1use candle::{DType, Device, IndexOp, Module, Result, StreamTensor, StreamingModule, Tensor, D};
6use candle_nn::{linear_no_bias, Linear, VarBuilder};
7use std::sync::Arc;
8
9fn linear(in_d: usize, out_d: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
10 if bias {
11 candle_nn::linear(in_d, out_d, vb)
12 } else {
13 linear_no_bias(in_d, out_d, vb)
14 }
15}
16
/// Kind of positional information used by [`StreamingTransformer`].
///
/// - `Rope`: rotary embeddings applied to queries and keys inside self-attention.
/// - `Sin`: sinusoidal embeddings added to the input before the layer stack.
/// - `None`: no positional information is added.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum PositionalEmbedding {
    Rope,
    Sin,
    None,
}
23
/// Hyper-parameters for the streaming transformer defined in this file.
#[derive(Debug, Clone)]
pub struct Config {
    /// Model width; input/output size of the attention and MLP blocks.
    pub d_model: usize,
    /// Number of attention heads; the head dimension is `d_model / num_heads`.
    pub num_heads: usize,
    /// Number of stacked transformer layers.
    pub num_layers: usize,
    /// Not read by the code shown here.
    pub causal: bool,
    /// Pre-norm residual layout; layers bail at forward time unless this is `true`.
    pub norm_first: bool,
    /// Whether the non-gated MLP projections (`fc1`/`fc2`) carry a bias.
    pub bias_ff: bool,
    /// Whether the attention projections carry a bias.
    pub bias_attn: bool,
    /// When `Some`, per-branch `LayerScale` weights are loaded for each layer.
    /// The value itself is passed to `LayerScale::new`, which ignores it.
    pub layer_scale: Option<f64>,
    /// Positional embedding scheme (RoPE, sinusoidal, or none).
    pub positional_embedding: PositionalEmbedding,
    /// Must be `false`: layer construction bails when it is set.
    pub use_conv_block: bool,
    /// When `true`, each layer also loads a cross-attention block.
    pub cross_attention: bool,
    /// Not read by the code shown here.
    pub conv_kernel_size: usize,
    /// Not read by the code shown here.
    pub use_conv_bias: bool,
    /// When `Some`, the MLP uses a gated variant with this activation.
    pub gating: Option<candle_nn::Activation>,
    /// Normalization type (layer norm or RMS norm).
    pub norm: super::NormType,
    /// Capacity of each self-attention rotating KV cache (attention window).
    pub context: usize,
    /// Base (`theta`) for rotary embeddings and period for sinusoidal embeddings.
    pub max_period: usize,
    /// Number of positions in the precomputed RoPE tables.
    pub max_seq_len: usize,

