candle_transformers/models/mimi/
transformer.rs

1// Copyright (c) Kyutai, all rights reserved.
2// This source code is licensed under the license found in the
3// LICENSE file in the root directory of this source tree.
4
5use candle::{DType, Device, IndexOp, Module, Result, StreamTensor, StreamingModule, Tensor, D};
6use candle_nn::{linear_no_bias, Linear, VarBuilder};
7use std::sync::Arc;
8
9fn linear(in_d: usize, out_d: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
10    if bias {
11        candle_nn::linear(in_d, out_d, vb)
12    } else {
13        linear_no_bias(in_d, out_d, vb)
14    }
15}
16
/// Which positional information the transformer injects.
///
/// `Rope` applies rotary embeddings to queries and keys inside self-attention,
/// `Sin` adds a cos/sin embedding to the input once before the layer stack,
/// and `None` leaves the input untouched.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PositionalEmbedding {
    /// Rotary embedding applied to queries/keys in every self-attention layer.
    Rope,
    /// Additive sinusoidal embedding applied to the input of the stack.
    Sin,
    /// No positional embedding.
    None,
}
23
24#[derive(Debug, Clone)]
25pub struct Config {
26    pub d_model: usize,
27    pub num_heads: usize,
28    pub num_layers: usize,
29    pub causal: bool,
30    pub norm_first: bool,
31    pub bias_ff: bool,
32    pub bias_attn: bool,
33    pub layer_scale: Option<f64>,
34    pub positional_embedding: PositionalEmbedding,
35    pub use_conv_block: bool,
36    pub cross_attention: bool,
37    pub conv_kernel_size: usize,
38    pub use_conv_bias: bool,
39    pub gating: Option<candle_nn::Activation>,
40    pub norm: super::NormType,
41    pub context: usize,
42    pub max_period: usize,
43    pub max_seq_len: usize,
44
45    pub kv_repeat: usize,
46    pub dim_feedforward: usize,
47    pub conv_layout: bool,
48}
49
50#[derive(Debug, Clone)]
51pub struct RotaryEmbedding {
52    sin: Tensor,
53    cos: Tensor,
54    span: tracing::Span,
55}
56
57impl RotaryEmbedding {
58    pub fn new(dim: usize, max_seq_len: usize, theta: f32, dev: &Device) -> Result<Self> {
59        let inv_freq: Vec<_> = (0..dim)
60            .step_by(2)
61            .map(|i| 1f32 / theta.powf(i as f32 / dim as f32))
62            .collect();
63        let inv_freq_len = inv_freq.len();
64        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
65        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
66            .to_dtype(DType::F32)?
67            .reshape((max_seq_len, 1))?;
68        let freqs = t.matmul(&inv_freq)?;
69        Ok(Self {
70            sin: freqs.sin()?,
71            cos: freqs.cos()?,
72            span: tracing::span!(tracing::Level::TRACE, "rot"),
73        })
74    }
75
76    pub fn apply_rotary_emb(&self, qk: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
77        let _enter = self.span.enter();
78        let (_b_size, _nheads, seqlen, _headdim) = qk.dims4()?;
79        let qk_dtype = qk.dtype();
80        let c = self.cos.narrow(0, seqlen_offset, seqlen)?;
81        let s = self.sin.narrow(0, seqlen_offset, seqlen)?;
82        candle_nn::rotary_emb::rope_i(&qk.to_dtype(DType::F32)?, &c, &s)?.to_dtype(qk_dtype)
83    }
84}
85
86#[derive(Debug, Clone)]
87pub struct LayerScale {
88    scale: Tensor,
89}
90
91impl LayerScale {
92    pub fn new(d_model: usize, _init: f64, vb: VarBuilder) -> Result<Self> {
93        let scale = vb.get(d_model, "scale")?;
94        Ok(Self { scale })
95    }
96}
97
98impl Module for LayerScale {
99    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
100        xs.broadcast_mul(&self.scale)
101    }
102}
103
104#[derive(Debug, Clone)]
105pub struct StreamingMultiheadAttention {
106    q_proj: Linear,
107    k_proj: Linear,
108    v_proj: Linear,
109    out_proj: Linear,
110    kv_repeat: usize,
111    num_heads: usize,
112    context: usize,
113    neg_inf: Tensor,
114    rope: Option<Arc<RotaryEmbedding>>,
115    kv_cache: candle_nn::kv_cache::RotatingKvCache,
116    pos: usize,
117    use_flash_attn: bool,
118    span: tracing::Span,
119}
120
121impl StreamingMultiheadAttention {
122    pub fn new(rope: &Option<Arc<RotaryEmbedding>>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
123        let embed_dim = cfg.d_model;
124        let num_kv = cfg.num_heads / cfg.kv_repeat;
125        let kv_dim = num_kv * (embed_dim / cfg.num_heads);
126        let q_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("q_proj"))?;
127        let k_proj = linear(embed_dim, kv_dim, cfg.bias_attn, vb.pp("k_proj"))?;
128        let v_proj = linear(embed_dim, kv_dim, cfg.bias_attn, vb.pp("v_proj"))?;
129        let out_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("o_proj"))?;
130        let neg_inf = Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?;
131        Ok(Self {
132            q_proj,
133            k_proj,
134            v_proj,
135            out_proj,
136            rope: rope.clone(),
137            kv_repeat: cfg.kv_repeat,
138            num_heads: cfg.num_heads,
139            context: cfg.context,
140            neg_inf,
141            kv_cache: candle_nn::kv_cache::RotatingKvCache::new(2, cfg.context),
142            pos: 0,
143            use_flash_attn: false,
144            span: tracing::span!(tracing::Level::TRACE, "mha"),
145        })
146    }
147
148    pub fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
149        let _enter = self.span.enter();
150        if self.kv_repeat != 1 {
151            candle::bail!("only kv-repeat = 1 is supported")
152        }
153        let (b, t, hd) = xs.dims3()?;
154        let head_dim = hd / self.num_heads;
155        let q = xs
156            .apply(&self.q_proj)?
157            .reshape((b, t, self.num_heads, head_dim))?;
158        let k = xs
159            .apply(&self.k_proj)?
160            .reshape((b, t, self.num_heads, head_dim))?;
161        let v = xs
162            .apply(&self.v_proj)?
163            .reshape((b, t, self.num_heads, head_dim))?;
164        // qk_layer_norm = None
165        // kv_repeat = 1, otherwise we would need repeat_kv
166        let mut q = q.transpose(1, 2)?.contiguous()?; // b,h,t,d
167        let mut k = k.transpose(1, 2)?.contiguous()?; // b,h,k,d
168        let v = v.transpose(1, 2)?.contiguous()?; // b,h,k,d
169        if let Some(rope) = &self.rope {
170            q = rope.apply_rotary_emb(&q, self.pos)?;
171            k = rope.apply_rotary_emb(&k, self.pos)?;
172        }
173
174        let (k, v) = {
175            self.pos += k.dim(2)?;
176            self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?
177        };
178        // The KV cache keeps all the data at the moment, we want to trim
179        // down the part that comes from the cache to at most context to
180        // be coherent with the mask shape we provide.
181        let k_len = k.dim(2)?;
182        let k_target_len = t + usize::min(self.context, k_len - t);
183        let (k, v) = if k_target_len < k_len {
184            let k = k.narrow(2, k_len - k_target_len, k_target_len)?;
185            let v = v.narrow(2, k_len - k_target_len, k_target_len)?;
186            (k, v)
187        } else {
188            (k.clone(), v.clone())
189        };
190
191        let xs = if q.dtype() == DType::BF16 && self.use_flash_attn {
192            let q = q.transpose(1, 2)?;
193            let k = k.transpose(1, 2)?;
194            let v = v.transpose(1, 2)?;
195            let softmax_scale = 1f32 / (head_dim as f32).sqrt();
196            flash_attn(&q, &k, &v, softmax_scale, t > 1)?.transpose(1, 2)?
197        } else {
198            let pre_ws = q.matmul(&k.t()?)?; // b,h,t,k
199            let pre_ws = (pre_ws * (head_dim as f64).powf(-0.5))?;
200
201            let pre_ws = match mask {
202                None => pre_ws,
203                Some(mask) => {
204                    let mask = mask.broadcast_left((b, self.num_heads))?;
205                    let neg_inf = self.neg_inf.broadcast_as(pre_ws.shape())?;
206                    mask.where_cond(&neg_inf, &pre_ws)?
207                }
208            };
209
210            let ws = candle_nn::ops::softmax_last_dim(&pre_ws)?; // b,h,t,k
211            ws.matmul(&v)? // b,h,t,d
212        };
213        let xs = xs
214            .transpose(1, 2)? // b,t,h,d
215            .reshape((b, t, hd))?
216            .apply(&self.out_proj)?;
217        Ok(xs)
218    }
219
220    pub fn reset_kv_cache(&mut self) {
221        self.kv_cache.reset()
222    }
223
224    pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::RotatingKvCache) {
225        self.kv_cache = kv_cache
226    }
227}
228
229#[derive(Debug, Clone)]
230pub struct StreamingMultiheadCrossAttention {
231    in_proj_q: Linear,
232    in_proj_k: Linear,
233    in_proj_v: Linear,
234    out_proj: Linear,
235    kv_repeat: usize,
236    num_heads: usize,
237    neg_inf: Tensor,
238    span: tracing::Span,
239}
240
241impl StreamingMultiheadCrossAttention {
242    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
243        let embed_dim = cfg.d_model;
244        let num_kv = cfg.num_heads / cfg.kv_repeat;
245        let kv_dim = num_kv * (embed_dim / cfg.num_heads);
246        let out_dim = embed_dim + 2 * kv_dim;
247        let in_proj_weight = vb.get((out_dim, embed_dim), "in_proj_weight")?;
248        let in_proj_weight_q = in_proj_weight.narrow(0, 0, embed_dim)?;
249        let in_proj_weight_k = in_proj_weight.narrow(0, embed_dim, kv_dim)?;
250        let in_proj_weight_v = in_proj_weight.narrow(0, embed_dim + kv_dim, kv_dim)?;
251        let (in_proj_bias_q, in_proj_bias_k, in_proj_bias_v) = if cfg.bias_attn {
252            let b = vb.get(out_dim, "in_proj_bias")?;
253            let q = b.narrow(0, 0, embed_dim)?;
254            let k = b.narrow(0, embed_dim, kv_dim)?;
255            let v = b.narrow(0, embed_dim + kv_dim, kv_dim)?;
256            (Some(q), Some(k), Some(v))
257        } else {
258            (None, None, None)
259        };
260        let in_proj_q = Linear::new(in_proj_weight_q, in_proj_bias_q);
261        let in_proj_k = Linear::new(in_proj_weight_k, in_proj_bias_k);
262        let in_proj_v = Linear::new(in_proj_weight_v, in_proj_bias_v);
263        let out_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("out_proj"))?;
264        let neg_inf = Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?;
265        Ok(Self {
266            in_proj_q,
267            in_proj_k,
268            in_proj_v,
269            out_proj,
270            kv_repeat: cfg.kv_repeat,
271            num_heads: cfg.num_heads,
272            neg_inf,
273            span: tracing::span!(tracing::Level::TRACE, "mhca"),
274        })
275    }
276
277    pub fn forward(&self, xs: &Tensor, ca_src: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
278        let _enter = self.span.enter();
279        if self.kv_repeat != 1 {
280            candle::bail!("only kv-repeat = 1 is supported")
281        }
282        let (b, t, hd) = xs.dims3()?;
283        let head_dim = hd / self.num_heads;
284        // time_dim = 1, layout: b,t,h,d
285        let q = xs.apply(&self.in_proj_q)?;
286        let k = ca_src.apply(&self.in_proj_k)?;
287        let v = ca_src.apply(&self.in_proj_v)?;
288        let (ca_b, ca_t, ca_dim) = k.dims3()?;
289        let q = q.reshape((b, t, self.num_heads, head_dim))?;
290        let k = k.reshape((ca_b, ca_t, ca_dim / head_dim, head_dim))?;
291        let v = v.reshape((ca_b, ca_t, ca_dim / head_dim, head_dim))?;
292        // qk_layer_norm = None
293        // kv_repeat = 1, otherwise we would need repeat_kv
294        let q = q.transpose(1, 2)?.contiguous()?; // b,h,t,d
295        let k = k.transpose(1, 2)?.contiguous()?; // b,h,k,d
296        let v = v.transpose(1, 2)?.contiguous()?; // b,h,k,d
297
298        let pre_ws = q.matmul(&k.t()?)?; // b,h,t,k
299        let pre_ws = (pre_ws * (head_dim as f64).powf(-0.5))?;
300
301        let pre_ws = match mask {
302            None => pre_ws,
303            Some(mask) => {
304                let mask = mask.broadcast_left((b, self.num_heads))?;
305                let neg_inf = self.neg_inf.broadcast_as(pre_ws.shape())?;
306                mask.where_cond(&neg_inf, &pre_ws)?
307            }
308        };
309
310        let ws = candle_nn::ops::softmax_last_dim(&pre_ws)?; // b,h,t,k
311        let xs = ws.matmul(&v)?; // b,h,t,d
312        let xs = xs
313            .transpose(1, 2)? // b,t,h,d
314            .reshape((b, t, hd))?
315            .apply(&self.out_proj)?;
316        Ok(xs)
317    }
318}
319
320#[derive(Debug, Clone)]
321pub enum Mlp {
322    NoGating {
323        span1: tracing::Span,
324        linear1: Linear,
325        span2: tracing::Span,
326        linear2: Linear,
327        span: tracing::Span,
328    },
329    Gating {
330        linear_in: Linear,
331        linear_out: Linear,
332        activation: candle_nn::Activation,
333        span: tracing::Span,
334    },
335}
336
337impl Mlp {
338    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
339        let d_model = cfg.d_model;
340        let span = tracing::span!(tracing::Level::TRACE, "mlp");
341
342        match cfg.gating {
343            None => {
344                let span1 = tracing::span!(tracing::Level::TRACE, "lin1");
345                let span2 = tracing::span!(tracing::Level::TRACE, "lin2");
346                let linear1 = linear(d_model, cfg.dim_feedforward, cfg.bias_ff, vb.pp("mlp.fc1"))?;
347                let linear2 = linear(cfg.dim_feedforward, d_model, cfg.bias_ff, vb.pp("mlp.fc2"))?;
348                Ok(Self::NoGating {
349                    linear1,
350                    linear2,
351                    span,
352                    span1,
353                    span2,
354                })
355            }
356            Some(activation) => {
357                let vb = vb.pp("gating");
358                let hidden = if cfg.dim_feedforward == 4 * d_model {
359                    11 * d_model / 4
360                } else {
361                    2 * cfg.dim_feedforward / 3
362                };
363                // TODO: Maybe use bias_ff here?
364                let linear_in = linear(d_model, 2 * hidden, false, vb.pp("linear_in"))?;
365                let linear_out = linear(hidden, d_model, false, vb.pp("linear_out"))?;
366                Ok(Self::Gating {
367                    linear_in,
368                    linear_out,
369                    activation,
370                    span,
371                })
372            }
373        }
374    }
375}
376
377impl Module for Mlp {
378    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
379        match self {
380            Self::NoGating {
381                linear1,
382                linear2,
383                span,
384                span1,
385                span2,
386            } => {
387                let _enter = span.enter();
388                let xs = {
389                    let _enter = span1.enter();
390                    xs.apply(linear1)?
391                };
392                let xs = xs.gelu_erf()?;
393                {
394                    let _enter = span2.enter();
395                    xs.apply(linear2)
396                }
397            }
398            Self::Gating {
399                linear_in,
400                linear_out,
401                activation,
402                span,
403            } => {
404                let _enter = span.enter();
405                let xs = xs.apply(linear_in)?;
406                let (b, t, _) = xs.dims3()?;
407                let xs = xs.reshape((b, t, 2, ()))?;
408                let xs = (xs.i((.., .., 0))?.apply(activation)? * xs.i((.., .., 1))?)?;
409                xs.apply(linear_out)
410            }
411        }
412    }
413}
414
415#[derive(Debug, Clone)]
416pub struct RmsNorm {
417    pub(crate) alpha: Tensor,
418    pub(crate) eps: f32,
419}
420
421impl RmsNorm {
422    pub fn new(d_model: usize, eps: f32, vb: VarBuilder) -> Result<Self> {
423        let alpha = vb.get((1, 1, d_model), "alpha")?.reshape(d_model)?;
424        Ok(Self { alpha, eps })
425    }
426}
427
428impl Module for RmsNorm {
429    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
430        candle_nn::ops::rms_norm(xs, &self.alpha, self.eps)
431    }
432}
433
434#[derive(Debug, Clone)]
435pub enum Norm {
436    LayerNorm(candle_nn::LayerNorm),
437    RmsNorm(RmsNorm),
438}
439
440impl Norm {
441    pub fn new(d_model: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
442        let norm = match cfg.norm {
443            super::NormType::LayerNorm => {
444                let norm = candle_nn::layer_norm(d_model, 1e-5, vb)?;
445                Self::LayerNorm(norm)
446            }
447            super::NormType::RmsNorm => {
448                let norm = RmsNorm::new(d_model, 1e-8, vb)?;
449                Self::RmsNorm(norm)
450            }
451        };
452        Ok(norm)
453    }
454}
455
456impl Module for Norm {
457    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
458        match self {
459            Self::LayerNorm(m) => m.forward(xs),
460            Self::RmsNorm(m) => m.forward(xs),
461        }
462    }
463}
464
465#[derive(Debug, Clone)]
466pub struct StreamingTransformerLayer {
467    self_attn: StreamingMultiheadAttention,
468    mlp: Mlp,
469    norm1: Norm,
470    norm2: Norm,
471    layer_scale_1: Option<LayerScale>,
472    layer_scale_2: Option<LayerScale>,
473    cross_attn: Option<(candle_nn::LayerNorm, StreamingMultiheadCrossAttention)>,
474    norm_first: bool,
475    span: tracing::Span,
476}
477
478impl StreamingTransformerLayer {
479    pub fn new(rope: &Option<Arc<RotaryEmbedding>>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
480        if cfg.use_conv_block {
481            candle::bail!("conv-block is not supported")
482        }
483        let d_model = cfg.d_model;
484        let mlp = Mlp::new(cfg, vb.clone())?;
485        let (norm1, norm2) = match cfg.norm {
486            super::NormType::LayerNorm => {
487                let norm1 = candle_nn::layer_norm(d_model, 1e-5, vb.pp("input_layernorm"))?;
488                let norm2 =
489                    candle_nn::layer_norm(d_model, 1e-5, vb.pp("post_attention_layernorm"))?;
490                (Norm::LayerNorm(norm1), Norm::LayerNorm(norm2))
491            }
492            super::NormType::RmsNorm => {
493                let norm1 = RmsNorm::new(d_model, 1e-8, vb.pp("input_rmsnorm"))?;
494                let norm2 = RmsNorm::new(d_model, 1e-8, vb.pp("post_attention_rmsnorm"))?;
495                (Norm::RmsNorm(norm1), Norm::RmsNorm(norm2))
496            }
497        };
498        let layer_scale_1 = match cfg.layer_scale {
499            None => None,
500            Some(ls) => {
501                let ls = LayerScale::new(d_model, ls, vb.pp("self_attn_layer_scale"))?;
502                Some(ls)
503            }
504        };
505        let layer_scale_2 = match cfg.layer_scale {
506            None => None,
507            Some(ls) => {
508                let ls = LayerScale::new(d_model, ls, vb.pp("mlp_layer_scale"))?;
509                Some(ls)
510            }
511        };
512        let self_attn = StreamingMultiheadAttention::new(rope, cfg, vb.pp("self_attn"))?;
513        let cross_attn = if cfg.cross_attention {
514            let norm_cross = candle_nn::layer_norm(cfg.d_model, 1e-5, vb.pp("norm_cross"))?;
515            let cross_attn = StreamingMultiheadCrossAttention::new(cfg, vb.pp("cross_attention"))?;
516            Some((norm_cross, cross_attn))
517        } else {
518            None
519        };
520        Ok(Self {
521            self_attn,
522            mlp,
523            norm1,
524            norm2,
525            layer_scale_1,
526            layer_scale_2,
527            cross_attn,
528            norm_first: cfg.norm_first,
529            span: tracing::span!(tracing::Level::TRACE, "transformer-layer"),
530        })
531    }
532
533    pub fn forward(
534        &mut self,
535        xs: &Tensor,
536        ca_src: Option<&Tensor>,
537        mask: Option<&Tensor>,
538    ) -> Result<Tensor> {
539        let _enter = self.span.enter();
540        if !self.norm_first {
541            candle::bail!("only norm_first = true is supported")
542        }
543        let norm1 = xs.apply(&self.norm1)?;
544        let xs = (xs
545            + self
546                .self_attn
547                .forward(&norm1, mask)?
548                .apply(&self.layer_scale_1.as_ref())?)?;
549
550        let xs = match (&self.cross_attn, ca_src) {
551            (Some((norm_cross, cross_attn)), Some(ca_src)) => {
552                let residual = &xs;
553                let xs = xs.apply(norm_cross)?;
554                (residual + cross_attn.forward(&xs, ca_src, None)?)?
555            }
556            _ => xs,
557        };
558
559        let xs = (&xs
560            + xs.apply(&self.norm2)?
561                .apply(&self.mlp)?
562                .apply(&self.layer_scale_2.as_ref()))?;
563        Ok(xs)
564    }
565
566    pub fn reset_kv_cache(&mut self) {
567        self.self_attn.reset_kv_cache()
568    }
569
570    pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::RotatingKvCache) {
571        self.self_attn.set_kv_cache(kv_cache)
572    }
573}
574
575#[derive(Debug, Clone)]
576pub struct StreamingTransformer {
577    layers: Vec<StreamingTransformerLayer>,
578    positional_embedding: PositionalEmbedding,
579    max_period: usize,
580}
581
582impl StreamingTransformer {
583    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
584        let vb_l = vb.pp("layers");
585        let rope = match cfg.positional_embedding {
586            PositionalEmbedding::Rope => {
587                let rope = RotaryEmbedding::new(
588                    cfg.d_model / cfg.num_heads,
589                    cfg.max_seq_len,
590                    cfg.max_period as f32,
591                    vb.device(),
592                )?;
593                Some(Arc::new(rope))
594            }
595            PositionalEmbedding::Sin | PositionalEmbedding::None => None,
596        };
597        let mut layers = Vec::with_capacity(cfg.num_layers);
598        for layer_idx in 0..cfg.num_layers {
599            let layer = StreamingTransformerLayer::new(&rope, cfg, vb_l.pp(layer_idx))?;
600            layers.push(layer)
601        }
602        Ok(Self {
603            layers,
604            positional_embedding: cfg.positional_embedding,
605            max_period: cfg.max_period,
606        })
607    }
608
609    pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
610        self.forward_ca(xs, None)
611    }
612
613    pub fn forward_ca(&mut self, xs: &Tensor, ca_src: Option<&Tensor>) -> Result<Tensor> {
614        let (_b, t, c) = xs.dims3()?;
615        let pos = self.layers[0].self_attn.kv_cache.current_seq_len();
616        let mask = self.layers[0]
617            .self_attn
618            .kv_cache
619            .attn_mask(t, xs.device())?;
620        let mut xs = match self.positional_embedding {
621            PositionalEmbedding::Rope | PositionalEmbedding::None => xs.clone(),
622            PositionalEmbedding::Sin => {
623                let dev = xs.device();
624                let theta = self.max_period as f32;
625                let half_dim = c / 2;
626                let positions = Tensor::arange(pos as u32, (pos + t) as u32, dev)?
627                    .unsqueeze(1)?
628                    .to_dtype(DType::F32)?;
629                let inv_freq: Vec<_> = (0..half_dim)
630                    .map(|i| 1f32 / theta.powf(i as f32 / (half_dim - 1) as f32))
631                    .collect();
632                let inv_freq_len = inv_freq.len();
633                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
634                let freqs = positions.broadcast_mul(&inv_freq)?;
635                let pos_emb =
636                    Tensor::cat(&[freqs.cos()?, freqs.sin()?], D::Minus1)?.to_dtype(xs.dtype())?;
637                xs.broadcast_add(&pos_emb)?
638            }
639        };
640        for layer in self.layers.iter_mut() {
641            xs = layer.forward(&xs, ca_src, mask.as_ref())?;
642        }
643        Ok(xs)
644    }
645
646    pub fn copy_state(&mut self, from: &Self) -> Result<()> {
647        if self.layers.len() != from.layers.len() {
648            candle::bail!("cannot copy kv-caches as the transformers have different depths")
649        }
650        self.layers
651            .iter_mut()
652            .zip(from.layers.iter())
653            .for_each(|(v, w)| v.set_kv_cache(w.self_attn.kv_cache.clone()));
654        Ok(())
655    }
656}
657
658impl StreamingModule for StreamingTransformer {
659    fn reset_state(&mut self) {
660        self.layers.iter_mut().for_each(|v| v.reset_kv_cache())
661    }
662
663    fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
664        match xs.as_option() {
665            None => Ok(StreamTensor::empty()),
666            Some(xs) => Ok(StreamTensor::from_tensor(self.forward(xs)?)),
667        }
668    }
669}
670
671#[derive(Debug, Clone)]
672pub struct ProjectedTransformer {
673    transformer: StreamingTransformer,
674    input_proj: Option<Linear>,
675    output_projs: Vec<Option<Linear>>,
676    conv_layout: bool,
677    span: tracing::Span,
678}
679
680impl ProjectedTransformer {
681    pub fn new(
682        input_dim: usize,
683        output_dims: &[usize],
684        cfg: &Config,
685        vb: VarBuilder,
686    ) -> Result<Self> {
687        let transformer = StreamingTransformer::new(cfg, vb.clone())?;
688        let input_proj = if input_dim == cfg.d_model {
689            None
690        } else {
691            let l = linear_no_bias(input_dim, cfg.d_model, vb.pp("input_proj"))?;
692            Some(l)
693        };
694        let mut output_projs = Vec::with_capacity(output_dims.len());
695        let vb_o = vb.pp("output_projs");
696        for (i, &output_dim) in output_dims.iter().enumerate() {
697            let output_proj = if output_dim == cfg.d_model {
698                None
699            } else {
700                let l = linear_no_bias(cfg.d_model, output_dim, vb_o.pp(i))?;
701                Some(l)
702            };
703            output_projs.push(output_proj)
704        }
705        Ok(Self {
706            transformer,
707            input_proj,
708            output_projs,
709            conv_layout: cfg.conv_layout,
710            span: tracing::span!(tracing::Level::TRACE, "proj-transformer"),
711        })
712    }
713
714    pub fn forward(&mut self, xs: &Tensor) -> Result<Vec<Tensor>> {
715        let _enter = self.span.enter();
716        let xs = if self.conv_layout {
717            xs.transpose(1, 2)?
718        } else {
719            xs.clone()
720        };
721        let xs = xs.apply(&self.input_proj.as_ref())?;
722        let xs = self.transformer.forward(&xs)?;
723        let mut ys = Vec::with_capacity(self.output_projs.len());
724        for output_proj in self.output_projs.iter() {
725            let ys_ = xs.apply(&output_proj.as_ref())?;
726            let ys_ = if self.conv_layout {
727                ys_.transpose(1, 2)?
728            } else {
729                ys_
730            };
731            ys.push(ys_)
732        }
733        Ok(ys)
734    }
735}
736
737impl StreamingModule for ProjectedTransformer {
738    fn reset_state(&mut self) {
739        self.transformer.reset_state()
740    }
741
742    fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
743        let xs = xs.apply(&|x: &Tensor| {
744            if self.conv_layout {
745                x.transpose(1, 2)
746            } else {
747                Ok(x.clone())
748            }
749        })?;
750        let xs = xs.apply(&self.input_proj.as_ref())?;
751        let xs = self.transformer.step(&xs)?;
752        let ys = xs.apply(&self.output_projs[0].as_ref())?;
753        ys.apply(&|y: &Tensor| {
754            if self.conv_layout {
755                y.transpose(1, 2)
756            } else {
757                Ok(y.clone())
758            }
759        })
760    }
761}
762
763#[cfg(feature = "flash-attn")]
764fn flash_attn(
765    q: &Tensor,
766    k: &Tensor,
767    v: &Tensor,
768    softmax_scale: f32,
769    causal: bool,
770) -> Result<Tensor> {
771    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
772}
773
774#[cfg(not(feature = "flash-attn"))]
775fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
776    unimplemented!("compile with '--features flash-attn'")
777}