// candle_transformers/models/dinov2reg4.rs

//! Implementation of DINOv2 with registers (DINOv2-reg4).
//!
//! The DINOv2-reg4 model is a variant of DINOv2 that adds 4 learnable register tokens to the
//! original architecture. The classification head of this implementation is sized for the
//! weights trained for plant species classification on the PlantCLEF2024 dataset (7,806 classes).
//!
//! - [DINOv2 paper](https://arxiv.org/abs/2304.07193). DINOv2: Learning Robust Visual Features without Supervision
//! - [Registers paper](https://arxiv.org/abs/2309.16588). Vision Transformers Need Registers
//! - [GH Repo](https://github.com/facebookresearch/dinov2)
//!
//! # Example
//!
//! ```bash
//! # Download the class names and a plant picture to identify;
//! # see candle/examples/dinov2reg4 for the full code.
//!
//! # Perform inference
//! cargo run \
//!   --example dinov2reg4 \
//!   --release -- \
//!   --image <orchid-file>
//!
//! > Orchis simia Lam.       : 45.55%
//! > Orchis × bergonii Nanteuil: 9.80%
//! > Orchis italica Poir.    : 9.66%
//! > Orchis × angusticruris Franch.: 2.76%
//! > Orchis × bivonae Tod.   : 2.54%
//! ```
//!
//! <div align=center>
//!   <img src="https://bs.plantnet.org/image/o/bd2d3830ac3270218ba82fd24e2290becd01317c" alt="" width=320>
//! </div>
//!
33use candle::{IndexOp, Result, Tensor, D};
34use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
35
/// Side length, in pixels, of the square patches the vision transformer consumes.
const PATCH_SIZE: usize = 14;
/// Side length, in pixels, of the square input image: a 37 x 37 grid of patches (518 px).
const IMG_SIZE: usize = 37 * PATCH_SIZE;
/// Number of output classes of the PlantCLEF2024 DINOv2 checkpoint
/// (https://zenodo.org/records/10848263).
const NUM_CLASSES: usize = 7806;
39
40fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
41    if bias {
42        candle_nn::linear(in_dim, out_dim, vb)
43    } else {
44        candle_nn::linear_no_bias(in_dim, out_dim, vb)
45    }
46}
47
48#[derive(Debug)]
49struct Attention {
50    qkv: Linear,
51    proj: Linear,
52    num_heads: usize,
53    scale: f64,
54}
55
56impl Attention {
57    fn new(
58        vb: VarBuilder,
59        dim: usize,
60        num_heads: usize,
61        qkv_bias: bool,
62        proj_bias: bool,
63    ) -> Result<Self> {
64        let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
65        let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
66        let scale = 1. / ((dim / num_heads) as f64).sqrt();
67        Ok(Self {
68            qkv,
69            proj,
70            num_heads,
71            scale,
72        })
73    }
74}
75
76impl Module for Attention {
77    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
78        let (b, n, c) = xs.dims3()?;
79        let qkv = self
80            .qkv
81            .forward(xs)?
82            .reshape((b, n, 3, self.num_heads, c / self.num_heads))?
83            .transpose(1, 2)? // 02134
84            .transpose(0, 1)? // 20134
85            .transpose(2, 3)?; // 20314
86        let q = (qkv.i(0)? * self.scale)?;
87        let k = qkv.i(1)?.contiguous()?;
88        let v = qkv.i(2)?.contiguous()?;
89        let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?;
90        let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
91        self.proj.forward(&attn)
92    }
93}
94
95#[derive(Debug)]
96struct LayerScale {
97    gamma: Tensor,
98}
99
100impl LayerScale {
101    fn new(vb: VarBuilder, dim: usize) -> Result<Self> {
102        let gamma = vb.get(dim, "gamma")?;
103        Ok(Self { gamma })
104    }
105}
106
107impl Module for LayerScale {
108    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
109        xs.broadcast_mul(&self.gamma)
110    }
111}
112
113#[derive(Debug)]
114struct Mlp {
115    fc1: Linear,
116    fc2: Linear,
117}
118
119impl Mlp {
120    fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> {
121        let out_features = in_features;
122        let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?;
123        let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?;
124        Ok(Self { fc1, fc2 })
125    }
126}
127
128impl Module for Mlp {
129    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
130        let xs = self.fc1.forward(xs)?.gelu()?;
131        self.fc2.forward(&xs)
132    }
133}
134
135#[derive(Debug)]
136struct Block {
137    norm1: LayerNorm,
138    attn: Attention,
139    ls1: LayerScale,
140    norm2: LayerNorm,
141    mlp: Mlp,
142    ls2: LayerScale,
143}
144
145impl Block {
146    fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
147        let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?;
148        let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?;
149        let ls1 = LayerScale::new(vb.pp("ls1"), dim)?;
150        let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?;
151        let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?;
152        let ls2 = LayerScale::new(vb.pp("ls2"), dim)?;
153        Ok(Self {
154            norm1,
155            attn,
156            ls1,
157            norm2,
158            mlp,
159            ls2,
160        })
161    }
162}
163
164impl Module for Block {
165    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
166        let residual = xs;
167        let xs = self
168            .ls1
169            .forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?;
170        let xs = (xs + residual)?;
171        let residual = &xs;
172        let xs = self
173            .ls2
174            .forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?;
175        xs + residual
176    }
177}
178
179#[derive(Debug)]
180struct PatchEmbed {
181    proj: candle_nn::Conv2d,
182    patch_size: (usize, usize),
183    num_patches: usize,
184}
185
186impl PatchEmbed {
187    fn new(
188        vb: VarBuilder,
189        img_size: usize,
190        patch_size: usize,
191        in_chans: usize,
192        embed_dim: usize,
193    ) -> Result<Self> {
194        let config = candle_nn::Conv2dConfig {
195            stride: patch_size,
196            ..Default::default()
197        };
198        let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?;
199        let num_patches = (img_size / patch_size) * (img_size / patch_size);
200        Ok(Self {
201            proj,
202            patch_size: (patch_size, patch_size),
203            num_patches,
204        })
205    }
206}
207
208impl Module for PatchEmbed {
209    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
210        let (_b, _c, h, w) = xs.dims4()?;
211        let (patch_h, patch_w) = self.patch_size;
212        if (h % patch_h) != 0 {
213            candle::bail!("image height {h} is not a multiple of patch height {patch_h}")
214        }
215        if (w % patch_w) != 0 {
216            candle::bail!("image width {w} is not a multiple of patch width {patch_w}")
217        }
218        let xs = self.proj.forward(xs)?;
219        let (b, c, h, w) = xs.dims4()?;
220        // flatten embeddings.
221        xs.reshape((b, c, h * w))?.transpose(1, 2)
222    }
223}
224
225#[derive(Debug)]
226pub struct DinoVisionTransformer {
227    patch_embed: PatchEmbed,
228    cls_token: Tensor,
229    reg_token: Tensor,
230    pos_embed: Tensor,
231    blocks: Vec<Block>,
232    norm: LayerNorm,
233    head: Linear,
234}
235
236impl DinoVisionTransformer {
237    pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> {
238        let patch_embed =
239            PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?;
240        let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
241        let reg_token = vb.get((1, 4, embed_dim), "reg_token")?;
242        let pos_embed = vb.get((1, patch_embed.num_patches, embed_dim), "pos_embed")?;
243        let head = linear(vb.pp("head"), embed_dim, NUM_CLASSES, true)?;
244        let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?;
245        let vb_b = vb.pp("blocks");
246        let blocks = (0..depth)
247            .map(|i| Block::new(vb_b.pp(i.to_string()), embed_dim, num_heads))
248            .collect::<Result<Vec<_>>>()?;
249        Ok(Self {
250            patch_embed,
251            cls_token,
252            reg_token,
253            pos_embed,
254            blocks,
255            norm,
256            head,
257        })
258    }
259
260    fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
261        let npatch = xs.dim(1)? - 1;
262        let n = self.pos_embed.dim(1)? - 1;
263        let sqrt_n = (n as f64).sqrt();
264        if npatch == n && w == h {
265            return Ok(self.pos_embed.clone());
266        }
267        let patch_pos_embed = &self.pos_embed;
268        let dim = xs.dim(D::Minus1)?;
269        let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
270        let patch_pos_embed = patch_pos_embed
271            .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
272            .transpose(2, 3)?
273            .transpose(1, 2)?;
274        // This uses bicubic interpolation in the original implementation.
275        let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
276        let el_count = patch_pos_embed.shape().elem_count();
277        patch_pos_embed
278            .transpose(1, 2)?
279            .transpose(2, 3)?
280            .reshape((1, el_count / dim, dim))
281    }
282
283    fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
284        let (_b, _nc, w, h) = xs.dims4()?;
285        if (w != IMG_SIZE) || (h != IMG_SIZE) {
286            panic!("Error: The input tensor should have the shape: Bx3x518x518.");
287        }
288        let xs = self.patch_embed.forward(xs)?;
289        let xs = (&xs + &self.interpolate_pos_encoding(&xs, w, h)?)?;
290        let xs = Tensor::cat(&[&self.cls_token, &self.reg_token, &xs], 1)?;
291        Ok(xs)
292    }
293}
294
295impl Module for DinoVisionTransformer {
296    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
297        let mut xs = self.prepare_tokens_with_mask(xs)?;
298        for blk in self.blocks.iter() {
299            xs = blk.forward(&xs)?
300        }
301        let xs = self.norm.forward(&xs)?;
302        let xs_norm_clstoken = xs.i((.., 0))?;
303        self.head.forward(&xs_norm_clstoken)
304    }
305}
306
307pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> {
308    DinoVisionTransformer::new(vb, 12, 384, 6)
309}
310
311pub fn vit_base(vb: VarBuilder) -> Result<DinoVisionTransformer> {
312    DinoVisionTransformer::new(vb, 12, 768, 12)
313}