    /// Query heads per key/value head; only `1` is supported at forward time.
    pub kv_repeat: usize,
    /// Hidden size of the feed-forward block.
    pub dim_feedforward: usize,
    /// When `true`, `ProjectedTransformer` swaps dims 1 and 2 on input and output.
    pub conv_layout: bool,
}
49
50#[derive(Debug, Clone)]
51pub struct RotaryEmbedding {
52 sin: Tensor,
53 cos: Tensor,
54 span: tracing::Span,
55}
56
57impl RotaryEmbedding {
58 pub fn new(dim: usize, max_seq_len: usize, theta: f32, dev: &Device) -> Result<Self> {
59 let inv_freq: Vec<_> = (0..dim)
60 .step_by(2)
61 .map(|i| 1f32 / theta.powf(i as f32 / dim as f32))
62 .collect();
63 let inv_freq_len = inv_freq.len();
64 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
65 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
66 .to_dtype(DType::F32)?
67 .reshape((max_seq_len, 1))?;
68 let freqs = t.matmul(&inv_freq)?;
69 Ok(Self {
70 sin: freqs.sin()?,
71 cos: freqs.cos()?,
72 span: tracing::span!(tracing::Level::TRACE, "rot"),
73 })
74 }
75
76 pub fn apply_rotary_emb(&self, qk: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
77 let _enter = self.span.enter();
78 let (_b_size, _nheads, seqlen, _headdim) = qk.dims4()?;
79 let qk_dtype = qk.dtype();
80 let c = self.cos.narrow(0, seqlen_offset, seqlen)?;
81 let s = self.sin.narrow(0, seqlen_offset, seqlen)?;
82 candle_nn::rotary_emb::rope_i(&qk.to_dtype(DType::F32)?, &c, &s)?.to_dtype(qk_dtype)
83 }
84}
85
86#[derive(Debug, Clone)]
87pub struct LayerScale {
88 scale: Tensor,
89}
90
91impl LayerScale {
92 pub fn new(d_model: usize, _init: f64, vb: VarBuilder) -> Result<Self> {
93 let scale = vb.get(d_model, "scale")?;
94 Ok(Self { scale })
95 }
96}
97
98impl Module for LayerScale {
99 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
100 xs.broadcast_mul(&self.scale)
101 }
102}
103
104#[derive(Debug, Clone)]
105pub struct StreamingMultiheadAttention {
106 q_proj: Linear,
107 k_proj: Linear,
108 v_proj: Linear,
109 out_proj: Linear,
110 kv_repeat: usize,
111 num_heads: usize,
112 context: usize,
113 neg_inf: Tensor,
114 rope: Option<Arc<RotaryEmbedding>>,
115 kv_cache: candle_nn::kv_cache::RotatingKvCache,
116 pos: usize,
117 use_flash_attn: bool,
118 span: tracing::Span,
119}
120
121impl StreamingMultiheadAttention {
122 pub fn new(rope: &Option<Arc<RotaryEmbedding>>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
123 let embed_dim = cfg.d_model;
124 let num_kv = cfg.num_heads / cfg.kv_repeat;
125 let kv_dim = num_kv * (embed_dim / cfg.num_heads);
126 let q_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("q_proj"))?;
127 let k_proj = linear(embed_dim, kv_dim, cfg.bias_attn, vb.pp("k_proj"))?;
128 let v_proj = linear(embed_dim, kv_dim, cfg.bias_attn, vb.pp("v_proj"))?;
129 let out_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("o_proj"))?;
130 let neg_inf = Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?;
131 Ok(Self {
132 q_proj,
133 k_proj,
134 v_proj,
135 out_proj,
136 rope: rope.clone(),
137 kv_repeat: cfg.kv_repeat,
138 num_heads: cfg.num_heads,
139 context: cfg.context,
140 neg_inf,
141 kv_cache: candle_nn::kv_cache::RotatingKvCache::new(2, cfg.context),
142 pos: 0,
143 use_flash_attn: false,
144 span: tracing::span!(tracing::Level::TRACE, "mha"),
145 })
146 }
147
148 pub fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
149 let _enter = self.span.enter();
150 if self.kv_repeat != 1 {
151 candle::bail!("only kv-repeat = 1 is supported")
152 }
153 let (b, t, hd) = xs.dims3()?;
154 let head_dim = hd / self.num_heads;
155 let q = xs
156 .apply(&self.q_proj)?
157 .reshape((b, t, self.num_heads, head_dim))?;
158 let k = xs
159 .apply(&self.k_proj)?
160 .reshape((b, t, self.num_heads, head_dim))?;
161 let v = xs
162 .apply(&self.v_proj)?
163 .reshape((b, t, self.num_heads, head_dim))?;
164 let mut q = q.transpose(1, 2)?.contiguous()?; let mut k = k.transpose(1, 2)?.contiguous()?; let v = v.transpose(1, 2)?.contiguous()?; if let Some(rope) = &self.rope {
170 q = rope.apply_rotary_emb(&q, self.pos)?;
171 k = rope.apply_rotary_emb(&k, self.pos)?;
172 }
173
174 let (k, v) = {
175 self.pos += k.dim(2)?;
176 self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?
177 };
178 let k_len = k.dim(2)?;
182 let k_target_len = t + usize::min(self.context, k_len - t);
183 let (k, v) = if k_target_len < k_len {
184 let k = k.narrow(2, k_len - k_target_len, k_target_len)?;
185 let v = v.narrow(2, k_len - k_target_len, k_target_len)?;
186 (k, v)
187 } else {
188 (k.clone(), v.clone())
189 };
190
191 let xs = if q.dtype() == DType::BF16 && self.use_flash_attn {
192 let q = q.transpose(1, 2)?;
193 let k = k.transpose(1, 2)?;
194 let v = v.transpose(1, 2)?;
195 let softmax_scale = 1f32 / (head_dim as f32).sqrt();
196 flash_attn(&q, &k, &v, softmax_scale, t > 1)?.transpose(1, 2)?
197 } else {
198 let pre_ws = q.matmul(&k.t()?)?; let pre_ws = (pre_ws * (head_dim as f64).powf(-0.5))?;
200
201 let pre_ws = match mask {
202 None => pre_ws,
203 Some(mask) => {
204 let mask = mask.broadcast_left((b, self.num_heads))?;
205 let neg_inf = self.neg_inf.broadcast_as(pre_ws.shape())?;
206 mask.where_cond(&neg_inf, &pre_ws)?
207 }
208 };
209
210 let ws = candle_nn::ops::softmax_last_dim(&pre_ws)?; ws.matmul(&v)? };
213 let xs = xs
214 .transpose(1, 2)? .reshape((b, t, hd))?
216 .apply(&self.out_proj)?;
217 Ok(xs)
218 }
219
220 pub fn reset_kv_cache(&mut self) {
221 self.kv_cache.reset()
222 }
223
224 pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::RotatingKvCache) {
225 self.kv_cache = kv_cache
226 }
227}
228
/// Multi-head cross-attention: queries come from `xs`, keys/values from `ca_src`.
///
/// The Q/K/V projections are row slices of one packed `in_proj_weight` (and of
/// `in_proj_bias` when `cfg.bias_attn`). `forward` takes `&self`, so no state is
/// kept between calls.
#[derive(Debug, Clone)]
pub struct StreamingMultiheadCrossAttention {
    in_proj_q: Linear,
    in_proj_k: Linear,
    in_proj_v: Linear,
    out_proj: Linear,
    kv_repeat: usize,
    num_heads: usize,
    neg_inf: Tensor,
    span: tracing::Span,
}

