candle_transformers/models/
quantized_recurrent_gemma.rs

1//! Recurrent Gemma model implementation with quantization support.
2//!
3//! Gemma is a large language model optimized for efficiency.
4//! This implementation provides quantization for reduced memory and compute.
5//!
6//! Key characteristics:
7//! - Recurrent blocks with gated recurrent units
8//! - Convolution and attention blocks
9//! - RMSNorm for layer normalization
10//! - Rotary positional embeddings (RoPE)
11//! - Support for 8-bit quantization
12//!
13//! References:
14//! - [Gemma Paper](https://arxiv.org/abs/2401.06751)
15//! - [Model Card](https://ai.google.dev/gemma)
16//!
17
18use crate::quantized_nn::{linear_b as linear, Embedding, Linear};
19pub use crate::quantized_var_builder::VarBuilder;
20use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
21use std::sync::Arc;
22
23use crate::models::recurrent_gemma::{Config, Rglru, RmsNorm, RotaryEmbedding, TemporalBlockType};
24
25fn rms_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<RmsNorm> {
26    let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
27    Ok(RmsNorm::from_weight(weight, eps))
28}
29
30#[derive(Debug, Clone)]
31struct Mlp {
32    gate_proj: Linear,
33    up_proj: Linear,
34    down_proj: Linear,
35    act_fn: candle_nn::Activation,
36}
37
38impl Mlp {
39    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
40        let h = cfg.hidden_size;
41        let intermediate_size = cfg.intermediate_size / 2;
42        let gate_proj = linear(h, intermediate_size, true, vb.pp("gate_proj"))?;
43        let up_proj = linear(h, intermediate_size, true, vb.pp("up_proj"))?;
44        let down_proj = linear(intermediate_size, h, true, vb.pp("down_proj"))?;
45        Ok(Self {
46            gate_proj,
47            up_proj,
48            down_proj,
49            act_fn: cfg.hidden_activation,
50        })
51    }
52}
53
54impl Module for Mlp {
55    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
56        let gate = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
57        (gate * xs.apply(&self.up_proj))?.apply(&self.down_proj)
58    }
59}
60
61fn rglru(cfg: &Config, vb: VarBuilder) -> Result<Rglru> {
62    let h = cfg.hidden_size;
63    let lru_width = cfg.lru_width.unwrap_or(h);
64    let n_heads = cfg.num_attention_heads;
65    let block_width = lru_width / n_heads;
66    let recurrent_param = vb.get((lru_width,), "recurrent_param")?;
67    let input_gate_weight = vb.get((n_heads, block_width, block_width), "input_gate_weight")?;
68    let input_gate_bias = vb.get((n_heads, block_width), "input_gate_bias")?;
69    let recurrent_gate_weight =
70        vb.get((n_heads, block_width, block_width), "recurrent_gate_weight")?;
71    let recurrent_gate_bias = vb.get((n_heads, block_width), "recurrent_gate_bias")?;
72    Ok(Rglru {
73        recurrent_param: recurrent_param.dequantize(vb.device())?,
74        input_gate_bias: input_gate_bias.dequantize(vb.device())?,
75        input_gate_weight: input_gate_weight.dequantize(vb.device())?,
76        recurrent_gate_bias: recurrent_gate_bias.dequantize(vb.device())?,
77        recurrent_gate_weight: recurrent_gate_weight.dequantize(vb.device())?,
78        block_width,
79        n_heads,
80        recurrent_states: None,
81    })
82}
83
84#[derive(Debug, Clone)]
85struct RecurrentBlock {
86    linear_y: Linear,
87    linear_x: Linear,
88    linear_out: Linear,
89    conv_1d: candle_nn::Conv1d,
90    conv1d_state: Option<Tensor>,
91    conv1d_width: usize,
92    rg_lru: Rglru,
93    act_fn: candle_nn::Activation,
94}
95
96impl RecurrentBlock {
97    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
98        let h = cfg.hidden_size;
99        let lru_width = cfg.lru_width.unwrap_or(h);
100        let linear_y = linear(h, lru_width, true, vb.pp("linear_y"))?;
101        let linear_x = linear(h, lru_width, true, vb.pp("linear_x"))?;
102        let linear_out = linear(lru_width, h, true, vb.pp("linear_out"))?;
103
104        let conv_1d = {
105            let ws = vb
106                .get((lru_width, 1, cfg.conv1d_width), "conv_1d.weight")?
107                .dequantize(vb.device())?;
108            let bs = vb.get(lru_width, "conv_1d.bias")?.dequantize(vb.device())?;
109            let config = candle_nn::Conv1dConfig {
110                groups: lru_width,
111                padding: cfg.conv1d_width - 1,
112                ..Default::default()
113            };
114            candle_nn::Conv1d::new(ws, Some(bs), config)
115        };
116        let rg_lru = rglru(cfg, vb.pp("rg_lru"))?;
117        Ok(Self {
118            linear_y,
119            linear_x,
120            linear_out,
121            conv_1d,
122            conv1d_state: None,
123            conv1d_width: cfg.conv1d_width,
124            rg_lru,
125            act_fn: cfg.hidden_activation,
126        })
127    }
128
129    pub fn forward(&mut self, xs: &Tensor, pos: usize) -> Result<Tensor> {
130        let (_b_sz, seq_len, _) = xs.dims3()?;
131
132        let y_branch = xs.apply(&self.linear_y)?.apply(&self.act_fn)?;
133        let x_branch = xs.apply(&self.linear_x)?.transpose(1, 2)?;
134        let x_branch = if pos == 0 {
135            let x_len = x_branch.dim(D::Minus1)?;
136            let pad = self.conv1d_width as i64 - x_len as i64 - 1;
137            let padded = match pad.cmp(&0) {
138                std::cmp::Ordering::Equal => x_branch.clone(),
139                std::cmp::Ordering::Less => {
140                    let rev_pad = (-pad) as usize;
141                    x_branch.narrow(D::Minus1, rev_pad, x_len - rev_pad)?
142                }
143                std::cmp::Ordering::Greater => {
144                    x_branch.pad_with_zeros(D::Minus1, pad as usize, 0)?
145                }
146            };
147            self.conv1d_state = Some(padded);
148            x_branch
149                .apply(&self.conv_1d)?
150                .narrow(D::Minus1, 0, seq_len)?
151        } else {
152            let conv_state = match self.conv1d_state.as_ref() {
153                None => candle::bail!("empty cache despite pos > 0"),
154                Some(s) => Tensor::cat(&[s, &x_branch], D::Minus1)?,
155            };
156            let w = self.conv_1d.weight().i((.., 0, ..))?;
157            let x_branch = conv_state.broadcast_mul(&w)?.sum(D::Minus1)?;
158            let x_branch = match self.conv_1d.bias() {
159                None => x_branch,
160                Some(b) => x_branch.broadcast_add(b)?,
161            };
162            let x_branch = x_branch.unsqueeze(D::Minus1)?;
163            self.conv1d_state = Some(conv_state.i((.., .., 1..))?);
164            x_branch
165        };
166        let x_branch = x_branch.transpose(1, 2)?;
167        let x_branch = self.rg_lru.forward(&x_branch, pos)?;
168        (x_branch * y_branch)?.apply(&self.linear_out)
169    }
170}
171
172#[derive(Debug, Clone)]
173struct SdpaAttention {
174    q_proj: Linear,
175    k_proj: Linear,
176    v_proj: Linear,
177    o_proj: Linear,
178    n_heads: usize,
179    n_kv_heads: usize,
180    head_dim: usize,
181    hidden_size: usize,
182    kv_cache: Option<(Tensor, Tensor)>,
183    rotary_emb: Arc<RotaryEmbedding>,
184}
185
186impl SdpaAttention {
187    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
188        let h = cfg.hidden_size;
189        let n_heads = cfg.num_attention_heads;
190        let n_kv_heads = cfg.num_key_value_heads;
191        let hd = cfg.head_dim;
192        let q_proj = linear(h, n_heads * hd, cfg.attention_bias, vb.pp("q_proj"))?;
193        let k_proj = linear(h, n_kv_heads * hd, cfg.attention_bias, vb.pp("k_proj"))?;
194        let v_proj = linear(h, n_kv_heads * hd, cfg.attention_bias, vb.pp("v_proj"))?;
195        let o_proj = linear(n_heads * hd, h, true, vb.pp("o_proj"))?;
196        Ok(Self {
197            q_proj,
198            k_proj,
199            v_proj,
200            o_proj,
201            n_heads,
202            n_kv_heads,
203            head_dim: hd,
204            hidden_size: h,
205            kv_cache: None,
206            rotary_emb,
207        })
208    }
209
210    fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
211        let n_rep = self.n_heads / self.n_kv_heads;
212        crate::utils::repeat_kv(x, n_rep)
213    }
214
215    fn forward(
216        &mut self,
217        xs: &Tensor,
218        attention_mask: Option<&Tensor>,
219        pos: usize,
220    ) -> Result<Tensor> {
221        let (bsz, q_len, _) = xs.dims3()?;
222
223        let query_states = xs.apply(&self.q_proj)?;
224        let key_states = xs.apply(&self.k_proj)?;
225        let value_states = xs.apply(&self.v_proj)?;
226
227        let query_states = query_states
228            .reshape((bsz, q_len, self.n_heads, self.head_dim))?
229            .transpose(1, 2)?;
230        let key_states = key_states
231            .reshape((bsz, q_len, self.n_kv_heads, self.head_dim))?
232            .transpose(1, 2)?;
233        let value_states = value_states
234            .reshape((bsz, q_len, self.n_kv_heads, self.head_dim))?
235            .transpose(1, 2)?;
236        let query_states = query_states.chunk(2, D::Minus1)?;
237        let key_states = key_states.chunk(2, D::Minus1)?;
238        let (query_rot, key_rot) =
239            self.rotary_emb
240                .apply_rotary_emb_qkv(&query_states[0], &key_states[0], pos)?;
241        let query_states = Tensor::cat(&[&query_rot, &query_states[1]], D::Minus1)?.contiguous()?;
242        let key_states = Tensor::cat(&[&key_rot, &key_states[1]], D::Minus1)?.contiguous()?;
243
244        let (key_states, value_states) = match &self.kv_cache {
245            None => (key_states, value_states),
246            Some((prev_k, prev_v)) => {
247                let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
248                let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
249                (key_states, value_states)
250            }
251        };
252        self.kv_cache = Some((key_states.clone(), value_states.clone()));
253
254        let key_states = self.repeat_kv(key_states)?;
255        let value_states = self.repeat_kv(value_states)?;
256        let xs = {
257            let att = (query_states.matmul(&key_states.t()?)? / (self.head_dim as f64).sqrt())?;
258            let att = if q_len == 1 {
259                att
260            } else {
261                match attention_mask {
262                    None => att,
263                    Some(mask) => att.broadcast_add(mask)?,
264                }
265            };
266            let att = candle_nn::ops::softmax_last_dim(&att)?;
267            att.matmul(&value_states.contiguous()?)?
268        };
269
270        let xs = xs
271            .transpose(1, 2)?
272            .reshape((bsz, q_len, self.hidden_size))?;
273        self.o_proj.forward(&xs)
274    }
275}
276
277#[derive(Debug, Clone)]
278enum TemporalBlock {
279    Recurrent(RecurrentBlock),
280    Attention(SdpaAttention),
281}
282
283impl TemporalBlock {
284    fn forward(
285        &mut self,
286        xs: &Tensor,
287        attention_mask: Option<&Tensor>,
288        pos: usize,
289    ) -> Result<Tensor> {
290        match self {
291            Self::Recurrent(b) => b.forward(xs, pos),
292            Self::Attention(b) => b.forward(xs, attention_mask, pos),
293        }
294    }
295}
296
297#[derive(Debug, Clone)]
298struct DecoderLayer {
299    temporal_pre_norm: RmsNorm,
300    channel_pre_norm: RmsNorm,
301    temporal_block: TemporalBlock,
302    mlp_block: Mlp,
303}
304
305impl DecoderLayer {
306    fn new(
307        block_idx: usize,
308        rotary_emb: Arc<RotaryEmbedding>,
309        cfg: &Config,
310        vb: VarBuilder,
311    ) -> Result<Self> {
312        let h = cfg.hidden_size;
313        let temporal_pre_norm = rms_norm(h, cfg.rms_norm_eps, vb.pp("temporal_pre_norm"))?;
314        let channel_pre_norm = rms_norm(h, cfg.rms_norm_eps, vb.pp("channel_pre_norm"))?;
315        let temporal_block = match cfg.block_types[block_idx % cfg.block_types.len()] {
316            TemporalBlockType::Recurrent => {
317                let block = RecurrentBlock::new(cfg, vb.pp("temporal_block"))?;
318                TemporalBlock::Recurrent(block)
319            }
320            TemporalBlockType::Attention => {
321                let block = SdpaAttention::new(rotary_emb, cfg, vb.pp("temporal_block"))?;
322                TemporalBlock::Attention(block)
323            }
324        };
325        let mlp_block = Mlp::new(cfg, vb.pp("mlp_block"))?;
326        Ok(Self {
327            temporal_pre_norm,
328            channel_pre_norm,
329            temporal_block,
330            mlp_block,
331        })
332    }
333
334    fn forward(
335        &mut self,
336        xs: &Tensor,
337        attention_mask: Option<&Tensor>,
338        pos: usize,
339    ) -> Result<Tensor> {
340        let residual = xs;
341        let xs = xs.apply(&self.temporal_pre_norm)?;
342        let xs = self.temporal_block.forward(&xs, attention_mask, pos)?;
343        let xs = (xs + residual)?;
344        let residual = &xs;
345        let xs = xs.apply(&self.channel_pre_norm)?.apply(&self.mlp_block)?;
346        xs + residual
347    }
348}
349
350#[derive(Debug, Clone)]
351pub struct Model {
352    embed_tokens: Embedding,
353    layers: Vec<DecoderLayer>,
354    final_norm: RmsNorm,
355    lm_head: Linear,
356    hidden_size: usize,
357    logits_soft_cap: f64,
358    device: Device,
359}
360
361impl Model {
362    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
363        let embed_tokens = Embedding::new(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?;
364        let rotary_emb = Arc::new(RotaryEmbedding::new(DType::F32, cfg, vb.device())?);
365        let vb_b = vb.pp("layers");
366        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
367        for idx in 0..cfg.num_hidden_layers {
368            let layer = DecoderLayer::new(idx, rotary_emb.clone(), cfg, vb_b.pp(idx))?;
369            layers.push(layer)
370        }
371        let final_norm = rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("final_norm"))?;
372        let lm_head = linear(
373            cfg.hidden_size,
374            cfg.vocab_size,
375            false,
376            vb.pp("embed_tokens"),
377        )?;
378        Ok(Self {
379            embed_tokens,
380            layers,
381            final_norm,
382            lm_head,
383            hidden_size: cfg.hidden_size,
384            logits_soft_cap: cfg.logits_soft_cap,
385            device: vb.device().clone(),
386        })
387    }
388
389    fn prepare_decoder_attention_mask(
390        &self,
391        b_size: usize,
392        tgt_len: usize,
393        seqlen_offset: usize,
394    ) -> Result<Tensor> {
395        let mask: Vec<_> = (0..tgt_len)
396            .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
397            .collect();
398        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
399        let mask = if seqlen_offset > 0 {
400            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
401            Tensor::cat(&[&mask0, &mask], D::Minus1)?
402        } else {
403            mask
404        };
405        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
406            .to_dtype(DType::F32)
407    }
408
409    pub fn forward(&mut self, xs: &Tensor, pos: usize) -> Result<Tensor> {
410        let (b_size, seq_len) = xs.dims2()?;
411        let attention_mask = if seq_len <= 1 {
412            None
413        } else {
414            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, pos)?;
415            Some(mask)
416        };
417        let xs = xs.apply(&self.embed_tokens)?;
418        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
419        for layer in self.layers.iter_mut() {
420            xs = layer.forward(&xs, attention_mask.as_ref(), pos)?;
421        }
422        let logits = xs
423            .narrow(1, seq_len - 1, 1)?
424            .apply(&self.final_norm)?
425            .apply(&self.lm_head)?;
426        let logits = ((logits / self.logits_soft_cap)?.tanh()? * self.logits_soft_cap)?;
427        Ok(logits)
428    }
429}