candle_transformers/models/
eva2.rs

1//! EVA-2 inference implementation.
2//!
//! EVA-02 is a computer vision model that can be used as an ImageNet classifier.
//! The model returns one logit for each of the 1000 ImageNet categories; apply a
//! softmax to turn them into the per-category probabilities shown in the example below.
6//!
7//! - [Paper](https://arxiv.org/abs/2303.11331). EVA-02: A Visual Representation for Neon Genesis
//! - [Code](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/eva.py)
9//!
10//! # Example
11//!
12//! ```bash
13//! cargo run \
14//!   --example eva2 \
15//!   --release -- \
16//!   --image candle-examples/examples/yolo-v8/assets/bike.jpg
17//!
18//! > mountain bike, all-terrain bike, off-roader: 37.09%
19//! > maillot                 : 8.30%
20//! > alp                     : 2.13%
21//! > bicycle-built-for-two, tandem bicycle, tandem: 0.84%
22//! > crash helmet            : 0.73%
23//! ```
24//!
25//! <div align=center>
26//!   <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.jpg" alt="" width=640>
27//! </div>
28//!
29use candle::{IndexOp, Result, Tensor, D};
30use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
31
/// Side length, in pixels, that the model requires for both spatial dimensions of its input.
const IMG_SIZE: usize = 448;
/// Side length, in pixels, of one square patch (used as both kernel size and stride of the
/// patch-embedding convolution).
const PATCH_SIZE: usize = 14;
/// Number of outputs produced by the classification head.
const NUM_CLASSES: usize = 1000;
35
36fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
37    if bias {
38        candle_nn::linear(in_dim, out_dim, vb)
39    } else {
40        candle_nn::linear_no_bias(in_dim, out_dim, vb)
41    }
42}
43
44#[derive(Debug)]
45struct Attention {
46    q: Linear,
47    k: Linear,
48    v: Linear,
49    proj: Linear,
50    rot_pos_embed: Tensor,
51    num_heads: usize,
52    scale: f64,
53}
54
55impl Attention {
56    fn new(
57        vb: VarBuilder,
58        dim: usize,
59        num_heads: usize,
60        qkv_bias: bool,
61        proj_bias: bool,
62        rot_pos_embed: &Tensor,
63    ) -> Result<Self> {
64        let q = linear(vb.pp("q_proj"), dim, dim, qkv_bias)?;
65        let k = linear(vb.pp("k_proj"), dim, dim, false)?; // no bias for Key
66        let v = linear(vb.pp("v_proj"), dim, dim, qkv_bias)?;
67        let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
68        let rot_pos_embed = rot_pos_embed.clone();
69        let scale = 1. / ((dim / num_heads) as f64).sqrt();
70        Ok(Self {
71            q,
72            k,
73            v,
74            proj,
75            rot_pos_embed,
76            num_heads,
77            scale,
78        })
79    }
80}
81
82impl Attention {
83    // See: https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/pos_embed_sincos.py#L210
84    fn apply_rot_embed_cat(x: &Tensor, emb: &Tensor) -> Result<Tensor> {
85        let cos_emb = emb.i((0.., 64..128))?; //.transpose(0, 1)?;
86        let sin_emb = emb.i((0.., 0..64))?; //.transpose(0, 1)?;
87        let index_even: [u32; 32] = (0u32..=63)
88            .step_by(2)
89            .collect::<Vec<_>>()
90            .try_into()
91            .expect("wrong size iterator");
92        let index_odd: [u32; 32] = (1u32..=63)
93            .step_by(2)
94            .collect::<Vec<_>>()
95            .try_into()
96            .expect("wrong size iterator");
97        let t_index_even = Tensor::new(&index_even, x.device())?;
98        let t_index_odd = Tensor::new(&index_odd, x.device())?;
99        let x_c = x.contiguous()?;
100        let rot_x_even = x_c.index_select(&t_index_even, D::Minus1)?;
101        let rot_x_odd_minus = (-1.0 * x_c.index_select(&t_index_odd, D::Minus1)?)?;
102        let rot_x =
103            Tensor::stack(&[&rot_x_odd_minus, &rot_x_even], D::Minus1)?.reshape(x.shape())?;
104        x.broadcast_mul(&cos_emb)? + rot_x.broadcast_mul(&sin_emb)?
105    }
106}
107
108impl Module for Attention {
109    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
110        let (b, n, c) = xs.dims3()?;
111        let qkv = Tensor::cat(
112            &[
113                &self.q.forward(xs)?,
114                &self.k.forward(xs)?,
115                &self.v.forward(xs)?,
116            ],
117            2,
118        )?
119        .reshape((b, n, 3, self.num_heads, c / self.num_heads))?
120        .transpose(1, 2)? // 02134
121        .transpose(0, 1)? // 20134
122        .transpose(2, 3)?; // 20314
123        let q = qkv.i(0)?;
124        let k = qkv.i(1)?.contiguous()?;
125        let v = qkv.i(2)?.contiguous()?;
126
127        let npt = 1; // num_prefix_tokens = 1 for CLS token
128        let q = Tensor::cat(
129            &[
130                &q.i((0.., 0.., ..npt, 0..))?,
131                &Self::apply_rot_embed_cat(&q.i((0.., 0.., npt.., 0..))?, &self.rot_pos_embed)?,
132            ],
133            2,
134        )?;
135        let k = Tensor::cat(
136            &[
137                &k.i((0.., 0.., ..npt, 0..))?,
138                &Self::apply_rot_embed_cat(&k.i((0.., 0.., npt.., 0..))?, &self.rot_pos_embed)?,
139            ],
140            2,
141        )?;
142
143        let q = (q * self.scale)?;
144        let attn = &q.matmul(&k.t()?)?;
145        let attn = candle_nn::ops::softmax(attn, D::Minus1)?;
146        let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
147        self.proj.forward(&attn)
148    }
149}
150
151#[derive(Debug)]
152struct Mlp {
153    fc1_g: Linear,
154    fc1_x: Linear,
155    norm: LayerNorm,
156    fc2: Linear,
157}
158
159impl Mlp {
160    fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> {
161        let out_features = in_features;
162        let fc1_g = linear(vb.pp("fc1_g"), in_features, hidden_features, bias)?;
163        let fc1_x = linear(vb.pp("fc1_x"), in_features, hidden_features, bias)?;
164        let norm = layer_norm(hidden_features, 1e-6, vb.pp("norm"))?;
165        let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?;
166        Ok(Self {
167            fc1_g,
168            fc1_x,
169            norm,
170            fc2,
171        })
172    }
173}
174
175impl Module for Mlp {
176    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
177        let xs_g = self.fc1_g.forward(xs)?.silu()?;
178        let xs = self.fc1_x.forward(xs)?;
179        let xs = self.norm.forward(&(xs_g.mul(&xs)?))?;
180        self.fc2.forward(&xs)
181    }
182}
183
184#[derive(Debug)]
185struct Block {
186    norm1: LayerNorm,
187    attn: Attention,
188    norm2: LayerNorm,
189    mlp: Mlp,
190}
191
192impl Block {
193    fn new(vb: VarBuilder, dim: usize, num_heads: usize, rot_pos_embed: &Tensor) -> Result<Self> {
194        let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?;
195        let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true, rot_pos_embed)?;
196        let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?;
197        let hidden_dim = dim * 4 * 2 / 3; // 768 * 4 * 2 / 3 = 3072 * 2 / 3 = 2048
198        let mlp = Mlp::new(vb.pp("mlp"), dim, hidden_dim, true)?;
199        Ok(Self {
200            norm1,
201            attn,
202            norm2,
203            mlp,
204        })
205    }
206}
207
208impl Module for Block {
209    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
210        let residual = xs;
211        let xs = &self.attn.forward(&self.norm1.forward(xs)?)?;
212        let xs = (xs + residual)?;
213        let residual = &xs;
214        let xs = &self.mlp.forward(&self.norm2.forward(&xs)?)?;
215        xs + residual
216    }
217}
218
219#[derive(Debug)]
220struct PatchEmbed {
221    proj: candle_nn::Conv2d,
222    patch_size: (usize, usize),
223    num_patches: usize,
224}
225
226impl PatchEmbed {
227    fn new(
228        vb: VarBuilder,
229        img_size: usize,
230        patch_size: usize,
231        in_chans: usize,
232        embed_dim: usize,
233    ) -> Result<Self> {
234        let config = candle_nn::Conv2dConfig {
235            stride: patch_size,
236            ..Default::default()
237        };
238        let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?;
239        let num_patches = (img_size / patch_size) * (img_size / patch_size);
240        Ok(Self {
241            proj,
242            patch_size: (patch_size, patch_size),
243            num_patches,
244        })
245    }
246}
247
248impl Module for PatchEmbed {
249    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
250        let (_b, _c, h, w) = xs.dims4()?;
251        let (patch_h, patch_w) = self.patch_size;
252        if (h % patch_h) != 0 {
253            candle::bail!("image height {h} is not a multiple of patch height {patch_h}")
254        }
255        if (w % patch_w) != 0 {
256            candle::bail!("image width {w} is not a multiple of patch width {patch_w}")
257        }
258        let xs = self.proj.forward(xs)?;
259        let (b, c, h, w) = xs.dims4()?;
260        // flatten embeddings.
261        xs.reshape((b, c, h * w))?.transpose(1, 2)
262    }
263}
264
265#[derive(Debug)]
266pub struct EVA2VisionTransformer {
267    patch_embed: PatchEmbed,
268    cls_token: Tensor,
269    pos_embed: Tensor,
270    blocks: Vec<Block>,
271    norm: LayerNorm,
272    head: Linear,
273}
274
275impl EVA2VisionTransformer {
276    pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> {
277        let patch_embed =
278            PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?;
279        let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
280        let pos_embed = vb.get((1, patch_embed.num_patches + 1, embed_dim), "pos_embed")?;
281        let rot_pos_embed = vb.get((patch_embed.num_patches, 128), "rot_pos_embed")?;
282        let head = linear(vb.pp("head"), embed_dim, NUM_CLASSES, true)?;
283        let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?;
284        let vb_b = vb.pp("blocks");
285        let blocks = (0..depth)
286            .map(|i| Block::new(vb_b.pp(i.to_string()), embed_dim, num_heads, &rot_pos_embed))
287            .collect::<Result<Vec<_>>>()?;
288        Ok(Self {
289            patch_embed,
290            cls_token,
291            pos_embed,
292            blocks,
293            norm,
294            head,
295        })
296    }
297
298    fn interpolate_pos_encoding(
299        &self,
300        xs: &Tensor,
301        w: usize,
302        h: usize,
303        num_prefix_tokens: usize,
304    ) -> Result<Tensor> {
305        let npatch = xs.dim(1)? - 1;
306        let n = self.pos_embed.dim(1)? - 1;
307        let sqrt_n = (n as f64).sqrt();
308        if npatch == n && w == h {
309            return Ok(self.pos_embed.clone());
310        }
311        // Interpolate only local tokens, i.e. those after the CLS token
312        let prefix_tokens_pos_embed = self.pos_embed.i((0.., ..num_prefix_tokens, 0..))?.clone();
313        let patch_pos_embed = &self.pos_embed.i((0.., num_prefix_tokens.., 0..))?;
314        let dim = xs.dim(D::Minus1)?;
315        let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
316        let patch_pos_embed = patch_pos_embed
317            .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
318            .transpose(2, 3)?
319            .transpose(1, 2)?;
320        // This uses bicubic interpolation in the original implementation.
321        let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
322        let el_count = patch_pos_embed.shape().elem_count();
323        let patch_pos_embed =
324            patch_pos_embed
325                .transpose(1, 2)?
326                .transpose(2, 3)?
327                .reshape((1, el_count / dim, dim))?;
328        Tensor::cat(&[&prefix_tokens_pos_embed, &patch_pos_embed], 1)
329    }
330
331    fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
332        let (_b, _nc, w, h) = xs.dims4()?;
333        if (w != IMG_SIZE) || (h != IMG_SIZE) {
334            panic!("Error: The input tensor should have the shape: Bx3x518x518.");
335        }
336        let xs = self.patch_embed.forward(xs)?;
337        let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
338        let xs = (&xs + &self.interpolate_pos_encoding(&xs, w, h, 1)?)?;
339        Ok(xs)
340    }
341
342    fn get_intermediate_layers_not_chunked(
343        &self,
344        xs: &Tensor,
345        blocks_to_take: &[usize],
346    ) -> Result<Vec<Tensor>> {
347        let mut xs = self.prepare_tokens_with_mask(xs)?;
348        let mut output = Vec::new();
349        for (i, blk) in self.blocks.iter().enumerate() {
350            xs = blk.forward(&xs)?;
351            if blocks_to_take.contains(&i) {
352                output.push(xs.clone());
353            }
354        }
355        if output.len() != blocks_to_take.len() {
356            candle::bail!(
357                "only {} / {} blocks found",
358                output.len(),
359                blocks_to_take.len()
360            );
361        }
362        Ok(output)
363    }
364
365    pub fn get_intermediate_layers(
366        &self,
367        xs: &Tensor,
368        blocks_to_take: &[usize],
369        reshape: bool,
370        return_class_token: bool,
371        norm: bool,
372    ) -> Result<Tensor> {
373        let outputs = self.get_intermediate_layers_not_chunked(xs, blocks_to_take)?;
374        let outputs = if norm {
375            outputs
376                .iter()
377                .map(|out| self.norm.forward(out))
378                .collect::<Result<Vec<_>>>()?
379        } else {
380            outputs
381        };
382        let class_tokens = outputs
383            .iter()
384            .map(|out| out.i((.., 0)))
385            .collect::<Result<Vec<_>>>()?;
386        let outputs = outputs
387            .iter()
388            .map(|out| out.i((.., 1..)))
389            .collect::<Result<Vec<_>>>()?;
390
391        let outputs = if reshape {
392            let (b, _c, w, h) = xs.dims4()?;
393            let patch_size = self.patch_embed.patch_size.0;
394            let num_channels = outputs[0].elem_count() / (b * (w / patch_size) * (h / patch_size));
395            outputs
396                .iter()
397                .map(|out| {
398                    out.reshape((b, w / patch_size, h / patch_size, num_channels))?
399                        .transpose(2, 3)?
400                        .transpose(1, 2)
401                })
402                .collect::<Result<Vec<_>>>()?
403        } else {
404            outputs
405        };
406
407        let outputs = if return_class_token {
408            outputs
409                .iter()
410                .zip(class_tokens.iter())
411                .map(|(out, class_token)| Tensor::cat(&[out, class_token], D::Minus1))
412                .collect::<Result<Vec<_>>>()?
413        } else {
414            outputs
415        };
416
417        Tensor::stack(&outputs[..], 0)
418    }
419}
420
421impl Module for EVA2VisionTransformer {
422    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
423        let mut xs = self.prepare_tokens_with_mask(xs)?;
424        for blk in self.blocks.iter() {
425            xs = blk.forward(&xs)?
426        }
427        let xs_moy_local_tokens = xs.i((.., 1..))?.mean(1)?;
428        let xs_norm = self.norm.forward(&xs_moy_local_tokens)?;
429        self.head.forward(&xs_norm)
430    }
431}
432
433pub fn vit_base(vb: VarBuilder) -> Result<EVA2VisionTransformer> {
434    EVA2VisionTransformer::new(vb, 12, 768, 12)
435}
436
437pub fn vit_large(vb: VarBuilder) -> Result<EVA2VisionTransformer> {
438    EVA2VisionTransformer::new(vb, 24, 1024, 16)
439}