impl StreamingMultiheadCrossAttention {
    /// Loads the packed input projection and the output projection from `vb`.
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let embed_dim = cfg.d_model;
        let num_kv = cfg.num_heads / cfg.kv_repeat;
        let kv_dim = num_kv * (embed_dim / cfg.num_heads);
        let out_dim = embed_dim + 2 * kv_dim;
        // Packed rows are laid out as [q: embed_dim | k: kv_dim | v: kv_dim].
        let in_proj_weight = vb.get((out_dim, embed_dim), "in_proj_weight")?;
        let in_proj_weight_q = in_proj_weight.narrow(0, 0, embed_dim)?;
        let in_proj_weight_k = in_proj_weight.narrow(0, embed_dim, kv_dim)?;
        let in_proj_weight_v = in_proj_weight.narrow(0, embed_dim + kv_dim, kv_dim)?;
        // The packed bias uses the same [q | k | v] layout as the weight.
        let (in_proj_bias_q, in_proj_bias_k, in_proj_bias_v) = if cfg.bias_attn {
            let b = vb.get(out_dim, "in_proj_bias")?;
            let q = b.narrow(0, 0, embed_dim)?;
            let k = b.narrow(0, embed_dim, kv_dim)?;
            let v = b.narrow(0, embed_dim + kv_dim, kv_dim)?;
            (Some(q), Some(k), Some(v))
        } else {
            (None, None, None)
        };
        let in_proj_q = Linear::new(in_proj_weight_q, in_proj_bias_q);
        let in_proj_k = Linear::new(in_proj_weight_k, in_proj_bias_k);
        let in_proj_v = Linear::new(in_proj_weight_v, in_proj_bias_v);
        let out_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("out_proj"))?;
        let neg_inf = Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?;
        Ok(Self {
            in_proj_q,
            in_proj_k,
            in_proj_v,
            out_proj,
            kv_repeat: cfg.kv_repeat,
            num_heads: cfg.num_heads,
            neg_inf,
            span: tracing::span!(tracing::Level::TRACE, "mhca"),
        })
    }

    /// Attends queries from `xs` (`(batch, seq, d_model)`) over keys/values projected
    /// from `ca_src` (rank 3, as required by `dims3`).
    ///
    /// Where `mask` is nonzero, the pre-softmax score is replaced with `-inf`.
    /// Bails when `kv_repeat != 1`.
    pub fn forward(&self, xs: &Tensor, ca_src: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
        let _enter = self.span.enter();
        if self.kv_repeat != 1 {
            candle::bail!("only kv-repeat = 1 is supported")
        }
        let (b, t, hd) = xs.dims3()?;
        let head_dim = hd / self.num_heads;
        let q = xs.apply(&self.in_proj_q)?;
        let k = ca_src.apply(&self.in_proj_k)?;
        let v = ca_src.apply(&self.in_proj_v)?;
        // The source may have a different batch/sequence length than `xs`.
        let (ca_b, ca_t, ca_dim) = k.dims3()?;
        let q = q.reshape((b, t, self.num_heads, head_dim))?;
        let k = k.reshape((ca_b, ca_t, ca_dim / head_dim, head_dim))?;
        let v = v.reshape((ca_b, ca_t, ca_dim / head_dim, head_dim))?;
        let q = q.transpose(1, 2)?.contiguous()?; // b,h,t,d
        let k = k.transpose(1, 2)?.contiguous()?; // b,h,k,d
        let v = v.transpose(1, 2)?.contiguous()?; // b,h,k,d
        let pre_ws = q.matmul(&k.t()?)?; // b,h,t,k
        let pre_ws = (pre_ws * (head_dim as f64).powf(-0.5))?;

        let pre_ws = match mask {
            None => pre_ws,
            Some(mask) => {
                let mask = mask.broadcast_left((b, self.num_heads))?;
                let neg_inf = self.neg_inf.broadcast_as(pre_ws.shape())?;
                mask.where_cond(&neg_inf, &pre_ws)?
            }
        };

        let ws = candle_nn::ops::softmax_last_dim(&pre_ws)?; // b,h,t,k
        let xs = ws.matmul(&v)?; // b,h,t,d
        let xs = xs
            .transpose(1, 2)? // b,t,h,d
            .reshape((b, t, hd))?
            .apply(&self.out_proj)?;
        Ok(xs)
    }
}
319
320#[derive(Debug, Clone)]
321pub enum Mlp {
322 NoGating {
323 span1: tracing::Span,
324 linear1: Linear,
325 span2: tracing::Span,
326 linear2: Linear,
327 span: tracing::Span,
328 },
329 Gating {
330 linear_in: Linear,
331 linear_out: Linear,
332 activation: candle_nn::Activation,
333 span: tracing::Span,
334 },
335}
336
337impl Mlp {
338 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
339 let d_model = cfg.d_model;
340 let span = tracing::span!(tracing::Level::TRACE, "mlp");
341
342 match cfg.gating {
343 None => {
344 let span1 = tracing::span!(tracing::Level::TRACE, "lin1");
345 let span2 = tracing::span!(tracing::Level::TRACE, "lin2");
346 let linear1 = linear(d_model, cfg.dim_feedforward, cfg.bias_ff, vb.pp("mlp.fc1"))?;
347 let linear2 = linear(cfg.dim_feedforward, d_model, cfg.bias_ff, vb.pp("mlp.fc2"))?;
348 Ok(Self::NoGating {
349 linear1,
350 linear2,
351 span,
352 span1,
353 span2,
354 })
355 }
356 Some(activation) => {
357 let vb = vb.pp("gating");
358 let hidden = if cfg.dim_feedforward == 4 * d_model {
359 11 * d_model / 4
360 } else {
361 2 * cfg.dim_feedforward / 3
362 };
363 let linear_in = linear(d_model, 2 * hidden, false, vb.pp("linear_in"))?;
365 let linear_out = linear(hidden, d_model, false, vb.pp("linear_out"))?;
366 Ok(Self::Gating {
367 linear_in,
368 linear_out,
369 activation,
370 span,
371 })
372 }
373 }
374 }
375}
376
377impl Module for Mlp {
378 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
379 match self {
380 Self::NoGating {
381 linear1,
382 linear2,
383 span,
384 span1,
385 span2,
386 } => {
387 let _enter = span.enter();
388 let xs = {
389 let _enter = span1.enter();
390 xs.apply(linear1)?
391 };
392 let xs = xs.gelu_erf()?;
393 {
394 let _enter = span2.enter();
395 xs.apply(linear2)
396 }
397 }
398 Self::Gating {
399 linear_in,
400 linear_out,
401 activation,
402 span,
403 } => {
404 let _enter = span.enter();
405 let xs = xs.apply(linear_in)?;
406 let (b, t, _) = xs.dims3()?;
407 let xs = xs.reshape((b, t, 2, ()))?;
408 let xs = (xs.i((.., .., 0))?.apply(activation)? * xs.i((.., .., 1))?)?;
409 xs.apply(linear_out)
410 }
411 }
412 }
413}
414
415#[derive(Debug, Clone)]
416pub struct RmsNorm {
417 pub(crate) alpha: Tensor,
418 pub(crate) eps: f32,
419}
420
421impl RmsNorm {
422 pub fn new(d_model: usize, eps: f32, vb: VarBuilder) -> Result<Self> {
423 let alpha = vb.get((1, 1, d_model), "alpha")?.reshape(d_model)?;
424 Ok(Self { alpha, eps })
425 }
426}
427
428impl Module for RmsNorm {
429 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
430 candle_nn::ops::rms_norm(xs, &self.alpha, self.eps)
431 }
432}
433
434#[derive(Debug, Clone)]
435pub enum Norm {
436 LayerNorm(candle_nn::LayerNorm),
437 RmsNorm(RmsNorm),
438}
439
440impl Norm {
441 pub fn new(d_model: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
442 let norm = match cfg.norm {
443 super::NormType::LayerNorm => {
444 let norm = candle_nn::layer_norm(d_model, 1e-5, vb)?;
445 Self::LayerNorm(norm)
446 }
447 super::NormType::RmsNorm => {
448 let norm = RmsNorm::new(d_model, 1e-8, vb)?;
449 Self::RmsNorm(norm)
450 }
451 };
452 Ok(norm)
453 }
454}
455
456impl Module for Norm {
457 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
458 match self {
459 Self::LayerNorm(m) => m.forward(xs),
460 Self::RmsNorm(m) => m.forward(xs),
461 }
462 }
463}
464
465#[derive(Debug, Clone)]
466pub struct StreamingTransformerLayer {
467 self_attn: StreamingMultiheadAttention,
468 mlp: Mlp,
469 norm1: Norm,
470 norm2: Norm,
471 layer_scale_1: Option<LayerScale>,
472 layer_scale_2: Option<LayerScale>,
473 cross_attn: Option<(candle_nn::LayerNorm, StreamingMultiheadCrossAttention)>,
474 norm_first: bool,
475 span: tracing::Span,
476}
477
478impl StreamingTransformerLayer {
479 pub fn new(rope: &Option<Arc<RotaryEmbedding>>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
480 if cfg.use_conv_block {
481 candle::bail!("conv-block is not supported")
482 }
483 let d_model = cfg.d_model;
484 let mlp = Mlp::new(cfg, vb.clone())?;
485 let (norm1, norm2) = match cfg.norm {
486 super::NormType::LayerNorm => {
487 let norm1 = candle_nn::layer_norm(d_model, 1e-5, vb.pp("input_layernorm"))?;
488 let norm2 =
489 candle_nn::layer_norm(d_model, 1e-5, vb.pp("post_attention_layernorm"))?;
490 (Norm::LayerNorm(norm1), Norm::LayerNorm(norm2))
491 }
492 super::NormType::RmsNorm => {
493 let norm1 = RmsNorm::new(d_model, 1e-8, vb.pp("input_rmsnorm"))?;
494 let norm2 = RmsNorm::new(d_model, 1e-8, vb.pp("post_attention_rmsnorm"))?;
495 (Norm::RmsNorm(norm1), Norm::RmsNorm(norm2))
496 }
497 };
498 let layer_scale_1 = match cfg.layer_scale {
499 None => None,
500 Some(ls) => {
501 let ls = LayerScale::new(d_model, ls, vb.pp("self_attn_layer_scale"))?;
502 Some(ls)
503 }
504 };
505 let layer_scale_2 = match cfg.layer_scale {
506 None => None,
507 Some(ls) => {
508 let ls = LayerScale::new(d_model, ls, vb.pp("mlp_layer_scale"))?;
509 Some(ls)
510 }
511 };
512 let self_attn = StreamingMultiheadAttention::new(rope, cfg, vb.pp("self_attn"))?;
513 let cross_attn = if cfg.cross_attention {
514 let norm_cross = candle_nn::layer_norm(cfg.d_model, 1e-5, vb.pp("norm_cross"))?;
515 let cross_attn = StreamingMultiheadCrossAttention::new(cfg, vb.pp("cross_attention"))?;
516 Some((norm_cross, cross_attn))
517 } else {
518 None
519 };
520 Ok(Self {
521 self_attn,
522 mlp,
523 norm1,
524 norm2,
525 layer_scale_1,
526 layer_scale_2,
527 cross_attn,
528 norm_first: cfg.norm_first,
529 span: tracing::span!(tracing::Level::TRACE, "transformer-layer"),
530 })
531 }
532
533 pub fn forward(
534 &mut self,
535 xs: &Tensor,
536 ca_src: Option<&Tensor>,
537 mask: Option<&Tensor>,
538 ) -> Result<Tensor> {
539 let _enter = self.span.enter();
540 if !self.norm_first {
541 candle::bail!("only norm_first = true is supported")
542 }
543 let norm1 = xs.apply(&self.norm1)?;
544 let xs = (xs
545 + self
546 .self_attn
547 .forward(&norm1, mask)?
548 .apply(&self.layer_scale_1.as_ref())?)?;
549
550 let xs = match (&self.cross_attn, ca_src) {
551 (Some((norm_cross, cross_attn)), Some(ca_src)) => {
552 let residual = &xs;
553 let xs = xs.apply(norm_cross)?;
554 (residual + cross_attn.forward(&xs, ca_src, None)?)?
555 }
556 _ => xs,
557 };
558
559 let xs = (&xs
560 + xs.apply(&self.norm2)?
561 .apply(&self.mlp)?
562 .apply(&self.layer_scale_2.as_ref()))?;
563 Ok(xs)
564 }
565
566 pub fn reset_kv_cache(&mut self) {
567 self.self_attn.reset_kv_cache()
568 }
569
570 pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::RotatingKvCache) {
571 self.self_attn.set_kv_cache(kv_cache)
572 }
573}
574
575#[derive(Debug, Clone)]
576pub struct StreamingTransformer {
577 layers: Vec<StreamingTransformerLayer>,
578 positional_embedding: PositionalEmbedding,
579 max_period: usize,
580}
581
582impl StreamingTransformer {
583 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
584 let vb_l = vb.pp("layers");
585 let rope = match cfg.positional_embedding {
586 PositionalEmbedding::Rope => {
587 let rope = RotaryEmbedding::new(
588 cfg.d_model / cfg.num_heads,
589 cfg.max_seq_len,
590 cfg.max_period as f32,
591 vb.device(),
592 )?;
593 Some(Arc::new(rope))
594 }
595 PositionalEmbedding::Sin | PositionalEmbedding::None => None,
596 };
597 let mut layers = Vec::with_capacity(cfg.num_layers);
598 for layer_idx in 0..cfg.num_layers {
599 let layer = StreamingTransformerLayer::new(&rope, cfg, vb_l.pp(layer_idx))?;
600 layers.push(layer)
601 }
602 Ok(Self {
603 layers,
604 positional_embedding: cfg.positional_embedding,
605 max_period: cfg.max_period,
606 })
607 }
608
609 pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
610 self.forward_ca(xs, None)
611 }
612
613 pub fn forward_ca(&mut self, xs: &Tensor, ca_src: Option<&Tensor>) -> Result<Tensor> {
614 let (_b, t, c) = xs.dims3()?;
615 let pos = self.layers[0].self_attn.kv_cache.current_seq_len();
616 let mask = self.layers[0]
617 .self_attn
618 .kv_cache
619 .attn_mask(t, xs.device())?;
620 let mut xs = match self.positional_embedding {
621 PositionalEmbedding::Rope | PositionalEmbedding::None => xs.clone(),
622 PositionalEmbedding::Sin => {
623 let dev = xs.device();
624 let theta = self.max_period as f32;
625 let half_dim = c / 2;
626 let positions = Tensor::arange(pos as u32, (pos + t) as u32, dev)?
627 .unsqueeze(1)?
628 .to_dtype(DType::F32)?;
629 let inv_freq: Vec<_> = (0..half_dim)
630 .map(|i| 1f32 / theta.powf(i as f32 / (half_dim - 1) as f32))
631 .collect();
632 let inv_freq_len = inv_freq.len();
633 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
634 let freqs = positions.broadcast_mul(&inv_freq)?;
635 let pos_emb =
636 Tensor::cat(&[freqs.cos()?, freqs.sin()?], D::Minus1)?.to_dtype(xs.dtype())?;
637 xs.broadcast_add(&pos_emb)?
638 }
639 };
640 for layer in self.layers.iter_mut() {
641 xs = layer.forward(&xs, ca_src, mask.as_ref())?;
642 }
643 Ok(xs)
644 }
645
646 pub fn copy_state(&mut self, from: &Self) -> Result<()> {
647 if self.layers.len() != from.layers.len() {
648 candle::bail!("cannot copy kv-caches as the transformers have different depths")
649 }
650 self.layers
651 .iter_mut()
652 .zip(from.layers.iter())
653 .for_each(|(v, w)| v.set_kv_cache(w.self_attn.kv_cache.clone()));
654 Ok(())
655 }
656}
657
658impl StreamingModule for StreamingTransformer {
659 fn reset_state(&mut self) {
660 self.layers.iter_mut().for_each(|v| v.reset_kv_cache())
661 }
662
663 fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
664 match xs.as_option() {
665 None => Ok(StreamTensor::empty()),
666 Some(xs) => Ok(StreamTensor::from_tensor(self.forward(xs)?)),
667 }
668 }
669}
670
671#[derive(Debug, Clone)]
672pub struct ProjectedTransformer {
673 transformer: StreamingTransformer,
674 input_proj: Option<Linear>,
675 output_projs: Vec<Option<Linear>>,
676 conv_layout: bool,
677 span: tracing::Span,
678}
679
680impl ProjectedTransformer {
681 pub fn new(
682 input_dim: usize,
683 output_dims: &[usize],
684 cfg: &Config,
685 vb: VarBuilder,
686 ) -> Result<Self> {
687 let transformer = StreamingTransformer::new(cfg, vb.clone())?;
688 let input_proj = if input_dim == cfg.d_model {
689 None
690 } else {
691 let l = linear_no_bias(input_dim, cfg.d_model, vb.pp("input_proj"))?;
692 Some(l)
693 };
694 let mut output_projs = Vec::with_capacity(output_dims.len());
695 let vb_o = vb.pp("output_projs");
696 for (i, &output_dim) in output_dims.iter().enumerate() {
697 let output_proj = if output_dim == cfg.d_model {
698 None
699 } else {
700 let l = linear_no_bias(cfg.d_model, output_dim, vb_o.pp(i))?;
701 Some(l)
702 };
703 output_projs.push(output_proj)
704 }
705 Ok(Self {
706 transformer,
707 input_proj,
708 output_projs,
709 conv_layout: cfg.conv_layout,
710 span: tracing::span!(tracing::Level::TRACE, "proj-transformer"),
711 })
712 }
713
714 pub fn forward(&mut self, xs: &Tensor) -> Result<Vec<Tensor>> {
715 let _enter = self.span.enter();
716 let xs = if self.conv_layout {
717 xs.transpose(1, 2)?
718 } else {
719 xs.clone()
720 };
721 let xs = xs.apply(&self.input_proj.as_ref())?;
722 let xs = self.transformer.forward(&xs)?;
723 let mut ys = Vec::with_capacity(self.output_projs.len());
724 for output_proj in self.output_projs.iter() {
725 let ys_ = xs.apply(&output_proj.as_ref())?;
726 let ys_ = if self.conv_layout {
727 ys_.transpose(1, 2)?
728 } else {
729 ys_
730 };
731 ys.push(ys_)
732 }
733 Ok(ys)
734 }
735}
736
737impl StreamingModule for ProjectedTransformer {
738 fn reset_state(&mut self) {
739 self.transformer.reset_state()
740 }
741
742 fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
743 let xs = xs.apply(&|x: &Tensor| {
744 if self.conv_layout {
745 x.transpose(1, 2)
746 } else {
747 Ok(x.clone())
748 }
749 })?;
750 let xs = xs.apply(&self.input_proj.as_ref())?;
751 let xs = self.transformer.step(&xs)?;
752 let ys = xs.apply(&self.output_projs[0].as_ref())?;
753 ys.apply(&|y: &Tensor| {
754 if self.conv_layout {
755 y.transpose(1, 2)
756 } else {
757 Ok(y.clone())
758 }
759 })
760 }
761}
762
/// Flash-attention kernel wrapper, available with the `flash-attn` feature.
///
/// Forwards directly to `candle_flash_attn::flash_attn`. In this file it is called
/// with q/k/v transposed to put the sequence before the heads; `causal` enables
/// the kernel's causal masking.
#[cfg(feature = "flash-attn")]
fn flash_attn(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    softmax_scale: f32,
    causal: bool,
) -> Result<Tensor> {
    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
}
773
/// Fallback used when the `flash-attn` feature is disabled.
///
/// # Panics
/// Always panics via `unimplemented!`; callers must not enable the flash-attention
/// path without the feature.
#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
    unimplemented!("compile with '--features flash-attn'")
}