1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{f32::consts::PI, sync::Arc};
4
5use candle::{
6 shape::Dim, CpuStorage, CustomOp1, DType, Device, Error, IndexOp, Layout, Result, Shape,
7 Tensor, WithDType, D,
8};
9use candle_nn::{embedding, rms_norm, Activation, Embedding, Linear, Module, RmsNorm, VarBuilder};
10use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
11use serde::Deserialize;
12
/// Stateless marker type that carries the `"nonzero"` custom op implemented below.
struct NonZero {}
14
15impl NonZero {
16 fn nonzero<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Vec<u32> {
18 let n = layout.dims().len();
19 let mut result = Vec::new();
20 let mut indices = vec![0u32; n];
21 for (i, v) in vs.iter().enumerate() {
22 if !v.is_zero() {
23 let mut idx = i;
24 for (dim_index, dim) in layout.dims().iter().enumerate().rev() {
25 let d = idx % dim;
26 indices[dim_index] = u32::try_from(d).unwrap();
27 idx /= dim;
28 }
29 result.extend_from_slice(&indices);
30 }
31 }
32 result
33 }
34}
35
36impl CustomOp1 for NonZero {
37 fn name(&self) -> &'static str {
38 "nonzero"
39 }
40
41 fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
42 if !layout.is_contiguous() {
43 return Err(Error::RequiresContiguous { op: "nonzero" });
44 }
45 let result = match storage {
46 candle::CpuStorage::U8(vs) => self.nonzero(vs, layout),
47 candle::CpuStorage::U32(vs) => self.nonzero(vs, layout),
48 candle::CpuStorage::I64(vs) => self.nonzero(vs, layout),
49 candle::CpuStorage::BF16(vs) => self.nonzero(vs, layout),
50 candle::CpuStorage::F16(vs) => self.nonzero(vs, layout),
51 candle::CpuStorage::F32(vs) => self.nonzero(vs, layout),
52 candle::CpuStorage::F64(vs) => self.nonzero(vs, layout),
53 };
54 let index_len = layout.dims().len();
55 let result_len = result.len() / index_len;
56 let result = CpuStorage::U32(result);
57 let shape = Shape::from_dims(&[result_len, index_len]);
58 Ok((result, shape))
59 }
60}
61
/// Extension trait exposing a `nonzero` operation.
pub trait NonZeroOp {
    /// Returns a tensor describing the non-zero elements of `self`.
    fn nonzero(&self) -> Result<Tensor>;
}
65
66impl NonZeroOp for Tensor {
67 fn nonzero(&self) -> Result<Tensor> {
68 if !self.is_contiguous() {
69 return Err(candle::Error::RequiresContiguous { op: "nonzero" });
70 }
71 let original_device = self.device();
72 self.to_device(&candle::Device::Cpu)?
73 .apply_op1_no_bwd(&NonZero {})?
74 .to_device(original_device)
75 }
76}
77
/// Result of a top-k selection.
pub struct TopKOutput {
    /// The selected values.
    pub values: Tensor,
    /// Positions of the selected values (the implementations in this file index the
    /// last dimension).
    pub indices: Tensor,
}
82
/// Top-k selection along the last dimension.
pub trait TopKLastDimOp {
    /// Selects `topk` entries along the last dimension. The `Tensor` implementation in
    /// this file returns them in the order produced by `arg_sort_last_dim(false)`.
    fn topk(&self, topk: usize) -> Result<TopKOutput>;

