candle_transformers/models/
granite.rs

1//! Granite is a Long Context Transformer Language Model.
2//!
//! A high-performance transformer model optimized for efficient processing
//! of very long context sequences.
5//!
6//! Based on implementation from [Nod.ai](https://github.com/nod-ai/granite)
7
8use super::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
9use candle::{DType, Device, IndexOp, Result, Tensor, D};
10use candle_nn::{embedding, Embedding, Module, VarBuilder};
11use std::{collections::HashMap, f32::consts::PI};
12
/// Default maximum sequence length exported by this module.
///
/// NOTE(review): not referenced by the code in this module; presumably consumed by callers
/// (e.g. as a generation limit) — confirm.
pub const DEFAULT_MAX_SEQ_LEN: usize = 4096;
14
/// Selects how rotary-embedding inverse frequencies are computed in `Cache::new`.
///
/// Deserialized from the strings `"granite"` and `"default"`; `Default` is also the
/// value produced by `GraniteRopeType::default()`.
#[derive(Debug, Clone, serde::Deserialize, Default)]
pub enum GraniteRopeType {
    /// `"granite"`: apply the wavelength-dependent scaling described by `GraniteRopeConfig`.
    #[serde(rename = "granite")]
    Granite,
    /// `"default"`: use the unscaled inverse frequencies.
    #[default]
    #[serde(rename = "default")]
    Default,
}
23
/// Parameters for rotary-embedding frequency scaling (the `rope_scaling` config entry).
///
/// Only consulted by `Cache::new` when `rope_type` is `GraniteRopeType::Granite`.
#[derive(Debug, Clone, serde::Deserialize, Default)]
pub struct GraniteRopeConfig {
    /// Divisor applied to the inverse frequency of long-wavelength (low-frequency) components.
    pub factor: f32,
    /// Components whose wavelength exceeds `original_max_position_embeddings / low_freq_factor`
    /// are fully divided by `factor`.
    pub low_freq_factor: f32,
    /// Components whose wavelength is below `original_max_position_embeddings / high_freq_factor`
    /// are left unchanged; wavelengths between the two thresholds are interpolated.
    pub high_freq_factor: f32,
    /// Value divided by the two frequency factors to obtain the wavelength thresholds; also
    /// used in the interpolation weight.
    pub original_max_position_embeddings: usize,
    /// Whether scaling is applied at all (`Default` disables it).
    pub rope_type: GraniteRopeType,
}
/// End-of-sequence token id(s) from the config: either a single id or a list of ids.
///
/// `#[serde(untagged)]` lets the same field deserialize from either a number or an array.
#[derive(Debug, Clone, serde::Deserialize)]
#[serde(untagged)]
pub enum GraniteEosToks {
    /// A single end-of-sequence token id.
    Single(u32),
    /// Several end-of-sequence token ids.
    Multiple(Vec<u32>),
}
38
/// Model configuration as deserialized from a config file.
///
/// NOTE(review): field names presumably mirror a Hugging Face `config.json` — confirm.
/// Convert to the runtime `Config` with `GraniteConfig::into_config`.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct GraniteConfig {
    pub hidden_size: usize,
    /// Inner width of the gated MLP (`gate_proj`/`up_proj` output size).
    pub intermediate_size: usize,
    pub vocab_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    /// Optional; when absent, `GraniteConfig::num_key_value_heads()` falls back to
    /// `num_attention_heads`.
    pub num_key_value_heads: Option<usize>,
    /// Epsilon used by every RMS norm in the model.
    pub rms_norm_eps: f64,
    /// Rotary base; defaults to 10 000 when missing (see `default_rope`).
    #[serde(default = "default_rope")]
    pub rope_theta: f32,
    pub bos_token_id: Option<u32>,
    pub eos_token_id: Option<GraniteEosToks>,
    /// Optional rotary frequency scaling.
    pub rope_scaling: Option<GraniteRopeConfig>,
    /// Number of positions for which rotary cos/sin tables are precomputed.
    pub max_position_embeddings: usize,
}
55
56impl GraniteConfig {
57    pub fn num_key_value_heads(&self) -> usize {
58        self.num_key_value_heads.unwrap_or(self.num_attention_heads)
59    }
60}
61
/// Serde default for `rope_theta` when the config omits it: a rotary base of 10 000.
fn default_rope() -> f32 {
    const DEFAULT_ROPE_THETA: f32 = 1e4;
    DEFAULT_ROPE_THETA
}
65
66impl GraniteConfig {
67    pub fn into_config(self, use_flash_attn: bool) -> Config {
68        Config {
69            hidden_size: self.hidden_size,
70            intermediate_size: self.intermediate_size,
71            vocab_size: self.vocab_size,
72            num_hidden_layers: self.num_hidden_layers,
73            num_attention_heads: self.num_attention_heads,
74            num_key_value_heads: self.num_key_value_heads(),
75            rms_norm_eps: self.rms_norm_eps,
76            rope_theta: self.rope_theta,
77            use_flash_attn,
78            bos_token_id: self.bos_token_id,
79            eos_token_id: self.eos_token_id,
80            rope_scaling: self.rope_scaling,
81            max_position_embeddings: self.max_position_embeddings,
82        }
83    }
84}
85
/// Runtime model configuration used by `Cache::new` and `Granite::load`.
///
/// Usually built with `GraniteConfig::into_config`, which resolves `num_key_value_heads`
/// and adds `use_flash_attn`.
#[derive(Debug, Clone)]
pub struct Config {
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub vocab_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    /// Expected to divide `num_attention_heads`: attention repeats keys/values
    /// `num_attention_heads / num_key_value_heads` times.
    pub num_key_value_heads: usize,
    /// Use the flash-attention kernel (only available with the `flash-attn` feature).
    pub use_flash_attn: bool,
    pub rms_norm_eps: f64,
    pub rope_theta: f32,
    pub bos_token_id: Option<u32>,
    pub eos_token_id: Option<GraniteEosToks>,
    pub rope_scaling: Option<GraniteRopeConfig>,
    /// Number of positions covered by the precomputed rotary tables.
    pub max_position_embeddings: usize,
}
102
/// Inference state shared by all layers: rotary tables, cached causal masks and the
/// per-layer key/value cache.
#[derive(Debug, Clone)]
pub struct Cache {
    /// Causal masks keyed by sequence length, built lazily by `Cache::mask`.
    masks: HashMap<usize, Tensor>,
    /// When true, each attention layer prepends its stored keys/values and stores the result.
    pub use_kv_cache: bool,
    /// One `(keys, values)` slot per hidden layer; `None` until that layer runs with caching.
    kvs: Vec<Option<(Tensor, Tensor)>>,
    /// Cosine table with one row per position in `0..max_position_embeddings`.
    cos: Tensor,
    /// Sine table, same layout as `cos`.
    sin: Tensor,
    /// Device on which causal masks are created.
    device: Device,
}
112
113fn calculate_default_inv_freq(cfg: &Config) -> Vec<f32> {
114    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
115    (0..head_dim)
116        .step_by(2)
117        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
118        .collect()
119}
120
121impl Cache {
122    pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> {
123        // precompute freqs_cis
124        let theta = match &config.rope_scaling {
125            None
126            | Some(GraniteRopeConfig {
127                rope_type: GraniteRopeType::Default,
128                ..
129            }) => calculate_default_inv_freq(config),
130            Some(rope_scaling) => {
131                let low_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
132                    / rope_scaling.low_freq_factor;
133                let high_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
134                    / rope_scaling.high_freq_factor;
135
136                calculate_default_inv_freq(config)
137                    .into_iter()
138                    .map(|freq| {
139                        let wavelen = 2. * PI / freq;
140                        if wavelen < high_freq_wavelen {
141                            freq
142                        } else if wavelen > low_freq_wavelen {
143                            freq / rope_scaling.factor
144                        } else {
145                            let smooth = (rope_scaling.original_max_position_embeddings as f32
146                                / wavelen
147                                - rope_scaling.low_freq_factor)
148                                / (rope_scaling.high_freq_factor - rope_scaling.low_freq_factor);
149                            (1. - smooth) * freq / rope_scaling.factor + smooth * freq
150                        }
151                    })
152                    .collect::<Vec<_>>()
153            }
154        };
155
156        let theta = Tensor::new(theta, device)?;
157
158        let idx_theta = Tensor::arange(0, config.max_position_embeddings as u32, device)?
159            .to_dtype(DType::F32)?
160            .reshape((config.max_position_embeddings, 1))?
161            .matmul(&theta.reshape((1, theta.elem_count()))?)?;
162        let cos = idx_theta.cos()?.to_dtype(dtype)?;
163        let sin = idx_theta.sin()?.to_dtype(dtype)?;
164        Ok(Self {
165            masks: HashMap::new(),
166            use_kv_cache,
167            kvs: vec![None; config.num_hidden_layers],
168            device: device.clone(),
169            cos,
170            sin,
171        })
172    }
173
174    fn mask(&mut self, t: usize) -> Result<Tensor> {
175        if let Some(mask) = self.masks.get(&t) {
176            Ok(mask.clone())
177        } else {
178            let mask: Vec<_> = (0..t)
179                .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
180                .collect();
181            let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
182            self.masks.insert(t, mask.clone());
183            Ok(mask)
184        }
185    }
186}
187
/// Causal self-attention layer with grouped key/value heads and rotary position embeddings.
#[derive(Debug, Clone)]
struct CausalSelfAttention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    num_attention_heads: usize,
    /// Key/value heads are repeated `num_attention_heads / num_key_value_heads` times.
    num_key_value_heads: usize,
    /// `hidden_size / num_attention_heads`.
    head_dim: usize,
    use_flash_attn: bool,
    span: tracing::Span,
    span_rot: tracing::Span,
    /// Upper bound used when trimming the key/value cache.
    max_position_embeddings: usize,
}
202
/// Thin wrapper over `candle_flash_attn::flash_attn`, compiled only with the `flash-attn` feature.
///
/// NOTE(review): the caller passes tensors as `(batch, seq_len, heads, head_dim)` per its own
/// comment — confirm against the kernel documentation.
#[cfg(feature = "flash-attn")]
fn flash_attn(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    softmax_scale: f32,
    causal: bool,
) -> Result<Tensor> {
    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
}

