candle_transformers/models/
dinov2.rs

1//! Implementation of the DINOv2 models from Meta Research.
2//!
3//! This module implements the DINOv2 vision transformer model from Meta AI Research.
4//! DINOv2 is a self-supervised learning model that can learn visual features
5//! without using any labeled data. See: ["DINOv2: Learning Robust Visual Features without Supervision"](https://github.com/facebookresearch/dinov2)
6//!
7//! ## Running an example with color map and CUDA
8//!
9//! ```bash
10//! cargo run \
11//!   --features cuda,depth_anything_v2 \
12//!   --package candle-examples \
13//!   --example depth_anything_v2 \
14//!   -- --color-map \
15//!   --image candle-examples/examples/yolo-v8/assets/bike.jpg
16//! ```
17//!
18//! ## Running as an ImageNet classifier
19//!
//! The classifier head outputs one score per ImageNet category (1000 in total); the example below reports the most likely categories as probabilities.
21//!
22//! <div align=center>
23//!   <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.jpg" alt="" width=640>
24//! </div>
25//!
26//! ```bash
27//! cargo run \
28//!   --example dinov2 \
29//!   --release \
30//!   -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
31//!
32//! > mountain bike, all-terrain bike, off-roader: 43.67%
33//! > bicycle-built-for-two, tandem bicycle, tandem: 33.20%
34//! > crash helmet            : 13.23%
35//! > unicycle, monocycle     : 2.44%
36//! > maillot                 : 2.42%
37//! ```
38//!
39
40use candle::{IndexOp, Result, Tensor, D};
41use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
42
43const IMG_SIZE: usize = 518;
44const PATCH_SIZE: usize = 14;
45const NUM_CLASSES: usize = 1000;
46
47fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
48    if bias {
49        candle_nn::linear(in_dim, out_dim, vb)
50    } else {
51        candle_nn::linear_no_bias(in_dim, out_dim, vb)
52    }
53}
54
55#[derive(Debug)]
56struct Attention {
57    qkv: Linear,
58    proj: Linear,
59    num_heads: usize,
60    scale: f64,
61}
62
63impl Attention {
64    fn new(
65        vb: VarBuilder,
66        dim: usize,
67        num_heads: usize,
68        qkv_bias: bool,
69        proj_bias: bool,
70    ) -> Result<Self> {
71        let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
72        let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
73        let scale = 1. / ((dim / num_heads) as f64).sqrt();
74        Ok(Self {
75            qkv,
76            proj,
77            num_heads,
78            scale,
79        })
80    }
81}
82
83impl Module for Attention {
84    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
85        let (b, n, c) = xs.dims3()?;
86        let qkv = self
87            .qkv
88            .forward(xs)?
89            .reshape((b, n, 3, self.num_heads, c / self.num_heads))?
90            .transpose(1, 2)? // 02134
91            .transpose(0, 1)? // 20134
92            .transpose(2, 3)?; // 20314
93        let q = (qkv.i(0)? * self.scale)?;
94        let k = qkv.i(1)?.contiguous()?;
95        let v = qkv.i(2)?.contiguous()?;
96        let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?;
97        let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
98        self.proj.forward(&attn)
99    }
100}
101
102#[derive(Debug)]
103struct LayerScale {
104    gamma: Tensor,
105}
106
107impl LayerScale {
108    fn new(vb: VarBuilder, dim: usize) -> Result<Self> {
109        let gamma = vb.get(dim, "gamma")?;
110        Ok(Self { gamma })
111    }
112}
113
114impl Module for LayerScale {
115    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
116        xs.broadcast_mul(&self.gamma)
117    }
118}
119
120#[derive(Debug)]
121struct Mlp {
122    fc1: Linear,
123    fc2: Linear,
124}
125
126impl Mlp {
127    fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> {
128        let out_features = in_features;
129        let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?;
130        let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?;
131        Ok(Self { fc1, fc2 })
132    }
133}
134
135impl Module for Mlp {
136    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
137        let xs = self.fc1.forward(xs)?.gelu()?;
138        self.fc2.forward(&xs)
139    }
140}
141
142#[derive(Debug)]
143struct Block {
144    norm1: LayerNorm,
145    attn: Attention,
146    ls1: LayerScale,
147    norm2: LayerNorm,
148    mlp: Mlp,
149    ls2: LayerScale,
150}
151
152impl Block {
153    fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
154        let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
155        let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?;
156        let ls1 = LayerScale::new(vb.pp("ls1"), dim)?;
157        let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
158        let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?;
159        let ls2 = LayerScale::new(vb.pp("ls2"), dim)?;
160        Ok(Self {
161            norm1,
162            attn,
163            ls1,
164            norm2,
165            mlp,
166            ls2,
167        })
168    }
169}
170
171impl Module for Block {
172    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
173        let residual = xs;
174        let xs = self
175            .ls1
176            .forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?;
177        let xs = (xs + residual)?;
178        let residual = &xs;
179        let xs = self
180            .ls2
181            .forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?;
182        xs + residual
183    }
184}
185
186#[derive(Debug)]
187struct PatchEmbed {
188    proj: candle_nn::Conv2d,
189    patch_size: (usize, usize),
190    num_patches: usize,
191}
192
193impl PatchEmbed {
194    fn new(
195        vb: VarBuilder,
196        img_size: usize,
197        patch_size: usize,
198        in_chans: usize,
199        embed_dim: usize,
200    ) -> Result<Self> {
201        let config = candle_nn::Conv2dConfig {
202            stride: patch_size,
203            ..Default::default()
204        };
205        let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?;
206        let num_patches = (img_size / patch_size) * (img_size / patch_size);
207        Ok(Self {
208            proj,
209            patch_size: (patch_size, patch_size),
210            num_patches,
211        })
212    }
213}
214
215impl Module for PatchEmbed {
216    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
217        let (_b, _c, h, w) = xs.dims4()?;
218        let (patch_h, patch_w) = self.patch_size;
219        if (h % patch_h) != 0 {
220            candle::bail!("image height {h} is not a multiple of patch height {patch_h}")
221        }
222        if (w % patch_w) != 0 {
223            candle::bail!("image width {w} is not a multiple of patch width {patch_w}")
224        }
225        let xs = self.proj.forward(xs)?;
226        let (b, c, h, w) = xs.dims4()?;
227        // flatten embeddings.
228        xs.reshape((b, c, h * w))?.transpose(1, 2)
229    }
230}
231
/// DINOv2 vision transformer backbone with an ImageNet classification head.
#[derive(Debug)]
pub struct DinoVisionTransformer {
    /// Patch embedding: strided convolution over the input image.
    patch_embed: PatchEmbed,
    /// Learned class token, loaded with shape `(1, 1, embed_dim)`.
    cls_token: Tensor,
    /// Learned positional embedding, loaded with shape `(1, num_patches + 1, embed_dim)`.
    pos_embed: Tensor,
    /// Transformer layers, applied in order.
    blocks: Vec<Block>,
    /// Final layer norm over the token sequence.
    norm: LayerNorm,
    /// Linear classifier from `2 * embed_dim` features to `NUM_CLASSES` scores.
    head: Linear,
}
241
242impl DinoVisionTransformer {
243    pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> {
244        let patch_embed =
245            PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?;
246        let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
247        let num_tokens = 1;
248        let pos_embed = vb.get(
249            (1, patch_embed.num_patches + num_tokens, embed_dim),
250            "pos_embed",
251        )?;
252        let head = linear(vb.pp("head"), 2 * embed_dim, NUM_CLASSES, true)?;
253        let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?;
254        let vb_b = vb.pp("blocks");
255        let blocks = (0..depth)
256            .map(|i| Block::new(vb_b.pp(i.to_string()), embed_dim, num_heads))
257            .collect::<Result<Vec<_>>>()?;
258        Ok(Self {
259            patch_embed,
260            cls_token,
261            pos_embed,
262            blocks,
263            norm,
264            head,
265        })
266    }
267
268    fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
269        let npatch = xs.dim(1)? - 1;
270        let n = self.pos_embed.dim(1)? - 1;
271        let sqrt_n = (n as f64).sqrt();
272        if npatch == n && w == h {
273            return Ok(xs.clone());
274        }
275        let class_pos_embed = self.pos_embed.i((.., ..1))?;
276        let patch_pos_embed = self.pos_embed.i((.., 1..))?;
277        let dim = xs.dim(D::Minus1)?;
278        let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
279        let patch_pos_embed = patch_pos_embed
280            .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
281            .transpose(2, 3)?
282            .transpose(1, 2)?;
283        // This uses bicubic interpolation in the original implementation.
284        let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
285        let el_count = patch_pos_embed.shape().elem_count();
286        let patch_pos_embed =
287            patch_pos_embed
288                .transpose(1, 2)?
289                .transpose(2, 3)?
290                .reshape((1, el_count / dim, dim))?;
291        Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1)
292    }
293
294    fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
295        let (_b, _nc, w, h) = xs.dims4()?;
296        let xs = self.patch_embed.forward(xs)?;
297        let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
298        &xs + &self.interpolate_pos_encoding(&xs, w, h)?
299    }
300
301    fn get_intermediate_layers_not_chunked(
302        &self,
303        xs: &Tensor,
304        blocks_to_take: &[usize],
305    ) -> Result<Vec<Tensor>> {
306        let mut xs = self.prepare_tokens_with_mask(xs)?;
307        let mut output = Vec::new();
308        for (i, blk) in self.blocks.iter().enumerate() {
309            xs = blk.forward(&xs)?;
310            if blocks_to_take.contains(&i) {
311                output.push(xs.clone());
312            }
313        }
314        if output.len() != blocks_to_take.len() {
315            candle::bail!(
316                "only {} / {} blocks found",
317                output.len(),
318                blocks_to_take.len()
319            );
320        }
321        Ok(output)
322    }
323
324    pub fn get_intermediate_layers(
325        &self,
326        xs: &Tensor,
327        blocks_to_take: &[usize],
328        reshape: bool,
329        return_class_token: bool,
330        norm: bool,
331    ) -> Result<Tensor> {
332        let outputs = self.get_intermediate_layers_not_chunked(xs, blocks_to_take)?;
333        let outputs = if norm {
334            outputs
335                .iter()
336                .map(|out| self.norm.forward(out))
337                .collect::<Result<Vec<_>>>()?
338        } else {
339            outputs
340        };
341        let class_tokens = outputs
342            .iter()
343            .map(|out| out.i((.., 0)))
344            .collect::<Result<Vec<_>>>()?;
345        let outputs = outputs
346            .iter()
347            .map(|out| out.i((.., 1..)))
348            .collect::<Result<Vec<_>>>()?;
349
350        let outputs = if reshape {
351            let (b, _c, w, h) = xs.dims4()?;
352            let patch_size = self.patch_embed.patch_size.0;
353            let num_channels = outputs[0].elem_count() / (b * (w / patch_size) * (h / patch_size));
354            outputs
355                .iter()
356                .map(|out| {
357                    out.reshape((b, w / patch_size, h / patch_size, num_channels))?
358                        .transpose(2, 3)?
359                        .transpose(1, 2)
360                })
361                .collect::<Result<Vec<_>>>()?
362        } else {
363            outputs
364        };
365
366        let outputs = if return_class_token {
367            outputs
368                .iter()
369                .zip(class_tokens.iter())
370                .map(|(out, class_token)| Tensor::cat(&[out, class_token], D::Minus1))
371                .collect::<Result<Vec<_>>>()?
372        } else {
373            outputs
374        };
375
376        Tensor::stack(&outputs[..], 0)
377    }
378}
379
380impl Module for DinoVisionTransformer {
381    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
382        let mut xs = self.prepare_tokens_with_mask(xs)?;
383        for blk in self.blocks.iter() {
384            xs = blk.forward(&xs)?
385        }
386        let xs = self.norm.forward(&xs)?;
387        let xs_norm_clstoken = xs.i((.., 0))?;
388        let xs_norm_patchtokens = xs.i((.., 1..))?.mean(1)?;
389        let xs = Tensor::cat(&[xs_norm_clstoken, xs_norm_patchtokens], D::Minus1)?;
390        self.head.forward(&xs)
391    }
392}
393
394pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> {
395    DinoVisionTransformer::new(vb, 12, 384, 6)
396}