1use std::collections::HashMap;
20
21use crate::quantized_nn::RmsNorm;
22use candle::quantized::QTensor;
23use candle::quantized::{ggml_file, gguf_file};
24use candle::{DType, Device, IndexOp, Result, Tensor};
25use candle_nn::{Embedding, Module};
26
/// Number of positions for which the rotary cos/sin tables are precomputed; token positions
/// (`index_pos + seq_len`) must stay within this bound.
pub const MAX_SEQ_LEN: usize = 4_096;
28
29#[derive(Debug, Clone)]
31struct QMatMul {
32 inner: candle::quantized::QMatMul,
33 span: tracing::Span,
34}
35
36impl QMatMul {
37 fn from_qtensor(qtensor: QTensor) -> Result<Self> {
38 let inner = candle::quantized::QMatMul::from_qtensor(qtensor)?;
39 let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
40 Ok(Self { inner, span })
41 }
42
43 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
44 let _enter = self.span.enter();
45 self.inner.forward(xs)
46 }
47}
48
49#[derive(Debug, Clone)]
50struct Mlp {
51 feed_forward_w1: QMatMul,
52 feed_forward_w2: QMatMul,
53 feed_forward_w3: QMatMul,
54}
55
56impl Module for Mlp {
57 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
58 let w1 = self.feed_forward_w1.forward(xs)?;
59 let w3 = self.feed_forward_w3.forward(xs)?;
60 self.feed_forward_w2
61 .forward(&(candle_nn::ops::silu(&w1)? * w3)?)
62 }
63}
64
65#[derive(Debug, Clone)]
66enum MlpOrMoe {
67 Mlp(Mlp),
68 MoE {
69 n_expert_used: usize,
70 feed_forward_gate_inp: QMatMul,
71 experts: Vec<Mlp>,
72 },
73}
74
75impl Module for MlpOrMoe {
76 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
77 match self {
78 Self::MoE {
79 feed_forward_gate_inp,
80 experts,
81 n_expert_used,
82 } => {
83 let (b_size, seq_len, hidden_dim) = xs.dims3()?;
84 let xs = xs.reshape(((), hidden_dim))?;
85 let router_logits = feed_forward_gate_inp.forward(&xs)?;
86 let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;
87
88 let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
91
92 let mut top_x = vec![vec![]; experts.len()];
95 let mut selected_rws = vec![vec![]; experts.len()];
96 for (row_idx, rw) in routing_weights.iter().enumerate() {
97 let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
98 dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
99 let mut sum_routing_weights = 0f32;
100 for &expert_idx in dst.iter().take(*n_expert_used) {
101 let expert_idx = expert_idx as usize;
102 let routing_weight = rw[expert_idx];
103 sum_routing_weights += routing_weight;
104 top_x[expert_idx].push(row_idx as u32);
105 }
106 for &expert_idx in dst.iter().take(*n_expert_used) {
107 let expert_idx = expert_idx as usize;
108 let routing_weight = rw[expert_idx];
109 selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
110 }
111 }
112
113 let mut ys = xs.zeros_like()?;
117 for (expert_idx, expert_layer) in experts.iter().enumerate() {
118 let top_x = &top_x[expert_idx];
119 if top_x.is_empty() {
120 continue;
121 }
122 let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
123 let selected_rws =
124 Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?
125 .reshape(((), 1))?;
126 let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
130 let current_hidden_states = expert_layer.forward(¤t_state)?;
132 let current_hidden_states =
133 current_hidden_states.broadcast_mul(&selected_rws)?;
134 ys = ys.index_add(&top_x, ¤t_hidden_states, 0)?;
135 }
136
137 let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
138 Ok(ys)
139 }
140 Self::Mlp(mlp) => mlp.forward(xs),
141 }
142 }
143}
144
145#[derive(Debug, Clone)]
146struct LayerWeights {
147 attention_wq: QMatMul,
148 attention_wk: QMatMul,
149 attention_wv: QMatMul,
150 attention_wo: QMatMul,
151 attention_norm: RmsNorm,
152 mlp_or_moe: MlpOrMoe,
153 ffn_norm: RmsNorm,
154 n_head: usize,
155 n_kv_head: usize,
156 head_dim: usize,
157 cos: Tensor,
158 sin: Tensor,
159 neg_inf: Tensor,
160 kv_cache: Option<(Tensor, Tensor)>,
161 span_attn: tracing::Span,
162 span_rot: tracing::Span,
163 span_mlp: tracing::Span,
164}
165
166fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
167 let shape = mask.shape();
168 let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?;
169 Ok(m)
170}
171
172impl LayerWeights {
173 fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
174 let _enter = self.span_rot.enter();
175 let (_b_sz, _n_head, seq_len, _n_embd) = x.dims4()?;
176 let cos = self.cos.narrow(0, index_pos, seq_len)?;
177 let sin = self.sin.narrow(0, index_pos, seq_len)?;
178 candle_nn::rotary_emb::rope_i(&x.contiguous()?, &cos, &sin)
181 }
182
183 fn forward_attn(
184 &mut self,
185 x: &Tensor,
186 mask: Option<&Tensor>,
187 index_pos: usize,
188 ) -> Result<Tensor> {
189 let _enter = self.span_attn.enter();
190 let (b_sz, seq_len, n_embd) = x.dims3()?;
191 let q = self.attention_wq.forward(x)?;
192 let k = self.attention_wk.forward(x)?;
193 let v = self.attention_wv.forward(x)?;
194
195 let q = q
196 .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
197 .transpose(1, 2)?;
198 let k = k
199 .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
200 .transpose(1, 2)?;
201 let v = v
202 .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
203 .transpose(1, 2)?
204 .contiguous()?;
208
209 let q = self.apply_rotary_emb(&q, index_pos)?;
210 let k = self.apply_rotary_emb(&k, index_pos)?;
211
212 let (k, v) = match &self.kv_cache {
213 None => (k, v),
214 Some((k_cache, v_cache)) => {
215 if index_pos == 0 {
216 (k, v)
217 } else {
218 let k = Tensor::cat(&[k_cache, &k], 2)?;
219 let v = Tensor::cat(&[v_cache, &v], 2)?;
220 (k, v)
221 }
222 }
223 };
224 self.kv_cache = Some((k.clone(), v.clone()));
225
226 let y = if q.device().is_metal() && seq_len == 1 {
227 candle_nn::ops::sdpa(&q, &k, &v, 1. / (self.head_dim as f32).sqrt(), 1.)?
229 } else {
230 let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
232 let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
233
234 let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
235 let att = match mask {
236 None => att,
237 Some(mask) => {
238 let mask = mask.broadcast_as(att.shape())?;
239 masked_fill(&att, &mask, &self.neg_inf)?
240 }
241 };
242 let att = candle_nn::ops::softmax_last_dim(&att)?;
243 att.matmul(&v.contiguous()?)?
245 };
246
247 let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
248 let y = self.attention_wo.forward(&y)?;
249 Ok(y)
250 }
251}
252
253#[derive(Debug, Clone)]
254pub struct ModelWeights {
255 tok_embeddings: Embedding,
256 layers: Vec<LayerWeights>,
257 norm: RmsNorm,
258 output: QMatMul,
259 masks: HashMap<usize, Tensor>,
260 span: tracing::Span,
261 span_output: tracing::Span,
262}
263
264fn precomput_freqs_cis(
265 head_dim: usize,
266 freq_base: f32,
267 device: &Device,
268) -> Result<(Tensor, Tensor)> {
269 let theta: Vec<_> = (0..head_dim)
270 .step_by(2)
271 .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
272 .collect();
273 let theta = Tensor::new(theta.as_slice(), device)?;
274 let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
275 .to_dtype(DType::F32)?
276 .reshape((MAX_SEQ_LEN, 1))?
277 .matmul(&theta.reshape((1, theta.elem_count()))?)?;
278 let cos = idx_theta.cos()?;
279 let sin = idx_theta.sin()?;
280 Ok((cos, sin))
281}
282
283impl ModelWeights {
284 pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
285 let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
286 let (cos, sin) = precomput_freqs_cis(head_dim, 10000., &ct.device)?;
287 let neg_inf = Tensor::new(f32::NEG_INFINITY, &ct.device)?;
288 let tok_embeddings = ct.remove("tok_embeddings.weight")?;
289 let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
290 let norm = RmsNorm::from_qtensor(ct.remove("norm.weight")?, 1e-5)?;
291 let output = ct.remove("output.weight")?;
292 let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
293 for layer_idx in 0..ct.hparams.n_layer {
294 let prefix = format!("layers.{layer_idx}");
295 let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?;
296 let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
297 let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
298 let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
299 let mlp_or_moe = {
300 let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
301 let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
302 let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
303 MlpOrMoe::Mlp(Mlp {
304 feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
305 feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
306 feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
307 })
308 };
309 let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?;
310 let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?;
311 let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
312 let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
313 let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
314 layers.push(LayerWeights {
315 attention_wq: QMatMul::from_qtensor(attention_wq)?,
316 attention_wk: QMatMul::from_qtensor(attention_wk)?,
317 attention_wv: QMatMul::from_qtensor(attention_wv)?,
318 attention_wo: QMatMul::from_qtensor(attention_wo)?,
319 attention_norm: RmsNorm::from_qtensor(attention_norm, 1e-5)?,
320 mlp_or_moe,
321 ffn_norm: RmsNorm::from_qtensor(ffn_norm, 1e-5)?,
322 n_head: ct.hparams.n_head as usize,
323 n_kv_head: ct.hparams.n_head as usize / gqa,
324 head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
325 cos: cos.clone(),
326 sin: sin.clone(),
327 neg_inf: neg_inf.clone(),
328 kv_cache: None,
329 span_attn,
330 span_rot,
331 span_mlp,
332 })
333 }
334 let span = tracing::span!(tracing::Level::TRACE, "model");
335 let span_output = tracing::span!(tracing::Level::TRACE, "output");
336 Ok(Self {
337 tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
338 layers,
339 norm,
340 output: QMatMul::from_qtensor(output)?,
341 masks: HashMap::new(),
342 span,
343 span_output,
344 })
345 }
346
347 pub fn from_gguf<R: std::io::Seek + std::io::Read>(
348 ct: gguf_file::Content,
349 reader: &mut R,
350 device: &Device,
351 ) -> Result<Self> {
352 let md_get = |s: &str| match ct.metadata.get(s) {
353 None => candle::bail!("cannot find {s} in metadata"),
354 Some(v) => Ok(v),
355 };
356
357 let n_expert = md_get("llama.expert_count")
359 .and_then(|v| v.to_u32())
360 .unwrap_or(0) as usize;
361 let n_expert_used = md_get("llama.expert_used_count")
362 .and_then(|v| v.to_u32())
363 .unwrap_or(0) as usize;
364 let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
365 let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
366 let block_count = md_get("llama.block_count")?.to_u32()? as usize;
367 let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
368 let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
369 let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
371
372 let rope_freq_base = md_get("llama.rope.freq_base")
373 .and_then(|m| m.to_f32())
374 .unwrap_or(10000f32);
375 let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
376 let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
377
378 let tok_embeddings_q = ct.tensor(reader, "token_embd.weight", device)?;
379 let tok_embeddings = tok_embeddings_q.dequantize(device)?;
380 let norm = RmsNorm::from_qtensor(
381 ct.tensor(reader, "output_norm.weight", device)?,
382 rms_norm_eps,
383 )?;
384 let output = match ct.tensor(reader, "output.weight", device) {
385 Ok(tensor) => tensor,
386 Err(_) => tok_embeddings_q,
387 };
388 let mut layers = Vec::with_capacity(block_count);
389 for layer_idx in 0..block_count {
390 let prefix = format!("blk.{layer_idx}");
391 let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
392 let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
393 let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
394 let attention_wo =
395 ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
396 let mlp_or_moe = if n_expert <= 1 {
397 let feed_forward_w1 =
398 ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
399 let feed_forward_w2 =
400 ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
401 let feed_forward_w3 =
402 ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
403 MlpOrMoe::Mlp(Mlp {
404 feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
405 feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
406 feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
407 })
408 } else {
409 let feed_forward_gate_inp =
410 ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"), device)?;
411 let mut experts = Vec::with_capacity(n_expert);
412 for i in 0..n_expert {
413 let feed_forward_w1 =
414 ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"), device)?;
415 let feed_forward_w2 =
416 ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?;
417 let feed_forward_w3 =
418 ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?;
419 experts.push(Mlp {
420 feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
421 feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
422 feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
423 })
424 }
425 MlpOrMoe::MoE {
426 n_expert_used,
427 feed_forward_gate_inp: QMatMul::from_qtensor(feed_forward_gate_inp)?,
428 experts,
429 }
430 };
431 let attention_norm =
432 ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?;
433 let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?;
434 let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
435 let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
436 let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
437 layers.push(LayerWeights {
438 attention_wq: QMatMul::from_qtensor(attention_wq)?,
439 attention_wk: QMatMul::from_qtensor(attention_wk)?,
440 attention_wv: QMatMul::from_qtensor(attention_wv)?,
441 attention_wo: QMatMul::from_qtensor(attention_wo)?,
442 attention_norm: RmsNorm::from_qtensor(attention_norm, rms_norm_eps)?,
443 mlp_or_moe,
444 ffn_norm: RmsNorm::from_qtensor(ffn_norm, rms_norm_eps)?,
445 n_head: head_count,
446 n_kv_head: head_count_kv,
447 head_dim: embedding_length / head_count,
448 cos: cos.clone(),
449 sin: sin.clone(),
450 neg_inf: neg_inf.clone(),
451 kv_cache: None,
452 span_attn,
453 span_rot,
454 span_mlp,
455 })
456 }
457 let span = tracing::span!(tracing::Level::TRACE, "model");
458 let span_output = tracing::span!(tracing::Level::TRACE, "output");
459 Ok(Self {
460 tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
461 layers,
462 norm,
463 output: QMatMul::from_qtensor(output)?,
464 masks: HashMap::new(),
465 span,
466 span_output,
467 })
468 }
469
470 fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
471 if let Some(mask) = self.masks.get(&t) {
472 Ok(mask.clone())
473 } else {
474 let mask: Vec<_> = (0..t)
475 .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
476 .collect();
477 let mask = Tensor::from_slice(&mask, (t, t), device)?;
478 self.masks.insert(t, mask.clone());
479 Ok(mask)
480 }
481 }
482
483 pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
484 let (_b_sz, seq_len) = x.dims2()?;
485 let mask = if seq_len == 1 {
486 None
487 } else {
488 Some(self.mask(seq_len, x.device())?)
489 };
490 let _enter = self.span.enter();
491 let mut layer_in = self.tok_embeddings.forward(x)?;
492 for layer in self.layers.iter_mut() {
493 let x = layer_in;
494 let residual = &x;
495 let x = layer.attention_norm.forward(&x)?;
496 let attn = layer.forward_attn(&x, mask.as_ref(), index_pos)?;
497 let x = (attn + residual)?;
498
499 let _enter = layer.span_mlp.enter();
501 let residual = &x;
502 let x = layer.ffn_norm.forward(&x)?;
503 let x = layer.mlp_or_moe.forward(&x)?;
504 let x = (x + residual)?;
505 layer_in = x
506 }
507 let x = self.norm.forward(&layer_in)?;
508 let x = x.i((.., seq_len - 1, ..))?;
509 let _enter = self.span_output.enter();
510 self.output.forward(&x)
511 }
512}