candle_transformers/models/
vit.rs

//! Vision Transformer (ViT) implementation.
//!
//! The Vision Transformer applies the transformer architecture to image classification
//! by splitting images into patches and processing them as a sequence.
//!
//! Key characteristics:
//! - Image patches as sequence tokens
//! - Self-attention between patches
//! - Position embeddings
//! - CLS token for classification
//! - Layer normalization
//!
//! References:
//! - [ViT Paper](https://arxiv.org/abs/2010.11929)
//! - [Model Card](https://huggingface.co/google/vit-base-patch16-224)
//!
17
18use crate::models::with_tracing::{conv2d, linear, linear_no_bias, Conv2d, Linear};
19use candle::{IndexOp, Module, Result, Tensor, D};
20use candle_nn::{layer_norm, LayerNorm, VarBuilder};
21
22// https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/configuration_vit.py
23#[derive(Debug, Clone, serde::Deserialize)]
24pub struct Config {
25    pub hidden_size: usize,
26    pub num_hidden_layers: usize,
27    pub num_attention_heads: usize,
28    pub intermediate_size: usize,
29    pub hidden_act: candle_nn::Activation,
30    pub layer_norm_eps: f64,
31    pub image_size: usize,
32    pub patch_size: usize,
33    pub num_channels: usize,
34    pub qkv_bias: bool,
35}
36
37impl Config {
38    // https://huggingface.co/google/vit-base-patch16-224/blob/main/config.json
39    pub fn vit_base_patch16_224() -> Self {
40        Self {
41            hidden_size: 768,
42            num_hidden_layers: 12,
43            num_attention_heads: 12,
44            intermediate_size: 3072,
45            hidden_act: candle_nn::Activation::Gelu,
46            layer_norm_eps: 1e-12,
47            image_size: 224,
48            patch_size: 16,
49            num_channels: 3,
50            qkv_bias: true,
51        }
52    }
53
54    pub fn microsoft_trocr_base_handwritten() -> Self {
55        Self {
56            hidden_size: 768,
57            num_hidden_layers: 12,
58            num_attention_heads: 12,
59            intermediate_size: 3072,
60            hidden_act: candle_nn::Activation::Gelu,
61            layer_norm_eps: 1e-12,
62            image_size: 384,
63            patch_size: 16,
64            num_channels: 3,
65            qkv_bias: false,
66        }
67    }
68}
69
70#[derive(Debug, Clone)]
71struct PatchEmbeddings {
72    num_patches: usize,
73    projection: Conv2d,
74}
75
76impl PatchEmbeddings {
77    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
78        let image_size = cfg.image_size;
79        let patch_size = cfg.patch_size;
80        let num_patches = (image_size / patch_size) * (image_size / patch_size);
81        let conv_cfg = candle_nn::Conv2dConfig {
82            stride: patch_size,
83            ..Default::default()
84        };
85        let projection = conv2d(
86            cfg.num_channels,
87            cfg.hidden_size,
88            patch_size,
89            conv_cfg,
90            vb.pp("projection"),
91        )?;
92        Ok(Self {
93            num_patches,
94            projection,
95        })
96    }
97}
98
99impl Module for PatchEmbeddings {
100    fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
101        let (_b_size, _num_channels, _height, _width) = pixel_values.dims4()?;
102        self.projection
103            .forward(pixel_values)?
104            .flatten_from(2)?
105            .transpose(1, 2)
106    }
107}
108
109#[derive(Debug, Clone)]
110pub struct Embeddings {
111    cls_token: Tensor,
112    mask_token: Option<Tensor>,
113    patch_embeddings: PatchEmbeddings,
114    position_embeddings: Tensor,
115    hidden_size: usize,
116}
117
118impl Embeddings {
119    pub fn new(cfg: &Config, use_mask_token: bool, vb: VarBuilder) -> Result<Self> {
120        let hidden_size = cfg.hidden_size;
121        let cls_token = vb.get((1, 1, hidden_size), "cls_token")?;
122        let mask_token = if use_mask_token {
123            Some(vb.get((1, 1, hidden_size), "mask_token")?)
124        } else {
125            None
126        };
127        let patch_embeddings = PatchEmbeddings::new(cfg, vb.pp("patch_embeddings"))?;
128        let num_patches = patch_embeddings.num_patches;
129        let position_embeddings =
130            vb.get((1, num_patches + 1, hidden_size), "position_embeddings")?;
131        Ok(Self {
132            cls_token,
133            mask_token,
134            patch_embeddings,
135            position_embeddings,
136            hidden_size,
137        })
138    }
139
140    fn interpolate_pos_encoding(
141        &self,
142        _embeddings: &Tensor,
143        _height: usize,
144        _width: usize,
145    ) -> Result<Tensor> {
146        todo!()
147    }
148
149    pub fn forward(
150        &self,
151        pixel_values: &Tensor,
152        bool_masked_pos: Option<&Tensor>,
153        interpolate_pos_encoding: bool,
154    ) -> Result<Tensor> {
155        let (b_size, _num_channels, height, width) = pixel_values.dims4()?;
156        let embeddings = self.patch_embeddings.forward(pixel_values)?;
157        let embeddings = match (bool_masked_pos, &self.mask_token) {
158            (None, _) => embeddings,
159            (Some(_), None) => candle::bail!("bool_masked_pos set without mask_token"),
160            (Some(bool_masked_pos), Some(mask_tokens)) => {
161                let seq_len = embeddings.dim(1)?;
162                let mask_tokens = mask_tokens.broadcast_as((b_size, seq_len, self.hidden_size))?;
163                let mask = bool_masked_pos
164                    .unsqueeze(D::Minus1)?
165                    .to_dtype(mask_tokens.dtype())?;
166                ((mask_tokens * &mask)? - (embeddings * (mask - 1.)?)?)?
167            }
168        };
169        let cls_tokens = self.cls_token.broadcast_as((b_size, 1, self.hidden_size))?;
170        let embeddings = Tensor::cat(&[&cls_tokens, &embeddings], 1)?;
171        if interpolate_pos_encoding {
172            let pos = self.interpolate_pos_encoding(&embeddings, height, width)?;
173            embeddings.broadcast_add(&pos)
174        } else {
175            embeddings.broadcast_add(&self.position_embeddings)
176        }
177    }
178}
179
180#[derive(Debug, Clone)]
181struct SelfAttention {
182    query: Linear,
183    key: Linear,
184    value: Linear,
185    num_attention_heads: usize,
186    attention_head_size: usize,
187}
188
189impl SelfAttention {
190    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
191        let attention_head_size = cfg.hidden_size / cfg.num_attention_heads;
192        let num_attention_heads = cfg.num_attention_heads;
193        let all_head_size = num_attention_heads * attention_head_size;
194        let linear = |name| {
195            if cfg.qkv_bias {
196                linear(cfg.hidden_size, all_head_size, vb.pp(name))
197            } else {
198                linear_no_bias(cfg.hidden_size, all_head_size, vb.pp(name))
199            }
200        };
201        let query = linear("query")?;
202        let key = linear("key")?;
203        let value = linear("value")?;
204        Ok(Self {
205            query,
206            key,
207            value,
208            num_attention_heads,
209            attention_head_size,
210        })
211    }
212
213    fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
214        let (b_size, seq_len, _) = xs.dims3()?;
215        xs.reshape((
216            b_size,
217            seq_len,
218            self.num_attention_heads,
219            self.attention_head_size,
220        ))?
221        .permute((0, 2, 1, 3))
222    }
223}
224
225impl Module for SelfAttention {
226    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
227        let query = self.query.forward(xs)?;
228        let key = self.key.forward(xs)?;
229        let value = self.value.forward(xs)?;
230
231        let query = self.transpose_for_scores(&query)?.contiguous()?;
232        let key = self.transpose_for_scores(&key)?.contiguous()?;
233        let value = self.transpose_for_scores(&value)?.contiguous()?;
234
235        let attention_scores =
236            (query.matmul(&key.t()?)? / f64::sqrt(self.attention_head_size as f64))?;
237        let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
238        attention_probs
239            .matmul(&value)?
240            .permute((0, 2, 1, 3))?
241            .contiguous()?
242            .flatten_from(D::Minus2)
243    }
244}
245
246#[derive(Debug, Clone)]
247struct SelfOutput {
248    dense: Linear,
249}
250
251impl SelfOutput {
252    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
253        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
254        Ok(Self { dense })
255    }
256}
257
258impl Module for SelfOutput {
259    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
260        xs.apply(&self.dense)
261    }
262}
263
264#[derive(Debug, Clone)]
265struct Attention {
266    attention: SelfAttention,
267    output: SelfOutput,
268}
269
270impl Attention {
271    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
272        let attention = SelfAttention::new(cfg, vb.pp("attention"))?;
273        let output = SelfOutput::new(cfg, vb.pp("output"))?;
274        Ok(Self { attention, output })
275    }
276}
277
278impl Module for Attention {
279    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
280        xs.apply(&self.attention)?.apply(&self.output)
281    }
282}
283
284#[derive(Debug, Clone)]
285struct Intermediate {
286    dense: Linear,
287    intermediate_act_fn: candle_nn::Activation,
288}
289
290impl Intermediate {
291    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
292        let dense = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("dense"))?;
293        Ok(Self {
294            dense,
295            intermediate_act_fn: cfg.hidden_act,
296        })
297    }
298}
299
300impl Module for Intermediate {
301    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
302        xs.apply(&self.dense)?.apply(&self.intermediate_act_fn)
303    }
304}
305
306#[derive(Debug, Clone)]
307struct Output {
308    dense: Linear,
309}
310
311impl Output {
312    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
313        let dense = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("dense"))?;
314        Ok(Self { dense })
315    }
316
317    fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
318        xs.apply(&self.dense)? + input_tensor
319    }
320}
321
322#[derive(Debug, Clone)]
323struct Layer {
324    attention: Attention,
325    intermediate: Intermediate,
326    output: Output,
327    layernorm_before: LayerNorm,
328    layernorm_after: LayerNorm,
329}
330
331impl Layer {
332    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
333        let attention = Attention::new(cfg, vb.pp("attention"))?;
334        let intermediate = Intermediate::new(cfg, vb.pp("intermediate"))?;
335        let output = Output::new(cfg, vb.pp("output"))?;
336        let h_sz = cfg.hidden_size;
337        let layernorm_before = layer_norm(h_sz, cfg.layer_norm_eps, vb.pp("layernorm_before"))?;
338        let layernorm_after = layer_norm(h_sz, cfg.layer_norm_eps, vb.pp("layernorm_after"))?;
339        Ok(Self {
340            attention,
341            intermediate,
342            output,
343            layernorm_after,
344            layernorm_before,
345        })
346    }
347}
348
349impl Module for Layer {
350    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
351        let xs = (xs.apply(&self.layernorm_before)?.apply(&self.attention)? + xs)?;
352        let ys = xs.apply(&self.layernorm_after)?.apply(&self.intermediate)?;
353        self.output.forward(&ys, &xs)
354    }
355}
356
357#[derive(Debug, Clone)]
358pub struct Encoder {
359    layers: Vec<Layer>,
360}
361
362impl Encoder {
363    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
364        let vb = vb.pp("layer");
365        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
366        for i in 0..cfg.num_hidden_layers {
367            let layer = Layer::new(cfg, vb.pp(i))?;
368            layers.push(layer)
369        }
370        Ok(Self { layers })
371    }
372}
373
374impl Module for Encoder {
375    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
376        let mut xs = xs.clone();
377        for layer in self.layers.iter() {
378            xs = xs.apply(layer)?
379        }
380        Ok(xs)
381    }
382}
383
384#[derive(Debug, Clone)]
385pub struct Model {
386    embeddings: Embeddings,
387    encoder: Encoder,
388    layernorm: LayerNorm,
389    // no need for pooling layer for image classification
390    classifier: Linear,
391}
392
393impl Model {
394    pub fn new(cfg: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
395        let vb_v = vb.pp("vit");
396        let embeddings = Embeddings::new(cfg, false, vb_v.pp("embeddings"))?;
397        let encoder = Encoder::new(cfg, vb_v.pp("encoder"))?;
398        let layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb_v.pp("layernorm"))?;
399        let classifier = linear(cfg.hidden_size, num_labels, vb.pp("classifier"))?;
400        Ok(Self {
401            embeddings,
402            encoder,
403            layernorm,
404            classifier,
405        })
406    }
407
408    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
409        let embedding_output = self.embeddings.forward(xs, None, false)?;
410        let encoder_outputs = self.encoder.forward(&embedding_output)?;
411        encoder_outputs
412            .i((.., 0, ..))?
413            .apply(&self.layernorm)?
414            .apply(&self.classifier)
415    }
416}