1use crate::models::with_tracing::{conv2d, linear, linear_no_bias, Conv2d, Linear};
19use candle::{IndexOp, Module, Result, Tensor, D};
20use candle_nn::{layer_norm, LayerNorm, VarBuilder};
21
22#[derive(Debug, Clone, serde::Deserialize)]
24pub struct Config {
25 pub hidden_size: usize,
26 pub num_hidden_layers: usize,
27 pub num_attention_heads: usize,
28 pub intermediate_size: usize,
29 pub hidden_act: candle_nn::Activation,
30 pub layer_norm_eps: f64,
31 pub image_size: usize,
32 pub patch_size: usize,
33 pub num_channels: usize,
34 pub qkv_bias: bool,
35}
36
37impl Config {
38 pub fn vit_base_patch16_224() -> Self {
40 Self {
41 hidden_size: 768,
42 num_hidden_layers: 12,
43 num_attention_heads: 12,
44 intermediate_size: 3072,
45 hidden_act: candle_nn::Activation::Gelu,
46 layer_norm_eps: 1e-12,
47 image_size: 224,
48 patch_size: 16,
49 num_channels: 3,
50 qkv_bias: true,
51 }
52 }
53
54 pub fn microsoft_trocr_base_handwritten() -> Self {
55 Self {
56 hidden_size: 768,
57 num_hidden_layers: 12,
58 num_attention_heads: 12,
59 intermediate_size: 3072,
60 hidden_act: candle_nn::Activation::Gelu,
61 layer_norm_eps: 1e-12,
62 image_size: 384,
63 patch_size: 16,
64 num_channels: 3,
65 qkv_bias: false,
66 }
67 }
68}
69
70#[derive(Debug, Clone)]
71struct PatchEmbeddings {
72 num_patches: usize,
73 projection: Conv2d,
74}
75
76impl PatchEmbeddings {
77 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
78 let image_size = cfg.image_size;
79 let patch_size = cfg.patch_size;
80 let num_patches = (image_size / patch_size) * (image_size / patch_size);
81 let conv_cfg = candle_nn::Conv2dConfig {
82 stride: patch_size,
83 ..Default::default()
84 };
85 let projection = conv2d(
86 cfg.num_channels,
87 cfg.hidden_size,
88 patch_size,
89 conv_cfg,
90 vb.pp("projection"),
91 )?;
92 Ok(Self {
93 num_patches,
94 projection,
95 })
96 }
97}
98
99impl Module for PatchEmbeddings {
100 fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
101 let (_b_size, _num_channels, _height, _width) = pixel_values.dims4()?;
102 self.projection
103 .forward(pixel_values)?
104 .flatten_from(2)?
105 .transpose(1, 2)
106 }
107}
108
109#[derive(Debug, Clone)]
110pub struct Embeddings {
111 cls_token: Tensor,
112 mask_token: Option<Tensor>,
113 patch_embeddings: PatchEmbeddings,
114 position_embeddings: Tensor,
115 hidden_size: usize,
116}
117
118impl Embeddings {
119 pub fn new(cfg: &Config, use_mask_token: bool, vb: VarBuilder) -> Result<Self> {
120 let hidden_size = cfg.hidden_size;
121 let cls_token = vb.get((1, 1, hidden_size), "cls_token")?;
122 let mask_token = if use_mask_token {
123 Some(vb.get((1, 1, hidden_size), "mask_token")?)
124 } else {
125 None
126 };
127 let patch_embeddings = PatchEmbeddings::new(cfg, vb.pp("patch_embeddings"))?;
128 let num_patches = patch_embeddings.num_patches;
129 let position_embeddings =
130 vb.get((1, num_patches + 1, hidden_size), "position_embeddings")?;
131 Ok(Self {
132 cls_token,
133 mask_token,
134 patch_embeddings,
135 position_embeddings,
136 hidden_size,
137 })
138 }
139
140 fn interpolate_pos_encoding(
141 &self,
142 _embeddings: &Tensor,
143 _height: usize,
144 _width: usize,
145 ) -> Result<Tensor> {
146 todo!()
147 }
148
149 pub fn forward(
150 &self,
151 pixel_values: &Tensor,
152 bool_masked_pos: Option<&Tensor>,
153 interpolate_pos_encoding: bool,
154 ) -> Result<Tensor> {
155 let (b_size, _num_channels, height, width) = pixel_values.dims4()?;
156 let embeddings = self.patch_embeddings.forward(pixel_values)?;
157 let embeddings = match (bool_masked_pos, &self.mask_token) {
158 (None, _) => embeddings,
159 (Some(_), None) => candle::bail!("bool_masked_pos set without mask_token"),
160 (Some(bool_masked_pos), Some(mask_tokens)) => {
161 let seq_len = embeddings.dim(1)?;
162 let mask_tokens = mask_tokens.broadcast_as((b_size, seq_len, self.hidden_size))?;
163 let mask = bool_masked_pos
164 .unsqueeze(D::Minus1)?
165 .to_dtype(mask_tokens.dtype())?;
166 ((mask_tokens * &mask)? - (embeddings * (mask - 1.)?)?)?
167 }
168 };
169 let cls_tokens = self.cls_token.broadcast_as((b_size, 1, self.hidden_size))?;
170 let embeddings = Tensor::cat(&[&cls_tokens, &embeddings], 1)?;
171 if interpolate_pos_encoding {
172 let pos = self.interpolate_pos_encoding(&embeddings, height, width)?;
173 embeddings.broadcast_add(&pos)
174 } else {
175 embeddings.broadcast_add(&self.position_embeddings)
176 }
177 }
178}
179
180#[derive(Debug, Clone)]
181struct SelfAttention {
182 query: Linear,
183 key: Linear,
184 value: Linear,
185 num_attention_heads: usize,
186 attention_head_size: usize,
187}
188
189impl SelfAttention {
190 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
191 let attention_head_size = cfg.hidden_size / cfg.num_attention_heads;
192 let num_attention_heads = cfg.num_attention_heads;
193 let all_head_size = num_attention_heads * attention_head_size;
194 let linear = |name| {
195 if cfg.qkv_bias {
196 linear(cfg.hidden_size, all_head_size, vb.pp(name))
197 } else {
198 linear_no_bias(cfg.hidden_size, all_head_size, vb.pp(name))
199 }
200 };
201 let query = linear("query")?;
202 let key = linear("key")?;
203 let value = linear("value")?;
204 Ok(Self {
205 query,
206 key,
207 value,
208 num_attention_heads,
209 attention_head_size,
210 })
211 }
212
213 fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
214 let (b_size, seq_len, _) = xs.dims3()?;
215 xs.reshape((
216 b_size,
217 seq_len,
218 self.num_attention_heads,
219 self.attention_head_size,
220 ))?
221 .permute((0, 2, 1, 3))
222 }
223}
224
225impl Module for SelfAttention {
226 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
227 let query = self.query.forward(xs)?;
228 let key = self.key.forward(xs)?;
229 let value = self.value.forward(xs)?;
230
231 let query = self.transpose_for_scores(&query)?.contiguous()?;
232 let key = self.transpose_for_scores(&key)?.contiguous()?;
233 let value = self.transpose_for_scores(&value)?.contiguous()?;
234
235 let attention_scores =
236 (query.matmul(&key.t()?)? / f64::sqrt(self.attention_head_size as f64))?;
237 let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
238 attention_probs
239 .matmul(&value)?
240 .permute((0, 2, 1, 3))?
241 .contiguous()?
242 .flatten_from(D::Minus2)
243 }
244}
245
246#[derive(Debug, Clone)]
247struct SelfOutput {
248 dense: Linear,
249}
250
251impl SelfOutput {
252 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
253 let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
254 Ok(Self { dense })
255 }
256}
257
258impl Module for SelfOutput {
259 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
260 xs.apply(&self.dense)
261 }
262}
263
264#[derive(Debug, Clone)]
265struct Attention {
266 attention: SelfAttention,
267 output: SelfOutput,
268}
269
270impl Attention {
271 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
272 let attention = SelfAttention::new(cfg, vb.pp("attention"))?;
273 let output = SelfOutput::new(cfg, vb.pp("output"))?;
274 Ok(Self { attention, output })
275 }
276}
277
278impl Module for Attention {
279 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
280 xs.apply(&self.attention)?.apply(&self.output)
281 }
282}
283
284#[derive(Debug, Clone)]
285struct Intermediate {
286 dense: Linear,
287 intermediate_act_fn: candle_nn::Activation,
288}
289
290impl Intermediate {
291 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
292 let dense = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("dense"))?;
293 Ok(Self {
294 dense,
295 intermediate_act_fn: cfg.hidden_act,
296 })
297 }
298}
299
300impl Module for Intermediate {
301 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
302 xs.apply(&self.dense)?.apply(&self.intermediate_act_fn)
303 }
304}
305
306#[derive(Debug, Clone)]
307struct Output {
308 dense: Linear,
309}
310
311impl Output {
312 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
313 let dense = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("dense"))?;
314 Ok(Self { dense })
315 }
316
317 fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
318 xs.apply(&self.dense)? + input_tensor
319 }
320}
321
322#[derive(Debug, Clone)]
323struct Layer {
324 attention: Attention,
325 intermediate: Intermediate,
326 output: Output,
327 layernorm_before: LayerNorm,
328 layernorm_after: LayerNorm,
329}
330
331impl Layer {
332 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
333 let attention = Attention::new(cfg, vb.pp("attention"))?;
334 let intermediate = Intermediate::new(cfg, vb.pp("intermediate"))?;
335 let output = Output::new(cfg, vb.pp("output"))?;
336 let h_sz = cfg.hidden_size;
337 let layernorm_before = layer_norm(h_sz, cfg.layer_norm_eps, vb.pp("layernorm_before"))?;
338 let layernorm_after = layer_norm(h_sz, cfg.layer_norm_eps, vb.pp("layernorm_after"))?;
339 Ok(Self {
340 attention,
341 intermediate,
342 output,
343 layernorm_after,
344 layernorm_before,
345 })
346 }
347}
348
349impl Module for Layer {
350 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
351 let xs = (xs.apply(&self.layernorm_before)?.apply(&self.attention)? + xs)?;
352 let ys = xs.apply(&self.layernorm_after)?.apply(&self.intermediate)?;
353 self.output.forward(&ys, &xs)
354 }
355}
356
357#[derive(Debug, Clone)]
358pub struct Encoder {
359 layers: Vec<Layer>,
360}
361
362impl Encoder {
363 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
364 let vb = vb.pp("layer");
365 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
366 for i in 0..cfg.num_hidden_layers {
367 let layer = Layer::new(cfg, vb.pp(i))?;
368 layers.push(layer)
369 }
370 Ok(Self { layers })
371 }
372}
373
374impl Module for Encoder {
375 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
376 let mut xs = xs.clone();
377 for layer in self.layers.iter() {
378 xs = xs.apply(layer)?
379 }
380 Ok(xs)
381 }
382}
383
384#[derive(Debug, Clone)]
385pub struct Model {
386 embeddings: Embeddings,
387 encoder: Encoder,
388 layernorm: LayerNorm,
389 classifier: Linear,
391}
392
393impl Model {
394 pub fn new(cfg: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
395 let vb_v = vb.pp("vit");
396 let embeddings = Embeddings::new(cfg, false, vb_v.pp("embeddings"))?;
397 let encoder = Encoder::new(cfg, vb_v.pp("encoder"))?;
398 let layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb_v.pp("layernorm"))?;
399 let classifier = linear(cfg.hidden_size, num_labels, vb.pp("classifier"))?;
400 Ok(Self {
401 embeddings,
402 encoder,
403 layernorm,
404 classifier,
405 })
406 }
407
408 pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
409 let embedding_output = self.embeddings.forward(xs, None, false)?;
410 let encoder_outputs = self.encoder.forward(&embedding_output)?;
411 encoder_outputs
412 .i((.., 0, ..))?
413 .apply(&self.layernorm)?
414 .apply(&self.classifier)
415 }
416}