1use super::with_tracing::{linear, Embedding, Linear};
8use candle::{Result, Tensor};
9use candle_nn::{layer_norm, LayerNorm, VarBuilder};
10
11#[derive(Debug, Clone, serde::Deserialize)]
12pub struct Config {
13 pub vocab_size: usize,
14 pub decoder_vocab_size: Option<usize>,
15 pub max_position_embeddings: usize,
16 pub encoder_layers: usize,
17 pub encoder_ffn_dim: usize,
18 pub encoder_attention_heads: usize,
19 pub decoder_layers: usize,
20 pub decoder_ffn_dim: usize,
21 pub decoder_attention_heads: usize,
22 pub use_cache: bool,
23 pub is_encoder_decoder: bool,
24 pub activation_function: candle_nn::Activation,
25 pub d_model: usize,
26 pub decoder_start_token_id: u32,
27 pub scale_embedding: bool,
28 pub pad_token_id: u32,
29 pub eos_token_id: u32,
30 pub forced_eos_token_id: u32,
31 pub share_encoder_decoder_embeddings: bool,
32}
33
34impl Config {
35 pub fn opus_mt_tc_big_fr_en() -> Self {
37 Self {
38 activation_function: candle_nn::Activation::Relu,
39 d_model: 1024,
40 decoder_attention_heads: 16,
41 decoder_ffn_dim: 4096,
42 decoder_layers: 6,
43 decoder_start_token_id: 53016,
44 decoder_vocab_size: Some(53017),
45 encoder_attention_heads: 16,
46 encoder_ffn_dim: 4096,
47 encoder_layers: 6,
48 eos_token_id: 43311,
49 forced_eos_token_id: 43311,
50 is_encoder_decoder: true,
51 max_position_embeddings: 1024,
52 pad_token_id: 53016,
53 scale_embedding: true,
54 share_encoder_decoder_embeddings: true,
55 use_cache: true,
56 vocab_size: 53017,
57 }
58 }
59
60 pub fn opus_mt_fr_en() -> Self {
62 Self {
63 activation_function: candle_nn::Activation::Swish,
64 d_model: 512,
65 decoder_attention_heads: 8,
66 decoder_ffn_dim: 2048,
67 decoder_layers: 6,
68 decoder_start_token_id: 59513,
69 decoder_vocab_size: Some(59514),
70 encoder_attention_heads: 8,
71 encoder_ffn_dim: 2048,
72 encoder_layers: 6,
73 eos_token_id: 0,
74 forced_eos_token_id: 0,
75 is_encoder_decoder: true,
76 max_position_embeddings: 512,
77 pad_token_id: 59513,
78 scale_embedding: true,
79 share_encoder_decoder_embeddings: true,
80 use_cache: true,
81 vocab_size: 59514,
82 }
83 }
84}
85
86#[derive(Debug, Clone)]
87struct SinusoidalPositionalEmbedding {
88 emb: Embedding,
89}
90
91impl SinusoidalPositionalEmbedding {
92 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
93 let dev = vb.device();
94 let dtype = vb.dtype();
95 let num_positions = cfg.max_position_embeddings;
96 let dim = cfg.d_model;
97 let inv_freq: Vec<_> = (0..dim)
98 .step_by(2)
99 .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
100 .collect();
101 let inv_freq_len = inv_freq.len();
102 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
103 let t = Tensor::arange(0u32, num_positions as u32, dev)?
104 .to_dtype(dtype)?
105 .reshape((num_positions, 1))?;
106 let freqs = t.matmul(&inv_freq)?;
107 let sin = freqs.sin()?;
108 let cos = freqs.cos()?;
109 let weights = Tensor::cat(&[&sin, &cos], 1)?.contiguous()?;
110 let emb = Embedding::from_weights(weights)?;
111 Ok(Self { emb })
112 }
113
114 fn forward(&self, input_ids: &Tensor, past_kv_len: usize) -> Result<Tensor> {
115 let seq_len = input_ids.dim(1)?;
116 Tensor::arange(
117 past_kv_len as u32,
118 (past_kv_len + seq_len) as u32,
119 input_ids.device(),
120 )?
121 .apply(&self.emb)
122 }
123}
124
125#[derive(Debug, Clone)]
126struct Attention {
127 q_proj: Linear,
128 k_proj: Linear,
129 v_proj: Linear,
130 out_proj: Linear,
131 scaling: f64,
132 num_heads: usize,
133 head_dim: usize,
134 kv_cache: Option<(Tensor, Tensor)>,
135 is_decoder: bool,
136}
137
138impl Attention {
139 fn new(cfg: &Config, is_decoder: bool, vb: VarBuilder) -> Result<Self> {
140 let num_heads = if is_decoder {
141 cfg.decoder_attention_heads
142 } else {
143 cfg.encoder_attention_heads
144 };
145 let embed_dim = cfg.d_model;
146 let head_dim = embed_dim / num_heads;
147 let scaling = (head_dim as f64).powf(-0.5);
148 let q_proj = linear(embed_dim, embed_dim, vb.pp("q_proj"))?;
149 let k_proj = linear(embed_dim, embed_dim, vb.pp("k_proj"))?;
150 let v_proj = linear(embed_dim, embed_dim, vb.pp("v_proj"))?;
151 let out_proj = linear(embed_dim, embed_dim, vb.pp("out_proj"))?;
152 Ok(Self {
153 q_proj,
154 k_proj,
155 v_proj,
156 out_proj,
157 scaling,
158 num_heads,
159 head_dim,
160 kv_cache: None,
161 is_decoder,
162 })
163 }
164
165 fn _shape(&self, tensor: &Tensor, bsz: usize) -> Result<Tensor> {
166 tensor
167 .reshape((bsz, (), self.num_heads, self.head_dim))?
168 .transpose(1, 2)?
169 .contiguous()
170 }
171
172 fn forward(
173 &mut self,
174 xs: &Tensor,
175 kv_states: Option<&Tensor>,
176 attn_mask: Option<&Tensor>,
177 ) -> Result<Tensor> {
178 let (b_sz, tgt_len, _) = xs.dims3()?;
179 let query_states = (xs.apply(&self.q_proj)? * self.scaling)?;
180 let (key_states, value_states) = match kv_states {
181 None => {
182 let key_states = self._shape(&xs.apply(&self.k_proj)?, b_sz)?;
183 let value_states = self._shape(&xs.apply(&self.v_proj)?, b_sz)?;
184 if self.is_decoder {
185 let kv_states = match &self.kv_cache {
186 None => (key_states, value_states),
187 Some((p_key_states, p_value_states)) => {
188 let key_states = Tensor::cat(&[p_key_states, &key_states], 2)?;
189 let value_states = Tensor::cat(&[p_value_states, &value_states], 2)?;
190 (key_states, value_states)
191 }
192 };
193 self.kv_cache = Some(kv_states.clone());
194 kv_states
195 } else {
196 (key_states, value_states)
197 }
198 }
199 Some(kv_states) => {
200 let key_states = self._shape(&kv_states.apply(&self.k_proj)?, b_sz)?;
201 let value_states = self._shape(&kv_states.apply(&self.v_proj)?, b_sz)?;
202 (key_states, value_states)
203 }
204 };
205 let proj_shape = (b_sz * self.num_heads, (), self.head_dim);
206 let query_states = self._shape(&query_states, b_sz)?.reshape(proj_shape)?;
207 let key_states = key_states.reshape(proj_shape)?;
208 let value_states = value_states.reshape(proj_shape)?;
209 let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
210 let attn_weights = match attn_mask {
211 None => attn_weights,
212 Some(attn_mask) => attn_weights.broadcast_add(attn_mask)?,
213 };
214 let attn_probs = candle_nn::ops::softmax_last_dim(&attn_weights)?;
215 let attn_output = attn_probs.matmul(&value_states)?;
216 attn_output
217 .reshape((b_sz, self.num_heads, tgt_len, self.head_dim))?
218 .transpose(1, 2)?
219 .reshape((b_sz, tgt_len, self.head_dim * self.num_heads))?
220 .apply(&self.out_proj)
221 }
222
223 fn reset_kv_cache(&mut self) {
224 self.kv_cache = None
225 }
226}
227
228#[derive(Debug, Clone)]
229struct EncoderLayer {
230 self_attn: Attention,
231 self_attn_layer_norm: LayerNorm,
232 activation_fn: candle_nn::Activation,
233 fc1: Linear,
234 fc2: Linear,
235 final_layer_norm: LayerNorm,
236}
237
238impl EncoderLayer {
239 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
240 let self_attn = Attention::new(cfg, true, vb.pp("self_attn"))?;
241 let self_attn_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("self_attn_layer_norm"))?;
242 let fc1 = linear(cfg.d_model, cfg.encoder_ffn_dim, vb.pp("fc1"))?;
243 let fc2 = linear(cfg.encoder_ffn_dim, cfg.d_model, vb.pp("fc2"))?;
244 let final_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("final_layer_norm"))?;
245 Ok(Self {
246 self_attn,
247 self_attn_layer_norm,
248 activation_fn: cfg.activation_function,
249 fc1,
250 fc2,
251 final_layer_norm,
252 })
253 }
254
255 fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
256 let residual = xs;
257 let xs = (self.self_attn.forward(xs, None, None)? + residual)?
258 .apply(&self.self_attn_layer_norm)?;
259 let residual = &xs;
260 let xs = xs
261 .apply(&self.fc1)?
262 .apply(&self.activation_fn)?
263 .apply(&self.fc2)?;
264 (xs + residual)?.apply(&self.final_layer_norm)
265 }
266
267 fn reset_kv_cache(&mut self) {
268 self.self_attn.reset_kv_cache()
269 }
270}
271
272#[derive(Debug, Clone)]
273struct DecoderLayer {
274 self_attn: Attention,
275 self_attn_layer_norm: LayerNorm,
276 activation_fn: candle_nn::Activation,
277 encoder_attn: Attention,
278 encoder_attn_layer_norm: LayerNorm,
279 fc1: Linear,
280 fc2: Linear,
281 final_layer_norm: LayerNorm,
282}
283
284impl DecoderLayer {
285 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
286 let self_attn = Attention::new(cfg, true, vb.pp("self_attn"))?;
287 let self_attn_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("self_attn_layer_norm"))?;
288 let encoder_attn = Attention::new(cfg, true, vb.pp("encoder_attn"))?;
289 let encoder_attn_layer_norm =
290 layer_norm(cfg.d_model, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
291 let fc1 = linear(cfg.d_model, cfg.decoder_ffn_dim, vb.pp("fc1"))?;
292 let fc2 = linear(cfg.decoder_ffn_dim, cfg.d_model, vb.pp("fc2"))?;
293 let final_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("final_layer_norm"))?;
294 Ok(Self {
295 self_attn,
296 self_attn_layer_norm,
297 activation_fn: cfg.activation_function,
298 encoder_attn,
299 encoder_attn_layer_norm,
300 fc1,
301 fc2,
302 final_layer_norm,
303 })
304 }
305
306 fn forward(
307 &mut self,
308 xs: &Tensor,
309 encoder_xs: Option<&Tensor>,
310 attn_mask: &Tensor,
311 ) -> Result<Tensor> {
312 let residual = xs;
313 let xs = (self.self_attn.forward(xs, None, Some(attn_mask))? + residual)?
314 .apply(&self.self_attn_layer_norm)?;
315 let xs = match encoder_xs {
316 None => xs,
317 Some(encoder_xs) => {
318 let residual = &xs;
319 let xs = self.encoder_attn.forward(&xs, Some(encoder_xs), None)?;
320 (residual + xs)?.apply(&self.encoder_attn_layer_norm)?
321 }
322 };
323 let residual = &xs;
324 let xs = xs
325 .apply(&self.fc1)?
326 .apply(&self.activation_fn)?
327 .apply(&self.fc2)?;
328 let xs = (xs + residual)?.apply(&self.final_layer_norm)?;
329 Ok(xs)
330 }
331
332 fn reset_kv_cache(&mut self) {
333 self.self_attn.reset_kv_cache();
334 self.encoder_attn.reset_kv_cache()
335 }
336}
337
338#[derive(Debug, Clone)]
339pub struct Encoder {
340 embed_tokens: Embedding,
341 embed_positions: SinusoidalPositionalEmbedding,
342 layers: Vec<EncoderLayer>,
343 embed_scale: Option<f64>,
344}
345
346impl Encoder {
347 fn new(cfg: &Config, embed_tokens: &Embedding, vb: VarBuilder) -> Result<Self> {
348 let embed_positions = SinusoidalPositionalEmbedding::new(cfg, vb.pp("embed_positions"))?;
349 let mut layers = Vec::with_capacity(cfg.encoder_layers);
350 let vb_l = vb.pp("layers");
351 for idx in 0..cfg.encoder_layers {
352 let layer = EncoderLayer::new(cfg, vb_l.pp(idx))?;
353 layers.push(layer)
354 }
355 let embed_scale = if cfg.scale_embedding {
356 Some((cfg.d_model as f64).sqrt())
357 } else {
358 None
359 };
360 Ok(Self {
361 embed_tokens: embed_tokens.clone(),
362 embed_positions,
363 layers,
364 embed_scale,
365 })
366 }
367
368 pub fn forward(&mut self, xs: &Tensor, past_kv_len: usize) -> Result<Tensor> {
369 let xs = xs.apply(&self.embed_tokens)?;
370 let xs = match self.embed_scale {
371 None => xs,
372 Some(scale) => (xs * scale)?,
373 };
374 let embed_pos = self
375 .embed_positions
376 .forward(&xs, past_kv_len)?
377 .unsqueeze(0)?;
378 let mut xs = xs.broadcast_add(&embed_pos)?;
379 for layer in self.layers.iter_mut() {
380 xs = layer.forward(&xs)?
381 }
382 Ok(xs)
383 }
384
385 pub fn reset_kv_cache(&mut self) {
386 for layer in self.layers.iter_mut() {
387 layer.reset_kv_cache()
388 }
389 }
390}
391
392#[derive(Debug, Clone)]
393pub struct Decoder {
394 embed_tokens: Embedding,
395 embed_positions: SinusoidalPositionalEmbedding,
396 layers: Vec<DecoderLayer>,
397 embed_scale: Option<f64>,
398}
399
400impl Decoder {
401 fn new(cfg: &Config, embed_tokens: &Embedding, vb: VarBuilder) -> Result<Self> {
402 let embed_positions = SinusoidalPositionalEmbedding::new(cfg, vb.pp("embed_positions"))?;
403 let mut layers = Vec::with_capacity(cfg.decoder_layers);
404 let vb_l = vb.pp("layers");
405 for idx in 0..cfg.decoder_layers {
406 let layer = DecoderLayer::new(cfg, vb_l.pp(idx))?;
407 layers.push(layer)
408 }
409 let embed_scale = if cfg.scale_embedding {
410 Some((cfg.d_model as f64).sqrt())
411 } else {
412 None
413 };
414 Ok(Self {
415 embed_tokens: embed_tokens.clone(),
416 embed_positions,
417 layers,
418 embed_scale,
419 })
420 }
421
422 pub fn forward(
423 &mut self,
424 xs: &Tensor,
425 encoder_xs: Option<&Tensor>,
426 past_kv_len: usize,
427 attn_mask: &Tensor,
428 ) -> Result<Tensor> {
429 let xs = xs.apply(&self.embed_tokens)?;
430 let xs = match self.embed_scale {
431 None => xs,
432 Some(scale) => (xs * scale)?,
433 };
434 let embed_pos = self
435 .embed_positions
436 .forward(&xs, past_kv_len)?
437 .unsqueeze(0)?;
438 let mut xs = xs.broadcast_add(&embed_pos)?;
439 for layer in self.layers.iter_mut() {
440 xs = layer.forward(&xs, encoder_xs, attn_mask)?;
441 }
442 Ok(xs)
443 }
444
445 pub fn reset_kv_cache(&mut self) {
446 for layer in self.layers.iter_mut() {
447 layer.reset_kv_cache()
448 }
449 }
450}
451
452#[derive(Debug, Clone)]
453struct Model {
454 shared: Embedding,
455 encoder: Encoder,
456 decoder: Decoder,
457}
458
459impl Model {
460 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
461 let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
462 let encoder = Encoder::new(cfg, &shared, vb.pp("encoder"))?;
463 let decoder = Decoder::new(cfg, &shared, vb.pp("decoder"))?;
464 Ok(Self {
465 shared,
466 encoder,
467 decoder,
468 })
469 }
470
471 fn reset_kv_cache(&mut self) {
472 self.encoder.reset_kv_cache();
473 self.decoder.reset_kv_cache();
474 }
475}
476
477#[derive(Debug, Clone)]
478pub struct MTModel {
479 model: Model,
480 lm_head: Linear,
481 final_logits_bias: Tensor,
482}
483
484impl MTModel {
485 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
486 let target_vocab_size = cfg.decoder_vocab_size.unwrap_or(cfg.vocab_size);
487 let final_logits_bias = vb.get((1, target_vocab_size), "final_logits_bias")?;
488 let model = Model::new(cfg, vb.pp("model"))?;
489 let lm_head = Linear::from_weights(model.shared.embeddings().clone(), None);
490 Ok(Self {
491 model,
492 lm_head,
493 final_logits_bias,
494 })
495 }
496
497 pub fn encoder(&mut self) -> &mut Encoder {
498 &mut self.model.encoder
499 }
500
501 pub fn decoder(&mut self) -> &mut Decoder {
502 &mut self.model.decoder
503 }
504
505 pub fn decode(
506 &mut self,
507 xs: &Tensor,
508 encoder_xs: &Tensor,
509 past_kv_len: usize,
510 ) -> Result<Tensor> {
511 let seq_len = xs.dim(1)?;
512 let mask: Vec<_> = (0..seq_len)
513 .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
514 .collect();
515 let mask = Tensor::from_vec(mask, (seq_len, seq_len), xs.device())?;
516 self.model
517 .decoder
518 .forward(xs, Some(encoder_xs), past_kv_len, &mask)?
519 .apply(&self.lm_head)?
520 .broadcast_add(&self.final_logits_bias)
521 }
522
523 pub fn reset_kv_cache(&mut self) {
524 self.model.reset_kv_cache();
525 }
526}