// candle_transformers/models/quantized_qwen2.rs

//! Qwen2 model implementation with quantization support.
//!
//! Qwen2 is a family of language models (including chat-tuned variants). This module runs it
//! from GGUF files whose weight matrices stay quantized (the file's own quantization type is
//! used as-is, not only 8-bit), for reduced memory usage and faster inference.
//!
//! Key characteristics:
//! - Grouped Query Attention (GQA)
//! - RMSNorm for layer normalization
//! - Rotary positional embeddings (RoPE)
//! - Support for GGUF quantized weights
//!
//! References:
//! - [Model Card](https://huggingface.co/Qwen/Qwen2)
//!
15
16use crate::{quantized_nn::RmsNorm, utils::repeat_kv};
17use candle::{
18    quantized::{gguf_file, QMatMul},
19    DType, Device, IndexOp, Result, Tensor,
20};
21use candle_nn::{Embedding, Module};
22use std::collections::HashMap;
23
24#[derive(Debug, Clone)]
25struct Mlp {
26    feed_forward_w1: QMatMul,
27    feed_forward_w2: QMatMul,
28    feed_forward_w3: QMatMul,
29}
30
31impl Module for Mlp {
32    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
33        let w1 = self.feed_forward_w1.forward(xs)?;
34        let w3 = self.feed_forward_w3.forward(xs)?;
35        self.feed_forward_w2
36            .forward(&(candle_nn::ops::silu(&w1)? * w3)?)
37    }
38}
39
40#[derive(Debug, Clone)]
41struct LayerWeights {
42    attention_wq: QMatMul,
43    attention_wk: QMatMul,
44    attention_wv: QMatMul,
45    attention_bq: Tensor,
46    attention_bk: Tensor,
47    attention_bv: Tensor,
48    attention_wo: QMatMul,
49    attention_norm: RmsNorm,
50    mlp: Mlp,
51    ffn_norm: RmsNorm,
52    n_head: usize,
53    n_kv_head: usize,
54    head_dim: usize,
55    cos: Tensor,
56    sin: Tensor,
57    neg_inf: Tensor,
58    kv_cache: Option<(Tensor, Tensor)>,
59    span_attn: tracing::Span,
60    span_rot: tracing::Span,
61    span_mlp: tracing::Span,
62}
63
64fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
65    let shape = mask.shape();
66    let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?;
67    Ok(m)
68}
69
70impl LayerWeights {
71    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
72        let _enter = self.span_rot.enter();
73        let (_b_sz, _n_head, seq_len, _n_embd) = x.dims4()?;
74        let cos = self.cos.narrow(0, index_pos, seq_len)?;
75        let sin = self.sin.narrow(0, index_pos, seq_len)?;
76        candle_nn::rotary_emb::rope(&x.contiguous()?, &cos, &sin)
77    }
78
79    fn forward_attn(
80        &mut self,
81        x: &Tensor,
82        mask: Option<&Tensor>,
83        index_pos: usize,
84    ) -> Result<Tensor> {
85        let _enter = self.span_attn.enter();
86        let (b_sz, seq_len, n_embd) = x.dims3()?;
87
88        let q = self.attention_wq.forward(x)?;
89        let k = self.attention_wk.forward(x)?;
90        let v = self.attention_wv.forward(x)?;
91
92        let q = q.broadcast_add(&self.attention_bq)?;
93        let k = k.broadcast_add(&self.attention_bk)?;
94        let v = v.broadcast_add(&self.attention_bv)?;
95
96        let q = q
97            .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
98            .transpose(1, 2)?
99            .contiguous()?;
100        let k = k
101            .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
102            .transpose(1, 2)?
103            .contiguous()?;
104        let v = v
105            .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
106            .transpose(1, 2)?
107            .contiguous()?;
108
109        // let (q, k) = self
110        //     .rotary_embedding
111        //     .apply_rotary_emb_qkv(&q, &k, index_pos)?;
112        let q = self.apply_rotary_emb(&q, index_pos)?;
113        let k = self.apply_rotary_emb(&k, index_pos)?;
114
115        let (k, v) = match &self.kv_cache {
116            None => (k, v),
117            Some((k_cache, v_cache)) => {
118                if index_pos == 0 {
119                    (k, v)
120                } else {
121                    let k = Tensor::cat(&[k_cache, &k], 2)?;
122                    let v = Tensor::cat(&[v_cache, &v], 2)?;
123                    (k, v)
124                }
125            }
126        };
127        self.kv_cache = Some((k.clone(), v.clone()));
128
129        // Support for MQA, useful for 70B models and mistral.
130        let k = repeat_kv(k, self.n_head / self.n_kv_head)?;
131        let v = repeat_kv(v, self.n_head / self.n_kv_head)?;
132
133        let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
134        let att = match mask {
135            None => att,
136            Some(mask) => {
137                let mask = mask.broadcast_as(att.shape())?;
138                masked_fill(&att, &mask, &self.neg_inf)?
139            }
140        };
141        let att = candle_nn::ops::softmax_last_dim(&att)?;
142        // Convert to contiguous as matmul doesn't support strided vs for now.
143        let y = att.matmul(&v.contiguous()?)?;
144        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
145        let y = self.attention_wo.forward(&y)?;
146        Ok(y)
147    }
148}
149
150pub struct ModelWeights {
151    tok_embeddings: Embedding,
152    layers: Vec<LayerWeights>,
153    norm: RmsNorm,
154    output: QMatMul,
155    masks: HashMap<usize, Tensor>,
156    span: tracing::Span,
157    span_output: tracing::Span,
158}
159
160fn precomput_freqs_cis(
161    head_dim: usize,
162    freq_base: f32,
163    context_length: usize,
164    device: &Device,
165) -> Result<(Tensor, Tensor)> {
166    let theta: Vec<_> = (0..head_dim)
167        .step_by(2)
168        .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
169        .collect();
170    let theta = Tensor::new(theta.as_slice(), device)?;
171    let idx_theta = Tensor::arange(0, context_length as u32, device)?
172        .to_dtype(DType::F32)?
173        .reshape((context_length, 1))?
174        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
175    let cos = idx_theta.cos()?;
176    let sin = idx_theta.sin()?;
177    Ok((cos, sin))
178}
179
180impl ModelWeights {
181    pub fn from_gguf<R: std::io::Seek + std::io::Read>(
182        ct: gguf_file::Content,
183        reader: &mut R,
184        device: &Device,
185    ) -> Result<Self> {
186        let md_get = |s: &str| match ct.metadata.get(s) {
187            None => candle::bail!("cannot find {s} in metadata"),
188            Some(v) => Ok(v),
189        };
190
191        let head_count = md_get("qwen2.attention.head_count")?.to_u32()? as usize;
192        let head_count_kv = md_get("qwen2.attention.head_count_kv")?.to_u32()? as usize;
193        let embedding_length = md_get("qwen2.embedding_length")?.to_u32()? as usize;
194        let context_length = md_get("qwen2.context_length")?.to_u32()? as usize;
195        let block_count = md_get("qwen2.block_count")?.to_u32()? as usize;
196        let rms_norm_eps = md_get("qwen2.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
197        let rope_freq_base = md_get("qwen2.rope.freq_base")
198            .and_then(|m| m.to_f32())
199            .unwrap_or(10000f32);
200
201        let head_dim = embedding_length / head_count;
202
203        let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
204
205        let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
206        let tok_embeddings = tok_embeddings.dequantize(device)?;
207        let norm = RmsNorm::from_qtensor(
208            ct.tensor(reader, "output_norm.weight", device)?,
209            rms_norm_eps,
210        )?;
211        let output = match ct.tensor(reader, "output.weight", device) {
212            Ok(v) => QMatMul::from_qtensor(v)?,
213            _ => {
214                // use tie_word_embeddings
215                QMatMul::from_qtensor(ct.tensor(reader, "token_embd.weight", device)?)?
216            }
217        };
218
219        let (cos, sin) = precomput_freqs_cis(head_dim, rope_freq_base, context_length, device)?;
220
221        let mut layers = Vec::with_capacity(block_count);
222
223        for layer_idx in 0..block_count {
224            let prefix = format!("blk.{layer_idx}");
225            let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
226            let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
227            let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
228
229            let attention_bq = ct.tensor(reader, &format!("{prefix}.attn_q.bias"), device)?;
230            let attention_bk = ct.tensor(reader, &format!("{prefix}.attn_k.bias"), device)?;
231            let attention_bv = ct.tensor(reader, &format!("{prefix}.attn_v.bias"), device)?;
232
233            let attention_wo =
234                ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
235
236            let mlp = {
237                let feed_forward_w1 =
238                    ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
239                let feed_forward_w2 =
240                    ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
241                let feed_forward_w3 =
242                    ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
243                Mlp {
244                    feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
245                    feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
246                    feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
247                }
248            };
249
250            let attention_norm =
251                ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?;
252            let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?;
253
254            let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
255            let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
256            let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
257
258            layers.push(LayerWeights {
259                attention_wq: QMatMul::from_qtensor(attention_wq)?,
260                attention_wk: QMatMul::from_qtensor(attention_wk)?,
261                attention_wv: QMatMul::from_qtensor(attention_wv)?,
262                attention_bq: attention_bq.dequantize(device)?,
263                attention_bk: attention_bk.dequantize(device)?,
264                attention_bv: attention_bv.dequantize(device)?,
265                attention_wo: QMatMul::from_qtensor(attention_wo)?,
266                attention_norm: RmsNorm::from_qtensor(attention_norm, rms_norm_eps)?,
267                cos: cos.clone(),
268                sin: sin.clone(),
269                mlp,
270                ffn_norm: RmsNorm::from_qtensor(ffn_norm, rms_norm_eps)?,
271                n_head: head_count,
272                n_kv_head: head_count_kv,
273                head_dim,
274                neg_inf: neg_inf.clone(),
275                kv_cache: None,
276                span_attn,
277                span_rot,
278                span_mlp,
279            });
280        }
281
282        let span = tracing::span!(tracing::Level::TRACE, "model");
283        let span_output = tracing::span!(tracing::Level::TRACE, "output");
284
285        Ok(Self {
286            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
287            layers,
288            norm,
289            output,
290            masks: HashMap::new(),
291            span,
292            span_output,
293        })
294    }
295
296    fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
297        if let Some(mask) = self.masks.get(&t) {
298            Ok(mask.clone())
299        } else {
300            let mask: Vec<_> = (0..t)
301                .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
302                .collect();
303            let mask = Tensor::from_slice(&mask, (t, t), device)?;
304            self.masks.insert(t, mask.clone());
305            Ok(mask)
306        }
307    }
308
309    pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
310        let (_b_sz, seq_len) = x.dims2()?;
311        let mask = if seq_len == 1 {
312            None
313        } else {
314            Some(self.mask(seq_len, x.device())?)
315        };
316        let _enter = self.span.enter();
317        let mut layer_in = self.tok_embeddings.forward(x)?;
318        for layer in self.layers.iter_mut() {
319            let x = layer_in;
320            let residual = &x;
321            let x = layer.attention_norm.forward(&x)?;
322            let attn = layer.forward_attn(&x, mask.as_ref(), index_pos)?;
323            let x = (attn + residual)?;
324
325            // MLP
326            let _enter = layer.span_mlp.enter();
327            let residual = &x;
328            let x = layer.ffn_norm.forward(&x)?;
329            let x = layer.mlp.forward(&x)?;
330            let x = (x + residual)?;
331            layer_in = x
332        }
333        let x = self.norm.forward(&layer_in)?;
334        let x = x.i((.., seq_len - 1, ..))?;
335        let _enter = self.span_output.enter();
336        self.output.forward(&x)
337    }
338}