/// Stand-in used when the `flash-attn` feature is disabled; panics if called, i.e. if
/// `use_flash_attn` is set in a build without the feature.
#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
    unimplemented!("compile with '--features flash-attn'")
}
218
219impl CausalSelfAttention {
220    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
221        let _enter = self.span_rot.enter();
222        let (_b_sz, _, seq_len, _hidden_size) = x.dims4()?;
223        let cos = cache.cos.narrow(0, index_pos, seq_len)?;
224        let sin = cache.sin.narrow(0, index_pos, seq_len)?;
225        candle_nn::rotary_emb::rope(x, &cos, &sin)
226    }
227
228    fn forward(
229        &self,
230        x: &Tensor,
231        index_pos: usize,
232        block_idx: usize,
233        cache: &mut Cache,
234    ) -> Result<Tensor> {
235        let _enter = self.span.enter();
236        let (b_sz, seq_len, hidden_size) = x.dims3()?;
237        let q = self.q_proj.forward(x)?;
238        let k = self.k_proj.forward(x)?;
239        let v = self.v_proj.forward(x)?;
240
241        let q = q
242            .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
243            .transpose(1, 2)?
244            .contiguous()?;
245        let k = k
246            .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
247            .transpose(1, 2)?
248            .contiguous()?;
249        let mut v = v
250            .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
251            .transpose(1, 2)?;
252
253        let q = self.apply_rotary_emb(&q, index_pos, cache)?;
254        let mut k = self.apply_rotary_emb(&k, index_pos, cache)?;
255
256        if cache.use_kv_cache {
257            if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] {
258                k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
259                v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
260                let k_seq_len = k.dims()[1];
261                if k_seq_len > self.max_position_embeddings {
262                    k = k
263                        .narrow(
264                            D::Minus1,
265                            k_seq_len - self.max_position_embeddings,
266                            self.max_position_embeddings,
267                        )?
268                        .contiguous()?
269                }
270                let v_seq_len = v.dims()[1];
271                if v_seq_len > 2 * self.max_position_embeddings {
272                    v = v
273                        .narrow(
274                            D::Minus1,
275                            v_seq_len - self.max_position_embeddings,
276                            self.max_position_embeddings,
277                        )?
278                        .contiguous()?
279                }
280            }
281            cache.kvs[block_idx] = Some((k.clone(), v.clone()))
282        }
283
284        let k = self.repeat_kv(k)?;
285        let v = self.repeat_kv(v)?;
286
287        let y = if self.use_flash_attn {
288            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
289            let q = q.transpose(1, 2)?;
290            let k = k.transpose(1, 2)?;
291            let v = v.transpose(1, 2)?;
292            let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
293            flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2)?
294        } else {
295            let in_dtype = q.dtype();
296            let q = q.to_dtype(DType::F32)?;
297            let k = k.to_dtype(DType::F32)?;
298            let v = v.to_dtype(DType::F32)?;
299            let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
300            let att = if seq_len == 1 {
301                att
302            } else {
303                let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
304                masked_fill(&att, &mask, f32::NEG_INFINITY)?
305            };
306            let att = candle_nn::ops::softmax(&att, D::Minus1)?;
307            // Convert to contiguous as matmul doesn't support strided vs for now.
308            att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
309        };
310        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
311        let y = self.o_proj.forward(&y)?;
312        Ok(y)
313    }
314
315    fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
316        crate::utils::repeat_kv(x, self.num_attention_heads / self.num_key_value_heads)
317    }
318
319    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
320        let span = tracing::span!(tracing::Level::TRACE, "attn");
321        let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
322        let size_in = cfg.hidden_size;
323        let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
324        let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
325        let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?;
326        let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?;
327        let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?;
328        let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?;
329        Ok(Self {
330            q_proj,
331            k_proj,
332            v_proj,
333            o_proj,
334            num_attention_heads: cfg.num_attention_heads,
335            num_key_value_heads: cfg.num_key_value_heads,
336            head_dim: cfg.hidden_size / cfg.num_attention_heads,
337            use_flash_attn: cfg.use_flash_attn,
338            span,
339            span_rot,
340            max_position_embeddings: cfg.max_position_embeddings,
341        })
342    }
343}
344
345fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
346    let shape = mask.shape();
347    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
348    let m = mask.where_cond(&on_true, on_false)?;
349    Ok(m)
350}
351
/// Gated feed-forward block: `c_proj(silu(c_fc1(x)) * c_fc2(x))`.
#[derive(Debug, Clone)]
struct Mlp {
    /// Loaded from `gate_proj`.
    c_fc1: Linear,
    /// Loaded from `up_proj`.
    c_fc2: Linear,
    /// Loaded from `down_proj`.
    c_proj: Linear,
    span: tracing::Span,
}
359
360impl Mlp {
361    fn forward(&self, x: &Tensor) -> Result<Tensor> {
362        let _enter = self.span.enter();
363        let x = (candle_nn::ops::silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
364        self.c_proj.forward(&x)
365    }
366
367    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
368        let span = tracing::span!(tracing::Level::TRACE, "mlp");
369        let h_size = cfg.hidden_size;
370        let i_size = cfg.intermediate_size;
371        let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
372        let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?;
373        let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?;
374        Ok(Self {
375            c_fc1,
376            c_fc2,
377            c_proj,
378            span,
379        })
380    }
381}
382
/// One transformer layer: RMS norm -> attention -> residual, then RMS norm -> MLP -> residual.
#[derive(Debug, Clone)]
struct Block {
    /// Loaded from `input_layernorm`; applied before attention.
    rms_1: RmsNorm,
    attn: CausalSelfAttention,
    /// Loaded from `post_attention_layernorm`; applied before the MLP.
    rms_2: RmsNorm,
    mlp: Mlp,
    span: tracing::Span,
}
391
392impl Block {
393    fn forward(
394        &self,
395        x: &Tensor,
396        index_pos: usize,
397        block_idx: usize,
398        cache: &mut Cache,
399    ) -> Result<Tensor> {
400        let _enter = self.span.enter();
401        let residual = x;
402        let x = self.rms_1.forward(x)?;
403        let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?;
404        let residual = &x;
405        let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
406        Ok(x)
407    }
408
409    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
410        let span = tracing::span!(tracing::Level::TRACE, "block");
411        let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
412        let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
413        let rms_1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
414        let rms_2 = RmsNorm::new(
415            cfg.hidden_size,
416            cfg.rms_norm_eps,
417            vb.pp("post_attention_layernorm"),
418        )?;
419        Ok(Self {
420            rms_1,
421            attn,
422            rms_2,
423            mlp,
424            span,
425        })
426    }
427}
428
/// Full Granite decoder: token embedding, a stack of `Block`s, final RMS norm and LM head.
#[derive(Debug, Clone)]
pub struct Granite {
    /// Token embedding (`model.embed_tokens`).
    wte: Embedding,
    /// One entry per hidden layer (`model.layers.{i}`).
    blocks: Vec<Block>,
    /// Final norm (`model.norm`).
    ln_f: RmsNorm,
    /// Bias-free projection to vocabulary logits (`lm_head`).
    lm_head: Linear,
}
436
437impl Granite {
438    pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
439        let (_b_sz, seq_len) = x.dims2()?;
440        let mut x = self.wte.forward(x)?;
441        for (block_idx, block) in self.blocks.iter().enumerate() {
442            x = block.forward(&x, index_pos, block_idx, cache)?;
443        }
444        let x = self.ln_f.forward(&x)?;
445        let x = x.i((.., seq_len - 1, ..))?.contiguous()?;
446        let logits = self.lm_head.forward(&x)?;
447        logits.to_dtype(DType::F32)
448    }
449
450    pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
451        let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
452        let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
453        let ln_f = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
454        let blocks: Vec<_> = (0..cfg.num_hidden_layers)
455            .map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cfg).unwrap())
456            .collect();
457
458        Ok(Self {
459            wte,
460            blocks,
461            ln_f,
462            lm_head,
463        })
464    }
465}