// candle_transformers/models/quantized_t5.rs

//! T5 model implementation with quantization support.
//!
//! T5 is an encoder-decoder model pre-trained on a multi-task mixture of supervised
//! and unsupervised tasks. This implementation loads quantized weights to reduce
//! memory and compute requirements.
//!
//! Key characteristics:
//! - Encoder-decoder architecture
//! - RMS-style layer normalization (no mean subtraction, no bias)
//! - Bucketed relative-position attention biases
//! - Support for quantized weights (e.g. 8-bit)
//!
//! References:
//! - 📝 [T5 Paper](https://arxiv.org/abs/1910.10683)
//! - 🤗 [Model Card](https://huggingface.co/t5-base)
//! - 🤗 Reference implementation in [transformers](https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py)

use crate::models::t5::{deserialize_feed_forward_proj_activation, ActivationWithOptionalGating};
use crate::models::with_tracing::QMatMul;
use crate::quantized_nn::Embedding;
pub use crate::quantized_var_builder::VarBuilder;
use candle::{bail, DType, Device, Module, Result, Tensor, D};
use candle_nn::Activation;
use serde::Deserialize;
use std::sync::Arc;

// Serde fallbacks used by `Config` for fields missing from a serialized config.
// The values are the same ones `Config::default()` uses.
const DEFAULT_RELATIVE_ATTENTION_MAX_DISTANCE: usize = 128;
const DEFAULT_IS_DECODER: bool = false;
const DEFAULT_USE_CACHE: bool = true;
const DEFAULT_TIE_WORD_EMBEDDINGS: bool = true;

/// Fallback for `Config::relative_attention_max_distance`.
fn default_relative_attention_max_distance() -> usize {
    DEFAULT_RELATIVE_ATTENTION_MAX_DISTANCE
}

/// Fallback for `Config::is_decoder`: a config is treated as an encoder unless it says otherwise.
fn default_is_decoder() -> bool {
    DEFAULT_IS_DECODER
}

/// Fallback for `Config::use_cache`: KV caching is enabled unless disabled explicitly.
fn default_use_cache() -> bool {
    DEFAULT_USE_CACHE
}

/// Fallback for `Config::tie_word_embeddings`: the shared embedding doubles as the output head.
fn default_tie_word_embeddings() -> bool {
    DEFAULT_TIE_WORD_EMBEDDINGS
}

43fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
44    let mask: Vec<_> = (0..size)
45        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
46        .collect();
47    Tensor::from_slice(&mask, (size, size), device)
48}
49
50fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
51    let shape = mask.shape();
52    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
53    let m = mask.where_cond(&on_true, on_false)?;
54    Ok(m)
55}
56
57#[derive(Debug, Clone, PartialEq, Deserialize)]
58pub struct Config {
59    vocab_size: usize,
60    d_model: usize,
61    d_kv: usize,
62    d_ff: usize,
63    num_layers: usize,
64    num_decoder_layers: Option<usize>,
65    num_heads: usize,
66    relative_attention_num_buckets: usize,
67    #[serde(default = "default_relative_attention_max_distance")]
68    relative_attention_max_distance: usize,
69    dropout_rate: f64,
70    layer_norm_epsilon: f64,
71    initializer_factor: f64,
72    #[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")]
73    pub feed_forward_proj: ActivationWithOptionalGating,
74    #[serde(default = "default_tie_word_embeddings")]
75    tie_word_embeddings: bool,
76    #[serde(default = "default_is_decoder")]
77    is_decoder: bool,
78    is_encoder_decoder: bool,
79    #[serde(default = "default_use_cache")]
80    pub use_cache: bool,
81    pub pad_token_id: usize,
82    pub eos_token_id: usize,
83    pub decoder_start_token_id: Option<usize>,
84}
85
86impl Default for Config {
87    fn default() -> Self {
88        Self {
89            vocab_size: 32128,
90            d_model: 512,
91            d_kv: 64,
92            d_ff: 2048,
93            num_layers: 6,
94            num_decoder_layers: None,
95            num_heads: 8,
96            relative_attention_num_buckets: 32,
97            relative_attention_max_distance: 128,
98            dropout_rate: 0.1,
99            layer_norm_epsilon: 1e-6,
100            initializer_factor: 1.0,
101            feed_forward_proj: ActivationWithOptionalGating {
102                gated: false,
103                activation: Activation::Relu,
104            },
105            tie_word_embeddings: true,
106            is_decoder: false,
107            is_encoder_decoder: true,
108            use_cache: true,
109            pad_token_id: 0,
110            eos_token_id: 1,
111            decoder_start_token_id: Some(0),
112        }
113    }
114}
115
116#[derive(Debug, Clone)]
117struct T5LayerNorm {
118    weight: Tensor,
119    variance_epsilon: f64,
120    span: tracing::Span,
121}
122
123impl T5LayerNorm {
124    fn load(h: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
125        let weight = vb.get(h, "weight")?.dequantize(vb.device())?;
126        Ok(Self {
127            weight,
128            variance_epsilon: eps,
129            span: tracing::span!(tracing::Level::TRACE, "layer-norm"),
130        })
131    }
132}
133
134impl Module for T5LayerNorm {
135    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
136        let _enter = self.span.enter();
137        let dtype = xs.dtype();
138        let xs_f32 = xs.to_dtype(DType::F32)?;
139        // variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
140        let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
141        let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
142        let xs = xs.to_dtype(dtype)?;
143        let xs = xs.broadcast_mul(&self.weight)?;
144        Ok(xs)
145    }
146}
147
148#[derive(Debug, Clone)]
149struct T5DenseActDense {
150    wi: QMatMul,
151    wo: QMatMul,
152    act: Activation,
153    span: tracing::Span,
154}
155
156impl T5DenseActDense {
157    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
158        let wi = QMatMul::new(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
159        let wo = QMatMul::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
160        Ok(Self {
161            wi,
162            wo,
163            act: Activation::Relu,
164            span: tracing::span!(tracing::Level::TRACE, "dense-act-dense"),
165        })
166    }
167}
168
169impl Module for T5DenseActDense {
170    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
171        let _enter = self.span.enter();
172        let xs = self.wi.forward(xs)?;
173        let xs = self.act.forward(&xs)?;
174        let xs = self.wo.forward(&xs)?;
175        Ok(xs)
176    }
177}
178
179#[derive(Debug, Clone)]
180struct T5DenseGatedActDense {
181    wi_0: QMatMul,
182    wi_1: QMatMul,
183    wo: QMatMul,
184    act: Activation,
185    span: tracing::Span,
186}
187
188impl T5DenseGatedActDense {
189    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
190        let wi_0 = QMatMul::new(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?;
191        let wi_1 = QMatMul::new(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?;
192        let wo = QMatMul::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
193        Ok(Self {
194            wi_0,
195            wi_1,
196            wo,
197            act: cfg.feed_forward_proj.activation,
198            span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
199        })
200    }
201}
202
203impl Module for T5DenseGatedActDense {
204    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
205        let _enter = self.span.enter();
206        let hidden_gelu = self.act.forward(&self.wi_0.forward(xs)?)?;
207        let hidden_linear = self.wi_1.forward(xs)?;
208        let xs = hidden_gelu.broadcast_mul(&hidden_linear)?;
209        let xs = self.wo.forward(&xs)?;
210        Ok(xs)
211    }
212}
213
214#[derive(Debug, Clone)]
215struct T5LayerFF {
216    dense_act: Option<T5DenseActDense>,
217    gated_dense_act: Option<T5DenseGatedActDense>,
218    layer_norm: T5LayerNorm,
219    span: tracing::Span,
220}
221
222impl T5LayerFF {
223    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
224        let layer_norm =
225            T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
226        let (dense_act, gated_dense_act) = if cfg.feed_forward_proj.gated {
227            (
228                None,
229                Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
230            )
231        } else {
232            (
233                Some(T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?),
234                None,
235            )
236        };
237        Ok(Self {
238            dense_act,
239            gated_dense_act,
240            layer_norm,
241            span: tracing::span!(tracing::Level::TRACE, "layer-ff"),
242        })
243    }
244}
245
246impl Module for T5LayerFF {
247    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
248        let _enter = self.span.enter();
249        let ys = self.layer_norm.forward(xs)?;
250        let ys = match &self.dense_act {
251            Some(dense_act) => dense_act.forward(&ys)?,
252            None => self.gated_dense_act.as_ref().unwrap().forward(&ys)?,
253        };
254        let xs = (xs + ys)?;
255        Ok(xs)
256    }
257}
258
259#[derive(Debug, Clone)]
260struct T5Attention {
261    q: QMatMul,
262    k: QMatMul,
263    v: QMatMul,
264    o: QMatMul,
265    n_heads: usize,
266    d_kv: usize,
267    relative_attention_bias: Option<Embedding>,
268    relative_attention_num_buckets: usize,
269    relative_attention_max_distance: usize,
270    inner_dim: usize,
271    use_cache: bool,
272    kv_cache: Option<(Tensor, Tensor)>,
273    span: tracing::Span,
274    span_cache: tracing::Span,
275    span_mm: tracing::Span,
276    span_sm: tracing::Span,
277}
278
279impl T5Attention {
280    fn load(
281        has_relative_attention_bias: bool,
282        decoder: bool,
283        vb: VarBuilder,
284        cfg: &Config,
285    ) -> Result<Self> {
286        let inner_dim = cfg.num_heads * cfg.d_kv;
287        let q = QMatMul::new(cfg.d_model, inner_dim, vb.pp("q"))?;
288        let k = QMatMul::new(cfg.d_model, inner_dim, vb.pp("k"))?;
289        let v = QMatMul::new(cfg.d_model, inner_dim, vb.pp("v"))?;
290        let o = QMatMul::new(inner_dim, cfg.d_model, vb.pp("o"))?;
291        let relative_attention_bias = if has_relative_attention_bias {
292            let emb = Embedding::new(
293                cfg.relative_attention_num_buckets,
294                cfg.num_heads,
295                vb.pp("relative_attention_bias"),
296            )?;
297            Some(emb)
298        } else {
299            None
300        };
301        Ok(Self {
302            q,
303            k,
304            v,
305            o,
306            n_heads: cfg.num_heads,
307            d_kv: cfg.d_kv,
308            relative_attention_bias,
309            relative_attention_num_buckets: cfg.relative_attention_num_buckets,
310            relative_attention_max_distance: cfg.relative_attention_max_distance,
311            inner_dim,
312            use_cache: cfg.use_cache && decoder,
313            kv_cache: None,
314            span: tracing::span!(tracing::Level::TRACE, "attention"),
315            span_cache: tracing::span!(tracing::Level::TRACE, "attention-cache"),
316            span_mm: tracing::span!(tracing::Level::TRACE, "attention-mm"),
317            span_sm: tracing::span!(tracing::Level::TRACE, "attention-sm"),
318        })
319    }
320
321    fn forward(
322        &mut self,
323        xs: &Tensor,
324        position_bias: Option<&Tensor>,
325        key_value_states: Option<&Tensor>,
326        mask: Option<&Tensor>,
327    ) -> Result<(Tensor, Option<Tensor>)> {
328        // Performs Self-attention (if key_value_states is None) or attention
329        // over source sentence (provided by key_value_states).
330        let _enter = self.span.enter();
331        let kv_input = match key_value_states {
332            None => xs,
333            Some(key_value_states) => key_value_states,
334        };
335        let (b_sz, q_len) = (xs.dim(0)?, xs.dim(1)?);
336        let kv_len = kv_input.dim(1)?;
337        let q = self.q.forward(xs)?;
338        let k = self.k.forward(kv_input)?;
339        let v = self.v.forward(kv_input)?;
340        let q = q
341            .reshape((b_sz, q_len, self.n_heads, self.d_kv))?
342            .transpose(1, 2)?
343            .contiguous()?;
344        let mut k = k
345            .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
346            .transpose(1, 2)?;
347        let mut v = v
348            .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
349            .transpose(1, 2)?;
350
351        if self.use_cache && key_value_states.is_none() {
352            let _enter = self.span_cache.enter();
353            if let Some((kv_cache_k, kv_cache_v)) = &self.kv_cache {
354                k = Tensor::cat(&[kv_cache_k, &k], 2)?;
355                v = Tensor::cat(&[kv_cache_v, &v], 2)?;
356            };
357            self.kv_cache = Some((k.clone(), v.clone()));
358        };
359        let k = k.contiguous()?;
360        let v = v.contiguous()?;
361        // TODO: Use flash_attn.
362        let scores = {
363            let _enter = self.span_mm.enter();
364            q.matmul(&k.t()?)?
365        };
366        let scores = match mask {
367            None => scores,
368            Some(mask) => masked_fill(
369                &scores,
370                &mask
371                    .unsqueeze(0)?
372                    .unsqueeze(0)?
373                    .repeat((b_sz, self.n_heads))?,
374                f32::NEG_INFINITY,
375            )?,
376        };
377
378        let (scores, position_bias) = match position_bias {
379            Some(position_bias) => (
380                scores.broadcast_add(position_bias)?,
381                Some(position_bias.clone()),
382            ),
383            None => match &self.relative_attention_bias {
384                None => (scores, None),
385                Some(relative_attention_bias) => {
386                    // This only handles the bidirectional case.
387                    let kv_len = k.dim(2)?;
388                    let (q_start, q_end) = match self.use_cache {
389                        true => ((kv_len - q_len) as u32, kv_len as u32),
390                        false => (0_u32, kv_len as u32),
391                    };
392                    let num_buckets = self.relative_attention_num_buckets as u32 / 2;
393                    let max_exact = num_buckets / 2;
394                    let relative_position = (q_start..q_end)
395                        .map(|i| {
396                            (0..kv_len as u32)
397                                .map(|j| {
398                                    if i < j {
399                                        if j - i < max_exact {
400                                            j - i + num_buckets
401                                        } else {
402                                            let b = f32::log(
403                                                (j - i) as f32 / max_exact as f32,
404                                                self.relative_attention_max_distance as f32
405                                                    / max_exact as f32,
406                                            ) * (num_buckets - max_exact) as f32;
407                                            u32::min(
408                                                max_exact + num_buckets + b as u32,
409                                                self.relative_attention_num_buckets as u32 - 1,
410                                            )
411                                        }
412                                    } else if i - j < max_exact {
413                                        i - j
414                                    } else {
415                                        let b = f32::log(
416                                            (i - j) as f32 / max_exact as f32,
417                                            self.relative_attention_max_distance as f32
418                                                / max_exact as f32,
419                                        ) * (num_buckets - max_exact) as f32;
420                                        max_exact + b as u32
421                                    }
422                                })
423                                .collect::<Vec<u32>>()
424                        })
425                        .collect::<Vec<Vec<_>>>();
426                    let relative_buckets = Tensor::new(relative_position, q.device())?;
427                    let position_bias = relative_attention_bias
428                        .forward(&relative_buckets)?
429                        .permute((2, 0, 1))?
430                        .unsqueeze(0)?;
431                    (scores.broadcast_add(&position_bias)?, Some(position_bias))
432                    // TODO: position_bias_masked?
433                }
434            },
435        };
436
437        let attn_weights = {
438            let _enter = self.span_sm.enter();
439            candle_nn::ops::softmax_last_dim(&scores)?
440        };
441        let attn_output = attn_weights.matmul(&v)?;
442        let attn_output = attn_output
443            .transpose(1, 2)?
444            .reshape((b_sz, q_len, self.inner_dim))?;
445        let attn_output = self.o.forward(&attn_output)?;
446        Ok((attn_output, position_bias))
447    }
448
449    fn clear_kv_cache(&mut self) {
450        self.kv_cache = None
451    }
452}
453
454#[derive(Debug, Clone)]
455struct T5LayerSelfAttention {
456    self_attention: T5Attention,
457    layer_norm: T5LayerNorm,
458    span: tracing::Span,
459}
460
461impl T5LayerSelfAttention {
462    fn load(h: bool, d: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
463        let self_attention = T5Attention::load(h, d, vb.pp("SelfAttention"), cfg)?;
464        let layer_norm =
465            T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
466        Ok(Self {
467            self_attention,
468            layer_norm,
469            span: tracing::span!(tracing::Level::TRACE, "self-attn"),
470        })
471    }
472
473    fn forward(
474        &mut self,
475        xs: &Tensor,
476        position_bias: Option<&Tensor>,
477        mask: Option<&Tensor>,
478    ) -> Result<(Tensor, Option<Tensor>)> {
479        let _enter = self.span.enter();
480        let normed_xs = self.layer_norm.forward(xs)?;
481        let (ys, position_bias) =
482            self.self_attention
483                .forward(&normed_xs, position_bias, None, mask)?;
484        let ys = (xs + ys)?;
485        Ok((ys, position_bias))
486    }
487
488    fn clear_kv_cache(&mut self) {
489        self.self_attention.clear_kv_cache()
490    }
491}
492
493#[derive(Debug, Clone)]
494struct T5LayerCrossAttention {
495    cross_attention: T5Attention,
496    layer_norm: T5LayerNorm,
497    span: tracing::Span,
498}
499
500impl T5LayerCrossAttention {
501    fn load(decoder: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
502        let cross_attention = T5Attention::load(false, decoder, vb.pp("EncDecAttention"), cfg)?;
503        let layer_norm =
504            T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
505        Ok(Self {
506            cross_attention,
507            layer_norm,
508            span: tracing::span!(tracing::Level::TRACE, "cross-attn"),
509        })
510    }
511
512    fn forward(
513        &mut self,
514        hidden_states: &Tensor,
515        position_bias: Option<&Tensor>,
516        key_value_states: &Tensor,
517    ) -> Result<(Tensor, Option<Tensor>)> {
518        let _enter = self.span.enter();
519        let normed_hidden_states = self.layer_norm.forward(hidden_states)?;
520        let (ys, position_bias) = self.cross_attention.forward(
521            &normed_hidden_states,
522            position_bias,
523            Some(key_value_states),
524            None,
525        )?;
526        let ys = (hidden_states + ys)?;
527        Ok((ys, position_bias))
528    }
529
530    fn clear_kv_cache(&mut self) {
531        self.cross_attention.clear_kv_cache()
532    }
533}
534
535#[derive(Debug, Clone)]
536struct T5Block {
537    self_attn: T5LayerSelfAttention,
538    cross_attn: Option<T5LayerCrossAttention>,
539    ff: T5LayerFF,
540    span: tracing::Span,
541}
542
543impl T5Block {
544    fn load(
545        has_relative_attention_bias: bool,
546        decoder: bool,
547        vb: VarBuilder,
548        cfg: &Config,
549    ) -> Result<Self> {
550        let vb = vb.pp("layer");
551        let self_attn =
552            T5LayerSelfAttention::load(has_relative_attention_bias, decoder, vb.pp("0"), cfg)?;
553        let cross_attn = if cfg.is_decoder {
554            Some(T5LayerCrossAttention::load(decoder, vb.pp("1"), cfg)?)
555        } else {
556            None
557        };
558        let ff_i = if cross_attn.is_some() { 2 } else { 1 };
559        let ff = T5LayerFF::load(vb.pp(ff_i), cfg)?;
560        Ok(Self {
561            self_attn,
562            cross_attn,
563            ff,
564            span: tracing::span!(tracing::Level::TRACE, "block"),
565        })
566    }
567
568    fn forward(
569        &mut self,
570        xs: &Tensor,
571        position_bias: Option<&Tensor>,
572        encoder_hidden_states: Option<&Tensor>,
573    ) -> Result<(Tensor, Option<Tensor>)> {
574        let _enter = self.span.enter();
575        // TODO: Cache masks
576        let mask = match self.cross_attn.is_some() {
577            true => {
578                let mask_len = xs.dim(1)?;
579                // If the input seq length is 1, no need for a mask, this is also helpful to avoid shape
580                // issues when using the KV cache in the decoder.
581                if mask_len <= 1 {
582                    None
583                } else {
584                    Some(get_mask(mask_len, xs.device())?)
585                }
586            }
587            false => None,
588        };
589        let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias, mask.as_ref())?;
590        // TODO: clamp for f16?
591        if let Some(cross_attn) = &mut self.cross_attn {
592            (xs, _) = cross_attn.forward(&xs, None, encoder_hidden_states.unwrap())?;
593            // TODO: clamp for f16?
594        }
595        let xs = self.ff.forward(&xs)?;
596        // TODO: clamp for f16?
597        Ok((xs, position_bias))
598    }
599
600    fn clear_kv_cache(&mut self) {
601        self.self_attn.clear_kv_cache();
602        self.cross_attn.iter_mut().for_each(|c| c.clear_kv_cache());
603    }
604}
605
606#[derive(Debug, Clone)]
607struct T5Stack {
608    block: Vec<T5Block>,
609    shared: Arc<Embedding>,
610    final_layer_norm: T5LayerNorm,
611    span: tracing::Span,
612}
613
614impl T5Stack {
615    fn load(decoder: bool, vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> {
616        let block = (0..cfg.num_layers)
617            .map(|i| T5Block::load(i == 0, decoder, vb.pp(format!("block.{i}")), cfg))
618            .collect::<Result<Vec<_>>>()?;
619        let final_layer_norm = T5LayerNorm::load(
620            cfg.d_model,
621            cfg.layer_norm_epsilon,
622            vb.pp("final_layer_norm"),
623        )?;
624        Ok(Self {
625            block,
626            shared: shared.clone(),
627            final_layer_norm,
628            span: tracing::span!(tracing::Level::TRACE, "stack"),
629        })
630    }
631
632    fn forward(
633        &mut self,
634        input_ids: &Tensor,
635        encoder_hidden_states: Option<&Tensor>,
636    ) -> Result<Tensor> {
637        let _enter = self.span.enter();
638        let input_embeds = self.shared.as_ref().forward(input_ids)?;
639        let mut hidden_states = input_embeds;
640        let mut position_bias = None;
641        for block in self.block.iter_mut() {
642            (hidden_states, position_bias) = block.forward(
643                &hidden_states,
644                position_bias.as_ref(),
645                encoder_hidden_states,
646            )?
647        }
648        self.final_layer_norm.forward(&hidden_states)
649    }
650
651    fn clear_kv_cache(&mut self) {
652        self.block.iter_mut().for_each(|b| b.clear_kv_cache())
653    }
654}
655
656#[derive(Debug, Clone)]
657pub struct T5EncoderModel {
658    encoder: T5Stack,
659    device: Device,
660    span: tracing::Span,
661}
662
663impl T5EncoderModel {
664    pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
665        let shared_vb = if vb.contains_key("shared.weight") {
666            vb.pp("shared")
667        } else {
668            vb.pp("decoder").pp("embed_tokens")
669        };
670        let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
671        let shared = Arc::new(shared);
672        let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
673        Ok(Self {
674            encoder,
675            device: vb.device().clone(),
676            span: tracing::span!(tracing::Level::TRACE, "encoder"),
677        })
678    }
679
680    pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
681        let _enter = self.span.enter();
682        self.encoder.forward(input_ids, None)
683    }
684
685    pub fn device(&self) -> &Device {
686        &self.device
687    }
688
689    pub fn clear_kv_cache(&mut self) {
690        self.encoder.clear_kv_cache()
691    }
692}
693
694#[derive(Debug, Clone)]
695pub struct T5ForConditionalGeneration {
696    encoder: T5Stack,
697    decoder: T5Stack,
698    d_model: usize,
699    tie_word_embeddings: bool,
700    lm_head: Option<QMatMul>,
701    shared: Arc<Embedding>,
702    device: Device,
703    span_decode: tracing::Span,
704    span_decode_head: tracing::Span,
705}
706
707impl T5ForConditionalGeneration {
708    pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
709        assert!(cfg.is_encoder_decoder);
710        let d_model = cfg.d_model;
711        let shared_vb = if vb.contains_key("shared.weight") {
712            vb.pp("shared")
713        } else {
714            vb.pp("decoder").pp("embed_tokens")
715        };
716        let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
717        let shared = Arc::new(shared);
718
719        let mut encoder_cfg = cfg.clone();
720        encoder_cfg.is_decoder = false;
721        encoder_cfg.use_cache = false;
722        encoder_cfg.is_encoder_decoder = false;
723        let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, &encoder_cfg)?;
724
725        let mut decoder_cfg = cfg.clone();
726        decoder_cfg.is_decoder = true;
727        decoder_cfg.is_encoder_decoder = false;
728        decoder_cfg.num_layers = cfg.num_decoder_layers.unwrap_or(cfg.num_layers);
729        let decoder = T5Stack::load(true, vb.pp("decoder"), &shared, &decoder_cfg)?;
730
731        let tie_word_embeddings = cfg.tie_word_embeddings;
732        let lm_head = if tie_word_embeddings {
733            None
734        } else {
735            Some(QMatMul::new(cfg.d_model, cfg.vocab_size, vb.pp("lm_head"))?)
736        };
737
738        Ok(Self {
739            encoder,
740            decoder,
741            d_model,
742            tie_word_embeddings,
743            lm_head,
744            shared,
745            device: vb.device().clone(),
746            span_decode: tracing::span!(tracing::Level::TRACE, "decode"),
747            span_decode_head: tracing::span!(tracing::Level::TRACE, "decode-head"),
748        })
749    }
750
751    pub fn encode(&mut self, input_ids: &Tensor) -> Result<Tensor> {
752        self.encoder.forward(input_ids, None)
753    }
754
755    pub fn decode(
756        &mut self,
757        decoder_input_ids: &Tensor,
758        encoder_output: &Tensor,
759    ) -> Result<Tensor> {
760        let _enter = self.span_decode.enter();
761        let decoder_output = self
762            .decoder
763            .forward(decoder_input_ids, Some(encoder_output))?;
764
765        let scaling_factor = if self.tie_word_embeddings {
766            // Rescale output before projecting on vocab
767            // See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
768            (self.d_model as f64).sqrt()
769        } else {
770            1.0
771        };
772        let sequence_output = ((decoder_output
773            .narrow(1, decoder_output.dim(1)? - 1, 1)?
774            .squeeze(1)?)
775            * scaling_factor)?;
776        let output = {
777            let _enter = self.span_decode_head.enter();
778            match self.lm_head {
779                None => sequence_output.matmul(&self.shared.embeddings().t()?)?,
780                Some(ref lm_head) => lm_head.forward(&sequence_output)?,
781            }
782        };
783        Ok(output)
784    }
785
786    pub fn forward(&mut self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> {
787        let encoder_output = self.encode(input_ids)?;
788        self.decode(decoder_input_ids, &encoder_output)
789    }
790
791    pub fn device(&self) -> &Device {
792        &self.device
793    }
794
795    pub fn clear_kv_cache(&mut self) {
796        self.encoder.clear_kv_cache();
797        self.decoder.clear_kv_cache();
798    }
799}