1use crate::models::with_tracing::{linear_no_bias, Linear, RmsNorm};
19use candle::{DType, Device, Module, Result, Tensor, D};
23use candle_nn::{Activation, VarBuilder};
24use serde::Deserialize;
25use std::sync::Arc;
26
27#[derive(Debug, Clone, PartialEq, Deserialize)]
29pub struct Config {
30 pub(crate) vocab_size: usize,
31 pub(crate) hidden_size: usize,
32 pub(crate) intermediate_size: usize,
33 pub(crate) num_hidden_layers: usize,
34 pub(crate) num_attention_heads: usize,
35 pub(crate) num_key_value_heads: usize,
36 pub(crate) hidden_act: Activation,
37 pub(crate) max_position_embeddings: usize,
38 pub(crate) rms_norm_eps: f64,
39 pub(crate) rope_theta: f64,
40 pub(crate) sliding_window: usize,
41 pub(crate) num_experts_per_tok: usize,
42 pub(crate) num_local_experts: usize,
43 pub(crate) use_flash_attn: bool,
44}
45
46impl Config {
47 pub fn v0_1_8x7b(use_flash_attn: bool) -> Self {
49 Self {
50 vocab_size: 32000,
51 hidden_size: 4096,
52 intermediate_size: 14336,
53 num_hidden_layers: 32,
54 num_attention_heads: 32,
55 num_key_value_heads: 8,
56 hidden_act: Activation::Silu,
57 max_position_embeddings: 32768,
58 rms_norm_eps: 1e-5,
59 rope_theta: 1e6,
60 sliding_window: 4096,
61 num_experts_per_tok: 2,
62 num_local_experts: 8,
63 use_flash_attn,
64 }
65 }
66}
67
68#[derive(Debug, Clone)]
69struct RotaryEmbedding {
70 sin: Tensor,
71 cos: Tensor,
72}
73
74fn rotate_half(xs: &Tensor) -> Result<Tensor> {
75 let last_dim = xs.dim(D::Minus1)?;
76 let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
77 let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
78 Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
79}
80
81impl RotaryEmbedding {
82 fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
83 let dim = cfg.hidden_size / cfg.num_attention_heads;
84 let max_seq_len = cfg.max_position_embeddings;
85 let inv_freq: Vec<_> = (0..dim)
86 .step_by(2)
87 .map(|i| 1f32 / (cfg.rope_theta as f32).powf(i as f32 / dim as f32))
88 .collect();
89 let inv_freq_len = inv_freq.len();
90 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
91 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
92 .to_dtype(dtype)?
93 .reshape((max_seq_len, 1))?;
94 let freqs = t.matmul(&inv_freq)?;
95 let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
96 Ok(Self {
97 sin: freqs.sin()?,
98 cos: freqs.cos()?,
99 })
100 }
101
102 fn apply_rotary_emb_qkv(
103 &self,
104 q: &Tensor,
105 k: &Tensor,
106 seqlen_offset: usize,
107 ) -> Result<(Tensor, Tensor)> {
108 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
109 let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
110 let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
111 let cos = cos.unsqueeze(0)?.unsqueeze(0)?; let sin = sin.unsqueeze(0)?.unsqueeze(0)?; let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
114 let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
115 Ok((q_embed, k_embed))
116 }
117}
118
/// Thin wrapper over `candle_flash_attn::flash_attn`, compiled only when the
/// `flash-attn` feature is enabled. All arguments are forwarded unchanged.
#[cfg(feature = "flash-attn")]
fn flash_attn(
    query: &Tensor,
    key: &Tensor,
    value: &Tensor,
    softmax_scale: f32,
    causal: bool,
) -> Result<Tensor> {
    candle_flash_attn::flash_attn(query, key, value, softmax_scale, causal)
}
129
130#[cfg(not(feature = "flash-attn"))]
131fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
132 unimplemented!("compile with '--features flash-attn'")
133}
134
135#[derive(Debug, Clone)]
136struct Attention {
137 q_proj: Linear,
138 k_proj: Linear,
139 v_proj: Linear,
140 o_proj: Linear,
141 num_heads: usize,
142 num_kv_heads: usize,
143 num_kv_groups: usize,
144 head_dim: usize,
145 hidden_size: usize,
146 rotary_emb: Arc<RotaryEmbedding>,
147 kv_cache: Option<(Tensor, Tensor)>,
148 use_flash_attn: bool,
149}
150
151impl Attention {
152 fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
153 let hidden_sz = cfg.hidden_size;
154 let num_heads = cfg.num_attention_heads;
155 let num_kv_heads = cfg.num_key_value_heads;
156 let num_kv_groups = num_heads / num_kv_heads;
157 let head_dim = hidden_sz / num_heads;
158 let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
159 let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
160 let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
161 let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
162 Ok(Self {
163 q_proj,
164 k_proj,
165 v_proj,
166 o_proj,
167 num_heads,
168 num_kv_heads,
169 num_kv_groups,
170 head_dim,
171 hidden_size: hidden_sz,
172 rotary_emb,
173 kv_cache: None,
174 use_flash_attn: cfg.use_flash_attn,
175 })
176 }
177
178 fn forward(
179 &mut self,
180 xs: &Tensor,
181 attention_mask: Option<&Tensor>,
182 seqlen_offset: usize,
183 ) -> Result<Tensor> {
184 let (b_sz, q_len, _) = xs.dims3()?;
185
186 let query_states = self.q_proj.forward(xs)?;
187 let key_states = self.k_proj.forward(xs)?;
188 let value_states = self.v_proj.forward(xs)?;
189
190 let query_states = query_states
191 .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
192 .transpose(1, 2)?;
193 let key_states = key_states
194 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
195 .transpose(1, 2)?;
196 let value_states = value_states
197 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
198 .transpose(1, 2)?;
199
200 let (query_states, key_states) =
201 self.rotary_emb
202 .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
203
204 let (key_states, value_states) = match &self.kv_cache {
205 None => (key_states, value_states),
206 Some((prev_k, prev_v)) => {
207 let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
208 let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
209 (key_states, value_states)
210 }
211 };
212 self.kv_cache = Some((key_states.clone(), value_states.clone()));
213
214 let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
215 let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
216
217 let attn_output = if self.use_flash_attn {
218 let q = query_states.transpose(1, 2)?;
220 let k = key_states.transpose(1, 2)?;
221 let v = value_states.transpose(1, 2)?;
222 let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
223 flash_attn(&q, &k, &v, softmax_scale, q_len > 1)?.transpose(1, 2)?
224 } else {
225 let scale = 1f64 / f64::sqrt(self.head_dim as f64);
226 let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
227
228 let attn_weights = match attention_mask {
229 None => attn_weights,
230 Some(mask) => attn_weights.broadcast_add(mask)?,
231 };
232 let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
233 attn_weights.matmul(&value_states)?
234 };
235 attn_output
236 .transpose(1, 2)?
237 .reshape((b_sz, q_len, self.hidden_size))?
238 .apply(&self.o_proj)
239 }
240}
241
242#[derive(Debug, Clone)]
243struct BlockSparseTop2MLP {
244 w1: Linear,
245 w2: Linear,
246 w3: Linear,
247 act_fn: Activation,
248}
249
250impl BlockSparseTop2MLP {
251 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
252 let hidden_sz = cfg.hidden_size;
253 let intermediate_sz = cfg.intermediate_size;
254 let w1 = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("w1"))?;
255 let w2 = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("w2"))?;
256 let w3 = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("w3"))?;
257 Ok(Self {
258 w1,
259 w2,
260 w3,
261 act_fn: cfg.hidden_act,
262 })
263 }
264}
265
266impl Module for BlockSparseTop2MLP {
267 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
268 let lhs = xs.apply(&self.w1)?.apply(&self.act_fn)?;
269 let rhs = xs.apply(&self.w3)?;
270 (lhs * rhs)?.apply(&self.w2)
271 }
272}
273
274#[derive(Debug, Clone)]
275struct SparseMoeBlock {
276 gate: Linear,
277 experts: Vec<BlockSparseTop2MLP>,
278 num_experts_per_tok: usize,
279}
280
281impl SparseMoeBlock {
282 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
283 let gate = linear_no_bias(cfg.hidden_size, cfg.num_local_experts, vb.pp("gate"))?;
284 let mut experts = Vec::with_capacity(cfg.num_local_experts);
285 let vb = vb.pp("experts");
286 for idx in 0..cfg.num_local_experts {
287 let expert = BlockSparseTop2MLP::new(cfg, vb.pp(idx))?;
288 experts.push(expert)
289 }
290 Ok(SparseMoeBlock {
291 gate,
292 experts,
293 num_experts_per_tok: cfg.num_experts_per_tok,
294 })
295 }
296}
297
298impl Module for SparseMoeBlock {
299 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
300 let (b_size, seq_len, hidden_dim) = xs.dims3()?;
301 let xs = xs.reshape(((), hidden_dim))?;
302 let router_logits = xs.apply(&self.gate)?;
303 let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;
304
305 let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
308
309 let mut top_x = vec![vec![]; self.experts.len()];
312 let mut selected_rws = vec![vec![]; self.experts.len()];
313 for (row_idx, rw) in routing_weights.iter().enumerate() {
314 let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
315 dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
316 let mut sum_routing_weights = 0f32;
317 for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
318 let expert_idx = expert_idx as usize;
319 let routing_weight = rw[expert_idx];
320 sum_routing_weights += routing_weight;
321 top_x[expert_idx].push(row_idx as u32);
322 }
323 for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
324 let expert_idx = expert_idx as usize;
325 let routing_weight = rw[expert_idx];
326 selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
327 }
328 }
329
330 let mut ys = xs.zeros_like()?;
334 for (expert_idx, expert_layer) in self.experts.iter().enumerate() {
335 let top_x = &top_x[expert_idx];
336 if top_x.is_empty() {
337 continue;
338 }
339 let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
340 let selected_rws =
341 Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?.reshape(((), 1))?;
342 let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
346 let current_hidden_states = expert_layer.forward(¤t_state)?;
348 let current_hidden_states = current_hidden_states.broadcast_mul(&selected_rws)?;
349 ys = ys.index_add(&top_x, ¤t_hidden_states, 0)?;
350 }
351
352 let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
353 Ok(ys)
354 }
355}
356
357#[derive(Debug, Clone)]
358struct DecoderLayer {
359 self_attn: Attention,
360 block_sparse_moe: SparseMoeBlock,
361 input_layernorm: RmsNorm,
362 post_attention_layernorm: RmsNorm,
363}
364
365impl DecoderLayer {
366 fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
367 let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
368 let block_sparse_moe = SparseMoeBlock::new(cfg, vb.pp("block_sparse_moe"))?;
369 let input_layernorm =
370 RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
371 let post_attention_layernorm = RmsNorm::new(
372 cfg.hidden_size,
373 cfg.rms_norm_eps,
374 vb.pp("post_attention_layernorm"),
375 )?;
376 Ok(Self {
377 self_attn,
378 block_sparse_moe,
379 input_layernorm,
380 post_attention_layernorm,
381 })
382 }
383
384 fn forward(
385 &mut self,
386 xs: &Tensor,
387 attention_mask: Option<&Tensor>,
388 seqlen_offset: usize,
389 ) -> Result<Tensor> {
390 let residual = xs;
391 let xs = self.input_layernorm.forward(xs)?;
392 let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
393 let xs = (xs + residual)?;
394 let residual = &xs;
395 let xs = xs
396 .apply(&self.post_attention_layernorm)?
397 .apply(&self.block_sparse_moe)?;
398 residual + xs
399 }
400}
401
402#[derive(Debug, Clone)]
403pub struct Model {
404 embed_tokens: candle_nn::Embedding,
405 layers: Vec<DecoderLayer>,
406 norm: RmsNorm,
407 lm_head: Linear,
408 sliding_window: usize,
409 device: Device,
410 dtype: DType,
411}
412
413impl Model {
414 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
415 let vb_m = vb.pp("model");
416 let embed_tokens =
417 candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
418 let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
419 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
420 let vb_l = vb_m.pp("layers");
421 for layer_idx in 0..cfg.num_hidden_layers {
422 let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
423 layers.push(layer)
424 }
425 let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
426 let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
427 Ok(Self {
428 embed_tokens,
429 layers,
430 norm,
431 lm_head,
432 sliding_window: cfg.sliding_window,
433 device: vb.device().clone(),
434 dtype: vb.dtype(),
435 })
436 }
437
438 fn prepare_decoder_attention_mask(
439 &self,
440 b_size: usize,
441 tgt_len: usize,
442 seqlen_offset: usize,
443 ) -> Result<Tensor> {
444 let mask: Vec<_> = (0..tgt_len)
446 .flat_map(|i| {
447 (0..tgt_len).map(move |j| {
448 if i < j || j + self.sliding_window < i {
449 f32::NEG_INFINITY
450 } else {
451 0.
452 }
453 })
454 })
455 .collect();
456 let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
457 let mask = if seqlen_offset > 0 {
458 let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
459 Tensor::cat(&[&mask0, &mask], D::Minus1)?
460 } else {
461 mask
462 };
463 mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
464 .to_dtype(self.dtype)
465 }
466
467 pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
468 let (b_size, seq_len) = input_ids.dims2()?;
469 let attention_mask = if seq_len <= 1 {
470 None
471 } else {
472 let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
473 Some(mask)
474 };
475 let mut xs = self.embed_tokens.forward(input_ids)?;
476 for layer in self.layers.iter_mut() {
477 xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
478 }
479 xs.narrow(1, seq_len - 1, 1)?
480 .apply(&self.norm)?
481 .apply(&self.lm_head)
482 }
483}