1use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
21use candle_nn::{linear_b as linear, Linear, VarBuilder};
22use std::sync::Arc;
23
/// Kind of temporal-mixing block a decoder layer is built around.
///
/// Deserialized from snake_case names (`attention` / `recurrent`).
#[derive(serde::Deserialize, Debug, Clone, Copy)]
#[serde(rename_all = "snake_case")]
pub enum TemporalBlockType {
    /// The layer uses an `SdpaAttention` block.
    Attention,
    /// The layer uses a `RecurrentBlock`.
    Recurrent,
}
30
/// Model hyper-parameters, deserialized from the model configuration.
#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
    /// Number of decoder layers instantiated by `Model::new`.
    pub num_hidden_layers: usize,
    /// Number of rows of the token embedding table.
    pub vocab_size: usize,
    /// Width of the token embeddings and of the layer inputs/outputs.
    pub hidden_size: usize,
    /// The MLP projections use `intermediate_size / 2` as their inner width.
    pub intermediate_size: usize,
    /// Number of query heads; also the number of blocks in the RG-LRU gate weights.
    pub num_attention_heads: usize,
    /// Number of key/value heads; repeated `num_attention_heads / num_key_value_heads` times.
    pub num_key_value_heads: usize,
    /// Per-head width of the attention projections.
    pub head_dim: usize,
    /// Width of the recurrent branch; `hidden_size` is used when this is `None`.
    pub lru_width: Option<usize>,
    /// Not read by the code visible here.
    /// NOTE(review): presumably a sliding-window size for attention — confirm against callers.
    pub attention_window_size: usize,
    /// Kernel size of the grouped 1d convolution in the recurrent block.
    pub conv1d_width: usize,
    /// Final logits are computed as `tanh(logits / cap) * cap` with this cap.
    pub logits_soft_cap: f64,
    /// Activation used by the MLP gate and by the recurrent block's `y` branch.
    pub hidden_activation: candle_nn::Activation,
    /// Must be exactly `0.5`; any other value is rejected by `RotaryEmbedding::new`.
    pub partial_rotary_factor: f64,
    /// Epsilon passed to every `RmsNorm`.
    pub rms_norm_eps: f64,
    /// Base used to derive the rotary inverse frequencies.
    pub rope_theta: f64,
    /// Block kinds, cycled through by layer index (`block_types[idx % len]`).
    #[serde(alias = "_block_types")]
    pub block_types: Vec<TemporalBlockType>,
    /// Whether the q/k/v projections carry a bias (the output projection always does).
    pub attention_bias: bool,
    /// Number of positions in the precomputed rotary tables; defaults to 8192 when absent.
    #[serde(default = "default_max_seq_len")]
    pub max_seq_len: usize,
}
54
/// Serde fallback for `Config::max_seq_len` when the field is missing from the input.
fn default_max_seq_len() -> usize {
    const DEFAULT_MAX_SEQ_LEN: usize = 8 * 1024;
    DEFAULT_MAX_SEQ_LEN
}
58
/// RMS normalization layer whose learned scale is applied as `weight + 1`.
#[derive(Debug, Clone)]
pub(crate) struct RmsNorm {
    /// Learned scale offset; the forward pass multiplies the normalized input by `weight + 1`.
    weight: Tensor,
    /// Constant added to the mean of squares before taking the square root.
    eps: f64,
}
64
65impl RmsNorm {
66 pub(crate) fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
67 let weight = vb.get(dim, "weight")?;
68 Ok(Self { weight, eps })
69 }
70
71 pub(crate) fn from_weight(weight: Tensor, eps: f64) -> Self {
72 Self { weight, eps }
73 }
74}
75
76impl Module for RmsNorm {
77 fn forward(&self, x: &Tensor) -> Result<Tensor> {
78 let x_dtype = x.dtype();
79 let internal_dtype = match x_dtype {
80 DType::F16 | DType::BF16 => DType::F32,
81 d => d,
82 };
83 let hidden_size = x.dim(D::Minus1)?;
84 let x = x.to_dtype(internal_dtype)?;
85 let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
86 let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
87 x_normed
88 .to_dtype(x_dtype)?
89 .broadcast_mul(&(&self.weight + 1.0)?)
90 }
91}
92
/// Precomputed sine and cosine tables for rotary position embeddings.
///
/// Both tables are sliced along dimension 0 by position when applied.
#[derive(Debug, Clone)]
pub(crate) struct RotaryEmbedding {
    /// `sin` of the position/frequency products.
    sin: Tensor,
    /// `cos` of the position/frequency products.
    cos: Tensor,
}
98
99fn rotate_half(xs: &Tensor) -> Result<Tensor> {
100 let last_dim = xs.dim(D::Minus1)?;
101 let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
102 let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
103 Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
104}
105
106impl RotaryEmbedding {
107 pub(crate) fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
108 if cfg.partial_rotary_factor != 0.5 {
109 candle::bail!("partial-rotary-factor {} <> 0.5", cfg.partial_rotary_factor)
110 }
111 let dim = cfg.head_dim / 2;
112 let max_seq_len = cfg.max_seq_len;
113 let inv_freq: Vec<_> = (0..dim)
114 .step_by(2)
115 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
116 .collect();
117 let inv_freq_len = inv_freq.len();
118 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
119 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
120 .to_dtype(dtype)?
121 .reshape((max_seq_len, 1))?;
122 let freqs = t.matmul(&inv_freq)?;
123 let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
124 Ok(Self {
125 sin: freqs.sin()?,
126 cos: freqs.cos()?,
127 })
128 }
129
130 pub(crate) fn apply_rotary_emb_qkv(
131 &self,
132 q: &Tensor,
133 k: &Tensor,
134 seqlen_offset: usize,
135 ) -> Result<(Tensor, Tensor)> {
136 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
137 let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
138 let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
139 let cos = cos.unsqueeze(0)?.unsqueeze(0)?; let sin = sin.unsqueeze(0)?.unsqueeze(0)?; let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
142 let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
143 Ok((q_embed, k_embed))
144 }
145}
146
/// Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`.
#[derive(Debug, Clone)]
struct Mlp {
    /// Projection whose output goes through `act_fn` to form the gate.
    gate_proj: Linear,
    /// Projection multiplied element-wise by the gate.
    up_proj: Linear,
    /// Projection applied to the gated product.
    down_proj: Linear,
    /// Activation applied to the gate projection.
    act_fn: candle_nn::Activation,
}
154
155impl Mlp {
156 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
157 let h = cfg.hidden_size;
158 let intermediate_size = cfg.intermediate_size / 2;
159 let gate_proj = linear(h, intermediate_size, true, vb.pp("gate_proj"))?;
160 let up_proj = linear(h, intermediate_size, true, vb.pp("up_proj"))?;
161 let down_proj = linear(intermediate_size, h, true, vb.pp("down_proj"))?;
162 Ok(Self {
163 gate_proj,
164 up_proj,
165 down_proj,
166 act_fn: cfg.hidden_activation,
167 })
168 }
169}
170
171impl Module for Mlp {
172 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
173 let gate = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
174 (gate * xs.apply(&self.up_proj))?.apply(&self.down_proj)
175 }
176}
177
/// Gated linear recurrent unit with block-wise (per-head) gate weights and a carried state.
#[derive(Debug, Clone)]
pub(crate) struct Rglru {
    /// Per-channel parameter; the recurrence decay is derived from its softplus.
    pub(crate) recurrent_param: Tensor,
    /// Input-gate weights, loaded with shape `(n_heads, block_width, block_width)`.
    pub(crate) input_gate_weight: Tensor,
    /// Input-gate bias, loaded with shape `(n_heads, block_width)`.
    pub(crate) input_gate_bias: Tensor,
    /// Recurrent-gate weights, loaded with shape `(n_heads, block_width, block_width)`.
    pub(crate) recurrent_gate_weight: Tensor,
    /// Recurrent-gate bias, loaded with shape `(n_heads, block_width)`.
    pub(crate) recurrent_gate_bias: Tensor,
    /// Channels handled by each gate block (`lru_width / n_heads`).
    pub(crate) block_width: usize,
    /// Number of gate blocks.
    pub(crate) n_heads: usize,
    /// State carried between `forward` calls; `None` until the first call.
    pub(crate) recurrent_states: Option<Tensor>,
}
190
191fn baddbmm(a: &Tensor, b: &Tensor, c: &Tensor) -> Result<Tensor> {
192 a.broadcast_add(&b.matmul(c)?)
193}
194
195fn softplus(xs: &Tensor) -> Result<Tensor> {
196 (xs.exp()? + 1.0)?.log()
197}
198
199impl Rglru {
200 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
201 let h = cfg.hidden_size;
202 let lru_width = cfg.lru_width.unwrap_or(h);
203 let n_heads = cfg.num_attention_heads;
204 let block_width = lru_width / n_heads;
205 let recurrent_param = vb.get((lru_width,), "recurrent_param")?;
206 let input_gate_weight = vb.get((n_heads, block_width, block_width), "input_gate_weight")?;
207 let input_gate_bias = vb.get((n_heads, block_width), "input_gate_bias")?;
208 let recurrent_gate_weight =
209 vb.get((n_heads, block_width, block_width), "recurrent_gate_weight")?;
210 let recurrent_gate_bias = vb.get((n_heads, block_width), "recurrent_gate_bias")?;
211 Ok(Self {
212 recurrent_param,
213 input_gate_bias,
214 input_gate_weight,
215 recurrent_gate_bias,
216 recurrent_gate_weight,
217 block_width,
218 n_heads,
219 recurrent_states: None,
220 })
221 }
222
223 pub(crate) fn forward(&mut self, xs: &Tensor, pos: usize) -> Result<Tensor> {
225 let (b_sz, seq_len, lru_width) = xs.dims3()?;
226 let pos = Tensor::arange(pos as u32, (pos + seq_len) as u32, xs.device())?;
227 let reset = pos.eq(0u32)?.unsqueeze(1)?.unsqueeze(0)?;
228 let reshape_act = xs
229 .reshape((b_sz * seq_len, self.n_heads, self.block_width))?
230 .permute((1, 0, 2))?
231 .contiguous()?;
232
233 let res = baddbmm(
234 &self.input_gate_bias.unsqueeze(1)?,
235 &reshape_act,
236 &self.input_gate_weight,
237 )?;
238 let input_gate = res.transpose(0, 1)?.reshape((b_sz, seq_len, lru_width))?;
239 let input_gate = candle_nn::ops::sigmoid(&input_gate)?;
240 let res = baddbmm(
241 &self.recurrent_gate_bias.unsqueeze(1)?,
242 &reshape_act,
243 &self.recurrent_gate_weight,
244 )?;
245 let recurrent_gate = res.transpose(0, 1)?.reshape((b_sz, seq_len, lru_width))?;
246 let recurrent_gate = candle_nn::ops::sigmoid(&recurrent_gate)?;
247
248 let log_recurrent_gate =
249 (recurrent_gate * (-8.0))?.broadcast_mul(&softplus(&self.recurrent_param)?)?;
250 let recurrent_gate = log_recurrent_gate.exp()?;
251 let a_square = (log_recurrent_gate * 2.)?.exp()?;
252
253 let gated_inputs = (xs * input_gate)?;
255
256 let reset = reset.to_dtype(a_square.dtype())?;
257 let multiplier =
258 reset.broadcast_add(&((1.0 - &reset)?.broadcast_mul(&(1.0 - a_square)?.sqrt()?))?)?;
259 let normalized_x = (gated_inputs * multiplier.to_dtype(xs.dtype()))?;
260
261 let (hidden_states, recurrent_states) = rnn_scan(
262 &normalized_x,
263 &recurrent_gate,
264 &reset,
265 self.recurrent_states.as_ref(),
266 )?;
267 self.recurrent_states = Some(recurrent_states);
268 Ok(hidden_states)
269 }
270}
271
272fn rnn_scan(
273 hidden_states: &Tensor,
274 recurrent_gate: &Tensor,
275 reset: &Tensor,
276 recurrent_states: Option<&Tensor>,
277) -> Result<(Tensor, Tensor)> {
278 let acc_dtype = DType::F32;
279 let dev = hidden_states.device();
280 let in_dtype = hidden_states.dtype();
281 let inv_reset = (1.0 - reset)?.to_dtype(recurrent_gate.dtype())?;
282 let recurrent_gate = recurrent_gate.broadcast_mul(&inv_reset)?;
283 let (c, r) = if hidden_states.dim(1)? == 1 {
284 match recurrent_states {
285 None => {
286 let next_state = hidden_states.i((.., 0))?.to_dtype(acc_dtype)?;
287 (hidden_states.clone(), next_state)
288 }
289 Some(recurrent_states) => {
290 let contextualized_states =
291 recurrent_gate.to_dtype(acc_dtype)? * recurrent_states.unsqueeze(1)?;
292 let contextualized_states =
293 (contextualized_states + hidden_states.to_dtype(acc_dtype)?)?;
294 let c = contextualized_states.to_dtype(in_dtype)?;
295 let l = contextualized_states.dim(1)?;
296 let r = contextualized_states.i((.., l - 1))?;
297 (c, r)
298 }
299 }
300 } else {
301 let mut recurrent_states = match recurrent_states {
302 None => Tensor::zeros(hidden_states.i((.., 0))?.shape(), acc_dtype, dev)?,
303 Some(r) => r.clone(),
304 };
305 let mut contextualized_states = vec![];
306 for t in 0..hidden_states.dim(1)? {
307 recurrent_states =
308 (recurrent_gate.i((.., t))?.to_dtype(acc_dtype)? * recurrent_states)?;
309 recurrent_states =
310 (recurrent_states + hidden_states.i((.., t))?.to_dtype(acc_dtype)?)?;
311 contextualized_states.push(recurrent_states.to_dtype(in_dtype)?)
312 }
313 let contextualized_states = Tensor::stack(&contextualized_states, 1)?;
314 (contextualized_states, recurrent_states)
315 };
316 Ok((c, r))
317}
318
/// Recurrent temporal block: a gated `y` branch multiplied with an `x` branch that goes
/// through a 1d convolution and the RG-LRU, followed by an output projection.
#[derive(Debug, Clone)]
struct RecurrentBlock {
    /// Projection feeding the activated `y` branch.
    linear_y: Linear,
    /// Projection feeding the convolution / recurrence `x` branch.
    linear_x: Linear,
    /// Projection applied to the product of both branches.
    linear_out: Linear,
    /// Grouped convolution over the sequence dimension (one group per channel).
    conv_1d: candle_nn::Conv1d,
    /// Trailing inputs of the `x` branch kept between calls; `None` before the first call.
    conv1d_state: Option<Tensor>,
    /// Kernel size of `conv_1d`.
    conv1d_width: usize,
    /// Recurrent unit applied after the convolution.
    rg_lru: Rglru,
    /// Activation applied on the `y` branch.
    act_fn: candle_nn::Activation,
}
330
331impl RecurrentBlock {
332 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
333 let h = cfg.hidden_size;
334 let lru_width = cfg.lru_width.unwrap_or(h);
335 let linear_y = linear(h, lru_width, true, vb.pp("linear_y"))?;
336 let linear_x = linear(h, lru_width, true, vb.pp("linear_x"))?;
337 let linear_out = linear(lru_width, h, true, vb.pp("linear_out"))?;
338 let conv_1d = candle_nn::conv1d(
339 lru_width,
340 lru_width,
341 cfg.conv1d_width,
342 candle_nn::Conv1dConfig {
343 groups: lru_width,
344 padding: cfg.conv1d_width - 1,
345 ..Default::default()
346 },
347 vb.pp("conv_1d"),
348 )?;
349 let rg_lru = Rglru::new(cfg, vb.pp("rg_lru"))?;
350 Ok(Self {
351 linear_y,
352 linear_x,
353 linear_out,
354 conv_1d,
355 conv1d_state: None,
356 conv1d_width: cfg.conv1d_width,
357 rg_lru,
358 act_fn: cfg.hidden_activation,
359 })
360 }
361
362 pub fn forward(&mut self, xs: &Tensor, pos: usize) -> Result<Tensor> {
363 let (_b_sz, seq_len, _) = xs.dims3()?;
364
365 let y_branch = xs.apply(&self.linear_y)?.apply(&self.act_fn)?;
366 let x_branch = xs.apply(&self.linear_x)?.transpose(1, 2)?;
367 let x_branch = if pos == 0 {
368 let x_len = x_branch.dim(D::Minus1)?;
369 let pad = self.conv1d_width as i64 - x_len as i64 - 1;
370 let padded = match pad.cmp(&0) {
371 std::cmp::Ordering::Equal => x_branch.clone(),
372 std::cmp::Ordering::Less => {
373 let rev_pad = (-pad) as usize;
374 x_branch.narrow(D::Minus1, rev_pad, x_len - rev_pad)?
375 }
376 std::cmp::Ordering::Greater => {
377 x_branch.pad_with_zeros(D::Minus1, pad as usize, 0)?
378 }
379 };
380 self.conv1d_state = Some(padded);
381 x_branch
382 .apply(&self.conv_1d)?
383 .narrow(D::Minus1, 0, seq_len)?
384 } else {
385 let conv_state = match self.conv1d_state.as_ref() {
386 None => candle::bail!("empty cache despite pos > 0"),
387 Some(s) => Tensor::cat(&[s, &x_branch], D::Minus1)?,
388 };
389 let w = self.conv_1d.weight().i((.., 0, ..))?;
390 let x_branch = conv_state.broadcast_mul(&w)?.sum(D::Minus1)?;
391 let x_branch = match self.conv_1d.bias() {
392 None => x_branch,
393 Some(b) => x_branch.broadcast_add(b)?,
394 };
395 let x_branch = x_branch.unsqueeze(D::Minus1)?;
396 self.conv1d_state = Some(conv_state.i((.., .., 1..))?);
397 x_branch
398 };
399 let x_branch = x_branch.transpose(1, 2)?;
400 let x_branch = self.rg_lru.forward(&x_branch, pos)?;
401 (x_branch * y_branch)?.apply(&self.linear_out)
402 }
403}
404
/// Scaled dot-product attention block with a key/value cache and rotary embeddings.
#[derive(Debug, Clone)]
struct SdpaAttention {
    /// Query projection (`hidden_size -> n_heads * head_dim`).
    q_proj: Linear,
    /// Key projection (`hidden_size -> n_kv_heads * head_dim`).
    k_proj: Linear,
    /// Value projection (`hidden_size -> n_kv_heads * head_dim`).
    v_proj: Linear,
    /// Output projection (`n_heads * head_dim -> hidden_size`).
    o_proj: Linear,
    /// Number of query heads.
    n_heads: usize,
    /// Number of key/value heads.
    n_kv_heads: usize,
    /// Width of each head.
    head_dim: usize,
    /// Width the attention output is reshaped to before `o_proj`.
    hidden_size: usize,
    /// Keys and values accumulated over previous calls; `None` before the first call.
    kv_cache: Option<(Tensor, Tensor)>,
    /// Shared rotary tables.
    rotary_emb: Arc<RotaryEmbedding>,
}
418
419impl SdpaAttention {
420 fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
421 let h = cfg.hidden_size;
422 let n_heads = cfg.num_attention_heads;
423 let n_kv_heads = cfg.num_key_value_heads;
424 let hd = cfg.head_dim;
425 let q_proj = linear(h, n_heads * hd, cfg.attention_bias, vb.pp("q_proj"))?;
426 let k_proj = linear(h, n_kv_heads * hd, cfg.attention_bias, vb.pp("k_proj"))?;
427 let v_proj = linear(h, n_kv_heads * hd, cfg.attention_bias, vb.pp("v_proj"))?;
428 let o_proj = linear(n_heads * hd, h, true, vb.pp("o_proj"))?;
429 Ok(Self {
430 q_proj,
431 k_proj,
432 v_proj,
433 o_proj,
434 n_heads,
435 n_kv_heads,
436 head_dim: hd,
437 hidden_size: h,
438 kv_cache: None,
439 rotary_emb,
440 })
441 }
442
443 fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
444 let n_rep = self.n_heads / self.n_kv_heads;
445 crate::utils::repeat_kv(x, n_rep)
446 }
447
448 fn forward(
449 &mut self,
450 xs: &Tensor,
451 attention_mask: Option<&Tensor>,
452 pos: usize,
453 ) -> Result<Tensor> {
454 let (bsz, q_len, _) = xs.dims3()?;
455
456 let query_states = xs.apply(&self.q_proj)?;
457 let key_states = xs.apply(&self.k_proj)?;
458 let value_states = xs.apply(&self.v_proj)?;
459
460 let query_states = query_states
461 .reshape((bsz, q_len, self.n_heads, self.head_dim))?
462 .transpose(1, 2)?;
463 let key_states = key_states
464 .reshape((bsz, q_len, self.n_kv_heads, self.head_dim))?
465 .transpose(1, 2)?;
466 let value_states = value_states
467 .reshape((bsz, q_len, self.n_kv_heads, self.head_dim))?
468 .transpose(1, 2)?;
469 let query_states = query_states.chunk(2, D::Minus1)?;
470 let key_states = key_states.chunk(2, D::Minus1)?;
471 let (query_rot, key_rot) =
472 self.rotary_emb
473 .apply_rotary_emb_qkv(&query_states[0], &key_states[0], pos)?;
474 let query_states = Tensor::cat(&[&query_rot, &query_states[1]], D::Minus1)?.contiguous()?;
475 let key_states = Tensor::cat(&[&key_rot, &key_states[1]], D::Minus1)?.contiguous()?;
476
477 let (key_states, value_states) = match &self.kv_cache {
478 None => (key_states, value_states),
479 Some((prev_k, prev_v)) => {
480 let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
481 let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
482 (key_states, value_states)
483 }
484 };
485 self.kv_cache = Some((key_states.clone(), value_states.clone()));
486
487 let key_states = self.repeat_kv(key_states)?;
488 let value_states = self.repeat_kv(value_states)?;
489 let xs = {
490 let att = (query_states.matmul(&key_states.t()?)? / (self.head_dim as f64).sqrt())?;
491 let att = if q_len == 1 {
492 att
493 } else {
494 match attention_mask {
495 None => att,
496 Some(mask) => att.broadcast_add(mask)?,
497 }
498 };
499 let att = candle_nn::ops::softmax_last_dim(&att)?;
500 att.matmul(&value_states.contiguous()?)?
501 };
502
503 let xs = xs
504 .transpose(1, 2)?
505 .reshape((bsz, q_len, self.hidden_size))?;
506 self.o_proj.forward(&xs)
507 }
508}
509
/// The temporal-mixing part of a decoder layer: either recurrent or attention-based.
#[derive(Debug, Clone)]
enum TemporalBlock {
    /// Convolution + RG-LRU block.
    Recurrent(RecurrentBlock),
    /// Scaled dot-product attention block.
    Attention(SdpaAttention),
}
515
516impl TemporalBlock {
517 fn forward(
518 &mut self,
519 xs: &Tensor,
520 attention_mask: Option<&Tensor>,
521 pos: usize,
522 ) -> Result<Tensor> {
523 match self {
524 Self::Recurrent(b) => b.forward(xs, pos),
525 Self::Attention(b) => b.forward(xs, attention_mask, pos),
526 }
527 }
528}
529
/// One decoder layer: a normalized temporal block and a normalized MLP, each with a
/// residual connection.
#[derive(Debug, Clone)]
struct DecoderLayer {
    /// Normalization applied before the temporal block.
    temporal_pre_norm: RmsNorm,
    /// Normalization applied before the MLP.
    channel_pre_norm: RmsNorm,
    /// Recurrent or attention block.
    temporal_block: TemporalBlock,
    /// Gated feed-forward block.
    mlp_block: Mlp,
}
537
538impl DecoderLayer {
539 fn new(
540 block_idx: usize,
541 rotary_emb: Arc<RotaryEmbedding>,
542 cfg: &Config,
543 vb: VarBuilder,
544 ) -> Result<Self> {
545 let h = cfg.hidden_size;
546 let temporal_pre_norm = RmsNorm::new(h, cfg.rms_norm_eps, vb.pp("temporal_pre_norm"))?;
547 let channel_pre_norm = RmsNorm::new(h, cfg.rms_norm_eps, vb.pp("channel_pre_norm"))?;
548 let temporal_block = match cfg.block_types[block_idx % cfg.block_types.len()] {
549 TemporalBlockType::Recurrent => {
550 let block = RecurrentBlock::new(cfg, vb.pp("temporal_block"))?;
551 TemporalBlock::Recurrent(block)
552 }
553 TemporalBlockType::Attention => {
554 let block = SdpaAttention::new(rotary_emb, cfg, vb.pp("temporal_block"))?;
555 TemporalBlock::Attention(block)
556 }
557 };
558 let mlp_block = Mlp::new(cfg, vb.pp("mlp_block"))?;
559 Ok(Self {
560 temporal_pre_norm,
561 channel_pre_norm,
562 temporal_block,
563 mlp_block,
564 })
565 }
566
567 fn forward(
568 &mut self,
569 xs: &Tensor,
570 attention_mask: Option<&Tensor>,
571 pos: usize,
572 ) -> Result<Tensor> {
573 let residual = xs;
574 let xs = xs.apply(&self.temporal_pre_norm)?;
575 let xs = self.temporal_block.forward(&xs, attention_mask, pos)?;
576 let xs = (xs + residual)?;
577 let residual = &xs;
578 let xs = xs.apply(&self.channel_pre_norm)?.apply(&self.mlp_block)?;
579 xs + residual
580 }
581}
582
/// Full decoder model: token embeddings, decoder layers, final norm and a language-model
/// head that reuses the embedding matrix.
#[derive(Debug, Clone)]
pub struct Model {
    /// Token embedding table.
    embed_tokens: candle_nn::Embedding,
    /// Decoder layers, applied in order.
    layers: Vec<DecoderLayer>,
    /// Normalization applied before the head.
    final_norm: RmsNorm,
    /// Bias-free output projection built from the embedding weights.
    lm_head: Linear,
    /// Embedding width; embeddings are scaled by its square root.
    hidden_size: usize,
    /// Cap used for the `tanh` soft-capping of the logits.
    logits_soft_cap: f64,
    /// Dtype the attention mask is cast to.
    dtype: DType,
    /// Device the attention mask is created on.
    device: Device,
}
594
595impl Model {
596 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
597 let embed_tokens =
598 candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?;
599 let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?);
600 let vb_b = vb.pp("layers");
601 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
602 for idx in 0..cfg.num_hidden_layers {
603 let layer = DecoderLayer::new(idx, rotary_emb.clone(), cfg, vb_b.pp(idx))?;
604 layers.push(layer)
605 }
606 let final_norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("final_norm"))?;
607 let lm_head = Linear::new(embed_tokens.embeddings().clone(), None);
608 Ok(Self {
609 embed_tokens,
610 layers,
611 final_norm,
612 lm_head,
613 hidden_size: cfg.hidden_size,
614 logits_soft_cap: cfg.logits_soft_cap,
615 dtype: vb.dtype(),
616 device: vb.device().clone(),
617 })
618 }
619
620 fn prepare_decoder_attention_mask(
621 &self,
622 b_size: usize,
623 tgt_len: usize,
624 seqlen_offset: usize,
625 ) -> Result<Tensor> {
626 let mask: Vec<_> = (0..tgt_len)
627 .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
628 .collect();
629 let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
630 let mask = if seqlen_offset > 0 {
631 let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
632 Tensor::cat(&[&mask0, &mask], D::Minus1)?
633 } else {
634 mask
635 };
636 mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
637 .to_dtype(self.dtype)
638 }
639
640 pub fn forward(&mut self, xs: &Tensor, pos: usize) -> Result<Tensor> {
641 let (b_size, seq_len) = xs.dims2()?;
642 let attention_mask = if seq_len <= 1 {
643 None
644 } else {
645 let mask = self.prepare_decoder_attention_mask(b_size, seq_len, pos)?;
646 Some(mask)
647 };
648 let xs = xs.apply(&self.embed_tokens)?;
649 let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
650 for layer in self.layers.iter_mut() {
651 xs = layer.forward(&xs, attention_mask.as_ref(), pos)?;
652 }
653 let logits = xs
654 .narrow(1, seq_len - 1, 1)?
655 .apply(&self.final_norm)?
656 .apply(&self.lm_head)?;
657 let logits = ((logits / self.logits_soft_cap)?.tanh()? * self.logits_soft_cap)?;
658 Ok(logits)
659 }
660}