candle_transformers/models/
gemma3.rs

1//! Gemma LLM architecture (Google) inference implementation.
2//!
3//! See ["Introducing Gemma 3: The most capable model you can run on a single GPU or TPU"](https://blog.google/technology/developers/gemma-3/)
4//!
5//! Based on implementations from HuggingFace transformers.
6
7use std::sync::Arc;
8
9use candle::{DType, Device, Module, Result, Tensor, D};
10use candle_nn::{linear_b as linear, Activation, Linear, VarBuilder};
11
12#[derive(serde::Deserialize, Debug, Clone)]
13pub struct Config {
14    pub attention_bias: bool,
15    pub head_dim: usize,
16    pub hidden_activation: Activation,
17    pub hidden_size: usize,
18    pub intermediate_size: usize,
19    pub num_attention_heads: usize,
20    pub num_hidden_layers: usize,
21    pub num_key_value_heads: usize,
22    pub rms_norm_eps: f64,
23    pub rope_theta: f64,
24    pub vocab_size: usize,
25    pub final_logit_softcapping: Option<f64>,
26    pub attn_logit_softcapping: Option<f64>,
27    pub query_pre_attn_scalar: usize,
28    pub sliding_window: usize,
29    pub sliding_window_pattern: usize,
30    pub max_position_embeddings: usize,
31}
32
33#[derive(Debug, Clone)]
34struct RmsNorm {
35    weight: Tensor,
36    eps: f64,
37}
38
39impl RmsNorm {
40    fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
41        let weight = vb.get(dim, "weight")?;
42        Ok(Self { weight, eps })
43    }
44}
45
46impl Module for RmsNorm {
47    fn forward(&self, x: &Tensor) -> Result<Tensor> {
48        let x_dtype = x.dtype();
49        let internal_dtype = match x_dtype {
50            DType::F16 | DType::BF16 => DType::F32,
51            d => d,
52        };
53        let hidden_size = x.dim(D::Minus1)?;
54        let x = x.to_dtype(internal_dtype)?;
55        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
56        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
57        x_normed
58            .to_dtype(x_dtype)?
59            .broadcast_mul(&(&self.weight + 1.0)?)
60    }
61}
62
63#[derive(Debug, Clone)]
64struct RotaryEmbedding {
65    sin: Tensor,
66    cos: Tensor,
67}
68
69impl RotaryEmbedding {
70    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
71        let dim = cfg.head_dim;
72        let max_seq_len = cfg.max_position_embeddings;
73        let inv_freq: Vec<_> = (0..dim)
74            .step_by(2)
75            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
76            .collect();
77        let inv_freq_len = inv_freq.len();
78        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
79        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
80            .to_dtype(dtype)?
81            .reshape((max_seq_len, 1))?;
82        let freqs = t.matmul(&inv_freq)?;
83        Ok(Self {
84            sin: freqs.sin()?,
85            cos: freqs.cos()?,
86        })
87    }
88
89    fn apply_rotary_emb_qkv(
90        &self,
91        q: &Tensor,
92        k: &Tensor,
93        seqlen_offset: usize,
94    ) -> Result<(Tensor, Tensor)> {
95        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
96        let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
97        let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
98        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
99        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
100        Ok((q_embed, k_embed))
101    }
102}
103
104#[derive(Debug, Clone)]
105#[allow(clippy::upper_case_acronyms)]
106struct MLP {
107    gate_proj: Linear,
108    up_proj: Linear,
109    down_proj: Linear,
110    act_fn: candle_nn::Activation,
111}
112
113impl MLP {
114    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
115        let hidden_sz = cfg.hidden_size;
116        let intermediate_sz = cfg.intermediate_size;
117        let gate_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("gate_proj"))?;
118        let up_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("up_proj"))?;
119        let down_proj = linear(intermediate_sz, hidden_sz, false, vb.pp("down_proj"))?;
120        Ok(Self {
121            gate_proj,
122            up_proj,
123            down_proj,
124            act_fn: cfg.hidden_activation,
125        })
126    }
127}
128
129impl Module for MLP {
130    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
131        let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
132        let rhs = xs.apply(&self.up_proj)?;
133        (lhs * rhs)?.apply(&self.down_proj)
134    }
135}
136
137#[derive(Debug, Clone)]
138enum KvCache {
139    Normal(candle_nn::kv_cache::KvCache),
140    Rotating(candle_nn::kv_cache::RotatingKvCache),
141}
142
143#[derive(Debug, Clone)]
144struct Attention {
145    q_proj: Linear,
146    k_proj: Linear,
147    v_proj: Linear,
148    o_proj: Linear,
149    q_norm: RmsNorm,
150    k_norm: RmsNorm,
151    num_heads: usize,
152    num_kv_heads: usize,
153    num_kv_groups: usize,
154    head_dim: usize,
155    attn_logit_softcapping: Option<f64>,
156    rotary_emb: Arc<RotaryEmbedding>,
157    kv_cache: KvCache,
158    use_flash_attn: bool,
159}
160
161impl Attention {
162    fn new(
163        rotary_emb: Arc<RotaryEmbedding>,
164        use_flash_attn: bool,
165        is_sliding: bool,
166        cfg: &Config,
167        vb: VarBuilder,
168    ) -> Result<Self> {
169        let hidden_sz = cfg.hidden_size;
170        let num_heads = cfg.num_attention_heads;
171        let num_kv_heads = cfg.num_key_value_heads;
172        let num_kv_groups = num_heads / num_kv_heads;
173        let head_dim = cfg.head_dim;
174        let bias = cfg.attention_bias;
175        let q_proj = linear(hidden_sz, num_heads * head_dim, bias, vb.pp("q_proj"))?;
176        let k_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("k_proj"))?;
177        let v_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("v_proj"))?;
178        let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?;
179        let q_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?;
180        let k_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?;
181        let kv_cache = if is_sliding {
182            KvCache::Rotating(candle_nn::kv_cache::RotatingKvCache::new(
183                2,
184                cfg.sliding_window,
185            ))
186        } else {
187            KvCache::Normal(candle_nn::kv_cache::KvCache::new(2, cfg.sliding_window))
188        };
189        Ok(Self {
190            q_proj,
191            k_proj,
192            v_proj,
193            o_proj,
194            q_norm,
195            k_norm,
196            num_heads,
197            num_kv_heads,
198            num_kv_groups,
199            head_dim,
200            attn_logit_softcapping: cfg.attn_logit_softcapping,
201            rotary_emb,
202            kv_cache,
203            use_flash_attn,
204        })
205    }
206
207    fn forward(
208        &mut self,
209        xs: &Tensor,
210        attention_mask: Option<&Tensor>,
211        seqlen_offset: usize,
212    ) -> Result<Tensor> {
213        let (b_sz, q_len, _) = xs.dims3()?;
214
215        let query_states = self.q_proj.forward(xs)?;
216        let key_states = self.k_proj.forward(xs)?;
217        let value_states = self.v_proj.forward(xs)?;
218
219        let query_states = query_states
220            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
221            .transpose(1, 2)?;
222        let key_states = key_states
223            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
224            .transpose(1, 2)?;
225        let value_states = value_states
226            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
227            .transpose(1, 2)?;
228        let query_states = self.q_norm.forward(&query_states)?;
229        let key_states = self.k_norm.forward(&key_states)?;
230
231        let (query_states, key_states) =
232            self.rotary_emb
233                .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
234
235        let (key_states, value_states) = match &mut self.kv_cache {
236            KvCache::Normal(cache) => cache.append(&key_states, &value_states)?,
237            KvCache::Rotating(cache) => cache.append(&key_states, &value_states)?,
238        };
239
240        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
241        let value_states =
242            crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
243
244        let attn_output = if self.use_flash_attn {
245            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
246            let q = query_states.transpose(1, 2)?;
247            let k = key_states.transpose(1, 2)?;
248            let v = value_states.transpose(1, 2)?;
249            let scale = 1f32 / (self.head_dim as f32).sqrt();
250            flash_attn(&q, &k, &v, scale, attention_mask.is_some())?.transpose(1, 2)?
251        } else {
252            let scale = 1f64 / f64::sqrt(self.head_dim as f64);
253            let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
254
255            let attn_weights = match self.attn_logit_softcapping {
256                None => attn_weights,
257                Some(sc) => ((attn_weights / sc)?.tanh()? * sc)?,
258            };
259
260            let attn_weights = match attention_mask {
261                None => attn_weights,
262                Some(mask) => attn_weights.broadcast_add(mask)?,
263            };
264            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
265            attn_weights.matmul(&value_states)?
266        };
267        attn_output
268            .transpose(1, 2)?
269            .reshape((b_sz, q_len, ()))?
270            .apply(&self.o_proj)
271    }
272
273    fn clear_kv_cache(&mut self) {
274        match &mut self.kv_cache {
275            KvCache::Normal(c) => c.reset(),
276            KvCache::Rotating(c) => c.reset(),
277        }
278    }
279}
280
/// Forwards to `candle_flash_attn::flash_attn`; compiled only when the
/// `flash-attn` feature is enabled.
#[cfg(feature = "flash-attn")]
fn flash_attn(
    query: &Tensor,
    key: &Tensor,
    value: &Tensor,
    scale: f32,
    is_causal: bool,
) -> Result<Tensor> {
    let attended = candle_flash_attn::flash_attn(query, key, value, scale, is_causal);
    attended
}
291
292#[cfg(not(feature = "flash-attn"))]
293fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
294    unimplemented!("compile with '--features flash-attn'")
295}
296
297#[derive(Debug, Clone)]
298struct DecoderLayer {
299    self_attn: Attention,
300    mlp: MLP,
301    input_layernorm: RmsNorm,
302    pre_feedforward_layernorm: RmsNorm,
303    post_feedforward_layernorm: RmsNorm,
304    post_attention_layernorm: RmsNorm,
305}
306
307impl DecoderLayer {
308    fn new(
309        rotary_emb: Arc<RotaryEmbedding>,
310        use_flash_attn: bool,
311        is_sliding: bool,
312        cfg: &Config,
313        vb: VarBuilder,
314    ) -> Result<Self> {
315        let self_attn = Attention::new(
316            rotary_emb,
317            use_flash_attn,
318            is_sliding,
319            cfg,
320            vb.pp("self_attn"),
321        )?;
322        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
323        let input_layernorm =
324            RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
325        let pre_feedforward_layernorm = RmsNorm::new(
326            cfg.hidden_size,
327            cfg.rms_norm_eps,
328            vb.pp("pre_feedforward_layernorm"),
329        )?;
330        let post_feedforward_layernorm = RmsNorm::new(
331            cfg.hidden_size,
332            cfg.rms_norm_eps,
333            vb.pp("post_feedforward_layernorm"),
334        )?;
335        let post_attention_layernorm = RmsNorm::new(
336            cfg.hidden_size,
337            cfg.rms_norm_eps,
338            vb.pp("post_attention_layernorm"),
339        )?;
340        Ok(Self {
341            self_attn,
342            mlp,
343            input_layernorm,
344            pre_feedforward_layernorm,
345            post_feedforward_layernorm,
346            post_attention_layernorm,
347        })
348    }
349
350    fn forward(
351        &mut self,
352        xs: &Tensor,
353        attention_mask: Option<&Tensor>,
354        seqlen_offset: usize,
355    ) -> Result<Tensor> {
356        let residual = xs;
357        let xs = self.input_layernorm.forward(xs)?;
358        let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
359        let xs = xs.apply(&self.post_attention_layernorm)?;
360        let xs = (xs + residual)?;
361        let residual = &xs;
362        let xs = xs.apply(&self.pre_feedforward_layernorm)?;
363        let xs = xs.apply(&self.mlp)?;
364        let xs = xs.apply(&self.post_feedforward_layernorm)?;
365        residual + xs
366    }
367
368    fn clear_kv_cache(&mut self) {
369        self.self_attn.clear_kv_cache()
370    }
371}
372
373#[derive(Debug, Clone)]
374pub struct Model {
375    embed_tokens: candle_nn::Embedding,
376    layers: Vec<DecoderLayer>,
377    norm: RmsNorm,
378    lm_head: Linear,
379    final_logit_softcapping: Option<f64>,
380    device: Device,
381    dtype: DType,
382    hidden_size: usize,
383    sliding_window: usize,
384}
385
386impl Model {
387    pub fn new(use_flash_attn: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> {
388        let vb_m = vb.pp("model");
389        let embed_tokens =
390            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
391        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
392        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
393        let vb_l = vb_m.pp("layers");
394        for layer_idx in 0..cfg.num_hidden_layers {
395            let is_sliding = (layer_idx + 1) % cfg.sliding_window_pattern > 0;
396            let layer = DecoderLayer::new(
397                rotary_emb.clone(),
398                use_flash_attn,
399                is_sliding,
400                cfg,
401                vb_l.pp(layer_idx),
402            )?;
403            layers.push(layer)
404        }
405        let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
406        let lm_head = Linear::new(embed_tokens.embeddings().clone(), None);
407        Ok(Self {
408            embed_tokens,
409            layers,
410            norm,
411            lm_head,
412            final_logit_softcapping: cfg.final_logit_softcapping,
413            device: vb.device().clone(),
414            dtype: vb.dtype(),
415            hidden_size: cfg.hidden_size,
416            sliding_window: cfg.sliding_window,
417        })
418    }
419
420    fn prepare_decoder_attention_mask(
421        &self,
422        b_size: usize,
423        tgt_len: usize,
424        seqlen_offset: usize,
425    ) -> Result<Tensor> {
426        let mask: Vec<_> = match Some(self.sliding_window) {
427            None => (0..tgt_len)
428                .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
429                .collect(),
430            Some(sliding_window) => (0..tgt_len)
431                .flat_map(|i| {
432                    (0..tgt_len).map(move |j| {
433                        if i < j || j + sliding_window < i {
434                            f32::NEG_INFINITY
435                        } else {
436                            0.
437                        }
438                    })
439                })
440                .collect(),
441        };
442        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
443        let mask = if seqlen_offset > 0 {
444            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
445            Tensor::cat(&[&mask0, &mask], D::Minus1)?
446        } else {
447            mask
448        };
449        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
450            .to_dtype(self.dtype)
451    }
452
453    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
454        let (b_size, seq_len) = input_ids.dims2()?;
455        let attention_mask = if seq_len <= 1 {
456            None
457        } else {
458            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
459            Some(mask)
460        };
461        let xs = self.embed_tokens.forward(input_ids)?;
462        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
463        for layer in self.layers.iter_mut() {
464            xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
465        }
466        let logits = xs
467            .narrow(1, seq_len - 1, 1)?
468            .apply(&self.norm)?
469            .apply(&self.lm_head)?;
470        let logits = match self.final_logit_softcapping {
471            None => logits,
472            Some(sc) => ((logits / sc)?.tanh()? * sc)?,
473        };
474
475        Ok(logits)
476    }
477
478    pub fn clear_kv_cache(&mut self) {
479        for layer in self.layers.iter_mut() {
480            layer.clear_kv_cache()
481        }
482    }
483}