1use crate::models::t5::{deserialize_feed_forward_proj_activation, ActivationWithOptionalGating};
19use crate::models::with_tracing::QMatMul;
20use crate::quantized_nn::Embedding;
21pub use crate::quantized_var_builder::VarBuilder;
22use candle::{DType, Device, Module, Result, Tensor, D};
23use candle_nn::Activation;
24use serde::Deserialize;
25use std::sync::Arc;
26
// Serde fallbacks for optional `Config` keys. Each value matches the one
// produced by `Config::default()`.

/// Fallback for `Config::relative_attention_max_distance`.
fn default_relative_attention_max_distance() -> usize {
    const MAX_DISTANCE: usize = 128;
    MAX_DISTANCE
}

/// Fallback for `Config::is_decoder`: configs describe an encoder unless stated otherwise.
fn default_is_decoder() -> bool {
    bool::default()
}

/// Fallback for `Config::use_cache`: the decoder key/value cache is on by default.
fn default_use_cache() -> bool {
    !default_is_decoder()
}

/// Fallback for `Config::tie_word_embeddings`: the LM head reuses the shared embedding.
fn default_tie_word_embeddings() -> bool {
    const TIED: bool = true;
    TIED
}
42
43fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
44 let mask: Vec<_> = (0..size)
45 .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
46 .collect();
47 Tensor::from_slice(&mask, (size, size), device)
48}
49
50fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
51 let shape = mask.shape();
52 let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
53 let m = mask.where_cond(&on_true, on_false)?;
54 Ok(m)
55}
56
57#[derive(Debug, Clone, PartialEq, Deserialize)]
58pub struct Config {
59 vocab_size: usize,
60 d_model: usize,
61 d_kv: usize,
62 d_ff: usize,
63 num_layers: usize,
64 num_decoder_layers: Option<usize>,
65 num_heads: usize,
66 relative_attention_num_buckets: usize,
67 #[serde(default = "default_relative_attention_max_distance")]
68 relative_attention_max_distance: usize,
69 dropout_rate: f64,
70 layer_norm_epsilon: f64,
71 initializer_factor: f64,
72 #[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")]
73 pub feed_forward_proj: ActivationWithOptionalGating,
74 #[serde(default = "default_tie_word_embeddings")]
75 tie_word_embeddings: bool,
76 #[serde(default = "default_is_decoder")]
77 is_decoder: bool,
78 is_encoder_decoder: bool,
79 #[serde(default = "default_use_cache")]
80 pub use_cache: bool,
81 pub pad_token_id: usize,
82 pub eos_token_id: usize,
83 pub decoder_start_token_id: Option<usize>,
84}
85
86impl Default for Config {
87 fn default() -> Self {
88 Self {
89 vocab_size: 32128,
90 d_model: 512,
91 d_kv: 64,
92 d_ff: 2048,
93 num_layers: 6,
94 num_decoder_layers: None,
95 num_heads: 8,
96 relative_attention_num_buckets: 32,
97 relative_attention_max_distance: 128,
98 dropout_rate: 0.1,
99 layer_norm_epsilon: 1e-6,
100 initializer_factor: 1.0,
101 feed_forward_proj: ActivationWithOptionalGating {
102 gated: false,
103 activation: Activation::Relu,
104 },
105 tie_word_embeddings: true,
106 is_decoder: false,
107 is_encoder_decoder: true,
108 use_cache: true,
109 pad_token_id: 0,
110 eos_token_id: 1,
111 decoder_start_token_id: Some(0),
112 }
113 }
114}
115
116#[derive(Debug, Clone)]
117struct T5LayerNorm {
118 weight: Tensor,
119 variance_epsilon: f64,
120 span: tracing::Span,
121}
122
123impl T5LayerNorm {
124 fn load(h: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
125 let weight = vb.get(h, "weight")?.dequantize(vb.device())?;
126 Ok(Self {
127 weight,
128 variance_epsilon: eps,
129 span: tracing::span!(tracing::Level::TRACE, "layer-norm"),
130 })
131 }
132}
133
134impl Module for T5LayerNorm {
135 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
136 let _enter = self.span.enter();
137 let dtype = xs.dtype();
138 let xs_f32 = xs.to_dtype(DType::F32)?;
139 let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
141 let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
142 let xs = xs.to_dtype(dtype)?;
143 let xs = xs.broadcast_mul(&self.weight)?;
144 Ok(xs)
145 }
146}
147
148#[derive(Debug, Clone)]
149struct T5DenseActDense {
150 wi: QMatMul,
151 wo: QMatMul,
152 act: Activation,
153 span: tracing::Span,
154}
155
156impl T5DenseActDense {
157 fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
158 let wi = QMatMul::new(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
159 let wo = QMatMul::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
160 Ok(Self {
161 wi,
162 wo,
163 act: Activation::Relu,
164 span: tracing::span!(tracing::Level::TRACE, "dense-act-dense"),
165 })
166 }
167}
168
169impl Module for T5DenseActDense {
170 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
171 let _enter = self.span.enter();
172 let xs = self.wi.forward(xs)?;
173 let xs = self.act.forward(&xs)?;
174 let xs = self.wo.forward(&xs)?;
175 Ok(xs)
176 }
177}
178
179#[derive(Debug, Clone)]
180struct T5DenseGatedActDense {
181 wi_0: QMatMul,
182 wi_1: QMatMul,
183 wo: QMatMul,
184 act: Activation,
185 span: tracing::Span,
186}
187
188impl T5DenseGatedActDense {
189 fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
190 let wi_0 = QMatMul::new(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?;
191 let wi_1 = QMatMul::new(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?;
192 let wo = QMatMul::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
193 Ok(Self {
194 wi_0,
195 wi_1,
196 wo,
197 act: cfg.feed_forward_proj.activation,
198 span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
199 })
200 }
201}
202
203impl Module for T5DenseGatedActDense {
204 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
205 let _enter = self.span.enter();
206 let hidden_gelu = self.act.forward(&self.wi_0.forward(xs)?)?;
207 let hidden_linear = self.wi_1.forward(xs)?;
208 let xs = hidden_gelu.broadcast_mul(&hidden_linear)?;
209 let xs = self.wo.forward(&xs)?;
210 Ok(xs)
211 }
212}
213
214#[derive(Debug, Clone)]
215struct T5LayerFF {
216 dense_act: Option<T5DenseActDense>,
217 gated_dense_act: Option<T5DenseGatedActDense>,
218 layer_norm: T5LayerNorm,
219 span: tracing::Span,
220}
221
222impl T5LayerFF {
223 fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
224 let layer_norm =
225 T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
226 let (dense_act, gated_dense_act) = if cfg.feed_forward_proj.gated {
227 (
228 None,
229 Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
230 )
231 } else {
232 (
233 Some(T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?),
234 None,
235 )
236 };
237 Ok(Self {
238 dense_act,
239 gated_dense_act,
240 layer_norm,
241 span: tracing::span!(tracing::Level::TRACE, "layer-ff"),
242 })
243 }
244}
245
246impl Module for T5LayerFF {
247 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
248 let _enter = self.span.enter();
249 let ys = self.layer_norm.forward(xs)?;
250 let ys = match &self.dense_act {
251 Some(dense_act) => dense_act.forward(&ys)?,
252 None => self.gated_dense_act.as_ref().unwrap().forward(&ys)?,
253 };
254 let xs = (xs + ys)?;
255 Ok(xs)
256 }
257}
258
259#[derive(Debug, Clone)]
260struct T5Attention {
261 q: QMatMul,
262 k: QMatMul,
263 v: QMatMul,
264 o: QMatMul,
265 n_heads: usize,
266 d_kv: usize,
267 relative_attention_bias: Option<Embedding>,
268 relative_attention_num_buckets: usize,
269 relative_attention_max_distance: usize,
270 inner_dim: usize,
271 use_cache: bool,
272 kv_cache: Option<(Tensor, Tensor)>,
273 span: tracing::Span,
274 span_cache: tracing::Span,
275 span_mm: tracing::Span,
276 span_sm: tracing::Span,
277}
278
279impl T5Attention {
280 fn load(
281 has_relative_attention_bias: bool,
282 decoder: bool,
283 vb: VarBuilder,
284 cfg: &Config,
285 ) -> Result<Self> {
286 let inner_dim = cfg.num_heads * cfg.d_kv;
287 let q = QMatMul::new(cfg.d_model, inner_dim, vb.pp("q"))?;
288 let k = QMatMul::new(cfg.d_model, inner_dim, vb.pp("k"))?;
289 let v = QMatMul::new(cfg.d_model, inner_dim, vb.pp("v"))?;
290 let o = QMatMul::new(inner_dim, cfg.d_model, vb.pp("o"))?;
291 let relative_attention_bias = if has_relative_attention_bias {
292 let emb = Embedding::new(
293 cfg.relative_attention_num_buckets,
294 cfg.num_heads,
295 vb.pp("relative_attention_bias"),
296 )?;
297 Some(emb)
298 } else {
299 None
300 };
301 Ok(Self {
302 q,
303 k,
304 v,
305 o,
306 n_heads: cfg.num_heads,
307 d_kv: cfg.d_kv,
308 relative_attention_bias,
309 relative_attention_num_buckets: cfg.relative_attention_num_buckets,
310 relative_attention_max_distance: cfg.relative_attention_max_distance,
311 inner_dim,
312 use_cache: cfg.use_cache && decoder,
313 kv_cache: None,
314 span: tracing::span!(tracing::Level::TRACE, "attention"),
315 span_cache: tracing::span!(tracing::Level::TRACE, "attention-cache"),
316 span_mm: tracing::span!(tracing::Level::TRACE, "attention-mm"),
317 span_sm: tracing::span!(tracing::Level::TRACE, "attention-sm"),
318 })
319 }
320
321 fn forward(
322 &mut self,
323 xs: &Tensor,
324 position_bias: Option<&Tensor>,
325 key_value_states: Option<&Tensor>,
326 mask: Option<&Tensor>,
327 ) -> Result<(Tensor, Option<Tensor>)> {
328 let _enter = self.span.enter();
331 let kv_input = match key_value_states {
332 None => xs,
333 Some(key_value_states) => key_value_states,
334 };
335 let (b_sz, q_len) = (xs.dim(0)?, xs.dim(1)?);
336 let kv_len = kv_input.dim(1)?;
337 let q = self.q.forward(xs)?;
338 let k = self.k.forward(kv_input)?;
339 let v = self.v.forward(kv_input)?;
340 let q = q
341 .reshape((b_sz, q_len, self.n_heads, self.d_kv))?
342 .transpose(1, 2)?
343 .contiguous()?;
344 let mut k = k
345 .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
346 .transpose(1, 2)?;
347 let mut v = v
348 .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
349 .transpose(1, 2)?;
350
351 if self.use_cache && key_value_states.is_none() {
352 let _enter = self.span_cache.enter();
353 if let Some((kv_cache_k, kv_cache_v)) = &self.kv_cache {
354 k = Tensor::cat(&[kv_cache_k, &k], 2)?;
355 v = Tensor::cat(&[kv_cache_v, &v], 2)?;
356 };
357 self.kv_cache = Some((k.clone(), v.clone()));
358 };
359 let k = k.contiguous()?;
360 let v = v.contiguous()?;
361 let scores = {
363 let _enter = self.span_mm.enter();
364 q.matmul(&k.t()?)?
365 };
366 let scores = match mask {
367 None => scores,
368 Some(mask) => masked_fill(
369 &scores,
370 &mask
371 .unsqueeze(0)?
372 .unsqueeze(0)?
373 .repeat((b_sz, self.n_heads))?,
374 f32::NEG_INFINITY,
375 )?,
376 };
377
378 let (scores, position_bias) = match position_bias {
379 Some(position_bias) => (
380 scores.broadcast_add(position_bias)?,
381 Some(position_bias.clone()),
382 ),
383 None => match &self.relative_attention_bias {
384 None => (scores, None),
385 Some(relative_attention_bias) => {
386 let kv_len = k.dim(2)?;
388 let (q_start, q_end) = match self.use_cache {
389 true => ((kv_len - q_len) as u32, kv_len as u32),
390 false => (0_u32, kv_len as u32),
391 };
392 let num_buckets = self.relative_attention_num_buckets as u32 / 2;
393 let max_exact = num_buckets / 2;
394 let relative_position = (q_start..q_end)
395 .map(|i| {
396 (0..kv_len as u32)
397 .map(|j| {
398 if i < j {
399 if j - i < max_exact {
400 j - i + num_buckets
401 } else {
402 let b = f32::log(
403 (j - i) as f32 / max_exact as f32,
404 self.relative_attention_max_distance as f32
405 / max_exact as f32,
406 ) * (num_buckets - max_exact) as f32;
407 u32::min(
408 max_exact + num_buckets + b as u32,
409 self.relative_attention_num_buckets as u32 - 1,
410 )
411 }
412 } else if i - j < max_exact {
413 i - j
414 } else {
415 let b = f32::log(
416 (i - j) as f32 / max_exact as f32,
417 self.relative_attention_max_distance as f32
418 / max_exact as f32,
419 ) * (num_buckets - max_exact) as f32;
420 max_exact + b as u32
421 }
422 })
423 .collect::<Vec<u32>>()
424 })
425 .collect::<Vec<Vec<_>>>();
426 let relative_buckets = Tensor::new(relative_position, q.device())?;
427 let position_bias = relative_attention_bias
428 .forward(&relative_buckets)?
429 .permute((2, 0, 1))?
430 .unsqueeze(0)?;
431 (scores.broadcast_add(&position_bias)?, Some(position_bias))
432 }
434 },
435 };
436
437 let attn_weights = {
438 let _enter = self.span_sm.enter();
439 candle_nn::ops::softmax_last_dim(&scores)?
440 };
441 let attn_output = attn_weights.matmul(&v)?;
442 let attn_output = attn_output
443 .transpose(1, 2)?
444 .reshape((b_sz, q_len, self.inner_dim))?;
445 let attn_output = self.o.forward(&attn_output)?;
446 Ok((attn_output, position_bias))
447 }
448
449 fn clear_kv_cache(&mut self) {
450 self.kv_cache = None
451 }
452}
453
454#[derive(Debug, Clone)]
455struct T5LayerSelfAttention {
456 self_attention: T5Attention,
457 layer_norm: T5LayerNorm,
458 span: tracing::Span,
459}
460
461impl T5LayerSelfAttention {
462 fn load(h: bool, d: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
463 let self_attention = T5Attention::load(h, d, vb.pp("SelfAttention"), cfg)?;
464 let layer_norm =
465 T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
466 Ok(Self {
467 self_attention,
468 layer_norm,
469 span: tracing::span!(tracing::Level::TRACE, "self-attn"),
470 })
471 }
472
473 fn forward(
474 &mut self,
475 xs: &Tensor,
476 position_bias: Option<&Tensor>,
477 mask: Option<&Tensor>,
478 ) -> Result<(Tensor, Option<Tensor>)> {
479 let _enter = self.span.enter();
480 let normed_xs = self.layer_norm.forward(xs)?;
481 let (ys, position_bias) =
482 self.self_attention
483 .forward(&normed_xs, position_bias, None, mask)?;
484 let ys = (xs + ys)?;
485 Ok((ys, position_bias))
486 }
487
488 fn clear_kv_cache(&mut self) {
489 self.self_attention.clear_kv_cache()
490 }
491}
492
493#[derive(Debug, Clone)]
494struct T5LayerCrossAttention {
495 cross_attention: T5Attention,
496 layer_norm: T5LayerNorm,
497 span: tracing::Span,
498}
499
500impl T5LayerCrossAttention {
501 fn load(decoder: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
502 let cross_attention = T5Attention::load(false, decoder, vb.pp("EncDecAttention"), cfg)?;
503 let layer_norm =
504 T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
505 Ok(Self {
506 cross_attention,
507 layer_norm,
508 span: tracing::span!(tracing::Level::TRACE, "cross-attn"),
509 })
510 }
511
512 fn forward(
513 &mut self,
514 hidden_states: &Tensor,
515 position_bias: Option<&Tensor>,
516 key_value_states: &Tensor,
517 ) -> Result<(Tensor, Option<Tensor>)> {
518 let _enter = self.span.enter();
519 let normed_hidden_states = self.layer_norm.forward(hidden_states)?;
520 let (ys, position_bias) = self.cross_attention.forward(
521 &normed_hidden_states,
522 position_bias,
523 Some(key_value_states),
524 None,
525 )?;
526 let ys = (hidden_states + ys)?;
527 Ok((ys, position_bias))
528 }
529
530 fn clear_kv_cache(&mut self) {
531 self.cross_attention.clear_kv_cache()
532 }
533}
534
535#[derive(Debug, Clone)]
536struct T5Block {
537 self_attn: T5LayerSelfAttention,
538 cross_attn: Option<T5LayerCrossAttention>,
539 ff: T5LayerFF,
540 span: tracing::Span,
541}
542
543impl T5Block {
544 fn load(
545 has_relative_attention_bias: bool,
546 decoder: bool,
547 vb: VarBuilder,
548 cfg: &Config,
549 ) -> Result<Self> {
550 let vb = vb.pp("layer");
551 let self_attn =
552 T5LayerSelfAttention::load(has_relative_attention_bias, decoder, vb.pp("0"), cfg)?;
553 let cross_attn = if cfg.is_decoder {
554 Some(T5LayerCrossAttention::load(decoder, vb.pp("1"), cfg)?)
555 } else {
556 None
557 };
558 let ff_i = if cross_attn.is_some() { 2 } else { 1 };
559 let ff = T5LayerFF::load(vb.pp(ff_i), cfg)?;
560 Ok(Self {
561 self_attn,
562 cross_attn,
563 ff,
564 span: tracing::span!(tracing::Level::TRACE, "block"),
565 })
566 }
567
568 fn forward(
569 &mut self,
570 xs: &Tensor,
571 position_bias: Option<&Tensor>,
572 encoder_hidden_states: Option<&Tensor>,
573 ) -> Result<(Tensor, Option<Tensor>)> {
574 let _enter = self.span.enter();
575 let mask = match self.cross_attn.is_some() {
577 true => {
578 let mask_len = xs.dim(1)?;
579 if mask_len <= 1 {
582 None
583 } else {
584 Some(get_mask(mask_len, xs.device())?)
585 }
586 }
587 false => None,
588 };
589 let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias, mask.as_ref())?;
590 if let Some(cross_attn) = &mut self.cross_attn {
592 (xs, _) = cross_attn.forward(&xs, None, encoder_hidden_states.unwrap())?;
593 }
595 let xs = self.ff.forward(&xs)?;
596 Ok((xs, position_bias))
598 }
599
600 fn clear_kv_cache(&mut self) {
601 self.self_attn.clear_kv_cache();
602 self.cross_attn.iter_mut().for_each(|c| c.clear_kv_cache());
603 }
604}
605
606#[derive(Debug, Clone)]
607struct T5Stack {
608 block: Vec<T5Block>,
609 shared: Arc<Embedding>,
610 final_layer_norm: T5LayerNorm,
611 span: tracing::Span,
612}
613
614impl T5Stack {
615 fn load(decoder: bool, vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> {
616 let block = (0..cfg.num_layers)
617 .map(|i| T5Block::load(i == 0, decoder, vb.pp(format!("block.{i}")), cfg))
618 .collect::<Result<Vec<_>>>()?;
619 let final_layer_norm = T5LayerNorm::load(
620 cfg.d_model,
621 cfg.layer_norm_epsilon,
622 vb.pp("final_layer_norm"),
623 )?;
624 Ok(Self {
625 block,
626 shared: shared.clone(),
627 final_layer_norm,
628 span: tracing::span!(tracing::Level::TRACE, "stack"),
629 })
630 }
631
632 fn forward(
633 &mut self,
634 input_ids: &Tensor,
635 encoder_hidden_states: Option<&Tensor>,
636 ) -> Result<Tensor> {
637 let _enter = self.span.enter();
638 let input_embeds = self.shared.as_ref().forward(input_ids)?;
639 let mut hidden_states = input_embeds;
640 let mut position_bias = None;
641 for block in self.block.iter_mut() {
642 (hidden_states, position_bias) = block.forward(
643 &hidden_states,
644 position_bias.as_ref(),
645 encoder_hidden_states,
646 )?
647 }
648 self.final_layer_norm.forward(&hidden_states)
649 }
650
651 fn clear_kv_cache(&mut self) {
652 self.block.iter_mut().for_each(|b| b.clear_kv_cache())
653 }
654}
655
656#[derive(Debug, Clone)]
657pub struct T5EncoderModel {
658 encoder: T5Stack,
659 device: Device,
660 span: tracing::Span,
661}
662
663impl T5EncoderModel {
664 pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
665 let shared_vb = if vb.contains_key("shared.weight") {
666 vb.pp("shared")
667 } else {
668 vb.pp("decoder").pp("embed_tokens")
669 };
670 let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
671 let shared = Arc::new(shared);
672 let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
673 Ok(Self {
674 encoder,
675 device: vb.device().clone(),
676 span: tracing::span!(tracing::Level::TRACE, "encoder"),
677 })
678 }
679
680 pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
681 let _enter = self.span.enter();
682 self.encoder.forward(input_ids, None)
683 }
684
685 pub fn device(&self) -> &Device {
686 &self.device
687 }
688
689 pub fn clear_kv_cache(&mut self) {
690 self.encoder.clear_kv_cache()
691 }
692}
693
694#[derive(Debug, Clone)]
695pub struct T5ForConditionalGeneration {
696 encoder: T5Stack,
697 decoder: T5Stack,
698 d_model: usize,
699 tie_word_embeddings: bool,
700 lm_head: Option<QMatMul>,
701 shared: Arc<Embedding>,
702 device: Device,
703 span_decode: tracing::Span,
704 span_decode_head: tracing::Span,
705}
706
707impl T5ForConditionalGeneration {
708 pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
709 assert!(cfg.is_encoder_decoder);
710 let d_model = cfg.d_model;
711 let shared_vb = if vb.contains_key("shared.weight") {
712 vb.pp("shared")
713 } else {
714 vb.pp("decoder").pp("embed_tokens")
715 };
716 let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
717 let shared = Arc::new(shared);
718
719 let mut encoder_cfg = cfg.clone();
720 encoder_cfg.is_decoder = false;
721 encoder_cfg.use_cache = false;
722 encoder_cfg.is_encoder_decoder = false;
723 let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, &encoder_cfg)?;
724
725 let mut decoder_cfg = cfg.clone();
726 decoder_cfg.is_decoder = true;
727 decoder_cfg.is_encoder_decoder = false;
728 decoder_cfg.num_layers = cfg.num_decoder_layers.unwrap_or(cfg.num_layers);
729 let decoder = T5Stack::load(true, vb.pp("decoder"), &shared, &decoder_cfg)?;
730
731 let tie_word_embeddings = cfg.tie_word_embeddings;
732 let lm_head = if tie_word_embeddings {
733 None
734 } else {
735 Some(QMatMul::new(cfg.d_model, cfg.vocab_size, vb.pp("lm_head"))?)
736 };
737
738 Ok(Self {
739 encoder,
740 decoder,
741 d_model,
742 tie_word_embeddings,
743 lm_head,
744 shared,
745 device: vb.device().clone(),
746 span_decode: tracing::span!(tracing::Level::TRACE, "decode"),
747 span_decode_head: tracing::span!(tracing::Level::TRACE, "decode-head"),
748 })
749 }
750
751 pub fn encode(&mut self, input_ids: &Tensor) -> Result<Tensor> {
752 self.encoder.forward(input_ids, None)
753 }
754
755 pub fn decode(
756 &mut self,
757 decoder_input_ids: &Tensor,
758 encoder_output: &Tensor,
759 ) -> Result<Tensor> {
760 let _enter = self.span_decode.enter();
761 let decoder_output = self
762 .decoder
763 .forward(decoder_input_ids, Some(encoder_output))?;
764
765 let scaling_factor = if self.tie_word_embeddings {
766 (self.d_model as f64).sqrt()
769 } else {
770 1.0
771 };
772 let sequence_output = ((decoder_output
773 .narrow(1, decoder_output.dim(1)? - 1, 1)?
774 .squeeze(1)?)
775 * scaling_factor)?;
776 let output = {
777 let _enter = self.span_decode_head.enter();
778 match self.lm_head {
779 None => sequence_output.matmul(&self.shared.embeddings().t()?)?,
780 Some(ref lm_head) => lm_head.forward(&sequence_output)?,
781 }
782 };
783 Ok(output)
784 }
785
786 pub fn forward(&mut self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> {
787 let encoder_output = self.encode(input_ids)?;
788 self.decode(decoder_input_ids, &encoder_output)
789 }
790
791 pub fn device(&self) -> &Device {
792 &self.device
793 }
794
795 pub fn clear_kv_cache(&mut self) {
796 self.encoder.clear_kv_cache();
797 self.decoder.clear_kv_cache();
798 }
799}