    /// Selects the same `topk` entries as [`TopKLastDimOp::topk`]. The `Tensor`
    /// implementation in this file returns them re-ordered by their index
    /// (`arg_sort_last_dim(true)` on the indices) rather than by value.
    fn topk_unsorted(&self, topk: usize) -> Result<TopKOutput>;
}
94
95impl TopKLastDimOp for Tensor {
96 fn topk(&self, topk: usize) -> Result<TopKOutput> {
97 let sorted_indices = self.arg_sort_last_dim(false)?;
99 let topk_indices = sorted_indices.narrow(D::Minus1, 0, topk)?.contiguous()?;
100 Ok(TopKOutput {
101 values: self.gather(&topk_indices, D::Minus1)?,
102 indices: topk_indices,
103 })
104 }
105
106 fn topk_unsorted(&self, topk: usize) -> Result<TopKOutput> {
107 let sorted_indices_all = self.arg_sort_last_dim(false)?;
109 let topk_indices_sorted = sorted_indices_all
110 .narrow(D::Minus1, 0, topk)?
111 .contiguous()?;
112 let topk_values_sorted = self.gather(&topk_indices_sorted, D::Minus1)?;
113
114 let reorder_indices = topk_indices_sorted.arg_sort_last_dim(true)?;
116 let topk_indices_unsorted = topk_indices_sorted.gather(&reorder_indices, D::Minus1)?;
117 let topk_values_unsorted = topk_values_sorted.gather(&reorder_indices, D::Minus1)?;
118 Ok(TopKOutput {
119 values: topk_values_unsorted,
120 indices: topk_indices_unsorted,
121 })
122 }
123}
124
/// Extension trait for cutting a tensor into consecutive pieces along one dimension.
pub trait SplitOp {
    /// Splits `self` along `dim` into pieces whose sizes are given by `splits`, in order.
    fn split<D: Dim>(&self, splits: &[usize], dim: D) -> Result<Vec<Tensor>>;
}
128
129impl SplitOp for Tensor {
130 fn split<D: Dim>(&self, splits: &[usize], dim: D) -> Result<Vec<Tensor>> {
131 let dim = dim.to_index(self.shape(), "split")?;
132 let mut split_res = Vec::new();
133 let mut index = 0;
134 for split in splits {
135 split_res.push(self.narrow(dim, index, *split)?);
136 index += *split;
137 }
138 Ok(split_res)
139 }
140}
141
/// Extension trait for counting occurrences of integer values.
pub trait BincountOp {
    /// Returns per-value occurrence counts; `minlength` is a lower bound on the length
    /// of the returned vector.
    fn bincount(&self, minlength: u32) -> Result<Vec<u32>>;
}
145
146fn bincount(values: &[u32], minlength: u32) -> Vec<u32> {
147 let max_val = values.par_iter().max().copied().unwrap_or(0);
149
150 let result_len = (max_val + 1).max(minlength);
153
154 values
157 .par_iter()
158 .fold(
159 || vec![0u32; result_len as usize],
161 |mut local_counts, &val| {
163 local_counts[val as usize] += 1;
164 local_counts
165 },
166 )
167 .reduce(
169 || vec![0u32; result_len as usize],
171 |mut global_counts, local_counts| {
173 for (g, l) in global_counts.iter_mut().zip(local_counts) {
174 *g += l;
175 }
176 global_counts
177 },
178 )
179}
180
181impl BincountOp for Tensor {
182 fn bincount(&self, minlength: u32) -> Result<Vec<u32>> {
183 let values = self.to_vec1::<u32>()?;
184
185 Ok(bincount(&values, minlength))
186 }
187}
188
189fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
190 let shape = mask.shape();
191 let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
192 let m = mask.where_cond(&on_true, on_false)?;
193 Ok(m)
194}
195
/// Generates a zero-argument function named by the second argument that returns the
/// third argument, typed as the first. Meant to be referenced from
/// `#[serde(default = "...")]` attributes.
#[doc(hidden)]
#[macro_export]
macro_rules! serde_default_fn {
    ($ret:ty, $fn_name:ident, $value:expr) => {
        fn $fn_name() -> $ret {
            $value
        }
    };
}
205
// Default-value providers named by the `#[serde(default = "...")]` attributes on
// `DeepSeekV2Config`.
serde_default_fn!(f64, routed_scaling_factor, 1.0);
serde_default_fn!(TopkMethod, topk_method, TopkMethod::Greedy);
serde_default_fn!(usize, moe_layer_freq, 1);
serde_default_fn!(usize, first_k_dense_replace, 0);
serde_default_fn!(bool, norm_topk_prob, false);
serde_default_fn!(ScoringFunc, scoring_func, ScoringFunc::Softmax);
serde_default_fn!(Activation, hidden_act, Activation::Silu);
serde_default_fn!(bool, tie_word_embeddings, false);
214
/// Expert-selection strategy used by the MoE gate.
#[derive(Deserialize, Clone, Debug)]
enum TopkMethod {
    /// Deserialized from `"greedy"`.
    #[serde(rename = "greedy")]
    Greedy,
    /// Deserialized from `"group_limited_greedy"`.
    #[serde(rename = "group_limited_greedy")]
    GroupLimitedGreedy,
}
222
/// Function that turns gate logits into routing scores.
#[derive(Deserialize, Clone, Debug)]
enum ScoringFunc {
    /// Deserialized from `"softmax"`.
    #[serde(rename = "softmax")]
    Softmax,
}
228
/// Deserializable model configuration for [`DeepSeekV2`].
#[derive(Deserialize, Clone, Debug)]
pub struct DeepSeekV2Config {
    /// Number of rows of the token embedding and output size of the LM head.
    pub(crate) vocab_size: usize,
    /// Width of the hidden states.
    pub(crate) hidden_size: usize,
    /// Default inner width of an `Mlp` when no override is given.
    pub(crate) intermediate_size: usize,
    /// Inner width of each routed expert; shared experts use this times
    /// `n_shared_experts`.
    pub(crate) moe_intermediate_size: usize,
    /// Number of decoder layers.
    pub(crate) num_hidden_layers: usize,
    /// Number of attention heads.
    pub(crate) num_attention_heads: usize,
    /// When `Some`, MoE layers also get a shared expert sized by this count.
    pub(crate) n_shared_experts: Option<usize>,
    /// Number of routed experts; when `None`, every layer uses a plain `Mlp`.
    pub(crate) n_routed_experts: Option<usize>,
    /// Multiplier applied to the top-k weights when they are not normalized.
    /// Defaults to `1.0`.
    #[serde(default = "routed_scaling_factor")]
    pub(crate) routed_scaling_factor: f64,
    /// Expert-selection strategy. Defaults to `Greedy`.
    #[serde(default = "topk_method")]
    topk_method: TopkMethod,
    /// Number of experts selected per token by the gate.
    pub(crate) num_experts_per_tok: Option<usize>,
    /// A layer can be MoE only if its index is a multiple of this. Defaults to `1`.
    #[serde(default = "moe_layer_freq")]
    pub(crate) moe_layer_freq: usize,
    /// Layers with an index below this always use a plain `Mlp`. Defaults to `0`.
    #[serde(default = "first_k_dense_replace")]
    pub(crate) first_k_dense_replace: usize,
    /// Whether to normalize the top-k weights (only applied when more than one expert
    /// is selected). Defaults to `false`.
    #[serde(default = "norm_topk_prob")]
    pub(crate) norm_topk_prob: bool,
    /// Scoring function for gate logits. Defaults to `Softmax`.
    #[serde(default = "scoring_func")]
    scoring_func: ScoringFunc,
    /// Activation used inside `Mlp`. Defaults to `Silu`.
    #[serde(default = "hidden_act")]
    pub(crate) hidden_act: Activation,
    /// Number of positions for which rotary tables are precomputed.
    pub(crate) max_position_embeddings: usize,
    /// Epsilon passed to every `rms_norm`.
    pub(crate) rms_norm_eps: f64,
    /// When `true`, the LM head reuses the token-embedding matrix. Defaults to `false`.
    #[serde(default = "tie_word_embeddings")]
    pub(crate) tie_word_embeddings: bool,
    /// Base of the rotary frequency schedule.
    pub(crate) rope_theta: f32,
    /// Optional rotary scaling; `None` selects the unscaled tables.
    pub(crate) rope_scaling: Option<DeepSeekV2RopeScaling>,
    /// Whether the `q_a_proj`, `kv_a_proj_with_mqa` and `o_proj` layers carry a bias.
    pub(crate) attention_bias: bool,
    /// When `Some`, queries go through a low-rank `q_a_proj` / `q_b_proj` pair of
    /// this rank; otherwise a single `q_proj` is used.
    pub(crate) q_lora_rank: Option<usize>,
    /// Per-head width of the query/key part that receives rotary embedding.
    pub(crate) qk_rope_head_dim: usize,
    /// Width of the compressed key/value representation.
    pub(crate) kv_lora_rank: usize,
    /// Per-head width of the values.
    pub(crate) v_head_dim: usize,
    /// Per-head width of the query/key part that does not receive rotary embedding.
    pub(crate) qk_nope_head_dim: usize,
    /// Number of expert groups used by group-limited routing.
    pub(crate) n_group: usize,
    /// Number of groups selected by group-limited routing.
    pub(crate) topk_group: usize,
}
270
/// Name of a rotary-embedding scaling scheme as it appears in a serialized config.
/// Variant names are matched in lowercase; `"longrope"` is accepted as another
/// spelling of `Su`.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ScaledRopeType {
    #[serde(alias = "su")]
    #[serde(alias = "longrope")]
    Su,
    #[serde(alias = "yarn")]
    Yarn,
    #[serde(alias = "dynamic")]
    Dynamic,
    #[serde(alias = "linear")]
    Linear,
}
284
/// Precomputed rotary-embedding tables, one row per position.
#[derive(Debug, Clone)]
pub struct DeepSeekV2RotaryEmbedding {
    /// Sine of the position/frequency angles.
    sin: Tensor,
    /// Cosine of the position/frequency angles.
    cos: Tensor,
}
290
/// Rotary scaling parameters. The enum is untagged, so serde tries the variants in
/// declaration order and picks the first whose fields all deserialize.
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
pub enum DeepSeekV2RopeScaling {
    /// YaRN scaling parameters, consumed by `DeepSeekV2RotaryEmbedding::new_yarn`.
    Yarn {
        original_max_position_embeddings: usize,
        beta_fast: f32,
        beta_slow: f32,
        mscale: f32,
        mscale_all_dim: f32,
        factor: f32,
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
    },
    /// Any other scheme with just a type and a factor; `DeepSeekV2RotaryEmbedding::new`
    /// currently rejects this variant.
    LinearOrDynamic {
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
        factor: f64,
    },
}
310
/// The subset of model settings needed to build [`DeepSeekV2RotaryEmbedding`].
pub struct DeepSeekV2RopeConfig {
    /// Optional scaling scheme; `None` selects the unscaled tables.
    pub rope_scaling: Option<DeepSeekV2RopeScaling>,
    /// Number of positions to precompute.
    pub max_position_embeddings: usize,
    /// Base of the frequency schedule.
    pub rope_theta: f32,
    /// Width of the rotary part of each head.
    pub qk_rope_head_dim: usize,
}
317
318impl DeepSeekV2RotaryEmbedding {
319 fn new_unscaled(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
320 let max_seq_len = cfg.max_position_embeddings;
321 let dim = cfg.qk_rope_head_dim;
322
323 let inv_freq: Vec<_> = (0..dim)
324 .step_by(2)
325 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32))
326 .collect();
327 let inv_freq_len = inv_freq.len();
328 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
329 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
330 .to_dtype(DType::F32)?
331 .reshape((max_seq_len, 1))?;
332 let freqs = t.matmul(&inv_freq)?;
333
334 let sin = freqs.sin()?.to_dtype(dtype)?;
335 let cos = freqs.cos()?.to_dtype(dtype)?;
336
337 Ok(Self { sin, cos })
338 }
339
340 fn yarn_find_correction_dim(
341 num_rot: f32,
342 dim: usize,
343 base: f32,
344 max_position_embeddings: usize,
345 ) -> f32 {
346 (dim as f32 * (max_position_embeddings as f32 / (num_rot * 2. * PI)).ln())
347 / (2. * base.ln())
348 }
349
350 fn yarn_find_correction_range(
351 low_rot: f32,
352 high_rot: f32,
353 dim: usize,
354 base: f32,
355 max_position_embeddings: usize,
356 ) -> (f32, f32) {
357 let low =
358 Self::yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings).floor();
359 let high =
360 Self::yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings).ceil();
361 (low.max(0.), high.min(dim as f32 - 1.))
362 }
363
364 fn yarn_linear_ramp_mask(min: f32, mut max: f32, dim: usize, dev: &Device) -> Result<Tensor> {
365 if min == max {
366 max += 0.001;
368 }
369 let linear_func =
370 ((Tensor::arange(0f32, dim as f32, dev)? - min as f64)? / (max as f64 - min as f64))?;
371 linear_func.clamp(0., 1.)
372 }
373
374 pub(crate) fn yarn_get_mscale(scale: f32, mscale: f32) -> f32 {
375 if scale <= 1. {
376 return 1.;
377 }
378 0.1 * mscale * scale.ln() + 1.
379 }
380
381 #[allow(clippy::too_many_arguments)]
382 fn new_yarn(
383 cfg: &DeepSeekV2RopeConfig,
384 dtype: DType,
385 dev: &Device,
386 original_max_position_embeddings: usize,
387 beta_fast: f32,
388 beta_slow: f32,
389 factor: f32,
390 mscale: f32,
391 mscale_all_dim: f32,
392 ) -> Result<Self> {
393 let freq_extra: Vec<_> = (0..cfg.qk_rope_head_dim)
394 .step_by(2)
395 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32))
396 .collect();
397 let freq_extra_len = freq_extra.len();
398 let freq_extra = Tensor::from_vec(freq_extra, freq_extra_len, dev)?;
399 let freq_inter: Vec<_> = (0..cfg.qk_rope_head_dim)
400 .step_by(2)
401 .map(|i| 1f32 / (factor * cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32)))
402 .collect();
403 let freq_inter_len = freq_inter.len();
404 let freq_inter = Tensor::from_vec(freq_inter, (1, freq_inter_len), dev)?;
405
406 let (low, high) = Self::yarn_find_correction_range(
407 beta_fast,
408 beta_slow,
409 cfg.qk_rope_head_dim,
410 cfg.rope_theta,
411 original_max_position_embeddings,
412 );
413 let inv_freq_mask =
414 (1. - Self::yarn_linear_ramp_mask(low, high, cfg.qk_rope_head_dim / 2, dev)?)?;
415 let inv_freq = freq_inter
416 .broadcast_mul(&(1. - &inv_freq_mask)?)?
417 .broadcast_add(&freq_extra.broadcast_mul(&inv_freq_mask)?)?;
418
419 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
420 .to_dtype(DType::F32)?
421 .reshape((cfg.max_position_embeddings, 1))?;
422 let freqs = t.matmul(&inv_freq)?;
423
424 let mscale =
425 Self::yarn_get_mscale(factor, mscale) / Self::yarn_get_mscale(factor, mscale_all_dim);
426 let sin = (freqs.sin()? * mscale as f64)?.to_dtype(dtype)?;
427 let cos = (freqs.cos()? * mscale as f64)?.to_dtype(dtype)?;
428
429 Ok(Self { sin, cos })
430 }
431
432 pub fn new(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
433 match &cfg.rope_scaling {
434 Some(DeepSeekV2RopeScaling::LinearOrDynamic {
435 scaling_type: _,
436 factor: _,
437 }) => candle::bail!("linear and dynamic rope are not implemented yet!"),
438 Some(DeepSeekV2RopeScaling::Yarn {
439 original_max_position_embeddings,
440 beta_fast,
441 beta_slow,
442 factor,
443 mscale,
444 mscale_all_dim,
445 scaling_type: _,
446 }) => Self::new_yarn(
447 cfg,
448 dtype,
449 dev,
450 *original_max_position_embeddings,
451 *beta_fast,
452 *beta_slow,
453 *factor,
454 *mscale,
455 *mscale_all_dim,
456 ),
457 None => Self::new_unscaled(cfg, dtype, dev),
458 }
459 }
460
461 pub fn forward(
462 &self,
463 q: &Tensor,
464 k: &Tensor,
465 seqlen_offset: usize,
466 ) -> Result<(Tensor, Tensor)> {
467 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
468
469 let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
470 let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
471
472 let q_embed = candle_nn::rotary_emb::rope_i(&q.contiguous()?, &cos, &sin)?;
473 let k_embed = candle_nn::rotary_emb::rope_i(&k.contiguous()?, &cos, &sin)?;
474
475 Ok((q_embed, k_embed))
476 }
477}
478
479impl DeepSeekV2Config {
480 pub(crate) fn q_head_dim(&self) -> usize {
481 self.qk_rope_head_dim + self.qk_nope_head_dim
482 }
483
484 fn softmax_scale(&self) -> f32 {
485 let mut softmax_scale = 1.0 / (self.q_head_dim() as f32).sqrt();
486 if let Some(DeepSeekV2RopeScaling::Yarn {
487 mscale_all_dim,
488 factor,
489 ..
490 }) = self.rope_scaling
491 {
492 let mscale = DeepSeekV2RotaryEmbedding::yarn_get_mscale(factor, mscale_all_dim);
493 softmax_scale = softmax_scale * mscale * mscale;
494 }
495 softmax_scale
496 }
497}
498
/// Query projection: either one linear layer or a low-rank pair with a norm between.
enum QProj {
    /// Single linear projection.
    Plain(Linear),
    /// `b(norm(a(x)))`.
    Lora { a: Linear, norm: RmsNorm, b: Linear },
}
503
504impl QProj {
505 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
506 match self {
507 Self::Lora { a, norm, b } => b.forward(&norm.forward(&a.forward(xs)?)?),
508 Self::Plain(lin) => lin.forward(xs),
509 }
510 }
511}
512
/// Attention layer with a compressed key/value projection and a key/value cache.
struct Attention {
    /// Query projection.
    q: QProj,
    /// Projects hidden states to the compressed key/value plus the rotary key part.
    kv_a_proj_with_mqa: Linear,
    /// Norm applied to the compressed key/value before `kv_b_proj`.
    kv_a_layernorm: RmsNorm,
    /// Expands the compressed key/value into per-head keys and values.
    kv_b_proj: Linear,
    /// Output projection back to the hidden size.
    o_proj: Linear,
    /// Rotary tables shared between layers.
    rotary_emb: Arc<DeepSeekV2RotaryEmbedding>,
    /// Copy of the model configuration.
    cfg: DeepSeekV2Config,
    /// Cached `cfg.q_head_dim()`.
    q_head_dim: usize,
    /// Factor applied to the attention logits before the softmax.
    softmax_scale: f64,
    /// Keys and values accumulated over previous `forward` calls, if any.
    kv_cache: Option<(Tensor, Tensor)>,
}
525
526impl Attention {
527 fn new(
528 rotary_emb: Arc<DeepSeekV2RotaryEmbedding>,
529 cfg: &DeepSeekV2Config,
530 vb: VarBuilder,
531 ) -> Result<Self> {
532 let q_head_dim = cfg.q_head_dim();
533 let q = match cfg.q_lora_rank {
534 Some(lora_rank) => {
535 let a = candle_nn::linear_b(
536 cfg.hidden_size,
537 lora_rank,
538 cfg.attention_bias,
539 vb.pp("q_a_proj"),
540 )?;
541 let norm = rms_norm(lora_rank, cfg.rms_norm_eps, vb.pp("q_a_layernorm"))?;
542 let b = candle_nn::linear_no_bias(
543 lora_rank,
544 cfg.num_attention_heads * q_head_dim,
545 vb.pp("q_b_proj"),
546 )?;
547 QProj::Lora { a, norm, b }
548 }
549 None => QProj::Plain(candle_nn::linear_no_bias(
550 cfg.hidden_size,
551 cfg.num_attention_heads * q_head_dim,
552 vb.pp("q_proj"),
553 )?),
554 };
555
556 let kv_a_proj_with_mqa = candle_nn::linear_b(
557 cfg.hidden_size,
558 cfg.kv_lora_rank + cfg.qk_rope_head_dim,
559 cfg.attention_bias,
560 vb.pp("kv_a_proj_with_mqa"),
561 )?;
562 let kv_a_layernorm = rms_norm(cfg.kv_lora_rank, cfg.rms_norm_eps, vb.pp("kv_a_layernorm"))?;
563 let kv_b_proj = candle_nn::linear_no_bias(
564 cfg.kv_lora_rank,
565 cfg.num_attention_heads * (q_head_dim - cfg.qk_rope_head_dim + cfg.v_head_dim),
566 vb.pp("kv_b_proj"),
567 )?;
568
569 let o_proj = candle_nn::linear_b(
570 cfg.num_attention_heads * cfg.v_head_dim,
571 cfg.hidden_size,
572 cfg.attention_bias,
573 vb.pp("o_proj"),
574 )?;
575
576 Ok(Self {
577 q,
578 kv_a_proj_with_mqa,
579 kv_a_layernorm,
580 kv_b_proj,
581 o_proj,
582 rotary_emb,
583 cfg: cfg.clone(),
584 q_head_dim,
585 softmax_scale: cfg.softmax_scale() as f64,
586 kv_cache: None,
587 })
588 }
589
590 fn forward(
591 &mut self,
592 xs: &Tensor,
593 attention_mask: Option<&Tensor>,
594 seqlen_offset: usize,
595 ) -> Result<Tensor> {
596 let (bs, seq_len, _) = xs.dims3()?;
597
598 let q = {
599 let q = self.q.forward(xs)?;
600 q.reshape((bs, seq_len, self.cfg.num_attention_heads, self.q_head_dim))?
601 .transpose(1, 2)?
602 };
603 let q_split = q.split(
604 &[self.cfg.qk_nope_head_dim, self.cfg.qk_rope_head_dim],
605 D::Minus1,
606 )?;
607 let q_nope = q_split[0].clone();
608 let q_pe = q_split[1].clone();
609
610 let compressed_kv = self.kv_a_proj_with_mqa.forward(xs)?;
611 let ckv_split = compressed_kv.split(
612 &[self.cfg.kv_lora_rank, self.cfg.qk_rope_head_dim],
613 D::Minus1,
614 )?;
615 let compressed_kv = ckv_split[0].clone();
616 let k_pe = {
617 let k_pe = ckv_split[1].clone();
618 k_pe.reshape((bs, seq_len, 1, self.cfg.qk_rope_head_dim))?
619 .transpose(1, 2)?
620 };
621 let kv = {
622 let kv = self
623 .kv_b_proj
624 .forward(&self.kv_a_layernorm.forward(&compressed_kv)?)?;
625 kv.reshape((
626 bs,
627 seq_len,
628 self.cfg.num_attention_heads,
629 self.cfg.qk_nope_head_dim + self.cfg.v_head_dim,
630 ))?
631 .transpose(1, 2)?
632 };
633
634 let kv_split = kv.split(&[self.cfg.qk_nope_head_dim, self.cfg.v_head_dim], D::Minus1)?;
635 let k_nope = kv_split[0].clone();
636 let v = kv_split[1].clone();
637
638 let (q_pe, k_pe) = self.rotary_emb.forward(&q_pe, &k_pe, seqlen_offset)?;
639
640 let q = Tensor::cat(&[q_nope, q_pe], D::Minus1)?;
641 let k = Tensor::cat(&[k_nope, k_pe.repeat((1, q.dim(1)?, 1, 1))?], D::Minus1)?;
642
643 let (k, v) = match &self.kv_cache {
644 None => (k, v),
645 Some((prev_k, prev_v)) => {
646 let key_states = Tensor::cat(&[prev_k, &k], 2)?;
647 let value_states = Tensor::cat(&[prev_v, &v], 2)?;
648 (key_states, value_states)
649 }
650 };
651 self.kv_cache = Some((k.clone(), v.clone()));
652
653 let attn_out = {
654 let att = (q.contiguous()?.matmul(&k.t()?.contiguous()?)? * self.softmax_scale)?;
655 let att = match attention_mask {
656 Some(mask) => att.broadcast_add(mask)?,
657 None => att,
658 };
659
660 let att = candle_nn::ops::softmax_last_dim(&att)?;
661 att.matmul(&v.contiguous()?)?
663 };
664
665 let attn_out = if attention_mask.is_some() {
666 attn_out.transpose(1, 2)?.reshape((bs, seq_len, ()))?
667 } else {
668 attn_out.reshape((bs, seq_len, ()))?
669 };
670
671 self.o_proj.forward(&attn_out)
672 }
673
674 fn clear_kv_cache(&mut self) {
675 self.kv_cache = None
676 }
677}
678
/// Gated feed-forward block: `down(act(gate(x)) * up(x))`.
struct Mlp {
    /// Projection whose output goes through `act`.
    gate: Linear,
    /// Projection multiplied element-wise with the activated gate output.
    up: Linear,
    /// Projection back to the hidden size.
    down: Linear,
    /// Activation applied to the gate output.
    act: Activation,
}
685
686impl Mlp {
687 fn new(
688 cfg: &DeepSeekV2Config,
689 vb: VarBuilder,
690 hidden_size: Option<usize>,
691 intermediate_size: Option<usize>,
692 ) -> Result<Self> {
693 let hidden_size = hidden_size.unwrap_or(cfg.hidden_size);
694 let intermediate_size = intermediate_size.unwrap_or(cfg.intermediate_size);
695
696 Ok(Self {
697 gate: candle_nn::linear_no_bias(hidden_size, intermediate_size, vb.pp("gate_proj"))?,
698 up: candle_nn::linear_no_bias(hidden_size, intermediate_size, vb.pp("up_proj"))?,
699 down: candle_nn::linear_no_bias(intermediate_size, hidden_size, vb.pp("down_proj"))?,
700 act: cfg.hidden_act,
701 })
702 }
703
704 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
705 let lhs = self.gate.forward(xs)?.apply(&self.act)?;
706 let rhs = self.up.forward(xs)?;
707 self.down.forward(&(&lhs * &rhs)?)
708 }
709}
710
/// Router that scores the routed experts for each token and selects the top ones.
struct MoeGate {
    /// Gate matrix of shape `(n_routed_experts, hidden_size)`.
    weight: Tensor,
    /// Copy of the model configuration.
    cfg: DeepSeekV2Config,
    /// Number of experts selected per token.
    top_k: usize,
    /// Number of routed experts.
    n_routed_experts: usize,
}
717
718impl MoeGate {
719 fn new(cfg: &DeepSeekV2Config, vb: VarBuilder, n_routed_experts: usize) -> Result<Self> {
720 let weight = vb.get((n_routed_experts, cfg.hidden_size), "weight")?;
721 Ok(Self {
722 weight,
723 cfg: cfg.clone(),
724 top_k: cfg.num_experts_per_tok.unwrap(),
725 n_routed_experts,
726 })
727 }
728
729 fn forward(&self, xs: &Tensor) -> Result<(Tensor, Tensor)> {
731 let (bs, seq_len, h) = xs.dims3()?;
732 let xs = xs.reshape(((), h))?;
734 let logits = xs
735 .to_dtype(DType::F32)?
736 .broadcast_matmul(&self.weight.t()?.to_dtype(DType::F32)?)?;
737 let scores = match self.cfg.scoring_func {
738 ScoringFunc::Softmax => candle_nn::ops::softmax_last_dim(&logits)?,
739 };
740
741 let (mut topk_weight, topk_idx) = match self.cfg.topk_method {
743 TopkMethod::Greedy => {
744 let TopKOutput { values, indices } = scores.topk_unsorted(self.top_k)?;
745 (values, indices)
746 }
747 TopkMethod::GroupLimitedGreedy => {
748 let group_scores = scores
750 .reshape((bs * seq_len, self.cfg.n_group, ()))?
751 .max(D::Minus1)?;
752 let group_idx = scores.topk_unsorted(self.cfg.topk_group)?.indices;
754 let group_mask = group_scores.zeros_like()?.scatter_add(
756 &group_idx,
757 &group_idx.ones_like()?.to_dtype(group_scores.dtype())?,
758 1,
759 )?;
760 let score_mask = group_mask
762 .unsqueeze(D::Minus1)?
763 .expand((
764 bs * seq_len,
765 self.cfg.n_group,
766 self.n_routed_experts / self.cfg.n_group,
767 ))?
768 .reshape((bs, seq_len, ()))?;
769 let tmp_scores = masked_fill(&score_mask, &(1. - &score_mask.ne(0.)?)?, 0.)?;
772 let TopKOutput { values, indices } = tmp_scores.topk_unsorted(self.top_k)?;
773 (values, indices)
774 }
775 };
776
777 if self.top_k > 1 && self.cfg.norm_topk_prob {
778 let denominator = (topk_weight.sum_keepdim(D::Minus1)? + 1e-20)?;
779 topk_weight = (topk_weight / denominator)?;
780 } else {
781 topk_weight = (topk_weight * self.cfg.routed_scaling_factor)?;
782 }
783 Ok((topk_idx, topk_weight))
784 }
785}
786
/// Mixture-of-experts feed-forward layer.
struct Moe {
    /// Routed experts, indexed by expert id.
    experts: Vec<Mlp>,
    /// Optional expert applied to every token in addition to the routed ones.
    shared_experts: Option<Mlp>,
    /// Router selecting experts per token.
    gate: MoeGate,
}
792
793impl Moe {
794 fn new(
795 cfg: &DeepSeekV2Config,
796 vb: VarBuilder,
797
798 n_shared_experts: Option<usize>,
799 n_routed_experts: usize,
800 ) -> Result<Self> {
801 let mut experts = Vec::with_capacity(n_routed_experts);
802 for i in 0..n_routed_experts {
803 let vb_e = vb.pp("experts").pp(i);
804 experts.push(Mlp::new(cfg, vb_e, None, Some(cfg.moe_intermediate_size))?);
805 }
806 let shared_experts = if let Some(n_shared_experts) = n_shared_experts {
807 let intermediate_size = cfg.moe_intermediate_size * n_shared_experts;
808 Some(Mlp::new(
809 cfg,
810 vb.pp("shared_experts"),
811 None,
812 Some(intermediate_size),
813 )?)
814 } else {
815 None
816 };
817 let gate = MoeGate::new(cfg, vb.pp("gate"), n_routed_experts)?;
818 Ok(Self {
819 experts,
820 shared_experts,
821 gate,
822 })
823 }
824
825 fn moe_infer(&self, xs: &Tensor, topk_ids: &Tensor, topk_weight: &Tensor) -> Result<Tensor> {
826 let mut y = xs.zeros_like()?;
827 let counts = topk_ids
828 .flatten_all()?
829 .bincount(self.experts.len() as u32)?;
830 for (i, expert) in self.experts.iter().enumerate() {
831 if counts[i] == 0 {
832 continue;
833 }
834 let idx_top = topk_ids.eq(i as f64)?.nonzero()?.t()?;
835 let idx = &idx_top.i(0)?.contiguous()?;
836 let top = &idx_top.i(1)?.contiguous()?;
837
838 y = y.index_add(
839 idx,
840 &expert.forward(&xs.index_select(idx, 0)?)?.broadcast_mul(
841 &topk_weight
842 .index_select(idx, 0)?
843 .gather(&top.unsqueeze(1)?, 1)?
844 .squeeze(1)?
845 .unsqueeze(D::Minus1)?
846 .to_dtype(xs.dtype())?,
847 )?,
848 0,
849 )?;
850 }
851
852 Ok(y)
853 }
854
855 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
856 let identity = xs.clone();
857 let orig_shape = xs.shape();
858 let (topk_idx, topk_weight) = self.gate.forward(xs)?;
859 let xs = xs.reshape(((), xs.dim(D::Minus1)?))?;
860
861 let mut y = self
862 .moe_infer(&xs, &topk_idx, &topk_weight)?
863 .reshape(orig_shape)?;
864 if let Some(ref shared_experts) = self.shared_experts {
865 y = (y + shared_experts.forward(&identity)?)?;
866 }
867 Ok(y)
868 }
869}
870
/// Feed-forward part of a decoder layer: mixture-of-experts or a plain MLP.
enum MoeOrMlp {
    Moe(Moe),
    Mlp(Mlp),
}
875
876impl MoeOrMlp {
877 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
878 match self {
879 Self::Mlp(mlp) => mlp.forward(xs),
880 Self::Moe(moe) => moe.forward(xs),
881 }
882 }
883}
884
/// One decoder layer: normed attention and normed feed-forward, each with a residual
/// connection.
struct DecoderLayer {
    /// Norm applied before attention.
    input_layernorm: RmsNorm,
    /// Norm applied before the feed-forward part.
    post_attention_layernorm: RmsNorm,
    attn: Attention,
    moe_or_mlp: MoeOrMlp,
}
891
892impl DecoderLayer {
893 fn new(
894 rotary_emb: Arc<DeepSeekV2RotaryEmbedding>,
895 cfg: &DeepSeekV2Config,
896 vb: VarBuilder,
897 layer_idx: usize,
898 ) -> Result<Self> {
899 let attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
900 let input_layernorm =
901 rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
902 let post_attention_layernorm = rms_norm(
903 cfg.hidden_size,
904 cfg.rms_norm_eps,
905 vb.pp("post_attention_layernorm"),
906 )?;
907 let moe_or_mlp = if cfg.n_routed_experts.is_some()
908 && layer_idx >= cfg.first_k_dense_replace
909 && layer_idx % cfg.moe_layer_freq == 0
910 {
911 MoeOrMlp::Moe(Moe::new(
912 cfg,
913 vb.pp("mlp"),
914 cfg.n_shared_experts,
915 cfg.n_routed_experts.unwrap(),
916 )?)
917 } else {
918 MoeOrMlp::Mlp(Mlp::new(cfg, vb.pp("mlp"), None, None)?)
919 };
920
921 Ok(Self {
922 input_layernorm,
923 post_attention_layernorm,
924 attn,
925 moe_or_mlp,
926 })
927 }
928
929 fn forward(
930 &mut self,
931 xs: &Tensor,
932 attention_mask: Option<&Tensor>,
933 seqlen_offset: usize,
934 ) -> Result<Tensor> {
935 let residual = xs;
936 let xs = self.input_layernorm.forward(xs)?;
937 let xs = self.attn.forward(&xs, attention_mask, seqlen_offset)?;
938 let xs = (xs + residual)?;
939 let residual = &xs;
940 let xs = self
941 .moe_or_mlp
942 .forward(&xs.apply(&self.post_attention_layernorm)?)?;
943 residual + xs
944 }
945
946 fn clear_kv_cache(&mut self) {
947 self.attn.clear_kv_cache();
948 }
949}
950
/// DeepSeek-V2 causal language model.
pub struct DeepSeekV2 {
    /// Projects the final hidden state to vocabulary logits.
    lm_head: Linear,
    /// Token embedding table.
    embed_tokens: Embedding,
    /// Norm applied after the last decoder layer.
    norm: RmsNorm,
    layers: Vec<DecoderLayer>,
    /// Dtype of the loaded weights; the attention mask is converted to it.
    dtype: DType,
    /// Device on which the attention mask is created.
    device: Device,
}
959
960impl DeepSeekV2 {
961 pub fn new(cfg: &DeepSeekV2Config, vb: VarBuilder) -> Result<Self> {
962 let vb_m = vb.pp("model");
963
964 let embed_tokens = embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
965 let lm_head = if !cfg.tie_word_embeddings {
966 candle_nn::linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
967 } else {
968 candle_nn::Linear::new(embed_tokens.embeddings().clone(), None)
969 };
970 let norm = rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
971
972 let rope_cfg = DeepSeekV2RopeConfig {
973 rope_scaling: cfg.rope_scaling.clone(),
974 max_position_embeddings: cfg.max_position_embeddings,
975 rope_theta: cfg.rope_theta,
976 qk_rope_head_dim: cfg.qk_rope_head_dim,
977 };
978 let rotary_emb = Arc::new(DeepSeekV2RotaryEmbedding::new(
979 &rope_cfg,
980 vb.dtype(),
981 vb.device(),
982 )?);
983
984 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
985 let vb_l = vb_m.pp("layers");
986 for layer_idx in 0..cfg.num_hidden_layers {
987 let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx), layer_idx)?;
988 layers.push(layer)
989 }
990
991 Ok(Self {
992 lm_head,
993 embed_tokens,
994 norm,
995 layers,
996 dtype: vb.dtype(),
997 device: vb.device().clone(),
998 })
999 }
1000
1001 fn prepare_decoder_attention_mask(
1002 &self,
1003 b_size: usize,
1004 tgt_len: usize,
1005 seqlen_offset: usize,
1006 ) -> Result<Tensor> {
1007 let mask: Vec<_> = (0..tgt_len)
1008 .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
1009 .collect();
1010 let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
1011 let mask = if seqlen_offset > 0 {
1012 let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
1013 Tensor::cat(&[&mask0, &mask], D::Minus1)?
1014 } else {
1015 mask
1016 };
1017 mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
1018 .to_dtype(self.dtype)
1019 }
1020
1021 pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
1022 let (bs, seq_len) = input_ids.dims2()?;
1023 let mut xs = self.embed_tokens.forward(input_ids)?;
1024 let attention_mask = if seq_len == 1 {
1025 None
1026 } else {
1027 let mask = self.prepare_decoder_attention_mask(bs, seq_len, seqlen_offset)?;
1028 Some(mask)
1029 };
1030 for layer in &mut self.layers {
1031 xs = layer.forward(
1032 &xs,
1033 attention_mask
1034 .as_ref()
1035 .map(|m| m.to_device(xs.device()).unwrap())
1036 .as_ref(),
1037 seqlen_offset,
1038 )?;
1039 }
1040 let xs = xs.apply(&self.norm)?;
1041 let xs = xs.i((.., seq_len - 1, ..))?.contiguous()?;
1042 let logits = self.lm_head.forward(&xs)?;
1043 logits.to_dtype(DType::F32)
1044 }
1045
1046 pub fn clear_kv_cache(&mut self) {
1047 for layer in self.layers.iter_mut() {
1048 layer.clear_kv_cache();
1049 }
1050 }
1051}