candle_transformers/models/
recurrent_gemma.rs

1//! Recurrent Gemma model implementation
2//!
3//! Recurrent Gemma is a version of the Gemma language model that incorporates recurrent memory.
4//! This allows the model to maintain state between predictions and have longer-range memory.
5//!
6//! Key characteristics:
7//! - Real-gated linear recurrent units (RGLRU)
8//! - 1D convolution for local context
9//! - RMSNorm for layer normalization
10//! - Rotary positional embeddings (RoPE)
11//! - Grouped query attention
12//!
13//! References:
14//! - [Gemma: Open Models Based on Gemini Technology](https://blog.google/technology/developers/gemma-open-models/)
15//! - [Recurrent Memory model architecture](https://arxiv.org/abs/2402.00441)
16//!
17//! This implementation is based on the python version from huggingface/transformers.
18//! https://github.com/huggingface/transformers/blob/b109257f4fb8b1166e7c53cc5418632014ed53a5/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py#L2
19//!
20use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
21use candle_nn::{linear_b as linear, Linear, VarBuilder};
22use std::sync::Arc;
23
/// Kind of temporal-mixing block used inside a decoder layer.
///
/// Deserialized from snake_case strings (`"attention"` / `"recurrent"`).
#[derive(serde::Deserialize, Debug, Clone, Copy)]
#[serde(rename_all = "snake_case")]
pub enum TemporalBlockType {
    /// The layer is built around `SdpaAttention`.
    Attention,
    /// The layer is built around `RecurrentBlock`.
    Recurrent,
}
30
/// Hyper-parameters of a Recurrent Gemma model, deserializable with serde.
#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
    /// Number of decoder layers instantiated by `Model::new`.
    pub num_hidden_layers: usize,
    /// Number of rows in the token embedding table.
    pub vocab_size: usize,
    /// Width of the embeddings and of the residual stream.
    pub hidden_size: usize,
    /// Nominal MLP width; `Mlp::new` uses half of it for each of the gate/up projections.
    pub intermediate_size: usize,
    /// Number of query heads; also the number of blocks in the RG-LRU gate weights.
    pub num_attention_heads: usize,
    /// Number of key/value heads (repeated to match the query heads).
    pub num_key_value_heads: usize,
    /// Per-head dimension of the attention projections.
    pub head_dim: usize,
    /// Width of the recurrent branch; falls back to `hidden_size` when absent.
    pub lru_width: Option<usize>,
    /// NOTE(review): not read by any code visible in this module.
    pub attention_window_size: usize,
    /// Kernel width of the depthwise temporal convolution in recurrent blocks.
    pub conv1d_width: usize,
    /// Final logits are squashed as `cap * tanh(logits / cap)` with this cap.
    pub logits_soft_cap: f64,
    /// Activation used by the MLP gate and by the recurrent block's `y` branch.
    pub hidden_activation: candle_nn::Activation,
    /// Fraction of each head that receives rotary embeddings; only `0.5` is accepted.
    pub partial_rotary_factor: f64,
    /// Epsilon added to the mean square inside `RmsNorm`.
    pub rms_norm_eps: f64,
    /// Base of the rotary frequency schedule.
    pub rope_theta: f64,
    /// Block-type pattern, cycled over the layer index; also accepted as `_block_types`.
    #[serde(alias = "_block_types")]
    pub block_types: Vec<TemporalBlockType>,
    /// Whether the q/k/v projections carry a bias term.
    pub attention_bias: bool,
    /// Number of positions in the precomputed rotary tables; defaults to 8192 when absent.
    #[serde(default = "default_max_seq_len")]
    pub max_seq_len: usize,
}
54
/// Serde fallback for `Config::max_seq_len` when the field is missing.
fn default_max_seq_len() -> usize {
    const DEFAULT_MAX_SEQ_LEN: usize = 8 * 1024;
    DEFAULT_MAX_SEQ_LEN
}
58
59#[derive(Debug, Clone)]
60pub(crate) struct RmsNorm {
61    weight: Tensor,
62    eps: f64,
63}
64
65impl RmsNorm {
66    pub(crate) fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
67        let weight = vb.get(dim, "weight")?;
68        Ok(Self { weight, eps })
69    }
70
71    pub(crate) fn from_weight(weight: Tensor, eps: f64) -> Self {
72        Self { weight, eps }
73    }
74}
75
76impl Module for RmsNorm {
77    fn forward(&self, x: &Tensor) -> Result<Tensor> {
78        let x_dtype = x.dtype();
79        let internal_dtype = match x_dtype {
80            DType::F16 | DType::BF16 => DType::F32,
81            d => d,
82        };
83        let hidden_size = x.dim(D::Minus1)?;
84        let x = x.to_dtype(internal_dtype)?;
85        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
86        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
87        x_normed
88            .to_dtype(x_dtype)?
89            .broadcast_mul(&(&self.weight + 1.0)?)
90    }
91}
92
/// Precomputed rotary-embedding tables.
///
/// Both tensors are built by `RotaryEmbedding::new` with one row per position
/// (`max_seq_len` rows in total).
#[derive(Debug, Clone)]
pub(crate) struct RotaryEmbedding {
    /// Sine of the position/frequency angles.
    sin: Tensor,
    /// Cosine of the same angles.
    cos: Tensor,
}
98
99fn rotate_half(xs: &Tensor) -> Result<Tensor> {
100    let last_dim = xs.dim(D::Minus1)?;
101    let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
102    let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
103    Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
104}
105
106impl RotaryEmbedding {
107    pub(crate) fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
108        if cfg.partial_rotary_factor != 0.5 {
109            candle::bail!("partial-rotary-factor {} <> 0.5", cfg.partial_rotary_factor)
110        }
111        let dim = cfg.head_dim / 2;
112        let max_seq_len = cfg.max_seq_len;
113        let inv_freq: Vec<_> = (0..dim)
114            .step_by(2)
115            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
116            .collect();
117        let inv_freq_len = inv_freq.len();
118        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
119        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
120            .to_dtype(dtype)?
121            .reshape((max_seq_len, 1))?;
122        let freqs = t.matmul(&inv_freq)?;
123        let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
124        Ok(Self {
125            sin: freqs.sin()?,
126            cos: freqs.cos()?,
127        })
128    }
129
130    pub(crate) fn apply_rotary_emb_qkv(
131        &self,
132        q: &Tensor,
133        k: &Tensor,
134        seqlen_offset: usize,
135    ) -> Result<(Tensor, Tensor)> {
136        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
137        let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
138        let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
139        let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
140        let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
141        let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
142        let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
143        Ok((q_embed, k_embed))
144    }
145}
146
147#[derive(Debug, Clone)]
148struct Mlp {
149    gate_proj: Linear,
150    up_proj: Linear,
151    down_proj: Linear,
152    act_fn: candle_nn::Activation,
153}
154
155impl Mlp {
156    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
157        let h = cfg.hidden_size;
158        let intermediate_size = cfg.intermediate_size / 2;
159        let gate_proj = linear(h, intermediate_size, true, vb.pp("gate_proj"))?;
160        let up_proj = linear(h, intermediate_size, true, vb.pp("up_proj"))?;
161        let down_proj = linear(intermediate_size, h, true, vb.pp("down_proj"))?;
162        Ok(Self {
163            gate_proj,
164            up_proj,
165            down_proj,
166            act_fn: cfg.hidden_activation,
167        })
168    }
169}
170
171impl Module for Mlp {
172    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
173        let gate = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
174        (gate * xs.apply(&self.up_proj))?.apply(&self.down_proj)
175    }
176}
177
/// Real-Gated Linear Recurrent Unit (RG-LRU).
///
/// The shapes listed below are the ones requested from the `VarBuilder` in
/// `Rglru::new`.
#[derive(Debug, Clone)]
pub(crate) struct Rglru {
    /// `(lru_width,)`; passed through `softplus` in the forward pass.
    pub(crate) recurrent_param: Tensor,
    /// `(n_heads, block_width, block_width)` weights of the input gate.
    pub(crate) input_gate_weight: Tensor,
    /// `(n_heads, block_width)` bias of the input gate.
    pub(crate) input_gate_bias: Tensor,
    /// `(n_heads, block_width, block_width)` weights of the recurrent gate.
    pub(crate) recurrent_gate_weight: Tensor,
    /// `(n_heads, block_width)` bias of the recurrent gate.
    pub(crate) recurrent_gate_bias: Tensor,
    /// `lru_width / n_heads`.
    pub(crate) block_width: usize,
    /// Number of blocks the gate weights are split into.
    pub(crate) n_heads: usize,
    /// State produced by the previous `forward` call; `None` before the first call.
    pub(crate) recurrent_states: Option<Tensor>,
}
190
191fn baddbmm(a: &Tensor, b: &Tensor, c: &Tensor) -> Result<Tensor> {
192    a.broadcast_add(&b.matmul(c)?)
193}
194
195fn softplus(xs: &Tensor) -> Result<Tensor> {
196    (xs.exp()? + 1.0)?.log()
197}
198
199impl Rglru {
200    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
201        let h = cfg.hidden_size;
202        let lru_width = cfg.lru_width.unwrap_or(h);
203        let n_heads = cfg.num_attention_heads;
204        let block_width = lru_width / n_heads;
205        let recurrent_param = vb.get((lru_width,), "recurrent_param")?;
206        let input_gate_weight = vb.get((n_heads, block_width, block_width), "input_gate_weight")?;
207        let input_gate_bias = vb.get((n_heads, block_width), "input_gate_bias")?;
208        let recurrent_gate_weight =
209            vb.get((n_heads, block_width, block_width), "recurrent_gate_weight")?;
210        let recurrent_gate_bias = vb.get((n_heads, block_width), "recurrent_gate_bias")?;
211        Ok(Self {
212            recurrent_param,
213            input_gate_bias,
214            input_gate_weight,
215            recurrent_gate_bias,
216            recurrent_gate_weight,
217            block_width,
218            n_heads,
219            recurrent_states: None,
220        })
221    }
222
223    // https://github.com/huggingface/transformers/blob/0bd58f1ce0573c0e3269de4215a17d318add49b9/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py#L303
224    pub(crate) fn forward(&mut self, xs: &Tensor, pos: usize) -> Result<Tensor> {
225        let (b_sz, seq_len, lru_width) = xs.dims3()?;
226        let pos = Tensor::arange(pos as u32, (pos + seq_len) as u32, xs.device())?;
227        let reset = pos.eq(0u32)?.unsqueeze(1)?.unsqueeze(0)?;
228        let reshape_act = xs
229            .reshape((b_sz * seq_len, self.n_heads, self.block_width))?
230            .permute((1, 0, 2))?
231            .contiguous()?;
232
233        let res = baddbmm(
234            &self.input_gate_bias.unsqueeze(1)?,
235            &reshape_act,
236            &self.input_gate_weight,
237        )?;
238        let input_gate = res.transpose(0, 1)?.reshape((b_sz, seq_len, lru_width))?;
239        let input_gate = candle_nn::ops::sigmoid(&input_gate)?;
240        let res = baddbmm(
241            &self.recurrent_gate_bias.unsqueeze(1)?,
242            &reshape_act,
243            &self.recurrent_gate_weight,
244        )?;
245        let recurrent_gate = res.transpose(0, 1)?.reshape((b_sz, seq_len, lru_width))?;
246        let recurrent_gate = candle_nn::ops::sigmoid(&recurrent_gate)?;
247
248        let log_recurrent_gate =
249            (recurrent_gate * (-8.0))?.broadcast_mul(&softplus(&self.recurrent_param)?)?;
250        let recurrent_gate = log_recurrent_gate.exp()?;
251        let a_square = (log_recurrent_gate * 2.)?.exp()?;
252
253        // Gate the input.
254        let gated_inputs = (xs * input_gate)?;
255
256        let reset = reset.to_dtype(a_square.dtype())?;
257        let multiplier =
258            reset.broadcast_add(&((1.0 - &reset)?.broadcast_mul(&(1.0 - a_square)?.sqrt()?))?)?;
259        let normalized_x = (gated_inputs * multiplier.to_dtype(xs.dtype()))?;
260
261        let (hidden_states, recurrent_states) = rnn_scan(
262            &normalized_x,
263            &recurrent_gate,
264            &reset,
265            self.recurrent_states.as_ref(),
266        )?;
267        self.recurrent_states = Some(recurrent_states);
268        Ok(hidden_states)
269    }
270}
271
272fn rnn_scan(
273    hidden_states: &Tensor,
274    recurrent_gate: &Tensor,
275    reset: &Tensor,
276    recurrent_states: Option<&Tensor>,
277) -> Result<(Tensor, Tensor)> {
278    let acc_dtype = DType::F32;
279    let dev = hidden_states.device();
280    let in_dtype = hidden_states.dtype();
281    let inv_reset = (1.0 - reset)?.to_dtype(recurrent_gate.dtype())?;
282    let recurrent_gate = recurrent_gate.broadcast_mul(&inv_reset)?;
283    let (c, r) = if hidden_states.dim(1)? == 1 {
284        match recurrent_states {
285            None => {
286                let next_state = hidden_states.i((.., 0))?.to_dtype(acc_dtype)?;
287                (hidden_states.clone(), next_state)
288            }
289            Some(recurrent_states) => {
290                let contextualized_states =
291                    recurrent_gate.to_dtype(acc_dtype)? * recurrent_states.unsqueeze(1)?;
292                let contextualized_states =
293                    (contextualized_states + hidden_states.to_dtype(acc_dtype)?)?;
294                let c = contextualized_states.to_dtype(in_dtype)?;
295                let l = contextualized_states.dim(1)?;
296                let r = contextualized_states.i((.., l - 1))?;
297                (c, r)
298            }
299        }
300    } else {
301        let mut recurrent_states = match recurrent_states {
302            None => Tensor::zeros(hidden_states.i((.., 0))?.shape(), acc_dtype, dev)?,
303            Some(r) => r.clone(),
304        };
305        let mut contextualized_states = vec![];
306        for t in 0..hidden_states.dim(1)? {
307            recurrent_states =
308                (recurrent_gate.i((.., t))?.to_dtype(acc_dtype)? * recurrent_states)?;
309            recurrent_states =
310                (recurrent_states + hidden_states.i((.., t))?.to_dtype(acc_dtype)?)?;
311            contextualized_states.push(recurrent_states.to_dtype(in_dtype)?)
312        }
313        let contextualized_states = Tensor::stack(&contextualized_states, 1)?;
314        (contextualized_states, recurrent_states)
315    };
316    Ok((c, r))
317}
318
319#[derive(Debug, Clone)]
320struct RecurrentBlock {
321    linear_y: Linear,
322    linear_x: Linear,
323    linear_out: Linear,
324    conv_1d: candle_nn::Conv1d,
325    conv1d_state: Option<Tensor>,
326    conv1d_width: usize,
327    rg_lru: Rglru,
328    act_fn: candle_nn::Activation,
329}
330
331impl RecurrentBlock {
332    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
333        let h = cfg.hidden_size;
334        let lru_width = cfg.lru_width.unwrap_or(h);
335        let linear_y = linear(h, lru_width, true, vb.pp("linear_y"))?;
336        let linear_x = linear(h, lru_width, true, vb.pp("linear_x"))?;
337        let linear_out = linear(lru_width, h, true, vb.pp("linear_out"))?;
338        let conv_1d = candle_nn::conv1d(
339            lru_width,
340            lru_width,
341            cfg.conv1d_width,
342            candle_nn::Conv1dConfig {
343                groups: lru_width,
344                padding: cfg.conv1d_width - 1,
345                ..Default::default()
346            },
347            vb.pp("conv_1d"),
348        )?;
349        let rg_lru = Rglru::new(cfg, vb.pp("rg_lru"))?;
350        Ok(Self {
351            linear_y,
352            linear_x,
353            linear_out,
354            conv_1d,
355            conv1d_state: None,
356            conv1d_width: cfg.conv1d_width,
357            rg_lru,
358            act_fn: cfg.hidden_activation,
359        })
360    }
361
362    pub fn forward(&mut self, xs: &Tensor, pos: usize) -> Result<Tensor> {
363        let (_b_sz, seq_len, _) = xs.dims3()?;
364
365        let y_branch = xs.apply(&self.linear_y)?.apply(&self.act_fn)?;
366        let x_branch = xs.apply(&self.linear_x)?.transpose(1, 2)?;
367        let x_branch = if pos == 0 {
368            let x_len = x_branch.dim(D::Minus1)?;
369            let pad = self.conv1d_width as i64 - x_len as i64 - 1;
370            let padded = match pad.cmp(&0) {
371                std::cmp::Ordering::Equal => x_branch.clone(),
372                std::cmp::Ordering::Less => {
373                    let rev_pad = (-pad) as usize;
374                    x_branch.narrow(D::Minus1, rev_pad, x_len - rev_pad)?
375                }
376                std::cmp::Ordering::Greater => {
377                    x_branch.pad_with_zeros(D::Minus1, pad as usize, 0)?
378                }
379            };
380            self.conv1d_state = Some(padded);
381            x_branch
382                .apply(&self.conv_1d)?
383                .narrow(D::Minus1, 0, seq_len)?
384        } else {
385            let conv_state = match self.conv1d_state.as_ref() {
386                None => candle::bail!("empty cache despite pos > 0"),
387                Some(s) => Tensor::cat(&[s, &x_branch], D::Minus1)?,
388            };
389            let w = self.conv_1d.weight().i((.., 0, ..))?;
390            let x_branch = conv_state.broadcast_mul(&w)?.sum(D::Minus1)?;
391            let x_branch = match self.conv_1d.bias() {
392                None => x_branch,
393                Some(b) => x_branch.broadcast_add(b)?,
394            };
395            let x_branch = x_branch.unsqueeze(D::Minus1)?;
396            self.conv1d_state = Some(conv_state.i((.., .., 1..))?);
397            x_branch
398        };
399        let x_branch = x_branch.transpose(1, 2)?;
400        let x_branch = self.rg_lru.forward(&x_branch, pos)?;
401        (x_branch * y_branch)?.apply(&self.linear_out)
402    }
403}
404
405#[derive(Debug, Clone)]
406struct SdpaAttention {
407    q_proj: Linear,
408    k_proj: Linear,
409    v_proj: Linear,
410    o_proj: Linear,
411    n_heads: usize,
412    n_kv_heads: usize,
413    head_dim: usize,
414    hidden_size: usize,
415    kv_cache: Option<(Tensor, Tensor)>,
416    rotary_emb: Arc<RotaryEmbedding>,
417}
418
419impl SdpaAttention {
420    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
421        let h = cfg.hidden_size;
422        let n_heads = cfg.num_attention_heads;
423        let n_kv_heads = cfg.num_key_value_heads;
424        let hd = cfg.head_dim;
425        let q_proj = linear(h, n_heads * hd, cfg.attention_bias, vb.pp("q_proj"))?;
426        let k_proj = linear(h, n_kv_heads * hd, cfg.attention_bias, vb.pp("k_proj"))?;
427        let v_proj = linear(h, n_kv_heads * hd, cfg.attention_bias, vb.pp("v_proj"))?;
428        let o_proj = linear(n_heads * hd, h, true, vb.pp("o_proj"))?;
429        Ok(Self {
430            q_proj,
431            k_proj,
432            v_proj,
433            o_proj,
434            n_heads,
435            n_kv_heads,
436            head_dim: hd,
437            hidden_size: h,
438            kv_cache: None,
439            rotary_emb,
440        })
441    }
442
443    fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
444        let n_rep = self.n_heads / self.n_kv_heads;
445        crate::utils::repeat_kv(x, n_rep)
446    }
447
448    fn forward(
449        &mut self,
450        xs: &Tensor,
451        attention_mask: Option<&Tensor>,
452        pos: usize,
453    ) -> Result<Tensor> {
454        let (bsz, q_len, _) = xs.dims3()?;
455
456        let query_states = xs.apply(&self.q_proj)?;
457        let key_states = xs.apply(&self.k_proj)?;
458        let value_states = xs.apply(&self.v_proj)?;
459
460        let query_states = query_states
461            .reshape((bsz, q_len, self.n_heads, self.head_dim))?
462            .transpose(1, 2)?;
463        let key_states = key_states
464            .reshape((bsz, q_len, self.n_kv_heads, self.head_dim))?
465            .transpose(1, 2)?;
466        let value_states = value_states
467            .reshape((bsz, q_len, self.n_kv_heads, self.head_dim))?
468            .transpose(1, 2)?;
469        let query_states = query_states.chunk(2, D::Minus1)?;
470        let key_states = key_states.chunk(2, D::Minus1)?;
471        let (query_rot, key_rot) =
472            self.rotary_emb
473                .apply_rotary_emb_qkv(&query_states[0], &key_states[0], pos)?;
474        let query_states = Tensor::cat(&[&query_rot, &query_states[1]], D::Minus1)?.contiguous()?;
475        let key_states = Tensor::cat(&[&key_rot, &key_states[1]], D::Minus1)?.contiguous()?;
476
477        let (key_states, value_states) = match &self.kv_cache {
478            None => (key_states, value_states),
479            Some((prev_k, prev_v)) => {
480                let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
481                let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
482                (key_states, value_states)
483            }
484        };
485        self.kv_cache = Some((key_states.clone(), value_states.clone()));
486
487        let key_states = self.repeat_kv(key_states)?;
488        let value_states = self.repeat_kv(value_states)?;
489        let xs = {
490            let att = (query_states.matmul(&key_states.t()?)? / (self.head_dim as f64).sqrt())?;
491            let att = if q_len == 1 {
492                att
493            } else {
494                match attention_mask {
495                    None => att,
496                    Some(mask) => att.broadcast_add(mask)?,
497                }
498            };
499            let att = candle_nn::ops::softmax_last_dim(&att)?;
500            att.matmul(&value_states.contiguous()?)?
501        };
502
503        let xs = xs
504            .transpose(1, 2)?
505            .reshape((bsz, q_len, self.hidden_size))?;
506        self.o_proj.forward(&xs)
507    }
508}
509
510#[derive(Debug, Clone)]
511enum TemporalBlock {
512    Recurrent(RecurrentBlock),
513    Attention(SdpaAttention),
514}
515
516impl TemporalBlock {
517    fn forward(
518        &mut self,
519        xs: &Tensor,
520        attention_mask: Option<&Tensor>,
521        pos: usize,
522    ) -> Result<Tensor> {
523        match self {
524            Self::Recurrent(b) => b.forward(xs, pos),
525            Self::Attention(b) => b.forward(xs, attention_mask, pos),
526        }
527    }
528}
529
530#[derive(Debug, Clone)]
531struct DecoderLayer {
532    temporal_pre_norm: RmsNorm,
533    channel_pre_norm: RmsNorm,
534    temporal_block: TemporalBlock,
535    mlp_block: Mlp,
536}
537
538impl DecoderLayer {
539    fn new(
540        block_idx: usize,
541        rotary_emb: Arc<RotaryEmbedding>,
542        cfg: &Config,
543        vb: VarBuilder,
544    ) -> Result<Self> {
545        let h = cfg.hidden_size;
546        let temporal_pre_norm = RmsNorm::new(h, cfg.rms_norm_eps, vb.pp("temporal_pre_norm"))?;
547        let channel_pre_norm = RmsNorm::new(h, cfg.rms_norm_eps, vb.pp("channel_pre_norm"))?;
548        let temporal_block = match cfg.block_types[block_idx % cfg.block_types.len()] {
549            TemporalBlockType::Recurrent => {
550                let block = RecurrentBlock::new(cfg, vb.pp("temporal_block"))?;
551                TemporalBlock::Recurrent(block)
552            }
553            TemporalBlockType::Attention => {
554                let block = SdpaAttention::new(rotary_emb, cfg, vb.pp("temporal_block"))?;
555                TemporalBlock::Attention(block)
556            }
557        };
558        let mlp_block = Mlp::new(cfg, vb.pp("mlp_block"))?;
559        Ok(Self {
560            temporal_pre_norm,
561            channel_pre_norm,
562            temporal_block,
563            mlp_block,
564        })
565    }
566
567    fn forward(
568        &mut self,
569        xs: &Tensor,
570        attention_mask: Option<&Tensor>,
571        pos: usize,
572    ) -> Result<Tensor> {
573        let residual = xs;
574        let xs = xs.apply(&self.temporal_pre_norm)?;
575        let xs = self.temporal_block.forward(&xs, attention_mask, pos)?;
576        let xs = (xs + residual)?;
577        let residual = &xs;
578        let xs = xs.apply(&self.channel_pre_norm)?.apply(&self.mlp_block)?;
579        xs + residual
580    }
581}
582
583#[derive(Debug, Clone)]
584pub struct Model {
585    embed_tokens: candle_nn::Embedding,
586    layers: Vec<DecoderLayer>,
587    final_norm: RmsNorm,
588    lm_head: Linear,
589    hidden_size: usize,
590    logits_soft_cap: f64,
591    dtype: DType,
592    device: Device,
593}
594
595impl Model {
596    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
597        let embed_tokens =
598            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?;
599        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?);
600        let vb_b = vb.pp("layers");
601        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
602        for idx in 0..cfg.num_hidden_layers {
603            let layer = DecoderLayer::new(idx, rotary_emb.clone(), cfg, vb_b.pp(idx))?;
604            layers.push(layer)
605        }
606        let final_norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("final_norm"))?;
607        let lm_head = Linear::new(embed_tokens.embeddings().clone(), None);
608        Ok(Self {
609            embed_tokens,
610            layers,
611            final_norm,
612            lm_head,
613            hidden_size: cfg.hidden_size,
614            logits_soft_cap: cfg.logits_soft_cap,
615            dtype: vb.dtype(),
616            device: vb.device().clone(),
617        })
618    }
619
620    fn prepare_decoder_attention_mask(
621        &self,
622        b_size: usize,
623        tgt_len: usize,
624        seqlen_offset: usize,
625    ) -> Result<Tensor> {
626        let mask: Vec<_> = (0..tgt_len)
627            .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
628            .collect();
629        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
630        let mask = if seqlen_offset > 0 {
631            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
632            Tensor::cat(&[&mask0, &mask], D::Minus1)?
633        } else {
634            mask
635        };
636        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
637            .to_dtype(self.dtype)
638    }
639
640    pub fn forward(&mut self, xs: &Tensor, pos: usize) -> Result<Tensor> {
641        let (b_size, seq_len) = xs.dims2()?;
642        let attention_mask = if seq_len <= 1 {
643            None
644        } else {
645            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, pos)?;
646            Some(mask)
647        };
648        let xs = xs.apply(&self.embed_tokens)?;
649        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
650        for layer in self.layers.iter_mut() {
651            xs = layer.forward(&xs, attention_mask.as_ref(), pos)?;
652        }
653        let logits = xs
654            .narrow(1, seq_len - 1, 1)?
655            .apply(&self.final_norm)?
656            .apply(&self.lm_head)?;
657        let logits = ((logits / self.logits_soft_cap)?.tanh()? * self.logits_soft_cap)?;
658        Ok(logits)
659    }
660}