candle_transformers/models/
trocr.rs

1//! TrOCR model implementation.
2//!
3//! TrOCR is a Transformer-based OCR model that uses a Vision Transformer encoder
4//! and a BART-like decoder for optical character recognition.
5//!
6//! Key characteristics:
7//! - Vision Transformer encoder for image processing
8//! - BART-style decoder for text generation
//! - Learned or sinusoidal positional embeddings
10//! - Layer normalization and self-attention
11//!
12//! References:
13//! - [Paper](https://arxiv.org/abs/2109.10282)
14//! - [Model Card](https://huggingface.co/microsoft/trocr-base-handwritten)
15//!
16
17use crate::models::vit::{Config, Embeddings, Encoder};
18use candle::{DType, Result, Tensor};
19use candle_nn::{
20    embedding, layer_norm, linear_no_bias, Embedding, LayerNorm, Linear, Module, VarBuilder,
21};
22
/// Serde fallback used when `tie_word_embeddings` is missing from the config: weights are tied.
fn default_tie_word_embeddings() -> bool {
    true
}
/// Serde fallback used when `use_learned_position_embeddings` is missing from the config:
/// the learned table is selected.
fn default_use_learned_position_embeddings() -> bool {
    true
}
29
30#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
31pub struct TrOCRConfig {
32    pub vocab_size: usize,
33    pub d_model: usize,
34    pub cross_attention_hidden_size: usize,
35    pub decoder_layers: usize,
36    pub decoder_attention_heads: usize,
37    pub decoder_ffn_dim: usize,
38    pub activation_function: candle_nn::Activation,
39    pub max_position_embeddings: usize,
40    pub dropout: f64,
41    pub attention_dropout: f64,
42    pub activation_dropout: f64,
43    pub decoder_start_token_id: u32,
44    pub init_std: f64,
45    pub decoder_layerdrop: f64,
46    pub use_cache: bool,
47    pub scale_embedding: bool,
48    pub pad_token_id: usize,
49    pub bos_token_id: usize,
50    pub eos_token_id: u32,
51    pub decoder_vocab_size: Option<usize>,
52    #[serde(default = "default_use_learned_position_embeddings")]
53    pub use_learned_position_embeddings: bool,
54    #[serde(default = "default_tie_word_embeddings")]
55    pub tie_word_embeddings: bool,
56}
57
58impl Default for TrOCRConfig {
59    fn default() -> Self {
60        Self {
61            vocab_size: 50265,
62            d_model: 1024,
63            cross_attention_hidden_size: 768,
64            decoder_layers: 12,
65            decoder_attention_heads: 16,
66            decoder_ffn_dim: 4096,
67            activation_function: candle_nn::Activation::Gelu,
68            max_position_embeddings: 512,
69            dropout: 0.1,
70            attention_dropout: 0.0,
71            activation_dropout: 0.0,
72            decoder_start_token_id: 2,
73            init_std: 0.02,
74            decoder_layerdrop: 0.0,
75            use_cache: true,
76            scale_embedding: false,
77            pad_token_id: 1,
78            bos_token_id: 0,
79            eos_token_id: 2,
80            decoder_vocab_size: Some(50265),
81            use_learned_position_embeddings: true,
82            tie_word_embeddings: true,
83        }
84    }
85}
86
87#[derive(Debug, Clone)]
88struct TrOCRLearnedPositionalEmbedding {
89    offset: usize,
90    weights: Embedding,
91}
92
93impl TrOCRLearnedPositionalEmbedding {
94    fn load(vb: VarBuilder, cfg: &TrOCRConfig) -> Result<Self> {
95        let offset: usize = 2;
96        let num_embeddings = cfg.max_position_embeddings;
97        let embedding_dim = cfg.d_model;
98        let weights = embedding(num_embeddings + offset, embedding_dim, vb)?;
99
100        Ok(Self { offset, weights })
101    }
102
103    fn new_sinusoidal(vb: VarBuilder, cfg: &TrOCRConfig) -> Result<Self> {
104        // https://github.com/huggingface/transformers/blob/58e3d23e97078f361a533b9ec4a6a2de674ea52a/src/transformers/models/trocr/modeling_trocr.py#L81
105        let embedding_dim = cfg.d_model;
106        let half_dim = embedding_dim / 2;
107        let num_positions = cfg.max_position_embeddings + cfg.pad_token_id + 1;
108        let dev = vb.device();
109        let inv_freq: Vec<_> = (0..half_dim)
110            .map(|i| 1f32 / 10000f32.powf(i as f32 / (half_dim - 1) as f32))
111            .collect();
112        let inv_freq_len = inv_freq.len();
113        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
114        let t = Tensor::arange(0u32, num_positions as u32, dev)?
115            .to_dtype(DType::F32)?
116            .reshape((num_positions, 1))?;
117        let freqs = t.matmul(&inv_freq)?;
118        let emb = Tensor::cat(&[freqs.sin()?, freqs.cos()?], 1)?;
119        let emb = Tensor::cat(
120            &[
121                emb.narrow(0, 0, cfg.pad_token_id)?,
122                Tensor::zeros((1, embedding_dim), DType::F32, dev)?,
123                emb.narrow(0, cfg.pad_token_id + 1, cfg.max_position_embeddings)?,
124            ],
125            0,
126        )?
127        .contiguous()?;
128        let emb = Embedding::new(emb, embedding_dim);
129        Ok(Self {
130            offset: cfg.pad_token_id + 1,
131            weights: emb,
132        })
133    }
134
135    fn forward(&mut self, input_ids: &Tensor, past_key_values_length: u32) -> Result<Tensor> {
136        let (b_sz, seq_len) = input_ids.dims2()?;
137
138        let positions = Tensor::arange(
139            past_key_values_length,
140            seq_len as u32 + past_key_values_length,
141            input_ids.device(),
142        )?
143        .expand((b_sz, seq_len))?;
144
145        let positions =
146            positions.broadcast_add(&Tensor::new(self.offset as u32, input_ids.device())?)?;
147        self.weights.forward(&positions)
148    }
149}
150
151#[derive(Debug, Clone)]
152struct TrOCRAttention {
153    head_dim: usize,
154    num_heads: usize,
155    is_decoder: bool,
156    scaling: f64,
157    k_proj: Linear,
158    v_proj: Linear,
159    q_proj: Linear,
160    out_proj: Linear,
161    kv_cache: Option<(Tensor, Tensor)>,
162}
163
164impl TrOCRAttention {
165    fn load(
166        vb: VarBuilder,
167        cfg: &TrOCRConfig,
168        kdim: Option<usize>,
169        vdim: Option<usize>,
170    ) -> Result<Self> {
171        let embed_dim = cfg.d_model;
172        let num_heads = cfg.decoder_attention_heads;
173        let head_dim = embed_dim / num_heads;
174        let kdim = kdim.unwrap_or(embed_dim);
175        let vdim = vdim.unwrap_or(embed_dim);
176
177        let k_proj = linear_no_bias(kdim, embed_dim, vb.pp("k_proj"))?;
178        let v_proj = linear_no_bias(vdim, embed_dim, vb.pp("v_proj"))?;
179        let q_proj = linear_no_bias(embed_dim, embed_dim, vb.pp("q_proj"))?;
180
181        let out_proj = linear_no_bias(embed_dim, embed_dim, vb.pp("out_proj"))?;
182        Ok(Self {
183            head_dim,
184            num_heads,
185            is_decoder: true,
186            scaling: 1. / (head_dim as f64).sqrt(),
187            k_proj,
188            v_proj,
189            q_proj,
190            out_proj,
191            kv_cache: None,
192        })
193    }
194
195    fn reset_kv_cache(&mut self) {
196        self.kv_cache = None
197    }
198
199    fn _shape(&self, tensor: &Tensor, bsz: usize) -> Result<Tensor> {
200        tensor
201            .reshape((bsz, (), self.num_heads, self.head_dim))?
202            .transpose(1, 2)?
203            .contiguous()
204    }
205
206    fn forward(
207        &mut self,
208        xs: &Tensor,
209        kv_states: Option<&Tensor>,
210        attn_mask: Option<&Tensor>,
211    ) -> Result<Tensor> {
212        let (b_sz, tgt_len, _) = xs.dims3()?;
213        let query_states = (xs.apply(&self.q_proj)? * self.scaling)?;
214        let (key_states, value_states) = match kv_states {
215            None => {
216                let key_states = self._shape(&xs.apply(&self.k_proj)?, b_sz)?;
217                let value_states = self._shape(&xs.apply(&self.v_proj)?, b_sz)?;
218                if self.is_decoder {
219                    let kv_states = match &self.kv_cache {
220                        None => (key_states, value_states),
221                        Some((p_key_states, p_value_states)) => {
222                            let key_states = Tensor::cat(&[p_key_states, &key_states], 2)?;
223                            let value_states = Tensor::cat(&[p_value_states, &value_states], 2)?;
224                            (key_states, value_states)
225                        }
226                    };
227                    self.kv_cache = Some(kv_states.clone());
228                    kv_states
229                } else {
230                    (key_states, value_states)
231                }
232            }
233            Some(kv_states) => {
234                let key_states = self._shape(&kv_states.apply(&self.k_proj)?, b_sz)?;
235                let value_states = self._shape(&kv_states.apply(&self.v_proj)?, b_sz)?;
236                (key_states, value_states)
237            }
238        };
239        let proj_shape = (b_sz * self.num_heads, (), self.head_dim);
240        let query_states = self._shape(&query_states, b_sz)?.reshape(proj_shape)?;
241        let key_states = key_states.reshape(proj_shape)?;
242        let value_states = value_states.reshape(proj_shape)?;
243        let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
244        let attn_weights = match attn_mask {
245            None => attn_weights,
246            Some(attn_mask) => attn_weights.broadcast_add(attn_mask)?,
247        };
248        let attn_probs = candle_nn::ops::softmax_last_dim(&attn_weights)?;
249        let attn_output = attn_probs.matmul(&value_states)?;
250        attn_output
251            .reshape((b_sz, self.num_heads, tgt_len, self.head_dim))?
252            .transpose(1, 2)?
253            .reshape((b_sz, tgt_len, self.head_dim * self.num_heads))?
254            .apply(&self.out_proj)
255    }
256}
257
258#[derive(Debug, Clone)]
259struct TrOCRDecoderLayer {
260    self_attn: TrOCRAttention,
261    activation_fn: candle_nn::Activation,
262    self_attn_layer_norm: LayerNorm,
263    encoder_attn: TrOCRAttention,
264    encoder_attn_layer_norm: LayerNorm,
265    fc1: Linear,
266    fc2: Linear,
267    final_layer_norm: LayerNorm,
268}
269
270impl TrOCRDecoderLayer {
271    fn load(vb: VarBuilder, cfg: &TrOCRConfig) -> Result<Self> {
272        let embed_dim = cfg.d_model;
273        let self_attn = TrOCRAttention::load(vb.pp("self_attn"), cfg, None, None)?;
274        let self_attn_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("self_attn_layer_norm"))?;
275        let encoder_attn = TrOCRAttention::load(
276            vb.pp("encoder_attn"),
277            cfg,
278            Some(cfg.cross_attention_hidden_size),
279            Some(cfg.cross_attention_hidden_size),
280        )?;
281        let encoder_attn_layer_norm =
282            layer_norm(embed_dim, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
283        let fc1 = linear_no_bias(embed_dim, cfg.decoder_ffn_dim, vb.pp("fc1"))?;
284        let fc2 = linear_no_bias(cfg.decoder_ffn_dim, embed_dim, vb.pp("fc2"))?;
285        let final_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("final_layer_norm"))?;
286        Ok(Self {
287            self_attn,
288            activation_fn: cfg.activation_function,
289            self_attn_layer_norm,
290            encoder_attn,
291            encoder_attn_layer_norm,
292            fc1,
293            fc2,
294            final_layer_norm,
295        })
296    }
297
298    fn reset_kv_cache(&mut self) {
299        self.self_attn.reset_kv_cache();
300    }
301
302    fn forward(
303        &mut self,
304        xs: &Tensor,
305        attention_mask: &Tensor,
306        encoder_hidden_states: Option<&Tensor>,
307    ) -> Result<Tensor> {
308        let residual = xs.clone();
309        let xs = self.self_attn.forward(xs, None, Some(attention_mask))?;
310        let xs = (xs + residual)?;
311        let mut xs = self.self_attn_layer_norm.forward(&xs)?;
312
313        if let Some(encoder_hidden_states) = &encoder_hidden_states {
314            let residual = xs.clone();
315            let encoder_attention_mask = attention_mask.clone(); // TODO
316            xs = self.encoder_attn.forward(
317                &xs,
318                Some(encoder_hidden_states),
319                Some(&encoder_attention_mask),
320            )?;
321            xs = (xs + residual)?;
322            xs = self.encoder_attn_layer_norm.forward(&xs)?
323        }
324
325        let residual = xs.clone();
326        let xs = self.fc1.forward(&xs)?;
327        let xs = self.activation_fn.forward(&xs)?;
328        let xs = self.fc2.forward(&xs)?;
329        let xs = (xs + residual)?;
330        let xs = self.final_layer_norm.forward(&xs)?;
331
332        Ok(xs)
333    }
334}
335
336#[derive(Debug, Clone)]
337pub struct TrOCRDecoder {
338    layers: Vec<TrOCRDecoderLayer>,
339    embed_scale: Option<f64>,
340    embed_tokens: Embedding,
341    embed_positions: TrOCRLearnedPositionalEmbedding,
342}
343
344impl TrOCRDecoder {
345    fn new(cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
346        let vb = vb.pp("decoder.model.decoder");
347
348        let embed_tokens = embedding(cfg.vocab_size, cfg.d_model, vb.pp("embed_tokens"))?;
349        let embed_positions = if cfg.use_learned_position_embeddings {
350            TrOCRLearnedPositionalEmbedding::load(vb.pp("embed_positions"), cfg)?
351        } else {
352            TrOCRLearnedPositionalEmbedding::new_sinusoidal(vb.pp("embed_positions"), cfg)?
353        };
354        let mut layers = Vec::with_capacity(cfg.decoder_layers);
355        let vb_l = vb.pp("layers");
356        for idx in 0..cfg.decoder_layers {
357            let layer = TrOCRDecoderLayer::load(vb_l.pp(idx), cfg)?;
358            layers.push(layer)
359        }
360        let embed_scale = if cfg.scale_embedding {
361            Some((cfg.d_model as f64).sqrt())
362        } else {
363            None
364        };
365
366        Ok(Self {
367            layers,
368            embed_scale,
369            embed_tokens,
370            embed_positions,
371        })
372    }
373
374    fn reset_kv_cache(&mut self) {
375        self.layers.iter_mut().for_each(|l| l.reset_kv_cache())
376    }
377
378    pub fn forward(
379        &mut self,
380        xs: &Tensor,
381        encoder_xs: Option<&Tensor>,
382        past_kv_len: usize,
383        attn_mask: &Tensor,
384    ) -> Result<Tensor> {
385        let embed_pos = self.embed_positions.forward(xs, past_kv_len as u32)?;
386        let xs = xs.apply(&self.embed_tokens)?;
387
388        let xs = match self.embed_scale {
389            None => xs,
390            Some(scale) => (xs * scale)?,
391        };
392
393        let mut xs = xs.broadcast_add(&embed_pos)?;
394
395        for layer in self.layers.iter_mut() {
396            xs = layer.forward(&xs, attn_mask, encoder_xs)?;
397        }
398        Ok(xs)
399    }
400}
401
402#[derive(Debug, Clone)]
403pub struct TrOCREncoder {
404    embeddings: Embeddings,
405    encoder: Encoder,
406    layernorm: LayerNorm,
407}
408
409impl TrOCREncoder {
410    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
411        let vb_v = vb.pp("encoder");
412
413        let embeddings = Embeddings::new(cfg, false, vb_v.pp("embeddings"))?;
414
415        let encoder = Encoder::new(cfg, vb_v.pp("encoder"))?;
416        let layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb_v.pp("layernorm"))?;
417
418        Ok(Self {
419            embeddings,
420            encoder,
421            layernorm,
422        })
423    }
424
425    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
426        let embedding_output = self.embeddings.forward(xs, None, false)?;
427        let encoder_outputs = self.encoder.forward(&embedding_output)?;
428
429        self.layernorm.forward(&encoder_outputs)
430    }
431}
432
433#[derive(Debug, Clone)]
434pub struct TrOCRForCausalLM {
435    decoder: TrOCRDecoder,
436    output_projection: Linear,
437}
438
439impl TrOCRForCausalLM {
440    pub fn new(decoder_cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
441        let decoder = TrOCRDecoder::new(decoder_cfg, vb.clone())?;
442        let output_projection = if decoder_cfg.tie_word_embeddings {
443            candle_nn::Linear::new(decoder.embed_tokens.embeddings().clone(), None)
444        } else {
445            candle_nn::linear_no_bias(
446                decoder_cfg.d_model,
447                decoder_cfg.vocab_size,
448                vb.pp("decoder.output_projection"),
449            )?
450        };
451        Ok(Self {
452            decoder,
453            output_projection,
454        })
455    }
456
457    pub fn forward(
458        &mut self,
459        xs: &Tensor,
460        encoder_xs: Option<&Tensor>,
461        past_kv_len: usize,
462        attn_mask: &Tensor,
463    ) -> Result<Tensor> {
464        let xs = self
465            .decoder
466            .forward(xs, encoder_xs, past_kv_len, attn_mask)?;
467        let xs = xs.apply(&self.output_projection)?;
468
469        Ok(xs)
470    }
471
472    fn reset_kv_cache(&mut self) {
473        self.decoder.reset_kv_cache();
474    }
475}
476
477#[derive(Debug, Clone)]
478pub struct TrOCRModel {
479    encoder: TrOCREncoder,
480    decoder: TrOCRForCausalLM,
481}
482
483impl TrOCRModel {
484    pub fn new(encoder_cfg: &Config, decoder_cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
485        let encoder = TrOCREncoder::new(encoder_cfg, vb.clone())?;
486        let decoder = TrOCRForCausalLM::new(decoder_cfg, vb)?;
487        Ok(Self { encoder, decoder })
488    }
489
490    pub fn encoder(&mut self) -> &mut TrOCREncoder {
491        &mut self.encoder
492    }
493
494    pub fn decoder(&mut self) -> &mut TrOCRForCausalLM {
495        &mut self.decoder
496    }
497
498    pub fn decode(
499        &mut self,
500        xs: &Tensor,
501        encoder_xs: &Tensor,
502        past_kv_len: usize,
503    ) -> Result<Tensor> {
504        let seq_len = xs.dim(1)?;
505        let mask: Vec<_> = (0..seq_len)
506            .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
507            .collect();
508        let mask = Tensor::from_vec(mask, (seq_len, seq_len), xs.device())?;
509
510        self.decoder
511            .forward(xs, Some(encoder_xs), past_kv_len, &mask)
512    }
513
514    pub fn reset_kv_cache(&mut self) {
515        self.decoder.reset_kv_cache();
516    }
517}