1use crate::models::vit::{Config, Embeddings, Encoder};
18use candle::{DType, Result, Tensor};
19use candle_nn::{
20 embedding, layer_norm, linear_no_bias, Embedding, LayerNorm, Linear, Module, VarBuilder,
21};
22
/// Serde fallback for `TrOCRConfig::tie_word_embeddings` when the key is absent from the
/// config: the LM head shares the token-embedding matrix.
fn default_tie_word_embeddings() -> bool {
    true
}

/// Serde fallback for `TrOCRConfig::use_learned_position_embeddings` when the key is absent
/// from the config: positions use a learned table rather than a sinusoidal one.
fn default_use_learned_position_embeddings() -> bool {
    true
}
29
/// Hyper-parameters of the TrOCR text decoder.
///
/// Deserialized with serde; presumably from the decoder section of a Hugging Face style
/// `config.json` (TODO confirm against the loader). Several fields are carried for
/// compatibility only and are not read anywhere in this module.
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct TrOCRConfig {
    /// Number of rows in the token-embedding table and output width of the LM head.
    pub vocab_size: usize,
    /// Hidden size of the decoder (embedding width, attention width, layer-norm width).
    pub d_model: usize,
    /// Feature width of the encoder states consumed by cross-attention (input width of the
    /// cross-attention `k_proj` / `v_proj`).
    pub cross_attention_hidden_size: usize,
    /// Number of stacked decoder layers.
    pub decoder_layers: usize,
    /// Attention heads per layer; the per-head size is `d_model / decoder_attention_heads`.
    pub decoder_attention_heads: usize,
    /// Inner width of each layer's feed-forward block (`fc1` output, `fc2` input).
    pub decoder_ffn_dim: usize,
    /// Non-linearity applied between `fc1` and `fc2`.
    pub activation_function: candle_nn::Activation,
    /// Number of positions covered by the positional-embedding table.
    pub max_position_embeddings: usize,
    // Dropout rates: not applied by the inference code in this module.
    pub dropout: f64,
    pub attention_dropout: f64,
    pub activation_dropout: f64,
    /// Token id that starts decoding. Not read in this module; presumably used by callers.
    pub decoder_start_token_id: u32,
    /// Weight-initialisation standard deviation. Not read in this module.
    pub init_std: f64,
    /// Layer-drop probability. Not read in this module.
    pub decoder_layerdrop: f64,
    /// Not read in this module; the self-attention KV cache is always maintained.
    pub use_cache: bool,
    /// When `true`, token embeddings are multiplied by `sqrt(d_model)` before positions are added.
    pub scale_embedding: bool,
    /// Padding token id. In sinusoidal mode the table row at this index is zeroed and position
    /// ids are shifted by `pad_token_id + 1`.
    pub pad_token_id: usize,
    /// Beginning-of-sequence token id. Not read in this module.
    pub bos_token_id: usize,
    /// End-of-sequence token id. Not read in this module; presumably used by the generation loop.
    pub eos_token_id: u32,
    /// Optional decoder vocabulary size. Not read in this module (`vocab_size` is used instead).
    pub decoder_vocab_size: Option<usize>,
    /// `true` (serde default): load a learned position table from the checkpoint;
    /// `false`: build a fixed sinusoidal table at load time.
    #[serde(default = "default_use_learned_position_embeddings")]
    pub use_learned_position_embeddings: bool,
    /// `true` (serde default): the LM head reuses the token-embedding matrix;
    /// `false`: a separate `decoder.output_projection` weight is loaded.
    #[serde(default = "default_tie_word_embeddings")]
    pub tie_word_embeddings: bool,
}
57
58impl Default for TrOCRConfig {
59 fn default() -> Self {
60 Self {
61 vocab_size: 50265,
62 d_model: 1024,
63 cross_attention_hidden_size: 768,
64 decoder_layers: 12,
65 decoder_attention_heads: 16,
66 decoder_ffn_dim: 4096,
67 activation_function: candle_nn::Activation::Gelu,
68 max_position_embeddings: 512,
69 dropout: 0.1,
70 attention_dropout: 0.0,
71 activation_dropout: 0.0,
72 decoder_start_token_id: 2,
73 init_std: 0.02,
74 decoder_layerdrop: 0.0,
75 use_cache: true,
76 scale_embedding: false,
77 pad_token_id: 1,
78 bos_token_id: 0,
79 eos_token_id: 2,
80 decoder_vocab_size: Some(50265),
81 use_learned_position_embeddings: true,
82 tie_word_embeddings: true,
83 }
84 }
85}
86
87#[derive(Debug, Clone)]
88struct TrOCRLearnedPositionalEmbedding {
89 offset: usize,
90 weights: Embedding,
91}
92
93impl TrOCRLearnedPositionalEmbedding {
94 fn load(vb: VarBuilder, cfg: &TrOCRConfig) -> Result<Self> {
95 let offset: usize = 2;
96 let num_embeddings = cfg.max_position_embeddings;
97 let embedding_dim = cfg.d_model;
98 let weights = embedding(num_embeddings + offset, embedding_dim, vb)?;
99
100 Ok(Self { offset, weights })
101 }
102
103 fn new_sinusoidal(vb: VarBuilder, cfg: &TrOCRConfig) -> Result<Self> {
104 let embedding_dim = cfg.d_model;
106 let half_dim = embedding_dim / 2;
107 let num_positions = cfg.max_position_embeddings + cfg.pad_token_id + 1;
108 let dev = vb.device();
109 let inv_freq: Vec<_> = (0..half_dim)
110 .map(|i| 1f32 / 10000f32.powf(i as f32 / (half_dim - 1) as f32))
111 .collect();
112 let inv_freq_len = inv_freq.len();
113 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
114 let t = Tensor::arange(0u32, num_positions as u32, dev)?
115 .to_dtype(DType::F32)?
116 .reshape((num_positions, 1))?;
117 let freqs = t.matmul(&inv_freq)?;
118 let emb = Tensor::cat(&[freqs.sin()?, freqs.cos()?], 1)?;
119 let emb = Tensor::cat(
120 &[
121 emb.narrow(0, 0, cfg.pad_token_id)?,
122 Tensor::zeros((1, embedding_dim), DType::F32, dev)?,
123 emb.narrow(0, cfg.pad_token_id + 1, cfg.max_position_embeddings)?,
124 ],
125 0,
126 )?
127 .contiguous()?;
128 let emb = Embedding::new(emb, embedding_dim);
129 Ok(Self {
130 offset: cfg.pad_token_id + 1,
131 weights: emb,
132 })
133 }
134
135 fn forward(&mut self, input_ids: &Tensor, past_key_values_length: u32) -> Result<Tensor> {
136 let (b_sz, seq_len) = input_ids.dims2()?;
137
138 let positions = Tensor::arange(
139 past_key_values_length,
140 seq_len as u32 + past_key_values_length,
141 input_ids.device(),
142 )?
143 .expand((b_sz, seq_len))?;
144
145 let positions =
146 positions.broadcast_add(&Tensor::new(self.offset as u32, input_ids.device())?)?;
147 self.weights.forward(&positions)
148 }
149}
150
151#[derive(Debug, Clone)]
152struct TrOCRAttention {
153 head_dim: usize,
154 num_heads: usize,
155 is_decoder: bool,
156 scaling: f64,
157 k_proj: Linear,
158 v_proj: Linear,
159 q_proj: Linear,
160 out_proj: Linear,
161 kv_cache: Option<(Tensor, Tensor)>,
162}
163
164impl TrOCRAttention {
165 fn load(
166 vb: VarBuilder,
167 cfg: &TrOCRConfig,
168 kdim: Option<usize>,
169 vdim: Option<usize>,
170 ) -> Result<Self> {
171 let embed_dim = cfg.d_model;
172 let num_heads = cfg.decoder_attention_heads;
173 let head_dim = embed_dim / num_heads;
174 let kdim = kdim.unwrap_or(embed_dim);
175 let vdim = vdim.unwrap_or(embed_dim);
176
177 let k_proj = linear_no_bias(kdim, embed_dim, vb.pp("k_proj"))?;
178 let v_proj = linear_no_bias(vdim, embed_dim, vb.pp("v_proj"))?;
179 let q_proj = linear_no_bias(embed_dim, embed_dim, vb.pp("q_proj"))?;
180
181 let out_proj = linear_no_bias(embed_dim, embed_dim, vb.pp("out_proj"))?;
182 Ok(Self {
183 head_dim,
184 num_heads,
185 is_decoder: true,
186 scaling: 1. / (head_dim as f64).sqrt(),
187 k_proj,
188 v_proj,
189 q_proj,
190 out_proj,
191 kv_cache: None,
192 })
193 }
194
195 fn reset_kv_cache(&mut self) {
196 self.kv_cache = None
197 }
198
199 fn _shape(&self, tensor: &Tensor, bsz: usize) -> Result<Tensor> {
200 tensor
201 .reshape((bsz, (), self.num_heads, self.head_dim))?
202 .transpose(1, 2)?
203 .contiguous()
204 }
205
206 fn forward(
207 &mut self,
208 xs: &Tensor,
209 kv_states: Option<&Tensor>,
210 attn_mask: Option<&Tensor>,
211 ) -> Result<Tensor> {
212 let (b_sz, tgt_len, _) = xs.dims3()?;
213 let query_states = (xs.apply(&self.q_proj)? * self.scaling)?;
214 let (key_states, value_states) = match kv_states {
215 None => {
216 let key_states = self._shape(&xs.apply(&self.k_proj)?, b_sz)?;
217 let value_states = self._shape(&xs.apply(&self.v_proj)?, b_sz)?;
218 if self.is_decoder {
219 let kv_states = match &self.kv_cache {
220 None => (key_states, value_states),
221 Some((p_key_states, p_value_states)) => {
222 let key_states = Tensor::cat(&[p_key_states, &key_states], 2)?;
223 let value_states = Tensor::cat(&[p_value_states, &value_states], 2)?;
224 (key_states, value_states)
225 }
226 };
227 self.kv_cache = Some(kv_states.clone());
228 kv_states
229 } else {
230 (key_states, value_states)
231 }
232 }
233 Some(kv_states) => {
234 let key_states = self._shape(&kv_states.apply(&self.k_proj)?, b_sz)?;
235 let value_states = self._shape(&kv_states.apply(&self.v_proj)?, b_sz)?;
236 (key_states, value_states)
237 }
238 };
239 let proj_shape = (b_sz * self.num_heads, (), self.head_dim);
240 let query_states = self._shape(&query_states, b_sz)?.reshape(proj_shape)?;
241 let key_states = key_states.reshape(proj_shape)?;
242 let value_states = value_states.reshape(proj_shape)?;
243 let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
244 let attn_weights = match attn_mask {
245 None => attn_weights,
246 Some(attn_mask) => attn_weights.broadcast_add(attn_mask)?,
247 };
248 let attn_probs = candle_nn::ops::softmax_last_dim(&attn_weights)?;
249 let attn_output = attn_probs.matmul(&value_states)?;
250 attn_output
251 .reshape((b_sz, self.num_heads, tgt_len, self.head_dim))?
252 .transpose(1, 2)?
253 .reshape((b_sz, tgt_len, self.head_dim * self.num_heads))?
254 .apply(&self.out_proj)
255 }
256}
257
258#[derive(Debug, Clone)]
259struct TrOCRDecoderLayer {
260 self_attn: TrOCRAttention,
261 activation_fn: candle_nn::Activation,
262 self_attn_layer_norm: LayerNorm,
263 encoder_attn: TrOCRAttention,
264 encoder_attn_layer_norm: LayerNorm,
265 fc1: Linear,
266 fc2: Linear,
267 final_layer_norm: LayerNorm,
268}
269
270impl TrOCRDecoderLayer {
271 fn load(vb: VarBuilder, cfg: &TrOCRConfig) -> Result<Self> {
272 let embed_dim = cfg.d_model;
273 let self_attn = TrOCRAttention::load(vb.pp("self_attn"), cfg, None, None)?;
274 let self_attn_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("self_attn_layer_norm"))?;
275 let encoder_attn = TrOCRAttention::load(
276 vb.pp("encoder_attn"),
277 cfg,
278 Some(cfg.cross_attention_hidden_size),
279 Some(cfg.cross_attention_hidden_size),
280 )?;
281 let encoder_attn_layer_norm =
282 layer_norm(embed_dim, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
283 let fc1 = linear_no_bias(embed_dim, cfg.decoder_ffn_dim, vb.pp("fc1"))?;
284 let fc2 = linear_no_bias(cfg.decoder_ffn_dim, embed_dim, vb.pp("fc2"))?;
285 let final_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("final_layer_norm"))?;
286 Ok(Self {
287 self_attn,
288 activation_fn: cfg.activation_function,
289 self_attn_layer_norm,
290 encoder_attn,
291 encoder_attn_layer_norm,
292 fc1,
293 fc2,
294 final_layer_norm,
295 })
296 }
297
298 fn reset_kv_cache(&mut self) {
299 self.self_attn.reset_kv_cache();
300 }
301
302 fn forward(
303 &mut self,
304 xs: &Tensor,
305 attention_mask: &Tensor,
306 encoder_hidden_states: Option<&Tensor>,
307 ) -> Result<Tensor> {
308 let residual = xs.clone();
309 let xs = self.self_attn.forward(xs, None, Some(attention_mask))?;
310 let xs = (xs + residual)?;
311 let mut xs = self.self_attn_layer_norm.forward(&xs)?;
312
313 if let Some(encoder_hidden_states) = &encoder_hidden_states {
314 let residual = xs.clone();
315 let encoder_attention_mask = attention_mask.clone(); xs = self.encoder_attn.forward(
317 &xs,
318 Some(encoder_hidden_states),
319 Some(&encoder_attention_mask),
320 )?;
321 xs = (xs + residual)?;
322 xs = self.encoder_attn_layer_norm.forward(&xs)?
323 }
324
325 let residual = xs.clone();
326 let xs = self.fc1.forward(&xs)?;
327 let xs = self.activation_fn.forward(&xs)?;
328 let xs = self.fc2.forward(&xs)?;
329 let xs = (xs + residual)?;
330 let xs = self.final_layer_norm.forward(&xs)?;
331
332 Ok(xs)
333 }
334}
335
336#[derive(Debug, Clone)]
337pub struct TrOCRDecoder {
338 layers: Vec<TrOCRDecoderLayer>,
339 embed_scale: Option<f64>,
340 embed_tokens: Embedding,
341 embed_positions: TrOCRLearnedPositionalEmbedding,
342}
343
344impl TrOCRDecoder {
345 fn new(cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
346 let vb = vb.pp("decoder.model.decoder");
347
348 let embed_tokens = embedding(cfg.vocab_size, cfg.d_model, vb.pp("embed_tokens"))?;
349 let embed_positions = if cfg.use_learned_position_embeddings {
350 TrOCRLearnedPositionalEmbedding::load(vb.pp("embed_positions"), cfg)?
351 } else {
352 TrOCRLearnedPositionalEmbedding::new_sinusoidal(vb.pp("embed_positions"), cfg)?
353 };
354 let mut layers = Vec::with_capacity(cfg.decoder_layers);
355 let vb_l = vb.pp("layers");
356 for idx in 0..cfg.decoder_layers {
357 let layer = TrOCRDecoderLayer::load(vb_l.pp(idx), cfg)?;
358 layers.push(layer)
359 }
360 let embed_scale = if cfg.scale_embedding {
361 Some((cfg.d_model as f64).sqrt())
362 } else {
363 None
364 };
365
366 Ok(Self {
367 layers,
368 embed_scale,
369 embed_tokens,
370 embed_positions,
371 })
372 }
373
374 fn reset_kv_cache(&mut self) {
375 self.layers.iter_mut().for_each(|l| l.reset_kv_cache())
376 }
377
378 pub fn forward(
379 &mut self,
380 xs: &Tensor,
381 encoder_xs: Option<&Tensor>,
382 past_kv_len: usize,
383 attn_mask: &Tensor,
384 ) -> Result<Tensor> {
385 let embed_pos = self.embed_positions.forward(xs, past_kv_len as u32)?;
386 let xs = xs.apply(&self.embed_tokens)?;
387
388 let xs = match self.embed_scale {
389 None => xs,
390 Some(scale) => (xs * scale)?,
391 };
392
393 let mut xs = xs.broadcast_add(&embed_pos)?;
394
395 for layer in self.layers.iter_mut() {
396 xs = layer.forward(&xs, attn_mask, encoder_xs)?;
397 }
398 Ok(xs)
399 }
400}
401
402#[derive(Debug, Clone)]
403pub struct TrOCREncoder {
404 embeddings: Embeddings,
405 encoder: Encoder,
406 layernorm: LayerNorm,
407}
408
409impl TrOCREncoder {
410 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
411 let vb_v = vb.pp("encoder");
412
413 let embeddings = Embeddings::new(cfg, false, vb_v.pp("embeddings"))?;
414
415 let encoder = Encoder::new(cfg, vb_v.pp("encoder"))?;
416 let layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb_v.pp("layernorm"))?;
417
418 Ok(Self {
419 embeddings,
420 encoder,
421 layernorm,
422 })
423 }
424
425 pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
426 let embedding_output = self.embeddings.forward(xs, None, false)?;
427 let encoder_outputs = self.encoder.forward(&embedding_output)?;
428
429 self.layernorm.forward(&encoder_outputs)
430 }
431}
432
433#[derive(Debug, Clone)]
434pub struct TrOCRForCausalLM {
435 decoder: TrOCRDecoder,
436 output_projection: Linear,
437}
438
439impl TrOCRForCausalLM {
440 pub fn new(decoder_cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
441 let decoder = TrOCRDecoder::new(decoder_cfg, vb.clone())?;
442 let output_projection = if decoder_cfg.tie_word_embeddings {
443 candle_nn::Linear::new(decoder.embed_tokens.embeddings().clone(), None)
444 } else {
445 candle_nn::linear_no_bias(
446 decoder_cfg.d_model,
447 decoder_cfg.vocab_size,
448 vb.pp("decoder.output_projection"),
449 )?
450 };
451 Ok(Self {
452 decoder,
453 output_projection,
454 })
455 }
456
457 pub fn forward(
458 &mut self,
459 xs: &Tensor,
460 encoder_xs: Option<&Tensor>,
461 past_kv_len: usize,
462 attn_mask: &Tensor,
463 ) -> Result<Tensor> {
464 let xs = self
465 .decoder
466 .forward(xs, encoder_xs, past_kv_len, attn_mask)?;
467 let xs = xs.apply(&self.output_projection)?;
468
469 Ok(xs)
470 }
471
472 fn reset_kv_cache(&mut self) {
473 self.decoder.reset_kv_cache();
474 }
475}
476
477#[derive(Debug, Clone)]
478pub struct TrOCRModel {
479 encoder: TrOCREncoder,
480 decoder: TrOCRForCausalLM,
481}
482
483impl TrOCRModel {
484 pub fn new(encoder_cfg: &Config, decoder_cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
485 let encoder = TrOCREncoder::new(encoder_cfg, vb.clone())?;
486 let decoder = TrOCRForCausalLM::new(decoder_cfg, vb)?;
487 Ok(Self { encoder, decoder })
488 }
489
490 pub fn encoder(&mut self) -> &mut TrOCREncoder {
491 &mut self.encoder
492 }
493
494 pub fn decoder(&mut self) -> &mut TrOCRForCausalLM {
495 &mut self.decoder
496 }
497
498 pub fn decode(
499 &mut self,
500 xs: &Tensor,
501 encoder_xs: &Tensor,
502 past_kv_len: usize,
503 ) -> Result<Tensor> {
504 let seq_len = xs.dim(1)?;
505 let mask: Vec<_> = (0..seq_len)
506 .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
507 .collect();
508 let mask = Tensor::from_vec(mask, (seq_len, seq_len), xs.device())?;
509
510 self.decoder
511 .forward(xs, Some(encoder_xs), past_kv_len, &mask)
512 }
513
514 pub fn reset_kv_cache(&mut self) {
515 self.decoder.reset_kv_cache();
516 }
517}