candle_transformers/models/
quantized_llama.rs

1//! Quantized llama model implementation.
2//!
3//! This provides a quantized implementation of the llama language model architecture.
4//! The model implements parameter efficient quantization for reduced memory usage
5//! while maintaining model quality.
6//!
7//! Key characteristics:
8//! - Transformer decoder architecture
9//! - Support for 2/3/4/8-bit quantization
10//! - Optimized memory usage through quantization
11//! - Configurable model sizes and parameter counts
12//!
13//! - 💻 [GH Link](https://github.com/facebookresearch/llama)
14//! - 📝 [Paper](https://arxiv.org/abs/2302.13971)
15//!
16//! ![](https://raw.githubusercontent.com/huggingface/candle/main/candle-examples/examples/quantized/assets/aoc.gif)
17//!
18
19use std::collections::HashMap;
20
21use crate::quantized_nn::RmsNorm;
22use candle::quantized::QTensor;
23use candle::quantized::{ggml_file, gguf_file};
24use candle::{DType, Device, IndexOp, Result, Tensor};
25use candle_nn::{Embedding, Module};
26
/// Number of positions for which the rotary cos/sin tables are precomputed.
pub const MAX_SEQ_LEN: usize = 4096;
28
29// QMatMul wrapper adding some tracing.
30#[derive(Debug, Clone)]
31struct QMatMul {
32    inner: candle::quantized::QMatMul,
33    span: tracing::Span,
34}
35
36impl QMatMul {
37    fn from_qtensor(qtensor: QTensor) -> Result<Self> {
38        let inner = candle::quantized::QMatMul::from_qtensor(qtensor)?;
39        let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
40        Ok(Self { inner, span })
41    }
42
43    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
44        let _enter = self.span.enter();
45        self.inner.forward(xs)
46    }
47}
48
49#[derive(Debug, Clone)]
50struct Mlp {
51    feed_forward_w1: QMatMul,
52    feed_forward_w2: QMatMul,
53    feed_forward_w3: QMatMul,
54}
55
56impl Module for Mlp {
57    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
58        let w1 = self.feed_forward_w1.forward(xs)?;
59        let w3 = self.feed_forward_w3.forward(xs)?;
60        self.feed_forward_w2
61            .forward(&(candle_nn::ops::silu(&w1)? * w3)?)
62    }
63}
64
65#[derive(Debug, Clone)]
66enum MlpOrMoe {
67    Mlp(Mlp),
68    MoE {
69        n_expert_used: usize,
70        feed_forward_gate_inp: QMatMul,
71        experts: Vec<Mlp>,
72    },
73}
74
75impl Module for MlpOrMoe {
76    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
77        match self {
78            Self::MoE {
79                feed_forward_gate_inp,
80                experts,
81                n_expert_used,
82            } => {
83                let (b_size, seq_len, hidden_dim) = xs.dims3()?;
84                let xs = xs.reshape(((), hidden_dim))?;
85                let router_logits = feed_forward_gate_inp.forward(&xs)?;
86                let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;
87
88                // In order to extract topk, we extract the data from the tensor and manipulate it
89                // directly. Maybe we will want to use some custom ops instead at some point.
90                let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
91
92                // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
93                // top_x contains the row indexes to evaluate for each expert.
94                let mut top_x = vec![vec![]; experts.len()];
95                let mut selected_rws = vec![vec![]; experts.len()];
96                for (row_idx, rw) in routing_weights.iter().enumerate() {
97                    let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
98                    dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
99                    let mut sum_routing_weights = 0f32;
100                    for &expert_idx in dst.iter().take(*n_expert_used) {
101                        let expert_idx = expert_idx as usize;
102                        let routing_weight = rw[expert_idx];
103                        sum_routing_weights += routing_weight;
104                        top_x[expert_idx].push(row_idx as u32);
105                    }
106                    for &expert_idx in dst.iter().take(*n_expert_used) {
107                        let expert_idx = expert_idx as usize;
108                        let routing_weight = rw[expert_idx];
109                        selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
110                    }
111                }
112
113                // routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
114                // expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
115
116                let mut ys = xs.zeros_like()?;
117                for (expert_idx, expert_layer) in experts.iter().enumerate() {
118                    let top_x = &top_x[expert_idx];
119                    if top_x.is_empty() {
120                        continue;
121                    }
122                    let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
123                    let selected_rws =
124                        Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?
125                            .reshape(((), 1))?;
126                    // Index the correct hidden states and compute the expert hidden state for
127                    // the current expert. We need to make sure to multiply the output hidden
128                    // states by `routing_weights` on the corresponding tokens (top-1 and top-2)
129                    let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
130                    // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
131                    let current_hidden_states = expert_layer.forward(&current_state)?;
132                    let current_hidden_states =
133                        current_hidden_states.broadcast_mul(&selected_rws)?;
134                    ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
135                }
136
137                let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
138                Ok(ys)
139            }
140            Self::Mlp(mlp) => mlp.forward(xs),
141        }
142    }
143}
144
145#[derive(Debug, Clone)]
146struct LayerWeights {
147    attention_wq: QMatMul,
148    attention_wk: QMatMul,
149    attention_wv: QMatMul,
150    attention_wo: QMatMul,
151    attention_norm: RmsNorm,
152    mlp_or_moe: MlpOrMoe,
153    ffn_norm: RmsNorm,
154    n_head: usize,
155    n_kv_head: usize,
156    head_dim: usize,
157    cos: Tensor,
158    sin: Tensor,
159    neg_inf: Tensor,
160    kv_cache: Option<(Tensor, Tensor)>,
161    span_attn: tracing::Span,
162    span_rot: tracing::Span,
163    span_mlp: tracing::Span,
164}
165
166fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
167    let shape = mask.shape();
168    let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?;
169    Ok(m)
170}
171
172impl LayerWeights {
173    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
174        let _enter = self.span_rot.enter();
175        let (_b_sz, _n_head, seq_len, _n_embd) = x.dims4()?;
176        let cos = self.cos.narrow(0, index_pos, seq_len)?;
177        let sin = self.sin.narrow(0, index_pos, seq_len)?;
178        // The call to contiguous below is only necessary when processing the prompt.
179        // When the seq_len is 1 in the inference loop, this is a no-op.
180        candle_nn::rotary_emb::rope_i(&x.contiguous()?, &cos, &sin)
181    }
182
183    fn forward_attn(
184        &mut self,
185        x: &Tensor,
186        mask: Option<&Tensor>,
187        index_pos: usize,
188    ) -> Result<Tensor> {
189        let _enter = self.span_attn.enter();
190        let (b_sz, seq_len, n_embd) = x.dims3()?;
191        let q = self.attention_wq.forward(x)?;
192        let k = self.attention_wk.forward(x)?;
193        let v = self.attention_wv.forward(x)?;
194
195        let q = q
196            .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
197            .transpose(1, 2)?;
198        let k = k
199            .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
200            .transpose(1, 2)?;
201        let v = v
202            .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
203            .transpose(1, 2)?
204            // This call to contiguous ensures that the fast kernel can be called below. It's
205            // actually a no-op except when processing the initial prompt so has no significant
206            // impact on performance.
207            .contiguous()?;
208
209        let q = self.apply_rotary_emb(&q, index_pos)?;
210        let k = self.apply_rotary_emb(&k, index_pos)?;
211
212        let (k, v) = match &self.kv_cache {
213            None => (k, v),
214            Some((k_cache, v_cache)) => {
215                if index_pos == 0 {
216                    (k, v)
217                } else {
218                    let k = Tensor::cat(&[k_cache, &k], 2)?;
219                    let v = Tensor::cat(&[v_cache, &v], 2)?;
220                    (k, v)
221                }
222            }
223        };
224        self.kv_cache = Some((k.clone(), v.clone()));
225
226        let y = if q.device().is_metal() && seq_len == 1 {
227            // SDPA will do MQA for us
228            candle_nn::ops::sdpa(&q, &k, &v, 1. / (self.head_dim as f32).sqrt(), 1.)?
229        } else {
230            // Support for MQA, useful for 70B models and mistral.
231            let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
232            let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
233
234            let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
235            let att = match mask {
236                None => att,
237                Some(mask) => {
238                    let mask = mask.broadcast_as(att.shape())?;
239                    masked_fill(&att, &mask, &self.neg_inf)?
240                }
241            };
242            let att = candle_nn::ops::softmax_last_dim(&att)?;
243            // Convert to contiguous as matmul doesn't support strided vs for now.
244            att.matmul(&v.contiguous()?)?
245        };
246
247        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
248        let y = self.attention_wo.forward(&y)?;
249        Ok(y)
250    }
251}
252
253#[derive(Debug, Clone)]
254pub struct ModelWeights {
255    tok_embeddings: Embedding,
256    layers: Vec<LayerWeights>,
257    norm: RmsNorm,
258    output: QMatMul,
259    masks: HashMap<usize, Tensor>,
260    span: tracing::Span,
261    span_output: tracing::Span,
262}
263
264fn precomput_freqs_cis(
265    head_dim: usize,
266    freq_base: f32,
267    device: &Device,
268) -> Result<(Tensor, Tensor)> {
269    let theta: Vec<_> = (0..head_dim)
270        .step_by(2)
271        .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
272        .collect();
273    let theta = Tensor::new(theta.as_slice(), device)?;
274    let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
275        .to_dtype(DType::F32)?
276        .reshape((MAX_SEQ_LEN, 1))?
277        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
278    let cos = idx_theta.cos()?;
279    let sin = idx_theta.sin()?;
280    Ok((cos, sin))
281}
282
283impl ModelWeights {
284    pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
285        let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
286        let (cos, sin) = precomput_freqs_cis(head_dim, 10000., &ct.device)?;
287        let neg_inf = Tensor::new(f32::NEG_INFINITY, &ct.device)?;
288        let tok_embeddings = ct.remove("tok_embeddings.weight")?;
289        let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
290        let norm = RmsNorm::from_qtensor(ct.remove("norm.weight")?, 1e-5)?;
291        let output = ct.remove("output.weight")?;
292        let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
293        for layer_idx in 0..ct.hparams.n_layer {
294            let prefix = format!("layers.{layer_idx}");
295            let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?;
296            let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
297            let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
298            let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
299            let mlp_or_moe = {
300                let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
301                let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
302                let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
303                MlpOrMoe::Mlp(Mlp {
304                    feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
305                    feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
306                    feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
307                })
308            };
309            let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?;
310            let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?;
311            let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
312            let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
313            let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
314            layers.push(LayerWeights {
315                attention_wq: QMatMul::from_qtensor(attention_wq)?,
316                attention_wk: QMatMul::from_qtensor(attention_wk)?,
317                attention_wv: QMatMul::from_qtensor(attention_wv)?,
318                attention_wo: QMatMul::from_qtensor(attention_wo)?,
319                attention_norm: RmsNorm::from_qtensor(attention_norm, 1e-5)?,
320                mlp_or_moe,
321                ffn_norm: RmsNorm::from_qtensor(ffn_norm, 1e-5)?,
322                n_head: ct.hparams.n_head as usize,
323                n_kv_head: ct.hparams.n_head as usize / gqa,
324                head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
325                cos: cos.clone(),
326                sin: sin.clone(),
327                neg_inf: neg_inf.clone(),
328                kv_cache: None,
329                span_attn,
330                span_rot,
331                span_mlp,
332            })
333        }
334        let span = tracing::span!(tracing::Level::TRACE, "model");
335        let span_output = tracing::span!(tracing::Level::TRACE, "output");
336        Ok(Self {
337            tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
338            layers,
339            norm,
340            output: QMatMul::from_qtensor(output)?,
341            masks: HashMap::new(),
342            span,
343            span_output,
344        })
345    }
346
347    pub fn from_gguf<R: std::io::Seek + std::io::Read>(
348        ct: gguf_file::Content,
349        reader: &mut R,
350        device: &Device,
351    ) -> Result<Self> {
352        let md_get = |s: &str| match ct.metadata.get(s) {
353            None => candle::bail!("cannot find {s} in metadata"),
354            Some(v) => Ok(v),
355        };
356
357        // Parameter extraction from metadata.
358        let n_expert = md_get("llama.expert_count")
359            .and_then(|v| v.to_u32())
360            .unwrap_or(0) as usize;
361        let n_expert_used = md_get("llama.expert_used_count")
362            .and_then(|v| v.to_u32())
363            .unwrap_or(0) as usize;
364        let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
365        let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
366        let block_count = md_get("llama.block_count")?.to_u32()? as usize;
367        let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
368        let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
369        // Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
370        let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
371
372        let rope_freq_base = md_get("llama.rope.freq_base")
373            .and_then(|m| m.to_f32())
374            .unwrap_or(10000f32);
375        let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
376        let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
377
378        let tok_embeddings_q = ct.tensor(reader, "token_embd.weight", device)?;
379        let tok_embeddings = tok_embeddings_q.dequantize(device)?;
380        let norm = RmsNorm::from_qtensor(
381            ct.tensor(reader, "output_norm.weight", device)?,
382            rms_norm_eps,
383        )?;
384        let output = match ct.tensor(reader, "output.weight", device) {
385            Ok(tensor) => tensor,
386            Err(_) => tok_embeddings_q,
387        };
388        let mut layers = Vec::with_capacity(block_count);
389        for layer_idx in 0..block_count {
390            let prefix = format!("blk.{layer_idx}");
391            let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
392            let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
393            let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
394            let attention_wo =
395                ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
396            let mlp_or_moe = if n_expert <= 1 {
397                let feed_forward_w1 =
398                    ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
399                let feed_forward_w2 =
400                    ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
401                let feed_forward_w3 =
402                    ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
403                MlpOrMoe::Mlp(Mlp {
404                    feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
405                    feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
406                    feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
407                })
408            } else {
409                let feed_forward_gate_inp =
410                    ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"), device)?;
411                let mut experts = Vec::with_capacity(n_expert);
412                for i in 0..n_expert {
413                    let feed_forward_w1 =
414                        ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"), device)?;
415                    let feed_forward_w2 =
416                        ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?;
417                    let feed_forward_w3 =
418                        ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?;
419                    experts.push(Mlp {
420                        feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
421                        feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
422                        feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
423                    })
424                }
425                MlpOrMoe::MoE {
426                    n_expert_used,
427                    feed_forward_gate_inp: QMatMul::from_qtensor(feed_forward_gate_inp)?,
428                    experts,
429                }
430            };
431            let attention_norm =
432                ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?;
433            let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?;
434            let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
435            let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
436            let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
437            layers.push(LayerWeights {
438                attention_wq: QMatMul::from_qtensor(attention_wq)?,
439                attention_wk: QMatMul::from_qtensor(attention_wk)?,
440                attention_wv: QMatMul::from_qtensor(attention_wv)?,
441                attention_wo: QMatMul::from_qtensor(attention_wo)?,
442                attention_norm: RmsNorm::from_qtensor(attention_norm, rms_norm_eps)?,
443                mlp_or_moe,
444                ffn_norm: RmsNorm::from_qtensor(ffn_norm, rms_norm_eps)?,
445                n_head: head_count,
446                n_kv_head: head_count_kv,
447                head_dim: embedding_length / head_count,
448                cos: cos.clone(),
449                sin: sin.clone(),
450                neg_inf: neg_inf.clone(),
451                kv_cache: None,
452                span_attn,
453                span_rot,
454                span_mlp,
455            })
456        }
457        let span = tracing::span!(tracing::Level::TRACE, "model");
458        let span_output = tracing::span!(tracing::Level::TRACE, "output");
459        Ok(Self {
460            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
461            layers,
462            norm,
463            output: QMatMul::from_qtensor(output)?,
464            masks: HashMap::new(),
465            span,
466            span_output,
467        })
468    }
469
470    fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
471        if let Some(mask) = self.masks.get(&t) {
472            Ok(mask.clone())
473        } else {
474            let mask: Vec<_> = (0..t)
475                .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
476                .collect();
477            let mask = Tensor::from_slice(&mask, (t, t), device)?;
478            self.masks.insert(t, mask.clone());
479            Ok(mask)
480        }
481    }
482
483    pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
484        let (_b_sz, seq_len) = x.dims2()?;
485        let mask = if seq_len == 1 {
486            None
487        } else {
488            Some(self.mask(seq_len, x.device())?)
489        };
490        let _enter = self.span.enter();
491        let mut layer_in = self.tok_embeddings.forward(x)?;
492        for layer in self.layers.iter_mut() {
493            let x = layer_in;
494            let residual = &x;
495            let x = layer.attention_norm.forward(&x)?;
496            let attn = layer.forward_attn(&x, mask.as_ref(), index_pos)?;
497            let x = (attn + residual)?;
498
499            // MLP
500            let _enter = layer.span_mlp.enter();
501            let residual = &x;
502            let x = layer.ffn_norm.forward(&x)?;
503            let x = layer.mlp_or_moe.forward(&x)?;
504            let x = (x + residual)?;
505            layer_in = x
506        }
507        let x = self.norm.forward(&layer_in)?;
508        let x = x.i((.., seq_len - 1, ..))?;
509        let _enter = self.span_output.enter();
510        self.output.forward(&x)
511    }
512}