candle_transformers/models/
llama.rs

1//! Llama inference implementation.
2//!
3//! See ["LLaMA: Open and Efficient Foundation Language Models"](https://arxiv.org/abs/2302.13971)
4//!
5//! Implementation based on Hugging Face's [transformers](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py)
6
7use super::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
8use candle::{DType, Device, IndexOp, Result, Tensor, D};
9use candle_nn::{embedding, Embedding, Module, VarBuilder};
10use std::{collections::HashMap, f32::consts::PI};
11
12pub const DEFAULT_MAX_SEQ_LEN: usize = 4096;
13
/// Selects which rotary-embedding frequency schedule `Cache::new` builds.
#[derive(Debug, Clone, serde::Deserialize, Default)]
pub enum Llama3RopeType {
    /// Deserialized from `"llama3"`. `Cache::new` rescales the base inverse
    /// frequencies using the factors in `Llama3RopeConfig`.
    #[serde(rename = "llama3")]
    Llama3,
    /// Deserialized from `"default"`; also the `Default` value. `Cache::new`
    /// uses the base inverse frequencies unchanged.
    #[default]
    #[serde(rename = "default")]
    Default,
}
22
/// Parameters for rescaling the rotary-embedding inverse frequencies, read by `Cache::new`.
#[derive(Debug, Clone, serde::Deserialize, Default)]
pub struct Llama3RopeConfig {
    /// Divisor applied to frequencies whose wavelength exceeds the low-frequency cutoff.
    pub factor: f32,
    /// `original_max_position_embeddings / low_freq_factor` is the wavelength above
    /// which a frequency is divided by `factor`.
    pub low_freq_factor: f32,
    /// `original_max_position_embeddings / high_freq_factor` is the wavelength below
    /// which a frequency is kept unchanged.
    pub high_freq_factor: f32,
    /// Numerator of both wavelength cutoffs and of the interpolation weight used
    /// between them.
    pub original_max_position_embeddings: usize,
    /// Chooses between the unscaled schedule and the rescaled one.
    pub rope_type: Llama3RopeType,
}
/// End-of-sequence token id(s) as found in a model configuration.
///
/// The enum is `untagged`, so it deserializes from either a bare integer or a
/// list of integers.
#[derive(Debug, Clone, serde::Deserialize)]
#[serde(untagged)]
pub enum LlamaEosToks {
    /// A single end-of-sequence token id.
    Single(u32),
    /// Several token ids.
    Multiple(Vec<u32>),
}
37
/// Deserializable model configuration; converted into the runtime [`Config`]
/// by `LlamaConfig::into_config`.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct LlamaConfig {
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub vocab_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    /// When absent, `num_key_value_heads()` falls back to `num_attention_heads`.
    pub num_key_value_heads: Option<usize>,
    pub rms_norm_eps: f64,
    /// Falls back to `default_rope()` (10 000) when the field is missing.
    #[serde(default = "default_rope")]
    pub rope_theta: f32,
    pub bos_token_id: Option<u32>,
    pub eos_token_id: Option<LlamaEosToks>,
    /// Optional rotary-embedding rescaling parameters.
    pub rope_scaling: Option<Llama3RopeConfig>,
    pub max_position_embeddings: usize,
    /// `into_config` treats a missing value as `false`.
    pub tie_word_embeddings: Option<bool>,
}
55
56impl LlamaConfig {
57    pub fn num_key_value_heads(&self) -> usize {
58        self.num_key_value_heads.unwrap_or(self.num_attention_heads)
59    }
60}
61
/// Serde fallback for `LlamaConfig::rope_theta` when the field is absent.
fn default_rope() -> f32 {
    1e4
}
65
66impl LlamaConfig {
67    pub fn into_config(self, use_flash_attn: bool) -> Config {
68        Config {
69            hidden_size: self.hidden_size,
70            intermediate_size: self.intermediate_size,
71            vocab_size: self.vocab_size,
72            num_hidden_layers: self.num_hidden_layers,
73            num_attention_heads: self.num_attention_heads,
74            num_key_value_heads: self.num_key_value_heads(),
75            rms_norm_eps: self.rms_norm_eps,
76            rope_theta: self.rope_theta,
77            use_flash_attn,
78            bos_token_id: self.bos_token_id,
79            eos_token_id: self.eos_token_id,
80            rope_scaling: self.rope_scaling,
81            max_position_embeddings: self.max_position_embeddings,
82            tie_word_embeddings: self.tie_word_embeddings.unwrap_or(false),
83        }
84    }
85}
86
/// Runtime model configuration consumed by `Cache::new` and the `load` functions
/// in this file.
#[derive(Debug, Clone)]
pub struct Config {
    /// Width of the token embeddings and of every layer's input/output.
    pub hidden_size: usize,
    /// Output width of the MLP gate/up projections.
    pub intermediate_size: usize,
    /// Number of rows in the embedding table and outputs of the LM head.
    pub vocab_size: usize,
    /// Number of transformer blocks loaded, and of KV-cache slots allocated.
    pub num_hidden_layers: usize,
    /// Number of query heads; `hidden_size / num_attention_heads` is the head dimension.
    pub num_attention_heads: usize,
    /// Number of key/value heads; they are repeated to match the query heads.
    pub num_key_value_heads: usize,
    /// Selects the `flash_attn` code path in attention.
    pub use_flash_attn: bool,
    /// Epsilon passed to every `RmsNorm`.
    pub rms_norm_eps: f64,
    /// Base of the rotary-embedding inverse frequencies.
    pub rope_theta: f32,
    /// Carried along for callers; not read by the model code in this file.
    pub bos_token_id: Option<u32>,
    /// Carried along for callers; not read by the model code in this file.
    pub eos_token_id: Option<LlamaEosToks>,
    /// Optional rescaling of the rotary-embedding frequencies.
    pub rope_scaling: Option<Llama3RopeConfig>,
    /// Number of positions for which cos/sin tables are precomputed.
    pub max_position_embeddings: usize,
    /// When true, the LM head reuses the embedding weights instead of loading `lm_head`.
    pub tie_word_embeddings: bool,
}
104
105impl Config {
106    pub fn config_7b_v1(use_flash_attn: bool) -> Self {
107        Self {
108            hidden_size: 4096,
109            intermediate_size: 11008,
110            vocab_size: 32000,
111            num_hidden_layers: 32,
112            num_attention_heads: 32,
113            num_key_value_heads: 32,
114            use_flash_attn,
115            rms_norm_eps: 1e-6,
116            rope_theta: 10_000.0,
117            bos_token_id: None,
118            eos_token_id: None,
119            rope_scaling: None,
120            max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
121            tie_word_embeddings: false,
122        }
123    }
124
125    pub fn config_7b_v2(use_flash_attn: bool) -> Self {
126        Self {
127            hidden_size: 4096,
128            intermediate_size: 11008,
129            vocab_size: 32000,
130            num_hidden_layers: 32,
131            num_attention_heads: 32,
132            num_key_value_heads: 32,
133            use_flash_attn,
134            rms_norm_eps: 1e-5,
135            rope_theta: 10_000.0,
136            bos_token_id: None,
137            eos_token_id: None,
138            rope_scaling: None,
139            max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
140            tie_word_embeddings: false,
141        }
142    }
143}
144
/// Mutable state shared by all layers during inference.
#[derive(Debug, Clone)]
pub struct Cache {
    /// Causal masks built by `mask`, keyed by their side length.
    masks: HashMap<usize, Tensor>,
    /// When true, attention appends new keys/values to `kvs` and stores the result.
    pub use_kv_cache: bool,
    /// One optional `(key, value)` pair per layer, indexed by block index.
    kvs: Vec<Option<(Tensor, Tensor)>>,
    /// Cosine table with one row per position and one column per inverse frequency.
    cos: Tensor,
    /// Sine table with the same layout as `cos`.
    sin: Tensor,
    /// Device on which masks are created.
    device: Device,
}
154
155fn calculate_default_inv_freq(cfg: &Config) -> Vec<f32> {
156    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
157    (0..head_dim)
158        .step_by(2)
159        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
160        .collect()
161}
162
163impl Cache {
164    pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> {
165        // precompute freqs_cis
166        let theta = match &config.rope_scaling {
167            None
168            | Some(Llama3RopeConfig {
169                rope_type: Llama3RopeType::Default,
170                ..
171            }) => calculate_default_inv_freq(config),
172            Some(rope_scaling) => {
173                let low_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
174                    / rope_scaling.low_freq_factor;
175                let high_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
176                    / rope_scaling.high_freq_factor;
177
178                calculate_default_inv_freq(config)
179                    .into_iter()
180                    .map(|freq| {
181                        let wavelen = 2. * PI / freq;
182                        if wavelen < high_freq_wavelen {
183                            freq
184                        } else if wavelen > low_freq_wavelen {
185                            freq / rope_scaling.factor
186                        } else {
187                            let smooth = (rope_scaling.original_max_position_embeddings as f32
188                                / wavelen
189                                - rope_scaling.low_freq_factor)
190                                / (rope_scaling.high_freq_factor - rope_scaling.low_freq_factor);
191                            (1. - smooth) * freq / rope_scaling.factor + smooth * freq
192                        }
193                    })
194                    .collect::<Vec<_>>()
195            }
196        };
197
198        let theta = Tensor::new(theta, device)?;
199
200        let idx_theta = Tensor::arange(0, config.max_position_embeddings as u32, device)?
201            .to_dtype(DType::F32)?
202            .reshape((config.max_position_embeddings, 1))?
203            .matmul(&theta.reshape((1, theta.elem_count()))?)?;
204        // This is different from the paper, see:
205        // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
206        let cos = idx_theta.cos()?.to_dtype(dtype)?;
207        let sin = idx_theta.sin()?.to_dtype(dtype)?;
208        Ok(Self {
209            masks: HashMap::new(),
210            use_kv_cache,
211            kvs: vec![None; config.num_hidden_layers],
212            device: device.clone(),
213            cos,
214            sin,
215        })
216    }
217
218    fn mask(&mut self, t: usize) -> Result<Tensor> {
219        if let Some(mask) = self.masks.get(&t) {
220            Ok(mask.clone())
221        } else {
222            let mask: Vec<_> = (0..t)
223                .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
224                .collect();
225            let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
226            self.masks.insert(t, mask.clone());
227            Ok(mask)
228        }
229    }
230}
231
/// Multi-head causal self-attention with rotary embeddings and an optional KV cache.
#[derive(Debug, Clone)]
struct CausalSelfAttention {
    /// Query projection, loaded from `q_proj`.
    q_proj: Linear,
    /// Key projection, loaded from `k_proj`.
    k_proj: Linear,
    /// Value projection, loaded from `v_proj`.
    v_proj: Linear,
    /// Output projection, loaded from `o_proj`.
    o_proj: Linear,
    /// Number of query heads.
    num_attention_heads: usize,
    /// Number of key/value heads; repeated to match the query heads before attention.
    num_key_value_heads: usize,
    /// `hidden_size / num_attention_heads`.
    head_dim: usize,
    /// Selects the `flash_attn` path instead of the explicit softmax path.
    use_flash_attn: bool,
    /// Tracing span entered for the whole attention forward pass.
    span: tracing::Span,
    /// Tracing span entered while applying rotary embeddings.
    span_rot: tracing::Span,
    /// Upper bound used when trimming the cached keys/values.
    max_position_embeddings: usize,
}
246
/// Thin wrapper forwarding to `candle_flash_attn::flash_attn`; only compiled
/// when the `flash-attn` feature is enabled.
#[cfg(feature = "flash-attn")]
fn flash_attn(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    softmax_scale: f32,
    causal: bool,
) -> Result<Tensor> {
    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
}
257
/// Stand-in used when the crate is built without the `flash-attn` feature.
///
/// # Panics
/// Always panics: reaching this function means `use_flash_attn` was set in a
/// build that does not include flash attention.
#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
    unimplemented!("compile with '--features flash-attn'")
}
262
263impl CausalSelfAttention {
264    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
265        let _enter = self.span_rot.enter();
266        let (_b_sz, _, seq_len, _hidden_size) = x.dims4()?;
267        let cos = cache.cos.narrow(0, index_pos, seq_len)?;
268        let sin = cache.sin.narrow(0, index_pos, seq_len)?;
269        candle_nn::rotary_emb::rope(x, &cos, &sin)
270    }
271
272    fn forward(
273        &self,
274        x: &Tensor,
275        index_pos: usize,
276        block_idx: usize,
277        cache: &mut Cache,
278    ) -> Result<Tensor> {
279        let _enter = self.span.enter();
280        let (b_sz, seq_len, hidden_size) = x.dims3()?;
281        let q = self.q_proj.forward(x)?;
282        let k = self.k_proj.forward(x)?;
283        let v = self.v_proj.forward(x)?;
284
285        let q = q
286            .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
287            .transpose(1, 2)?
288            .contiguous()?;
289        let k = k
290            .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
291            .transpose(1, 2)?
292            .contiguous()?;
293        let mut v = v
294            .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
295            .transpose(1, 2)?;
296
297        let q = self.apply_rotary_emb(&q, index_pos, cache)?;
298        let mut k = self.apply_rotary_emb(&k, index_pos, cache)?;
299
300        if cache.use_kv_cache {
301            if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] {
302                k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
303                v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
304                let k_seq_len = k.dims()[1];
305                if k_seq_len > self.max_position_embeddings {
306                    k = k
307                        .narrow(
308                            D::Minus1,
309                            k_seq_len - self.max_position_embeddings,
310                            self.max_position_embeddings,
311                        )?
312                        .contiguous()?
313                }
314                let v_seq_len = v.dims()[1];
315                if v_seq_len > 2 * self.max_position_embeddings {
316                    v = v
317                        .narrow(
318                            D::Minus1,
319                            v_seq_len - self.max_position_embeddings,
320                            self.max_position_embeddings,
321                        )?
322                        .contiguous()?
323                }
324            }
325            cache.kvs[block_idx] = Some((k.clone(), v.clone()))
326        }
327
328        let k = self.repeat_kv(k)?;
329        let v = self.repeat_kv(v)?;
330
331        let y = if self.use_flash_attn {
332            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
333            let q = q.transpose(1, 2)?;
334            let k = k.transpose(1, 2)?;
335            let v = v.transpose(1, 2)?;
336            let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
337            flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2)?
338        } else {
339            let in_dtype = q.dtype();
340            let q = q.to_dtype(DType::F32)?;
341            let k = k.to_dtype(DType::F32)?;
342            let v = v.to_dtype(DType::F32)?;
343            let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
344            let att = if seq_len == 1 {
345                att
346            } else {
347                let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
348                masked_fill(&att, &mask, f32::NEG_INFINITY)?
349            };
350
351            let att = candle_nn::ops::softmax_last_dim(&att)?;
352            // Convert to contiguous as matmul doesn't support strided vs for now.
353            att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
354        };
355        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
356        let y = self.o_proj.forward(&y)?;
357        Ok(y)
358    }
359
360    fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
361        crate::utils::repeat_kv(x, self.num_attention_heads / self.num_key_value_heads)
362    }
363
364    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
365        let span = tracing::span!(tracing::Level::TRACE, "attn");
366        let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
367        let size_in = cfg.hidden_size;
368        let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
369        let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
370        let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?;
371        let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?;
372        let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?;
373        let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?;
374        Ok(Self {
375            q_proj,
376            k_proj,
377            v_proj,
378            o_proj,
379            num_attention_heads: cfg.num_attention_heads,
380            num_key_value_heads: cfg.num_key_value_heads,
381            head_dim: cfg.hidden_size / cfg.num_attention_heads,
382            use_flash_attn: cfg.use_flash_attn,
383            span,
384            span_rot,
385            max_position_embeddings: cfg.max_position_embeddings,
386        })
387    }
388}
389
390fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
391    let shape = mask.shape();
392    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
393    let m = mask.where_cond(&on_true, on_false)?;
394    Ok(m)
395}
396
/// Gated feed-forward network of a transformer block.
#[derive(Debug, Clone)]
struct Mlp {
    /// Loaded from `gate_proj`; its output goes through SiLU.
    c_fc1: Linear,
    /// Loaded from `up_proj`; multiplied element-wise with the activated gate.
    c_fc2: Linear,
    /// Loaded from `down_proj`; maps the product back to `hidden_size`.
    c_proj: Linear,
    /// Tracing span entered during the forward pass.
    span: tracing::Span,
}
404
405impl Mlp {
406    fn forward(&self, x: &Tensor) -> Result<Tensor> {
407        let _enter = self.span.enter();
408        let x = (candle_nn::ops::silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
409        self.c_proj.forward(&x)
410    }
411
412    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
413        let span = tracing::span!(tracing::Level::TRACE, "mlp");
414        let h_size = cfg.hidden_size;
415        let i_size = cfg.intermediate_size;
416        let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
417        let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?;
418        let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?;
419        Ok(Self {
420            c_fc1,
421            c_fc2,
422            c_proj,
423            span,
424        })
425    }
426}
427
/// One transformer layer: normalized attention and a normalized MLP, each
/// wrapped in a residual connection.
#[derive(Debug, Clone)]
struct Block {
    /// Norm applied before attention, loaded from `input_layernorm`.
    rms_1: RmsNorm,
    /// Self-attention sub-layer, loaded from `self_attn`.
    attn: CausalSelfAttention,
    /// Norm applied before the MLP, loaded from `post_attention_layernorm`.
    rms_2: RmsNorm,
    /// Feed-forward sub-layer, loaded from `mlp`.
    mlp: Mlp,
    /// Tracing span entered during the forward pass.
    span: tracing::Span,
}
436
437impl Block {
438    fn forward(
439        &self,
440        x: &Tensor,
441        index_pos: usize,
442        block_idx: usize,
443        cache: &mut Cache,
444    ) -> Result<Tensor> {
445        let _enter = self.span.enter();
446        let residual = x;
447        let x = self.rms_1.forward(x)?;
448        let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?;
449        let residual = &x;
450        let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
451        Ok(x)
452    }
453
454    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
455        let span = tracing::span!(tracing::Level::TRACE, "block");
456        let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
457        let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
458        let rms_1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
459        let rms_2 = RmsNorm::new(
460            cfg.hidden_size,
461            cfg.rms_norm_eps,
462            vb.pp("post_attention_layernorm"),
463        )?;
464        Ok(Self {
465            rms_1,
466            attn,
467            rms_2,
468            mlp,
469            span,
470        })
471    }
472}
473
/// The full model: token embedding, a stack of blocks, a final norm and the LM head.
#[derive(Debug, Clone)]
pub struct Llama {
    /// Token embedding table, loaded from `model.embed_tokens`.
    wte: Embedding,
    /// Transformer layers, loaded from `model.layers.{i}`.
    blocks: Vec<Block>,
    /// Final norm, loaded from `model.norm`.
    ln_f: RmsNorm,
    /// Output projection to vocabulary logits; shares the embedding weights when
    /// `tie_word_embeddings` is set, otherwise loaded from `lm_head`.
    lm_head: Linear,
}
481
482impl Llama {
483    // required by LLaVA
484    pub fn embed(&self, x: &Tensor) -> Result<Tensor> {
485        self.wte.forward(x)
486    }
487    // required by LLaVA
488    pub fn forward_input_embed(
489        &self,
490        input_embed: &Tensor,
491        index_pos: usize,
492        cache: &mut Cache,
493    ) -> Result<Tensor> {
494        let (_, seq_len, _) = input_embed.dims3()?;
495        let mut x = input_embed.clone();
496        for (block_idx, block) in self.blocks.iter().enumerate() {
497            x = block.forward(&x, index_pos, block_idx, cache)?;
498        }
499        let x = self.ln_f.forward(&x)?;
500        let x = x.i((.., seq_len - 1, ..))?.contiguous()?;
501        let logits = self.lm_head.forward(&x)?;
502        logits.to_dtype(DType::F32)
503    }
504
505    pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
506        let (_b_sz, seq_len) = x.dims2()?;
507        let mut x = self.wte.forward(x)?;
508        for (block_idx, block) in self.blocks.iter().enumerate() {
509            x = block.forward(&x, index_pos, block_idx, cache)?;
510        }
511        let x = self.ln_f.forward(&x)?;
512        let x = x.i((.., seq_len - 1, ..))?.contiguous()?;
513        let logits = self.lm_head.forward(&x)?;
514        logits.to_dtype(DType::F32)
515    }
516
517    pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
518        let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
519        let lm_head = if cfg.tie_word_embeddings {
520            Linear::from_weights(wte.embeddings().clone(), None)
521        } else {
522            linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
523        };
524        let ln_f = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
525        let blocks: Vec<_> = (0..cfg.num_hidden_layers)
526            .map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cfg).unwrap())
527            .collect();
528
529        Ok(Self {
530            wte,
531            blocks,
532            ln_f,
533            lm_head,
534        })
535    }
536}