1use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm};
20use candle::{DType, Device, Module, Result, Tensor, D};
21use candle_nn::{Activation, VarBuilder};
22use std::sync::Arc;
23
/// Model hyper-parameters; derives `serde::Deserialize` so it can be loaded
/// from any serde source (NOTE(review): the actual source, e.g. a JSON config
/// file, is not shown here).
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Config {
    /// Vocabulary size: rows of the token embedding table and outputs of `lm_head`.
    pub vocab_size: usize,
    /// Model width (size of every residual-stream vector).
    pub hidden_size: usize,
    /// Inner width of the dense MLP used by non-MoE layers.
    pub intermediate_size: usize,
    /// Number of decoder layers.
    pub num_hidden_layers: usize,
    /// Number of query heads; `hidden_size / num_attention_heads` is the head dimension.
    pub num_attention_heads: usize,
    /// Number of key/value heads; each one is shared by
    /// `num_attention_heads / num_key_value_heads` query heads.
    pub num_key_value_heads: usize,
    /// Number of positions for which the rotary sin/cos tables are precomputed.
    pub max_position_embeddings: usize,
    /// Attention-mask window: within the current chunk, a query at position `i`
    /// cannot see keys at positions `j` with `j + sliding_window < i`.
    pub sliding_window: usize,
    /// NOTE(review): not read by the code shown here.
    pub max_window_layers: usize,
    /// Presumably whether `lm_head` shares the token-embedding matrix;
    /// check `Model::new` for how (or whether) it is honoured.
    pub tie_word_embeddings: bool,
    /// Base `theta` of the rotary frequencies `1 / theta^(i / head_dim)`.
    pub rope_theta: f64,
    /// Epsilon used by every RMS-norm layer.
    pub rms_norm_eps: f64,
    /// NOTE(review): not read by the code shown here; the attention mask applies
    /// `sliding_window` regardless of this flag.
    pub use_sliding_window: bool,
    /// Activation applied to the gate branch of every MLP / expert.
    pub hidden_act: Activation,
    /// Layer `idx` (0-based) uses a sparse MoE block when `num_experts > 0` and
    /// `(idx + 1) % decoder_sparse_step == 0`; must be non-zero in that case.
    pub decoder_sparse_step: usize,
    /// Inner width of each routed expert.
    pub moe_intermediate_size: usize,
    /// Inner width of the shared expert present in every MoE block.
    pub shared_expert_intermediate_size: usize,
    /// Number of experts each token is routed to (top-k).
    pub num_experts_per_tok: usize,
    /// Number of routed experts per MoE block; `0` makes every layer dense.
    pub num_experts: usize,
    /// Renormalise the selected top-k routing weights so they sum to one.
    pub norm_topk_prob: bool,
}
47
48#[derive(Debug, Clone)]
49struct RotaryEmbedding {
50 sin: Tensor,
51 cos: Tensor,
52}
53
54impl RotaryEmbedding {
55 fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
56 let dim = cfg.hidden_size / cfg.num_attention_heads;
57 let max_seq_len = cfg.max_position_embeddings;
58 let inv_freq: Vec<_> = (0..dim)
59 .step_by(2)
60 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
61 .collect();
62 let inv_freq_len = inv_freq.len();
63 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
64 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
65 .to_dtype(dtype)?
66 .reshape((max_seq_len, 1))?;
67 let freqs = t.matmul(&inv_freq)?;
68 Ok(Self {
69 sin: freqs.sin()?,
70 cos: freqs.cos()?,
71 })
72 }
73
74 fn apply_rotary_emb_qkv(
75 &self,
76 q: &Tensor,
77 k: &Tensor,
78 seqlen_offset: usize,
79 ) -> Result<(Tensor, Tensor)> {
80 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
81 let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
82 let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
83 let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
84 let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
85 Ok((q_embed, k_embed))
86 }
87}
88
89#[derive(Debug, Clone)]
90#[allow(clippy::upper_case_acronyms)]
91struct MLP {
92 gate_proj: Linear,
93 up_proj: Linear,
94 down_proj: Linear,
95 act_fn: Activation,
96}
97
98impl MLP {
99 fn new(intermediate_sz: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
100 let hidden_sz = cfg.hidden_size;
101 let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
102 let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
103 let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
104 Ok(Self {
105 gate_proj,
106 up_proj,
107 down_proj,
108 act_fn: cfg.hidden_act,
109 })
110 }
111}
112
113impl Module for MLP {
114 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
115 let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
116 let rhs = xs.apply(&self.up_proj)?;
117 (lhs * rhs)?.apply(&self.down_proj)
118 }
119}
120
121#[derive(Debug, Clone)]
122struct Attention {
123 q_proj: Linear,
124 k_proj: Linear,
125 v_proj: Linear,
126 o_proj: Linear,
127 num_heads: usize,
128 num_kv_heads: usize,
129 num_kv_groups: usize,
130 head_dim: usize,
131 hidden_size: usize,
132 rotary_emb: Arc<RotaryEmbedding>,
133 kv_cache: Option<(Tensor, Tensor)>,
134}
135
136impl Attention {
137 fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
138 let hidden_sz = cfg.hidden_size;
139 let num_heads = cfg.num_attention_heads;
140 let num_kv_heads = cfg.num_key_value_heads;
141 let num_kv_groups = num_heads / num_kv_heads;
142 let head_dim = hidden_sz / num_heads;
143 let q_proj = linear(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
144 let k_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
145 let v_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
146 let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
147 Ok(Self {
148 q_proj,
149 k_proj,
150 v_proj,
151 o_proj,
152 num_heads,
153 num_kv_heads,
154 num_kv_groups,
155 head_dim,
156 hidden_size: hidden_sz,
157 rotary_emb,
158 kv_cache: None,
159 })
160 }
161
162 fn forward(
163 &mut self,
164 xs: &Tensor,
165 attention_mask: Option<&Tensor>,
166 seqlen_offset: usize,
167 ) -> Result<Tensor> {
168 let (b_sz, q_len, _) = xs.dims3()?;
169
170 let query_states = self.q_proj.forward(xs)?;
171 let key_states = self.k_proj.forward(xs)?;
172 let value_states = self.v_proj.forward(xs)?;
173
174 let query_states = query_states
175 .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
176 .transpose(1, 2)?;
177 let key_states = key_states
178 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
179 .transpose(1, 2)?;
180 let value_states = value_states
181 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
182 .transpose(1, 2)?;
183
184 let (query_states, key_states) =
185 self.rotary_emb
186 .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
187
188 let (key_states, value_states) = match &self.kv_cache {
189 None => (key_states, value_states),
190 Some((prev_k, prev_v)) => {
191 let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
192 let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
193 (key_states, value_states)
194 }
195 };
196 self.kv_cache = Some((key_states.clone(), value_states.clone()));
197
198 let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
199 let value_states =
200 crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
201
202 let attn_output = {
203 let scale = 1f64 / f64::sqrt(self.head_dim as f64);
204 let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
205
206 let attn_weights = match attention_mask {
207 None => attn_weights,
208 Some(mask) => attn_weights.broadcast_add(mask)?,
209 };
210 let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
211 attn_weights.matmul(&value_states)?
212 };
213 attn_output
214 .transpose(1, 2)?
215 .reshape((b_sz, q_len, self.hidden_size))?
216 .apply(&self.o_proj)
217 }
218
219 fn clear_kv_cache(&mut self) {
220 self.kv_cache = None
221 }
222}
223
224#[derive(Debug, Clone)]
226struct SparseMoeBlock {
227 gate: Linear,
228 experts: Vec<MLP>,
229 shared_expert: MLP,
230 shared_expert_gate: Linear,
231 norm_topk_prob: bool,
232 num_experts_per_tok: usize,
233}
234
235impl SparseMoeBlock {
236 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
237 let gate = linear_no_bias(cfg.hidden_size, cfg.num_experts, vb.pp("gate"))?;
238 let mut experts = Vec::with_capacity(cfg.num_experts);
239 let vb_e = vb.pp("experts");
240 for idx in 0..cfg.num_experts {
241 let expert = MLP::new(cfg.moe_intermediate_size, cfg, vb_e.pp(idx))?;
242 experts.push(expert)
243 }
244 let shared_expert = MLP::new(
245 cfg.shared_expert_intermediate_size,
246 cfg,
247 vb.pp("shared_expert"),
248 )?;
249 let shared_expert_gate = linear_no_bias(cfg.hidden_size, 1, vb.pp("shared_expert_gate"))?;
250 Ok(Self {
251 gate,
252 experts,
253 shared_expert,
254 shared_expert_gate,
255 norm_topk_prob: cfg.norm_topk_prob,
256 num_experts_per_tok: cfg.num_experts_per_tok,
257 })
258 }
259}
260
261impl Module for SparseMoeBlock {
262 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
263 let (b_size, seq_len, hidden_dim) = xs.dims3()?;
264 let xs = xs.reshape(((), hidden_dim))?;
265 let router_logits = xs.apply(&self.gate)?;
266 let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;
267
268 let experts_per_tok = routing_weights
271 .arg_sort_last_dim(false)?
272 .narrow(D::Minus1, 0, self.num_experts_per_tok)?
273 .contiguous()?;
274 let routing_weights = routing_weights.gather(&experts_per_tok, D::Minus1)?;
275
276 let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
279 let experts_per_tok = experts_per_tok.to_vec2::<u32>()?;
280 let mut top_x = vec![vec![]; self.experts.len()];
281 let mut selected_experts = vec![vec![]; self.experts.len()];
282 for (row_idx, (rw, expert_idxs)) in routing_weights
283 .iter()
284 .zip(experts_per_tok.iter())
285 .enumerate()
286 {
287 let sum_rw = rw.iter().sum::<f32>();
288 for (&rw, &expert_idx) in rw.iter().zip(expert_idxs.iter()) {
289 top_x[expert_idx as usize].push(row_idx as u32);
290 let rw = if self.norm_topk_prob { rw / sum_rw } else { rw };
291 selected_experts[expert_idx as usize].push(rw)
292 }
293 }
294
295 let mut ys = xs.zeros_like()?;
296 for (expert_idx, expert_layer) in self.experts.iter().enumerate() {
297 let top_x = &top_x[expert_idx];
298 if top_x.is_empty() {
299 continue;
300 }
301 let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
302 let selected_experts =
303 Tensor::new(selected_experts[expert_idx].as_slice(), xs.device())?
304 .reshape(((), 1))?
305 .to_dtype(xs.dtype())?;
306 let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
310 let current_hidden_states = expert_layer.forward(¤t_state)?;
312 let current_hidden_states = current_hidden_states.broadcast_mul(&selected_experts)?;
313 ys = ys.index_add(&top_x, ¤t_hidden_states, 0)?;
314 }
315 let shared_expert_output = xs.apply(&self.shared_expert)?;
316 let shared_expert_output = shared_expert_output.broadcast_mul(&candle_nn::ops::sigmoid(
317 &xs.apply(&self.shared_expert_gate)?,
318 )?)?;
319 let ys = (ys + shared_expert_output)?;
320 let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
321 Ok(ys)
322 }
323}
324
325#[derive(Debug, Clone)]
326enum MlpOrMoeBlock {
327 Mlp(MLP),
328 MoeBlock(SparseMoeBlock),
329}
330
331impl Module for MlpOrMoeBlock {
332 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
333 match self {
334 Self::MoeBlock(m) => m.forward(xs),
335 Self::Mlp(m) => m.forward(xs),
336 }
337 }
338}
339
340#[derive(Debug, Clone)]
341struct DecoderLayer {
342 self_attn: Attention,
343 mlp: MlpOrMoeBlock,
344 input_layernorm: RmsNorm,
345 post_attention_layernorm: RmsNorm,
346}
347
348impl DecoderLayer {
349 fn new(
350 layer_idx: usize,
351 rotary_emb: Arc<RotaryEmbedding>,
352 cfg: &Config,
353 vb: VarBuilder,
354 ) -> Result<Self> {
355 let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
356 let mlp = if cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0 {
357 MlpOrMoeBlock::MoeBlock(SparseMoeBlock::new(cfg, vb.pp("mlp"))?)
358 } else {
359 MlpOrMoeBlock::Mlp(MLP::new(cfg.intermediate_size, cfg, vb.pp("mlp"))?)
360 };
361 let input_layernorm =
362 RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
363 let post_attention_layernorm = RmsNorm::new(
364 cfg.hidden_size,
365 cfg.rms_norm_eps,
366 vb.pp("post_attention_layernorm"),
367 )?;
368 Ok(Self {
369 self_attn,
370 mlp,
371 input_layernorm,
372 post_attention_layernorm,
373 })
374 }
375
376 fn forward(
377 &mut self,
378 xs: &Tensor,
379 attention_mask: Option<&Tensor>,
380 seqlen_offset: usize,
381 ) -> Result<Tensor> {
382 let residual = xs;
383 let xs = self.input_layernorm.forward(xs)?;
384 let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
385 let xs = (xs + residual)?;
386 let residual = &xs;
387 let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
388 residual + xs
389 }
390
391 fn clear_kv_cache(&mut self) {
392 self.self_attn.clear_kv_cache()
393 }
394}
395
396#[derive(Debug, Clone)]
397pub struct Model {
398 embed_tokens: candle_nn::Embedding,
399 layers: Vec<DecoderLayer>,
400 norm: RmsNorm,
401 lm_head: Linear,
402 sliding_window: usize,
403 device: Device,
404 dtype: DType,
405}
406
407impl Model {
408 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
409 let vb_m = vb.pp("model");
410 let embed_tokens =
411 candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
412 let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
413 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
414 let vb_l = vb_m.pp("layers");
415 for layer_idx in 0..cfg.num_hidden_layers {
416 let layer = DecoderLayer::new(layer_idx, rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
417 layers.push(layer)
418 }
419 let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
420 let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
421 Ok(Self {
422 embed_tokens,
423 layers,
424 norm,
425 lm_head,
426 sliding_window: cfg.sliding_window,
427 device: vb.device().clone(),
428 dtype: vb.dtype(),
429 })
430 }
431
432 fn prepare_decoder_attention_mask(
433 &self,
434 b_size: usize,
435 tgt_len: usize,
436 seqlen_offset: usize,
437 ) -> Result<Tensor> {
438 let mask: Vec<_> = (0..tgt_len)
440 .flat_map(|i| {
441 (0..tgt_len).map(move |j| {
442 if i < j || j + self.sliding_window < i {
443 f32::NEG_INFINITY
444 } else {
445 0.
446 }
447 })
448 })
449 .collect();
450 let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
451 let mask = if seqlen_offset > 0 {
452 let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
453 Tensor::cat(&[&mask0, &mask], D::Minus1)?
454 } else {
455 mask
456 };
457 mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
458 .to_dtype(self.dtype)
459 }
460
461 pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
462 let (b_size, seq_len) = input_ids.dims2()?;
463 let attention_mask = if seq_len <= 1 {
464 None
465 } else {
466 let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
467 Some(mask)
468 };
469 let mut xs = self.embed_tokens.forward(input_ids)?;
470 for layer in self.layers.iter_mut() {
471 xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
472 }
473 xs.narrow(1, seq_len - 1, 1)?
474 .apply(&self.norm)?
475 .apply(&self.lm_head)
476 }
477
478 pub fn clear_kv_cache(&mut self) {
479 for layer in self.layers.iter_mut() {
480 layer.clear_kv_cache()
481 }
482 }
483}