// candle_transformers/models/gemma2.rs

//! Gemma 2 LLM architecture (Google) inference implementation.
//!
//! See ["Gemma 2: Improving Open Language Models at a Practical Size"](https://arxiv.org/abs/2408.00118)
//! and the original ["Gemma: Open Models Based on Gemini Technology"](https://blog.google/technology/developers/gemma-open-models/).
//!
//! Based on implementations from Google and OpenLLM.
6
7use std::sync::Arc;
8
9use candle::{DType, Device, Module, Result, Tensor, D};
10use candle_nn::{linear_b as linear, Activation, Linear, VarBuilder};
11
/// Fallback for `Config::max_position_embeddings` when the checkpoint config omits it.
fn default_max_position_embeddings() -> usize {
    const FALLBACK_MAX_POSITIONS: usize = 4096;
    FALLBACK_MAX_POSITIONS
}
15
/// Gemma 2 hyper-parameters, deserialized (via serde) from the checkpoint's config file.
#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
    /// Whether the q/k/v/o attention projections have a bias term.
    pub attention_bias: bool,
    /// Dimension of each attention head.
    pub head_dim: usize,
    /// Activation applied to the MLP gate projection.
    pub hidden_activation: Activation,
    /// Width of the residual stream and of the token embeddings.
    pub hidden_size: usize,
    /// Inner width of the gated MLP.
    pub intermediate_size: usize,
    /// Number of query heads per attention layer.
    pub num_attention_heads: usize,
    /// Number of decoder layers.
    pub num_hidden_layers: usize,
    /// Number of key/value heads (grouped-query attention); expected to divide
    /// `num_attention_heads`, since the group count is computed by integer division.
    pub num_key_value_heads: usize,
    /// Epsilon added to the mean square in every RMS norm.
    pub rms_norm_eps: f64,
    /// Base of the rotary-embedding inverse frequencies.
    pub rope_theta: f64,
    /// Vocabulary size (rows of the embedding matrix, which is also used as the LM head).
    pub vocab_size: usize,
    /// If set, the final logits are soft-capped as `tanh(logits / cap) * cap`.
    pub final_logit_softcapping: Option<f64>,
    /// If set, attention scores are soft-capped as `tanh(scores / cap) * cap` before the
    /// softmax (non-flash-attention path only).
    pub attn_logit_softcapping: Option<f64>,
    /// Attention-scale parameter from the checkpoint config.
    // NOTE(review): the reference Gemma 2 implementation scales attention logits by
    // `query_pre_attn_scalar^-0.5`; confirm the attention layer uses this rather than `head_dim`.
    pub query_pre_attn_scalar: usize,
    /// Width of the local attention window used when building the causal mask; `None`
    /// disables windowing.
    pub sliding_window: Option<usize>,

    /// Length of the precomputed rotary tables, i.e. the largest `seqlen_offset + seq_len`
    /// the model accepts. Defaults to 4096 when missing from the config.
    #[serde(default = "default_max_position_embeddings")]
    pub max_position_embeddings: usize,
}
38
39#[derive(Debug, Clone)]
40struct RmsNorm {
41    weight: Tensor,
42    eps: f64,
43}
44
45impl RmsNorm {
46    fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
47        let weight = vb.get(dim, "weight")?;
48        Ok(Self { weight, eps })
49    }
50}
51
52impl Module for RmsNorm {
53    fn forward(&self, x: &Tensor) -> Result<Tensor> {
54        let x_dtype = x.dtype();
55        let internal_dtype = match x_dtype {
56            DType::F16 | DType::BF16 => DType::F32,
57            d => d,
58        };
59        let hidden_size = x.dim(D::Minus1)?;
60        let x = x.to_dtype(internal_dtype)?;
61        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
62        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
63        x_normed
64            .to_dtype(x_dtype)?
65            .broadcast_mul(&(&self.weight + 1.0)?)
66    }
67}
68
69#[derive(Debug, Clone)]
70struct RotaryEmbedding {
71    sin: Tensor,
72    cos: Tensor,
73}
74
75impl RotaryEmbedding {
76    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
77        let dim = cfg.head_dim;
78        let max_seq_len = cfg.max_position_embeddings;
79        let inv_freq: Vec<_> = (0..dim)
80            .step_by(2)
81            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
82            .collect();
83        let inv_freq_len = inv_freq.len();
84        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
85        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
86            .to_dtype(dtype)?
87            .reshape((max_seq_len, 1))?;
88        let freqs = t.matmul(&inv_freq)?;
89        Ok(Self {
90            sin: freqs.sin()?,
91            cos: freqs.cos()?,
92        })
93    }
94
95    fn apply_rotary_emb_qkv(
96        &self,
97        q: &Tensor,
98        k: &Tensor,
99        seqlen_offset: usize,
100    ) -> Result<(Tensor, Tensor)> {
101        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
102        let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
103        let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
104        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
105        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
106        Ok((q_embed, k_embed))
107    }
108}
109
110#[derive(Debug, Clone)]
111#[allow(clippy::upper_case_acronyms)]
112struct MLP {
113    gate_proj: Linear,
114    up_proj: Linear,
115    down_proj: Linear,
116    act_fn: candle_nn::Activation,
117}
118
119impl MLP {
120    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
121        let hidden_sz = cfg.hidden_size;
122        let intermediate_sz = cfg.intermediate_size;
123        let gate_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("gate_proj"))?;
124        let up_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("up_proj"))?;
125        let down_proj = linear(intermediate_sz, hidden_sz, false, vb.pp("down_proj"))?;
126        Ok(Self {
127            gate_proj,
128            up_proj,
129            down_proj,
130            act_fn: cfg.hidden_activation,
131        })
132    }
133}
134
135impl Module for MLP {
136    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
137        let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
138        let rhs = xs.apply(&self.up_proj)?;
139        (lhs * rhs)?.apply(&self.down_proj)
140    }
141}
142
143#[derive(Debug, Clone)]
144struct Attention {
145    q_proj: Linear,
146    k_proj: Linear,
147    v_proj: Linear,
148    o_proj: Linear,
149    num_heads: usize,
150    num_kv_heads: usize,
151    num_kv_groups: usize,
152    head_dim: usize,
153    attn_logit_softcapping: Option<f64>,
154    rotary_emb: Arc<RotaryEmbedding>,
155    kv_cache: Option<(Tensor, Tensor)>,
156    use_flash_attn: bool,
157}
158
159impl Attention {
160    fn new(
161        rotary_emb: Arc<RotaryEmbedding>,
162        use_flash_attn: bool,
163        cfg: &Config,
164        vb: VarBuilder,
165    ) -> Result<Self> {
166        let hidden_sz = cfg.hidden_size;
167        let num_heads = cfg.num_attention_heads;
168        let num_kv_heads = cfg.num_key_value_heads;
169        let num_kv_groups = num_heads / num_kv_heads;
170        let head_dim = cfg.head_dim;
171        let bias = cfg.attention_bias;
172        let q_proj = linear(hidden_sz, num_heads * head_dim, bias, vb.pp("q_proj"))?;
173        let k_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("k_proj"))?;
174        let v_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("v_proj"))?;
175        let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?;
176        Ok(Self {
177            q_proj,
178            k_proj,
179            v_proj,
180            o_proj,
181            num_heads,
182            num_kv_heads,
183            num_kv_groups,
184            head_dim,
185            attn_logit_softcapping: cfg.attn_logit_softcapping,
186            rotary_emb,
187            kv_cache: None,
188            use_flash_attn,
189        })
190    }
191
192    fn forward(
193        &mut self,
194        xs: &Tensor,
195        attention_mask: Option<&Tensor>,
196        seqlen_offset: usize,
197    ) -> Result<Tensor> {
198        let (b_sz, q_len, _) = xs.dims3()?;
199
200        let query_states = self.q_proj.forward(xs)?;
201        let key_states = self.k_proj.forward(xs)?;
202        let value_states = self.v_proj.forward(xs)?;
203
204        let query_states = query_states
205            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
206            .transpose(1, 2)?;
207        let key_states = key_states
208            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
209            .transpose(1, 2)?;
210        let value_states = value_states
211            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
212            .transpose(1, 2)?;
213
214        let (query_states, key_states) =
215            self.rotary_emb
216                .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
217
218        let (key_states, value_states) = match &self.kv_cache {
219            None => (key_states, value_states),
220            Some((prev_k, prev_v)) => {
221                let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
222                let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
223                (key_states, value_states)
224            }
225        };
226        self.kv_cache = Some((key_states.clone(), value_states.clone()));
227
228        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
229        let value_states =
230            crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
231
232        let attn_output = if self.use_flash_attn {
233            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
234            let q = query_states.transpose(1, 2)?;
235            let k = key_states.transpose(1, 2)?;
236            let v = value_states.transpose(1, 2)?;
237            let scale = 1f32 / (self.head_dim as f32).sqrt();
238            flash_attn(&q, &k, &v, scale, attention_mask.is_some())?.transpose(1, 2)?
239        } else {
240            let scale = 1f64 / f64::sqrt(self.head_dim as f64);
241            let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
242
243            let attn_weights = match self.attn_logit_softcapping {
244                None => attn_weights,
245                Some(sc) => ((attn_weights / sc)?.tanh()? * sc)?,
246            };
247
248            let attn_weights = match attention_mask {
249                None => attn_weights,
250                Some(mask) => attn_weights.broadcast_add(mask)?,
251            };
252            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
253            attn_weights.matmul(&value_states)?
254        };
255        attn_output
256            .transpose(1, 2)?
257            .reshape((b_sz, q_len, ()))?
258            .apply(&self.o_proj)
259    }
260
261    fn clear_kv_cache(&mut self) {
262        self.kv_cache = None
263    }
264}
265
266#[cfg(feature = "flash-attn")]
267fn flash_attn(
268    q: &Tensor,
269    k: &Tensor,
270    v: &Tensor,
271    softmax_scale: f32,
272    causal: bool,
273) -> Result<Tensor> {
274    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
275}
276
277#[cfg(not(feature = "flash-attn"))]
278fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
279    unimplemented!("compile with '--features flash-attn'")
280}
281
282#[derive(Debug, Clone)]
283struct DecoderLayer {
284    self_attn: Attention,
285    mlp: MLP,
286    input_layernorm: RmsNorm,
287    pre_feedforward_layernorm: RmsNorm,
288    post_feedforward_layernorm: RmsNorm,
289    post_attention_layernorm: RmsNorm,
290}
291
292impl DecoderLayer {
293    fn new(
294        rotary_emb: Arc<RotaryEmbedding>,
295        use_flash_attn: bool,
296        cfg: &Config,
297        vb: VarBuilder,
298    ) -> Result<Self> {
299        let self_attn = Attention::new(rotary_emb, use_flash_attn, cfg, vb.pp("self_attn"))?;
300        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
301        let input_layernorm =
302            RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
303        let pre_feedforward_layernorm = RmsNorm::new(
304            cfg.hidden_size,
305            cfg.rms_norm_eps,
306            vb.pp("pre_feedforward_layernorm"),
307        )?;
308        let post_feedforward_layernorm = RmsNorm::new(
309            cfg.hidden_size,
310            cfg.rms_norm_eps,
311            vb.pp("post_feedforward_layernorm"),
312        )?;
313        let post_attention_layernorm = RmsNorm::new(
314            cfg.hidden_size,
315            cfg.rms_norm_eps,
316            vb.pp("post_attention_layernorm"),
317        )?;
318        Ok(Self {
319            self_attn,
320            mlp,
321            input_layernorm,
322            pre_feedforward_layernorm,
323            post_feedforward_layernorm,
324            post_attention_layernorm,
325        })
326    }
327
328    fn forward(
329        &mut self,
330        xs: &Tensor,
331        attention_mask: Option<&Tensor>,
332        seqlen_offset: usize,
333    ) -> Result<Tensor> {
334        let residual = xs;
335        let xs = self.input_layernorm.forward(xs)?;
336        let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
337        let xs = xs.apply(&self.post_attention_layernorm)?;
338        let xs = (xs + residual)?;
339        let residual = &xs;
340        let xs = xs.apply(&self.pre_feedforward_layernorm)?;
341        let xs = xs.apply(&self.mlp)?;
342        let xs = xs.apply(&self.post_feedforward_layernorm)?;
343        residual + xs
344    }
345
346    fn clear_kv_cache(&mut self) {
347        self.self_attn.clear_kv_cache()
348    }
349}
350
351#[derive(Debug, Clone)]
352pub struct Model {
353    embed_tokens: candle_nn::Embedding,
354    layers: Vec<DecoderLayer>,
355    norm: RmsNorm,
356    lm_head: Linear,
357    final_logit_softcapping: Option<f64>,
358    device: Device,
359    dtype: DType,
360    hidden_size: usize,
361    sliding_window: Option<usize>,
362}
363
364impl Model {
365    pub fn new(use_flash_attn: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> {
366        let vb_m = vb.pp("model");
367        let embed_tokens =
368            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
369        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
370        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
371        let vb_l = vb_m.pp("layers");
372        for layer_idx in 0..cfg.num_hidden_layers {
373            let layer =
374                DecoderLayer::new(rotary_emb.clone(), use_flash_attn, cfg, vb_l.pp(layer_idx))?;
375            layers.push(layer)
376        }
377        let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
378        let lm_head = Linear::new(embed_tokens.embeddings().clone(), None);
379        Ok(Self {
380            embed_tokens,
381            layers,
382            norm,
383            lm_head,
384            final_logit_softcapping: cfg.final_logit_softcapping,
385            device: vb.device().clone(),
386            dtype: vb.dtype(),
387            hidden_size: cfg.hidden_size,
388            sliding_window: cfg.sliding_window,
389        })
390    }
391
392    fn prepare_decoder_attention_mask(
393        &self,
394        b_size: usize,
395        tgt_len: usize,
396        seqlen_offset: usize,
397    ) -> Result<Tensor> {
398        let mask: Vec<_> = match self.sliding_window {
399            None => (0..tgt_len)
400                .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
401                .collect(),
402            Some(sliding_window) => (0..tgt_len)
403                .flat_map(|i| {
404                    (0..tgt_len).map(move |j| {
405                        if i < j || j + sliding_window < i {
406                            f32::NEG_INFINITY
407                        } else {
408                            0.
409                        }
410                    })
411                })
412                .collect(),
413        };
414        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
415        let mask = if seqlen_offset > 0 {
416            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
417            Tensor::cat(&[&mask0, &mask], D::Minus1)?
418        } else {
419            mask
420        };
421        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
422            .to_dtype(self.dtype)
423    }
424
425    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
426        let (b_size, seq_len) = input_ids.dims2()?;
427        let attention_mask = if seq_len <= 1 {
428            None
429        } else {
430            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
431            Some(mask)
432        };
433        let xs = self.embed_tokens.forward(input_ids)?;
434        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
435        for layer in self.layers.iter_mut() {
436            xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
437        }
438        let logits = xs
439            .narrow(1, seq_len - 1, 1)?
440            .apply(&self.norm)?
441            .apply(&self.lm_head)?;
442        let logits = match self.final_logit_softcapping {
443            None => logits,
444            Some(sc) => ((logits / sc)?.tanh()? * sc)?,
445        };
446
447        Ok(logits)
448    }
449
450    pub fn clear_kv_cache(&mut self) {
451        for layer in self.layers.iter_mut() {
452            layer.clear_kv_cache()
453        }
454    }
455}