candle_transformers/models/
marian.rs

1//! Marian Neural Machine Translation
2//!
3//! See "Marian: Fast Neural Machine Translation in C++" Junczys-Dowmunt et al. 2018
4//! - [ACL Anthology](https://aclanthology.org/P18-4020/)
5//! - [Github](https://github.com/marian-nmt/marian)
6//!
7use super::with_tracing::{linear, Embedding, Linear};
8use candle::{Result, Tensor};
9use candle_nn::{layer_norm, LayerNorm, VarBuilder};
10
11#[derive(Debug, Clone, serde::Deserialize)]
12pub struct Config {
13    pub vocab_size: usize,
14    pub decoder_vocab_size: Option<usize>,
15    pub max_position_embeddings: usize,
16    pub encoder_layers: usize,
17    pub encoder_ffn_dim: usize,
18    pub encoder_attention_heads: usize,
19    pub decoder_layers: usize,
20    pub decoder_ffn_dim: usize,
21    pub decoder_attention_heads: usize,
22    pub use_cache: bool,
23    pub is_encoder_decoder: bool,
24    pub activation_function: candle_nn::Activation,
25    pub d_model: usize,
26    pub decoder_start_token_id: u32,
27    pub scale_embedding: bool,
28    pub pad_token_id: u32,
29    pub eos_token_id: u32,
30    pub forced_eos_token_id: u32,
31    pub share_encoder_decoder_embeddings: bool,
32}
33
34impl Config {
35    // https://huggingface.co/Helsinki-NLP/opus-mt-tc-big-fr-en/blob/main/config.json
36    pub fn opus_mt_tc_big_fr_en() -> Self {
37        Self {
38            activation_function: candle_nn::Activation::Relu,
39            d_model: 1024,
40            decoder_attention_heads: 16,
41            decoder_ffn_dim: 4096,
42            decoder_layers: 6,
43            decoder_start_token_id: 53016,
44            decoder_vocab_size: Some(53017),
45            encoder_attention_heads: 16,
46            encoder_ffn_dim: 4096,
47            encoder_layers: 6,
48            eos_token_id: 43311,
49            forced_eos_token_id: 43311,
50            is_encoder_decoder: true,
51            max_position_embeddings: 1024,
52            pad_token_id: 53016,
53            scale_embedding: true,
54            share_encoder_decoder_embeddings: true,
55            use_cache: true,
56            vocab_size: 53017,
57        }
58    }
59
60    // https://huggingface.co/Helsinki-NLP/opus-mt-fr-en/blob/main/config.json
61    pub fn opus_mt_fr_en() -> Self {
62        Self {
63            activation_function: candle_nn::Activation::Swish,
64            d_model: 512,
65            decoder_attention_heads: 8,
66            decoder_ffn_dim: 2048,
67            decoder_layers: 6,
68            decoder_start_token_id: 59513,
69            decoder_vocab_size: Some(59514),
70            encoder_attention_heads: 8,
71            encoder_ffn_dim: 2048,
72            encoder_layers: 6,
73            eos_token_id: 0,
74            forced_eos_token_id: 0,
75            is_encoder_decoder: true,
76            max_position_embeddings: 512,
77            pad_token_id: 59513,
78            scale_embedding: true,
79            share_encoder_decoder_embeddings: true,
80            use_cache: true,
81            vocab_size: 59514,
82        }
83    }
84}
85
86#[derive(Debug, Clone)]
87struct SinusoidalPositionalEmbedding {
88    emb: Embedding,
89}
90
91impl SinusoidalPositionalEmbedding {
92    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
93        let dev = vb.device();
94        let dtype = vb.dtype();
95        let num_positions = cfg.max_position_embeddings;
96        let dim = cfg.d_model;
97        let inv_freq: Vec<_> = (0..dim)
98            .step_by(2)
99            .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
100            .collect();
101        let inv_freq_len = inv_freq.len();
102        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
103        let t = Tensor::arange(0u32, num_positions as u32, dev)?
104            .to_dtype(dtype)?
105            .reshape((num_positions, 1))?;
106        let freqs = t.matmul(&inv_freq)?;
107        let sin = freqs.sin()?;
108        let cos = freqs.cos()?;
109        let weights = Tensor::cat(&[&sin, &cos], 1)?.contiguous()?;
110        let emb = Embedding::from_weights(weights)?;
111        Ok(Self { emb })
112    }
113
114    fn forward(&self, input_ids: &Tensor, past_kv_len: usize) -> Result<Tensor> {
115        let seq_len = input_ids.dim(1)?;
116        Tensor::arange(
117            past_kv_len as u32,
118            (past_kv_len + seq_len) as u32,
119            input_ids.device(),
120        )?
121        .apply(&self.emb)
122    }
123}
124
125#[derive(Debug, Clone)]
126struct Attention {
127    q_proj: Linear,
128    k_proj: Linear,
129    v_proj: Linear,
130    out_proj: Linear,
131    scaling: f64,
132    num_heads: usize,
133    head_dim: usize,
134    kv_cache: Option<(Tensor, Tensor)>,
135    is_decoder: bool,
136}
137
138impl Attention {
139    fn new(cfg: &Config, is_decoder: bool, vb: VarBuilder) -> Result<Self> {
140        let num_heads = if is_decoder {
141            cfg.decoder_attention_heads
142        } else {
143            cfg.encoder_attention_heads
144        };
145        let embed_dim = cfg.d_model;
146        let head_dim = embed_dim / num_heads;
147        let scaling = (head_dim as f64).powf(-0.5);
148        let q_proj = linear(embed_dim, embed_dim, vb.pp("q_proj"))?;
149        let k_proj = linear(embed_dim, embed_dim, vb.pp("k_proj"))?;
150        let v_proj = linear(embed_dim, embed_dim, vb.pp("v_proj"))?;
151        let out_proj = linear(embed_dim, embed_dim, vb.pp("out_proj"))?;
152        Ok(Self {
153            q_proj,
154            k_proj,
155            v_proj,
156            out_proj,
157            scaling,
158            num_heads,
159            head_dim,
160            kv_cache: None,
161            is_decoder,
162        })
163    }
164
165    fn _shape(&self, tensor: &Tensor, bsz: usize) -> Result<Tensor> {
166        tensor
167            .reshape((bsz, (), self.num_heads, self.head_dim))?
168            .transpose(1, 2)?
169            .contiguous()
170    }
171
172    fn forward(
173        &mut self,
174        xs: &Tensor,
175        kv_states: Option<&Tensor>,
176        attn_mask: Option<&Tensor>,
177    ) -> Result<Tensor> {
178        let (b_sz, tgt_len, _) = xs.dims3()?;
179        let query_states = (xs.apply(&self.q_proj)? * self.scaling)?;
180        let (key_states, value_states) = match kv_states {
181            None => {
182                let key_states = self._shape(&xs.apply(&self.k_proj)?, b_sz)?;
183                let value_states = self._shape(&xs.apply(&self.v_proj)?, b_sz)?;
184                if self.is_decoder {
185                    let kv_states = match &self.kv_cache {
186                        None => (key_states, value_states),
187                        Some((p_key_states, p_value_states)) => {
188                            let key_states = Tensor::cat(&[p_key_states, &key_states], 2)?;
189                            let value_states = Tensor::cat(&[p_value_states, &value_states], 2)?;
190                            (key_states, value_states)
191                        }
192                    };
193                    self.kv_cache = Some(kv_states.clone());
194                    kv_states
195                } else {
196                    (key_states, value_states)
197                }
198            }
199            Some(kv_states) => {
200                let key_states = self._shape(&kv_states.apply(&self.k_proj)?, b_sz)?;
201                let value_states = self._shape(&kv_states.apply(&self.v_proj)?, b_sz)?;
202                (key_states, value_states)
203            }
204        };
205        let proj_shape = (b_sz * self.num_heads, (), self.head_dim);
206        let query_states = self._shape(&query_states, b_sz)?.reshape(proj_shape)?;
207        let key_states = key_states.reshape(proj_shape)?;
208        let value_states = value_states.reshape(proj_shape)?;
209        let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
210        let attn_weights = match attn_mask {
211            None => attn_weights,
212            Some(attn_mask) => attn_weights.broadcast_add(attn_mask)?,
213        };
214        let attn_probs = candle_nn::ops::softmax_last_dim(&attn_weights)?;
215        let attn_output = attn_probs.matmul(&value_states)?;
216        attn_output
217            .reshape((b_sz, self.num_heads, tgt_len, self.head_dim))?
218            .transpose(1, 2)?
219            .reshape((b_sz, tgt_len, self.head_dim * self.num_heads))?
220            .apply(&self.out_proj)
221    }
222
223    fn reset_kv_cache(&mut self) {
224        self.kv_cache = None
225    }
226}
227
228#[derive(Debug, Clone)]
229struct EncoderLayer {
230    self_attn: Attention,
231    self_attn_layer_norm: LayerNorm,
232    activation_fn: candle_nn::Activation,
233    fc1: Linear,
234    fc2: Linear,
235    final_layer_norm: LayerNorm,
236}
237
238impl EncoderLayer {
239    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
240        let self_attn = Attention::new(cfg, true, vb.pp("self_attn"))?;
241        let self_attn_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("self_attn_layer_norm"))?;
242        let fc1 = linear(cfg.d_model, cfg.encoder_ffn_dim, vb.pp("fc1"))?;
243        let fc2 = linear(cfg.encoder_ffn_dim, cfg.d_model, vb.pp("fc2"))?;
244        let final_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("final_layer_norm"))?;
245        Ok(Self {
246            self_attn,
247            self_attn_layer_norm,
248            activation_fn: cfg.activation_function,
249            fc1,
250            fc2,
251            final_layer_norm,
252        })
253    }
254
255    fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
256        let residual = xs;
257        let xs = (self.self_attn.forward(xs, None, None)? + residual)?
258            .apply(&self.self_attn_layer_norm)?;
259        let residual = &xs;
260        let xs = xs
261            .apply(&self.fc1)?
262            .apply(&self.activation_fn)?
263            .apply(&self.fc2)?;
264        (xs + residual)?.apply(&self.final_layer_norm)
265    }
266
267    fn reset_kv_cache(&mut self) {
268        self.self_attn.reset_kv_cache()
269    }
270}
271
272#[derive(Debug, Clone)]
273struct DecoderLayer {
274    self_attn: Attention,
275    self_attn_layer_norm: LayerNorm,
276    activation_fn: candle_nn::Activation,
277    encoder_attn: Attention,
278    encoder_attn_layer_norm: LayerNorm,
279    fc1: Linear,
280    fc2: Linear,
281    final_layer_norm: LayerNorm,
282}
283
284impl DecoderLayer {
285    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
286        let self_attn = Attention::new(cfg, true, vb.pp("self_attn"))?;
287        let self_attn_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("self_attn_layer_norm"))?;
288        let encoder_attn = Attention::new(cfg, true, vb.pp("encoder_attn"))?;
289        let encoder_attn_layer_norm =
290            layer_norm(cfg.d_model, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
291        let fc1 = linear(cfg.d_model, cfg.decoder_ffn_dim, vb.pp("fc1"))?;
292        let fc2 = linear(cfg.decoder_ffn_dim, cfg.d_model, vb.pp("fc2"))?;
293        let final_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("final_layer_norm"))?;
294        Ok(Self {
295            self_attn,
296            self_attn_layer_norm,
297            activation_fn: cfg.activation_function,
298            encoder_attn,
299            encoder_attn_layer_norm,
300            fc1,
301            fc2,
302            final_layer_norm,
303        })
304    }
305
306    fn forward(
307        &mut self,
308        xs: &Tensor,
309        encoder_xs: Option<&Tensor>,
310        attn_mask: &Tensor,
311    ) -> Result<Tensor> {
312        let residual = xs;
313        let xs = (self.self_attn.forward(xs, None, Some(attn_mask))? + residual)?
314            .apply(&self.self_attn_layer_norm)?;
315        let xs = match encoder_xs {
316            None => xs,
317            Some(encoder_xs) => {
318                let residual = &xs;
319                let xs = self.encoder_attn.forward(&xs, Some(encoder_xs), None)?;
320                (residual + xs)?.apply(&self.encoder_attn_layer_norm)?
321            }
322        };
323        let residual = &xs;
324        let xs = xs
325            .apply(&self.fc1)?
326            .apply(&self.activation_fn)?
327            .apply(&self.fc2)?;
328        let xs = (xs + residual)?.apply(&self.final_layer_norm)?;
329        Ok(xs)
330    }
331
332    fn reset_kv_cache(&mut self) {
333        self.self_attn.reset_kv_cache();
334        self.encoder_attn.reset_kv_cache()
335    }
336}
337
338#[derive(Debug, Clone)]
339pub struct Encoder {
340    embed_tokens: Embedding,
341    embed_positions: SinusoidalPositionalEmbedding,
342    layers: Vec<EncoderLayer>,
343    embed_scale: Option<f64>,
344}
345
346impl Encoder {
347    fn new(cfg: &Config, embed_tokens: &Embedding, vb: VarBuilder) -> Result<Self> {
348        let embed_positions = SinusoidalPositionalEmbedding::new(cfg, vb.pp("embed_positions"))?;
349        let mut layers = Vec::with_capacity(cfg.encoder_layers);
350        let vb_l = vb.pp("layers");
351        for idx in 0..cfg.encoder_layers {
352            let layer = EncoderLayer::new(cfg, vb_l.pp(idx))?;
353            layers.push(layer)
354        }
355        let embed_scale = if cfg.scale_embedding {
356            Some((cfg.d_model as f64).sqrt())
357        } else {
358            None
359        };
360        Ok(Self {
361            embed_tokens: embed_tokens.clone(),
362            embed_positions,
363            layers,
364            embed_scale,
365        })
366    }
367
368    pub fn forward(&mut self, xs: &Tensor, past_kv_len: usize) -> Result<Tensor> {
369        let xs = xs.apply(&self.embed_tokens)?;
370        let xs = match self.embed_scale {
371            None => xs,
372            Some(scale) => (xs * scale)?,
373        };
374        let embed_pos = self
375            .embed_positions
376            .forward(&xs, past_kv_len)?
377            .unsqueeze(0)?;
378        let mut xs = xs.broadcast_add(&embed_pos)?;
379        for layer in self.layers.iter_mut() {
380            xs = layer.forward(&xs)?
381        }
382        Ok(xs)
383    }
384
385    pub fn reset_kv_cache(&mut self) {
386        for layer in self.layers.iter_mut() {
387            layer.reset_kv_cache()
388        }
389    }
390}
391
392#[derive(Debug, Clone)]
393pub struct Decoder {
394    embed_tokens: Embedding,
395    embed_positions: SinusoidalPositionalEmbedding,
396    layers: Vec<DecoderLayer>,
397    embed_scale: Option<f64>,
398}
399
400impl Decoder {
401    fn new(cfg: &Config, embed_tokens: &Embedding, vb: VarBuilder) -> Result<Self> {
402        let embed_positions = SinusoidalPositionalEmbedding::new(cfg, vb.pp("embed_positions"))?;
403        let mut layers = Vec::with_capacity(cfg.decoder_layers);
404        let vb_l = vb.pp("layers");
405        for idx in 0..cfg.decoder_layers {
406            let layer = DecoderLayer::new(cfg, vb_l.pp(idx))?;
407            layers.push(layer)
408        }
409        let embed_scale = if cfg.scale_embedding {
410            Some((cfg.d_model as f64).sqrt())
411        } else {
412            None
413        };
414        Ok(Self {
415            embed_tokens: embed_tokens.clone(),
416            embed_positions,
417            layers,
418            embed_scale,
419        })
420    }
421
422    pub fn forward(
423        &mut self,
424        xs: &Tensor,
425        encoder_xs: Option<&Tensor>,
426        past_kv_len: usize,
427        attn_mask: &Tensor,
428    ) -> Result<Tensor> {
429        let xs = xs.apply(&self.embed_tokens)?;
430        let xs = match self.embed_scale {
431            None => xs,
432            Some(scale) => (xs * scale)?,
433        };
434        let embed_pos = self
435            .embed_positions
436            .forward(&xs, past_kv_len)?
437            .unsqueeze(0)?;
438        let mut xs = xs.broadcast_add(&embed_pos)?;
439        for layer in self.layers.iter_mut() {
440            xs = layer.forward(&xs, encoder_xs, attn_mask)?;
441        }
442        Ok(xs)
443    }
444
445    pub fn reset_kv_cache(&mut self) {
446        for layer in self.layers.iter_mut() {
447            layer.reset_kv_cache()
448        }
449    }
450}
451
452#[derive(Debug, Clone)]
453struct Model {
454    shared: Embedding,
455    encoder: Encoder,
456    decoder: Decoder,
457}
458
459impl Model {
460    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
461        let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
462        let encoder = Encoder::new(cfg, &shared, vb.pp("encoder"))?;
463        let decoder = Decoder::new(cfg, &shared, vb.pp("decoder"))?;
464        Ok(Self {
465            shared,
466            encoder,
467            decoder,
468        })
469    }
470
471    fn reset_kv_cache(&mut self) {
472        self.encoder.reset_kv_cache();
473        self.decoder.reset_kv_cache();
474    }
475}
476
477#[derive(Debug, Clone)]
478pub struct MTModel {
479    model: Model,
480    lm_head: Linear,
481    final_logits_bias: Tensor,
482}
483
484impl MTModel {
485    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
486        let target_vocab_size = cfg.decoder_vocab_size.unwrap_or(cfg.vocab_size);
487        let final_logits_bias = vb.get((1, target_vocab_size), "final_logits_bias")?;
488        let model = Model::new(cfg, vb.pp("model"))?;
489        let lm_head = Linear::from_weights(model.shared.embeddings().clone(), None);
490        Ok(Self {
491            model,
492            lm_head,
493            final_logits_bias,
494        })
495    }
496
497    pub fn encoder(&mut self) -> &mut Encoder {
498        &mut self.model.encoder
499    }
500
501    pub fn decoder(&mut self) -> &mut Decoder {
502        &mut self.model.decoder
503    }
504
505    pub fn decode(
506        &mut self,
507        xs: &Tensor,
508        encoder_xs: &Tensor,
509        past_kv_len: usize,
510    ) -> Result<Tensor> {
511        let seq_len = xs.dim(1)?;
512        let mask: Vec<_> = (0..seq_len)
513            .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
514            .collect();
515        let mask = Tensor::from_vec(mask, (seq_len, seq_len), xs.device())?;
516        self.model
517            .decoder
518            .forward(xs, Some(encoder_xs), past_kv_len, &mask)?
519            .apply(&self.lm_head)?
520            .broadcast_add(&self.final_logits_bias)
521    }
522
523    pub fn reset_kv_cache(&mut self) {
524        self.model.reset_kv_cache();
525    }
526}