1use crate::{quantized_nn::RmsNorm, utils::repeat_kv};
17use candle::{
18 quantized::{gguf_file, QMatMul},
19 DType, Device, IndexOp, Result, Tensor,
20};
21use candle_nn::{Embedding, Module};
22use std::collections::HashMap;
23
/// Gated feed-forward block of a transformer layer.
///
/// The three projections are combined in `forward` as `w2(silu(w1(x)) * w3(x))`.
#[derive(Debug, Clone)]
struct Mlp {
    /// Gate projection (loaded from `{prefix}.ffn_gate.weight`); passed through SiLU.
    feed_forward_w1: QMatMul,
    /// Down projection (loaded from `{prefix}.ffn_down.weight`); applied last.
    feed_forward_w2: QMatMul,
    /// Up projection (loaded from `{prefix}.ffn_up.weight`); multiplied with the gate.
    feed_forward_w3: QMatMul,
}
30
31impl Module for Mlp {
32 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
33 let w1 = self.feed_forward_w1.forward(xs)?;
34 let w3 = self.feed_forward_w3.forward(xs)?;
35 self.feed_forward_w2
36 .forward(&(candle_nn::ops::silu(&w1)? * w3)?)
37 }
38}
39
/// Weights and per-layer runtime state for one transformer block.
#[derive(Debug, Clone)]
struct LayerWeights {
    /// Query projection.
    attention_wq: QMatMul,
    /// Key projection.
    attention_wk: QMatMul,
    /// Value projection.
    attention_wv: QMatMul,
    /// Bias added (via `broadcast_add`) to the query projection output.
    attention_bq: Tensor,
    /// Bias added (via `broadcast_add`) to the key projection output.
    attention_bk: Tensor,
    /// Bias added (via `broadcast_add`) to the value projection output.
    attention_bv: Tensor,
    /// Output projection applied to the merged attention heads.
    attention_wo: QMatMul,
    /// RMS norm applied to the block input before attention.
    attention_norm: RmsNorm,
    /// Feed-forward sub-block.
    mlp: Mlp,
    /// RMS norm applied to the residual stream before the feed-forward sub-block.
    ffn_norm: RmsNorm,
    /// Number of query heads.
    n_head: usize,
    /// Number of key/value heads; keys and values are repeated `n_head / n_kv_head` times.
    n_kv_head: usize,
    /// Per-head feature width.
    head_dim: usize,
    /// Rotary cosine table; rows are selected by absolute position.
    cos: Tensor,
    /// Rotary sine table; rows are selected by absolute position.
    sin: Tensor,
    /// Scalar fill value written into masked attention scores before the softmax.
    neg_inf: Tensor,
    /// Keys and values from earlier calls, concatenated along dim 2; replaced (not extended)
    /// when a call starts at `index_pos == 0`.
    kv_cache: Option<(Tensor, Tensor)>,
    /// Tracing span for the attention sub-block.
    span_attn: tracing::Span,
    /// Tracing span for the rotary embedding.
    span_rot: tracing::Span,
    /// Tracing span for the feed-forward sub-block.
    span_mlp: tracing::Span,
}
63
64fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
65 let shape = mask.shape();
66 let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?;
67 Ok(m)
68}
69
70impl LayerWeights {
71 fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
72 let _enter = self.span_rot.enter();
73 let (_b_sz, _n_head, seq_len, _n_embd) = x.dims4()?;
74 let cos = self.cos.narrow(0, index_pos, seq_len)?;
75 let sin = self.sin.narrow(0, index_pos, seq_len)?;
76 candle_nn::rotary_emb::rope(&x.contiguous()?, &cos, &sin)
77 }
78
79 fn forward_attn(
80 &mut self,
81 x: &Tensor,
82 mask: Option<&Tensor>,
83 index_pos: usize,
84 ) -> Result<Tensor> {
85 let _enter = self.span_attn.enter();
86 let (b_sz, seq_len, n_embd) = x.dims3()?;
87
88 let q = self.attention_wq.forward(x)?;
89 let k = self.attention_wk.forward(x)?;
90 let v = self.attention_wv.forward(x)?;
91
92 let q = q.broadcast_add(&self.attention_bq)?;
93 let k = k.broadcast_add(&self.attention_bk)?;
94 let v = v.broadcast_add(&self.attention_bv)?;
95
96 let q = q
97 .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
98 .transpose(1, 2)?
99 .contiguous()?;
100 let k = k
101 .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
102 .transpose(1, 2)?
103 .contiguous()?;
104 let v = v
105 .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
106 .transpose(1, 2)?
107 .contiguous()?;
108
109 let q = self.apply_rotary_emb(&q, index_pos)?;
113 let k = self.apply_rotary_emb(&k, index_pos)?;
114
115 let (k, v) = match &self.kv_cache {
116 None => (k, v),
117 Some((k_cache, v_cache)) => {
118 if index_pos == 0 {
119 (k, v)
120 } else {
121 let k = Tensor::cat(&[k_cache, &k], 2)?;
122 let v = Tensor::cat(&[v_cache, &v], 2)?;
123 (k, v)
124 }
125 }
126 };
127 self.kv_cache = Some((k.clone(), v.clone()));
128
129 let k = repeat_kv(k, self.n_head / self.n_kv_head)?;
131 let v = repeat_kv(v, self.n_head / self.n_kv_head)?;
132
133 let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
134 let att = match mask {
135 None => att,
136 Some(mask) => {
137 let mask = mask.broadcast_as(att.shape())?;
138 masked_fill(&att, &mask, &self.neg_inf)?
139 }
140 };
141 let att = candle_nn::ops::softmax_last_dim(&att)?;
142 let y = att.matmul(&v.contiguous()?)?;
144 let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
145 let y = self.attention_wo.forward(&y)?;
146 Ok(y)
147 }
148}
149
/// A quantized Qwen2 model loaded from a GGUF file, together with its per-layer KV caches.
pub struct ModelWeights {
    /// Token embedding table (dequantized from `token_embd.weight`).
    tok_embeddings: Embedding,
    /// Transformer blocks, applied in order.
    layers: Vec<LayerWeights>,
    /// Final RMS norm applied before the output projection.
    norm: RmsNorm,
    /// Output projection producing the logits.
    output: QMatMul,
    /// Cache of causal masks keyed by sequence length.
    masks: HashMap<usize, Tensor>,
    /// Tracing span for the whole forward pass.
    span: tracing::Span,
    /// Tracing span for the output projection.
    span_output: tracing::Span,
}
159
160fn precomput_freqs_cis(
161 head_dim: usize,
162 freq_base: f32,
163 context_length: usize,
164 device: &Device,
165) -> Result<(Tensor, Tensor)> {
166 let theta: Vec<_> = (0..head_dim)
167 .step_by(2)
168 .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
169 .collect();
170 let theta = Tensor::new(theta.as_slice(), device)?;
171 let idx_theta = Tensor::arange(0, context_length as u32, device)?
172 .to_dtype(DType::F32)?
173 .reshape((context_length, 1))?
174 .matmul(&theta.reshape((1, theta.elem_count()))?)?;
175 let cos = idx_theta.cos()?;
176 let sin = idx_theta.sin()?;
177 Ok((cos, sin))
178}
179
180impl ModelWeights {
181 pub fn from_gguf<R: std::io::Seek + std::io::Read>(
182 ct: gguf_file::Content,
183 reader: &mut R,
184 device: &Device,
185 ) -> Result<Self> {
186 let md_get = |s: &str| match ct.metadata.get(s) {
187 None => candle::bail!("cannot find {s} in metadata"),
188 Some(v) => Ok(v),
189 };
190
191 let head_count = md_get("qwen2.attention.head_count")?.to_u32()? as usize;
192 let head_count_kv = md_get("qwen2.attention.head_count_kv")?.to_u32()? as usize;
193 let embedding_length = md_get("qwen2.embedding_length")?.to_u32()? as usize;
194 let context_length = md_get("qwen2.context_length")?.to_u32()? as usize;
195 let block_count = md_get("qwen2.block_count")?.to_u32()? as usize;
196 let rms_norm_eps = md_get("qwen2.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
197 let rope_freq_base = md_get("qwen2.rope.freq_base")
198 .and_then(|m| m.to_f32())
199 .unwrap_or(10000f32);
200
201 let head_dim = embedding_length / head_count;
202
203 let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
204
205 let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
206 let tok_embeddings = tok_embeddings.dequantize(device)?;
207 let norm = RmsNorm::from_qtensor(
208 ct.tensor(reader, "output_norm.weight", device)?,
209 rms_norm_eps,
210 )?;
211 let output = match ct.tensor(reader, "output.weight", device) {
212 Ok(v) => QMatMul::from_qtensor(v)?,
213 _ => {
214 QMatMul::from_qtensor(ct.tensor(reader, "token_embd.weight", device)?)?
216 }
217 };
218
219 let (cos, sin) = precomput_freqs_cis(head_dim, rope_freq_base, context_length, device)?;
220
221 let mut layers = Vec::with_capacity(block_count);
222
223 for layer_idx in 0..block_count {
224 let prefix = format!("blk.{layer_idx}");
225 let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
226 let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
227 let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
228
229 let attention_bq = ct.tensor(reader, &format!("{prefix}.attn_q.bias"), device)?;
230 let attention_bk = ct.tensor(reader, &format!("{prefix}.attn_k.bias"), device)?;
231 let attention_bv = ct.tensor(reader, &format!("{prefix}.attn_v.bias"), device)?;
232
233 let attention_wo =
234 ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
235
236 let mlp = {
237 let feed_forward_w1 =
238 ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
239 let feed_forward_w2 =
240 ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
241 let feed_forward_w3 =
242 ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
243 Mlp {
244 feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
245 feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
246 feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
247 }
248 };
249
250 let attention_norm =
251 ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?;
252 let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?;
253
254 let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
255 let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
256 let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
257
258 layers.push(LayerWeights {
259 attention_wq: QMatMul::from_qtensor(attention_wq)?,
260 attention_wk: QMatMul::from_qtensor(attention_wk)?,
261 attention_wv: QMatMul::from_qtensor(attention_wv)?,
262 attention_bq: attention_bq.dequantize(device)?,
263 attention_bk: attention_bk.dequantize(device)?,
264 attention_bv: attention_bv.dequantize(device)?,
265 attention_wo: QMatMul::from_qtensor(attention_wo)?,
266 attention_norm: RmsNorm::from_qtensor(attention_norm, rms_norm_eps)?,
267 cos: cos.clone(),
268 sin: sin.clone(),
269 mlp,
270 ffn_norm: RmsNorm::from_qtensor(ffn_norm, rms_norm_eps)?,
271 n_head: head_count,
272 n_kv_head: head_count_kv,
273 head_dim,
274 neg_inf: neg_inf.clone(),
275 kv_cache: None,
276 span_attn,
277 span_rot,
278 span_mlp,
279 });
280 }
281
282 let span = tracing::span!(tracing::Level::TRACE, "model");
283 let span_output = tracing::span!(tracing::Level::TRACE, "output");
284
285 Ok(Self {
286 tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
287 layers,
288 norm,
289 output,
290 masks: HashMap::new(),
291 span,
292 span_output,
293 })
294 }
295
296 fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
297 if let Some(mask) = self.masks.get(&t) {
298 Ok(mask.clone())
299 } else {
300 let mask: Vec<_> = (0..t)
301 .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
302 .collect();
303 let mask = Tensor::from_slice(&mask, (t, t), device)?;
304 self.masks.insert(t, mask.clone());
305 Ok(mask)
306 }
307 }
308
309 pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
310 let (_b_sz, seq_len) = x.dims2()?;
311 let mask = if seq_len == 1 {
312 None
313 } else {
314 Some(self.mask(seq_len, x.device())?)
315 };
316 let _enter = self.span.enter();
317 let mut layer_in = self.tok_embeddings.forward(x)?;
318 for layer in self.layers.iter_mut() {
319 let x = layer_in;
320 let residual = &x;
321 let x = layer.attention_norm.forward(&x)?;
322 let attn = layer.forward_attn(&x, mask.as_ref(), index_pos)?;
323 let x = (attn + residual)?;
324
325 let _enter = layer.span_mlp.enter();
327 let residual = &x;
328 let x = layer.ffn_norm.forward(&x)?;
329 let x = layer.mlp.forward(&x)?;
330 let x = (x + residual)?;
331 layer_in = x
332 }
333 let x = self.norm.forward(&layer_in)?;
334 let x = x.i((.., seq_len - 1, ..))?;
335 let _enter = self.span_output.enter();
336 self.output.forward(&x)
337 }
338}