candle_transformers/models/
deepseek2.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{f32::consts::PI, sync::Arc};
4
5use candle::{
6    shape::Dim, CpuStorage, CustomOp1, DType, Device, Error, IndexOp, Layout, Result, Shape,
7    Tensor, WithDType, D,
8};
9use candle_nn::{embedding, rms_norm, Activation, Embedding, Linear, Module, RmsNorm, VarBuilder};
10use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
11use serde::Deserialize;
12
13struct NonZero {}
14
15impl NonZero {
16    // Sequential version
17    fn nonzero<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Vec<u32> {
18        let n = layout.dims().len();
19        let mut result = Vec::new();
20        let mut indices = vec![0u32; n];
21        for (i, v) in vs.iter().enumerate() {
22            if !v.is_zero() {
23                let mut idx = i;
24                for (dim_index, dim) in layout.dims().iter().enumerate().rev() {
25                    let d = idx % dim;
26                    indices[dim_index] = u32::try_from(d).unwrap();
27                    idx /= dim;
28                }
29                result.extend_from_slice(&indices);
30            }
31        }
32        result
33    }
34}
35
36impl CustomOp1 for NonZero {
37    fn name(&self) -> &'static str {
38        "nonzero"
39    }
40
41    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
42        if !layout.is_contiguous() {
43            return Err(Error::RequiresContiguous { op: "nonzero" });
44        }
45        let result = match storage {
46            candle::CpuStorage::U8(vs) => self.nonzero(vs, layout),
47            candle::CpuStorage::U32(vs) => self.nonzero(vs, layout),
48            candle::CpuStorage::I64(vs) => self.nonzero(vs, layout),
49            candle::CpuStorage::BF16(vs) => self.nonzero(vs, layout),
50            candle::CpuStorage::F16(vs) => self.nonzero(vs, layout),
51            candle::CpuStorage::F32(vs) => self.nonzero(vs, layout),
52            candle::CpuStorage::F64(vs) => self.nonzero(vs, layout),
53        };
54        let index_len = layout.dims().len();
55        let result_len = result.len() / index_len;
56        let result = CpuStorage::U32(result);
57        let shape = Shape::from_dims(&[result_len, index_len]);
58        Ok((result, shape))
59    }
60}
61
62pub trait NonZeroOp {
63    fn nonzero(&self) -> Result<Tensor>;
64}
65
66impl NonZeroOp for Tensor {
67    fn nonzero(&self) -> Result<Tensor> {
68        if !self.is_contiguous() {
69            return Err(candle::Error::RequiresContiguous { op: "nonzero" });
70        }
71        let original_device = self.device();
72        self.to_device(&candle::Device::Cpu)?
73            .apply_op1_no_bwd(&NonZero {})?
74            .to_device(original_device)
75    }
76}
77
78pub struct TopKOutput {
79    pub values: Tensor,
80    pub indices: Tensor,
81}
82
83pub trait TopKLastDimOp {
84    /// Topk in the last dim. `values` retains a gradient but `indices` has none w.r.t self.
85    /// This expects a contiguous tensor.
86    /// Note: this implements torch.topk with sorted=True.
87    fn topk(&self, topk: usize) -> Result<TopKOutput>;
88
89    /// Topk in the last dim. `values` retains a gradient but `indices` has none w.r.t self.
90    /// This expects a contiguous tensor.
91    /// Note: this implements torch.topk with sorted=False.
92    fn topk_unsorted(&self, topk: usize) -> Result<TopKOutput>;
93}
94
95impl TopKLastDimOp for Tensor {
96    fn topk(&self, topk: usize) -> Result<TopKOutput> {
97        // Sorted descending
98        let sorted_indices = self.arg_sort_last_dim(false)?;
99        let topk_indices = sorted_indices.narrow(D::Minus1, 0, topk)?.contiguous()?;
100        Ok(TopKOutput {
101            values: self.gather(&topk_indices, D::Minus1)?,
102            indices: topk_indices,
103        })
104    }
105
106    fn topk_unsorted(&self, topk: usize) -> Result<TopKOutput> {
107        // Sorted descending
108        let sorted_indices_all = self.arg_sort_last_dim(false)?;
109        let topk_indices_sorted = sorted_indices_all
110            .narrow(D::Minus1, 0, topk)?
111            .contiguous()?;
112        let topk_values_sorted = self.gather(&topk_indices_sorted, D::Minus1)?;
113
114        // Reorder the indices ascending
115        let reorder_indices = topk_indices_sorted.arg_sort_last_dim(true)?;
116        let topk_indices_unsorted = topk_indices_sorted.gather(&reorder_indices, D::Minus1)?;
117        let topk_values_unsorted = topk_values_sorted.gather(&reorder_indices, D::Minus1)?;
118        Ok(TopKOutput {
119            values: topk_values_unsorted,
120            indices: topk_indices_unsorted,
121        })
122    }
123}
124
125pub trait SplitOp {
126    fn split<D: Dim>(&self, splits: &[usize], dim: D) -> Result<Vec<Tensor>>;
127}
128
129impl SplitOp for Tensor {
130    fn split<D: Dim>(&self, splits: &[usize], dim: D) -> Result<Vec<Tensor>> {
131        let dim = dim.to_index(self.shape(), "split")?;
132        let mut split_res = Vec::new();
133        let mut index = 0;
134        for split in splits {
135            split_res.push(self.narrow(dim, index, *split)?);
136            index += *split;
137        }
138        Ok(split_res)
139    }
140}
141
142pub trait BincountOp {
143    fn bincount(&self, minlength: u32) -> Result<Vec<u32>>;
144}
145
146fn bincount(values: &[u32], minlength: u32) -> Vec<u32> {
147    // Find the maximum value in `values` (or zero if empty)
148    let max_val = values.par_iter().max().copied().unwrap_or(0);
149
150    // The final size of the bin counts must be at least `minlength`
151    // and large enough to include the largest value in `values`.
152    let result_len = (max_val + 1).max(minlength);
153
154    // Each thread creates a local histogram (`fold`),
155    // and then they are merged together (`reduce`).
156    values
157        .par_iter()
158        .fold(
159            // Create a local histogram
160            || vec![0u32; result_len as usize],
161            // Update the local histogram
162            |mut local_counts, &val| {
163                local_counts[val as usize] += 1;
164                local_counts
165            },
166        )
167        // Merge histograms from all threads
168        .reduce(
169            // Identity (empty histogram)
170            || vec![0u32; result_len as usize],
171            // Combine two histograms
172            |mut global_counts, local_counts| {
173                for (g, l) in global_counts.iter_mut().zip(local_counts) {
174                    *g += l;
175                }
176                global_counts
177            },
178        )
179}
180
181impl BincountOp for Tensor {
182    fn bincount(&self, minlength: u32) -> Result<Vec<u32>> {
183        let values = self.to_vec1::<u32>()?;
184
185        Ok(bincount(&values, minlength))
186    }
187}
188
189fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
190    let shape = mask.shape();
191    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
192    let m = mask.where_cond(&on_true, on_false)?;
193    Ok(m)
194}
195
// Generates a zero-argument function `$fn_name` returning `$value` as `$ty`.
// Meant to be referenced from `#[serde(default = "...")]` attributes, which need
// the name of a function rather than an expression.
#[doc(hidden)]
#[macro_export]
macro_rules! serde_default_fn {
    ($ty:ty, $fn_name:ident, $value:expr) => {
        fn $fn_name() -> $ty {
            $value
        }
    };
}
205
206serde_default_fn!(f64, routed_scaling_factor, 1.0);
207serde_default_fn!(TopkMethod, topk_method, TopkMethod::Greedy);
208serde_default_fn!(usize, moe_layer_freq, 1);
209serde_default_fn!(usize, first_k_dense_replace, 0);
210serde_default_fn!(bool, norm_topk_prob, false);
211serde_default_fn!(ScoringFunc, scoring_func, ScoringFunc::Softmax);
212serde_default_fn!(Activation, hidden_act, Activation::Silu);
213serde_default_fn!(bool, tie_word_embeddings, false);
214
215#[derive(Deserialize, Clone, Debug)]
216enum TopkMethod {
217    #[serde(rename = "greedy")]
218    Greedy,
219    #[serde(rename = "group_limited_greedy")]
220    GroupLimitedGreedy,
221}
222
223#[derive(Deserialize, Clone, Debug)]
224enum ScoringFunc {
225    #[serde(rename = "softmax")]
226    Softmax,
227}
228
/// Model hyper-parameters, deserialized with serde. Fields carrying a
/// `#[serde(default = ...)]` attribute fall back to the named default function
/// when absent from the input.
#[derive(Deserialize, Clone, Debug)]
pub struct DeepSeekV2Config {
    /// Number of rows of the token embedding table (and output size of `lm_head`).
    pub(crate) vocab_size: usize,
    /// Width of the hidden states flowing through the layers.
    pub(crate) hidden_size: usize,
    /// Default intermediate width of a dense `Mlp`.
    pub(crate) intermediate_size: usize,
    /// Intermediate width of each routed expert; the shared-expert MLP uses this
    /// value multiplied by `n_shared_experts`.
    pub(crate) moe_intermediate_size: usize,
    /// Presumably the number of decoder layers; not read in the visible part of the file.
    pub(crate) num_hidden_layers: usize,
    /// Number of attention heads.
    pub(crate) num_attention_heads: usize,
    /// When set, MoE layers also get a shared-expert MLP.
    pub(crate) n_shared_experts: Option<usize>,
    /// Number of routed experts; when `None` every layer uses a dense MLP.
    pub(crate) n_routed_experts: Option<usize>,
    /// Multiplier applied to the top-k gate weights when they are not normalized.
    /// Defaults to 1.0.
    #[serde(default = "routed_scaling_factor")]
    pub(crate) routed_scaling_factor: f64,
    /// Expert selection strategy. Defaults to greedy.
    #[serde(default = "topk_method")]
    topk_method: TopkMethod,
    /// Number of experts selected per token by the gate.
    pub(crate) num_experts_per_tok: Option<usize>,
    /// A layer can be MoE only if its index is a multiple of this value. Defaults to 1.
    #[serde(default = "moe_layer_freq")]
    pub(crate) moe_layer_freq: usize,
    /// Layers with an index below this value always use a dense MLP. Defaults to 0.
    #[serde(default = "first_k_dense_replace")]
    pub(crate) first_k_dense_replace: usize,
    /// Whether the top-k gate weights are normalized to sum to one (only applied
    /// when more than one expert is selected). Defaults to false.
    #[serde(default = "norm_topk_prob")]
    pub(crate) norm_topk_prob: bool,
    /// Scoring function applied to the gate logits. Defaults to softmax.
    #[serde(default = "scoring_func")]
    scoring_func: ScoringFunc,
    /// Activation applied to the MLP gate projection. Defaults to SiLU.
    #[serde(default = "hidden_act")]
    pub(crate) hidden_act: Activation,
    /// Presumably the maximum supported sequence length; not read through this
    /// struct in the visible part of the file.
    pub(crate) max_position_embeddings: usize,
    /// Epsilon passed to every RMS norm.
    pub(crate) rms_norm_eps: f64,
    /// When false, a separate `lm_head` weight is loaded. Defaults to false.
    #[serde(default = "tie_word_embeddings")]
    pub(crate) tie_word_embeddings: bool,
    /// Presumably the RoPE base frequency; not read through this struct in the
    /// visible part of the file.
    pub(crate) rope_theta: f32,
    /// Optional RoPE scaling; a YaRN configuration also rescales the softmax scale.
    pub(crate) rope_scaling: Option<DeepSeekV2RopeScaling>,
    /// Whether the `q_a_proj`, `kv_a_proj_with_mqa` and `o_proj` layers have a bias.
    pub(crate) attention_bias: bool,
    /// When set, the query projection is factored through this rank.
    pub(crate) q_lora_rank: Option<usize>,
    /// Per-head width of the rotary (position-encoded) part of queries and keys.
    pub(crate) qk_rope_head_dim: usize,
    /// Width of the compressed key/value representation.
    pub(crate) kv_lora_rank: usize,
    /// Per-head width of the values.
    pub(crate) v_head_dim: usize,
    /// Per-head width of the non-rotary part of queries and keys.
    pub(crate) qk_nope_head_dim: usize,
    /// Number of expert groups used by group-limited greedy routing.
    pub(crate) n_group: usize,
    /// Number of groups kept per token by group-limited greedy routing.
    pub(crate) topk_group: usize,
}
270
271#[derive(Debug, Clone, Deserialize)]
272#[serde(rename_all = "lowercase")]
273pub enum ScaledRopeType {
274    #[serde(alias = "su")]
275    #[serde(alias = "longrope")]
276    Su,
277    #[serde(alias = "yarn")]
278    Yarn,
279    #[serde(alias = "dynamic")]
280    Dynamic,
281    #[serde(alias = "linear")]
282    Linear,
283}
284
/// Precomputed rotary-embedding tables, one row per position.
#[derive(Debug, Clone)]
pub struct DeepSeekV2RotaryEmbedding {
    /// Sine of the rotation angles (scaled by the YaRN factor when YaRN is used).
    sin: Tensor,
    /// Cosine of the rotation angles (scaled by the YaRN factor when YaRN is used).
    cos: Tensor,
}
290
/// RoPE scaling configuration.
///
/// Untagged: serde tries the variants in declaration order and keeps the first
/// one whose fields match, so `Yarn` must stay ahead of `LinearOrDynamic`.
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
pub enum DeepSeekV2RopeScaling {
    /// YaRN scaling parameters.
    Yarn {
        /// Context length used when computing the correction range.
        original_max_position_embeddings: usize,
        /// First rotation count passed to the correction-range computation.
        beta_fast: f32,
        /// Second rotation count passed to the correction-range computation.
        beta_slow: f32,
        /// Numerator coefficient of the sin/cos magnitude scale.
        mscale: f32,
        /// Denominator coefficient of the sin/cos magnitude scale; also used for
        /// the attention softmax scale.
        mscale_all_dim: f32,
        /// Scaling factor applied to the interpolated frequencies.
        factor: f32,
        /// Read from the `type` key.
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
    },
    /// Linear or dynamic scaling; rejected as not implemented by
    /// `DeepSeekV2RotaryEmbedding::new`.
    LinearOrDynamic {
        /// Read from the `type` key.
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
        /// Scaling factor.
        factor: f64,
    },
}
310
/// Subset of the model configuration needed to build the rotary embedding tables.
pub struct DeepSeekV2RopeConfig {
    /// Optional scaling; `None` selects the unscaled tables.
    pub rope_scaling: Option<DeepSeekV2RopeScaling>,
    /// Number of positions (rows) in the precomputed tables.
    pub max_position_embeddings: usize,
    /// Base of the inverse-frequency geometric progression.
    pub rope_theta: f32,
    /// Width of the rotary part of each head; one frequency per pair of channels.
    pub qk_rope_head_dim: usize,
}
317
318impl DeepSeekV2RotaryEmbedding {
319    fn new_unscaled(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
320        let max_seq_len = cfg.max_position_embeddings;
321        let dim = cfg.qk_rope_head_dim;
322
323        let inv_freq: Vec<_> = (0..dim)
324            .step_by(2)
325            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32))
326            .collect();
327        let inv_freq_len = inv_freq.len();
328        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
329        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
330            .to_dtype(DType::F32)?
331            .reshape((max_seq_len, 1))?;
332        let freqs = t.matmul(&inv_freq)?;
333
334        let sin = freqs.sin()?.to_dtype(dtype)?;
335        let cos = freqs.cos()?.to_dtype(dtype)?;
336
337        Ok(Self { sin, cos })
338    }
339
340    fn yarn_find_correction_dim(
341        num_rot: f32,
342        dim: usize,
343        base: f32,
344        max_position_embeddings: usize,
345    ) -> f32 {
346        (dim as f32 * (max_position_embeddings as f32 / (num_rot * 2. * PI)).ln())
347            / (2. * base.ln())
348    }
349
350    fn yarn_find_correction_range(
351        low_rot: f32,
352        high_rot: f32,
353        dim: usize,
354        base: f32,
355        max_position_embeddings: usize,
356    ) -> (f32, f32) {
357        let low =
358            Self::yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings).floor();
359        let high =
360            Self::yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings).ceil();
361        (low.max(0.), high.min(dim as f32 - 1.))
362    }
363
364    fn yarn_linear_ramp_mask(min: f32, mut max: f32, dim: usize, dev: &Device) -> Result<Tensor> {
365        if min == max {
366            // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/604d5664dddd88a0433dbae533b7fe9472482de0/modeling_deepseek.py#L255
367            max += 0.001;
368        }
369        let linear_func =
370            ((Tensor::arange(0f32, dim as f32, dev)? - min as f64)? / (max as f64 - min as f64))?;
371        linear_func.clamp(0., 1.)
372    }
373
374    pub(crate) fn yarn_get_mscale(scale: f32, mscale: f32) -> f32 {
375        if scale <= 1. {
376            return 1.;
377        }
378        0.1 * mscale * scale.ln() + 1.
379    }
380
381    #[allow(clippy::too_many_arguments)]
382    fn new_yarn(
383        cfg: &DeepSeekV2RopeConfig,
384        dtype: DType,
385        dev: &Device,
386        original_max_position_embeddings: usize,
387        beta_fast: f32,
388        beta_slow: f32,
389        factor: f32,
390        mscale: f32,
391        mscale_all_dim: f32,
392    ) -> Result<Self> {
393        let freq_extra: Vec<_> = (0..cfg.qk_rope_head_dim)
394            .step_by(2)
395            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32))
396            .collect();
397        let freq_extra_len = freq_extra.len();
398        let freq_extra = Tensor::from_vec(freq_extra, freq_extra_len, dev)?;
399        let freq_inter: Vec<_> = (0..cfg.qk_rope_head_dim)
400            .step_by(2)
401            .map(|i| 1f32 / (factor * cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32)))
402            .collect();
403        let freq_inter_len = freq_inter.len();
404        let freq_inter = Tensor::from_vec(freq_inter, (1, freq_inter_len), dev)?;
405
406        let (low, high) = Self::yarn_find_correction_range(
407            beta_fast,
408            beta_slow,
409            cfg.qk_rope_head_dim,
410            cfg.rope_theta,
411            original_max_position_embeddings,
412        );
413        let inv_freq_mask =
414            (1. - Self::yarn_linear_ramp_mask(low, high, cfg.qk_rope_head_dim / 2, dev)?)?;
415        let inv_freq = freq_inter
416            .broadcast_mul(&(1. - &inv_freq_mask)?)?
417            .broadcast_add(&freq_extra.broadcast_mul(&inv_freq_mask)?)?;
418
419        let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
420            .to_dtype(DType::F32)?
421            .reshape((cfg.max_position_embeddings, 1))?;
422        let freqs = t.matmul(&inv_freq)?;
423
424        let mscale =
425            Self::yarn_get_mscale(factor, mscale) / Self::yarn_get_mscale(factor, mscale_all_dim);
426        let sin = (freqs.sin()? * mscale as f64)?.to_dtype(dtype)?;
427        let cos = (freqs.cos()? * mscale as f64)?.to_dtype(dtype)?;
428
429        Ok(Self { sin, cos })
430    }
431
432    pub fn new(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
433        match &cfg.rope_scaling {
434            Some(DeepSeekV2RopeScaling::LinearOrDynamic {
435                scaling_type: _,
436                factor: _,
437            }) => candle::bail!("linear and dynamic rope are not implemented yet!"),
438            Some(DeepSeekV2RopeScaling::Yarn {
439                original_max_position_embeddings,
440                beta_fast,
441                beta_slow,
442                factor,
443                mscale,
444                mscale_all_dim,
445                scaling_type: _,
446            }) => Self::new_yarn(
447                cfg,
448                dtype,
449                dev,
450                *original_max_position_embeddings,
451                *beta_fast,
452                *beta_slow,
453                *factor,
454                *mscale,
455                *mscale_all_dim,
456            ),
457            None => Self::new_unscaled(cfg, dtype, dev),
458        }
459    }
460
461    pub fn forward(
462        &self,
463        q: &Tensor,
464        k: &Tensor,
465        seqlen_offset: usize,
466    ) -> Result<(Tensor, Tensor)> {
467        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
468
469        let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
470        let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
471
472        let q_embed = candle_nn::rotary_emb::rope_i(&q.contiguous()?, &cos, &sin)?;
473        let k_embed = candle_nn::rotary_emb::rope_i(&k.contiguous()?, &cos, &sin)?;
474
475        Ok((q_embed, k_embed))
476    }
477}
478
479impl DeepSeekV2Config {
480    pub(crate) fn q_head_dim(&self) -> usize {
481        self.qk_rope_head_dim + self.qk_nope_head_dim
482    }
483
484    fn softmax_scale(&self) -> f32 {
485        let mut softmax_scale = 1.0 / (self.q_head_dim() as f32).sqrt();
486        if let Some(DeepSeekV2RopeScaling::Yarn {
487            mscale_all_dim,
488            factor,
489            ..
490        }) = self.rope_scaling
491        {
492            let mscale = DeepSeekV2RotaryEmbedding::yarn_get_mscale(factor, mscale_all_dim);
493            softmax_scale = softmax_scale * mscale * mscale;
494        }
495        softmax_scale
496    }
497}
498
499enum QProj {
500    Plain(Linear),
501    Lora { a: Linear, norm: RmsNorm, b: Linear },
502}
503
504impl QProj {
505    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
506        match self {
507            Self::Lora { a, norm, b } => b.forward(&norm.forward(&a.forward(xs)?)?),
508            Self::Plain(lin) => lin.forward(xs),
509        }
510    }
511}
512
513struct Attention {
514    q: QProj,
515    kv_a_proj_with_mqa: Linear,
516    kv_a_layernorm: RmsNorm,
517    kv_b_proj: Linear,
518    o_proj: Linear,
519    rotary_emb: Arc<DeepSeekV2RotaryEmbedding>,
520    cfg: DeepSeekV2Config,
521    q_head_dim: usize,
522    softmax_scale: f64,
523    kv_cache: Option<(Tensor, Tensor)>,
524}
525
526impl Attention {
527    fn new(
528        rotary_emb: Arc<DeepSeekV2RotaryEmbedding>,
529        cfg: &DeepSeekV2Config,
530        vb: VarBuilder,
531    ) -> Result<Self> {
532        let q_head_dim = cfg.q_head_dim();
533        let q = match cfg.q_lora_rank {
534            Some(lora_rank) => {
535                let a = candle_nn::linear_b(
536                    cfg.hidden_size,
537                    lora_rank,
538                    cfg.attention_bias,
539                    vb.pp("q_a_proj"),
540                )?;
541                let norm = rms_norm(lora_rank, cfg.rms_norm_eps, vb.pp("q_a_layernorm"))?;
542                let b = candle_nn::linear_no_bias(
543                    lora_rank,
544                    cfg.num_attention_heads * q_head_dim,
545                    vb.pp("q_b_proj"),
546                )?;
547                QProj::Lora { a, norm, b }
548            }
549            None => QProj::Plain(candle_nn::linear_no_bias(
550                cfg.hidden_size,
551                cfg.num_attention_heads * q_head_dim,
552                vb.pp("q_proj"),
553            )?),
554        };
555
556        let kv_a_proj_with_mqa = candle_nn::linear_b(
557            cfg.hidden_size,
558            cfg.kv_lora_rank + cfg.qk_rope_head_dim,
559            cfg.attention_bias,
560            vb.pp("kv_a_proj_with_mqa"),
561        )?;
562        let kv_a_layernorm = rms_norm(cfg.kv_lora_rank, cfg.rms_norm_eps, vb.pp("kv_a_layernorm"))?;
563        let kv_b_proj = candle_nn::linear_no_bias(
564            cfg.kv_lora_rank,
565            cfg.num_attention_heads * (q_head_dim - cfg.qk_rope_head_dim + cfg.v_head_dim),
566            vb.pp("kv_b_proj"),
567        )?;
568
569        let o_proj = candle_nn::linear_b(
570            cfg.num_attention_heads * cfg.v_head_dim,
571            cfg.hidden_size,
572            cfg.attention_bias,
573            vb.pp("o_proj"),
574        )?;
575
576        Ok(Self {
577            q,
578            kv_a_proj_with_mqa,
579            kv_a_layernorm,
580            kv_b_proj,
581            o_proj,
582            rotary_emb,
583            cfg: cfg.clone(),
584            q_head_dim,
585            softmax_scale: cfg.softmax_scale() as f64,
586            kv_cache: None,
587        })
588    }
589
590    fn forward(
591        &mut self,
592        xs: &Tensor,
593        attention_mask: Option<&Tensor>,
594        seqlen_offset: usize,
595    ) -> Result<Tensor> {
596        let (bs, seq_len, _) = xs.dims3()?;
597
598        let q = {
599            let q = self.q.forward(xs)?;
600            q.reshape((bs, seq_len, self.cfg.num_attention_heads, self.q_head_dim))?
601                .transpose(1, 2)?
602        };
603        let q_split = q.split(
604            &[self.cfg.qk_nope_head_dim, self.cfg.qk_rope_head_dim],
605            D::Minus1,
606        )?;
607        let q_nope = q_split[0].clone();
608        let q_pe = q_split[1].clone();
609
610        let compressed_kv = self.kv_a_proj_with_mqa.forward(xs)?;
611        let ckv_split = compressed_kv.split(
612            &[self.cfg.kv_lora_rank, self.cfg.qk_rope_head_dim],
613            D::Minus1,
614        )?;
615        let compressed_kv = ckv_split[0].clone();
616        let k_pe = {
617            let k_pe = ckv_split[1].clone();
618            k_pe.reshape((bs, seq_len, 1, self.cfg.qk_rope_head_dim))?
619                .transpose(1, 2)?
620        };
621        let kv = {
622            let kv = self
623                .kv_b_proj
624                .forward(&self.kv_a_layernorm.forward(&compressed_kv)?)?;
625            kv.reshape((
626                bs,
627                seq_len,
628                self.cfg.num_attention_heads,
629                self.cfg.qk_nope_head_dim + self.cfg.v_head_dim,
630            ))?
631            .transpose(1, 2)?
632        };
633
634        let kv_split = kv.split(&[self.cfg.qk_nope_head_dim, self.cfg.v_head_dim], D::Minus1)?;
635        let k_nope = kv_split[0].clone();
636        let v = kv_split[1].clone();
637
638        let (q_pe, k_pe) = self.rotary_emb.forward(&q_pe, &k_pe, seqlen_offset)?;
639
640        let q = Tensor::cat(&[q_nope, q_pe], D::Minus1)?;
641        let k = Tensor::cat(&[k_nope, k_pe.repeat((1, q.dim(1)?, 1, 1))?], D::Minus1)?;
642
643        let (k, v) = match &self.kv_cache {
644            None => (k, v),
645            Some((prev_k, prev_v)) => {
646                let key_states = Tensor::cat(&[prev_k, &k], 2)?;
647                let value_states = Tensor::cat(&[prev_v, &v], 2)?;
648                (key_states, value_states)
649            }
650        };
651        self.kv_cache = Some((k.clone(), v.clone()));
652
653        let attn_out = {
654            let att = (q.contiguous()?.matmul(&k.t()?.contiguous()?)? * self.softmax_scale)?;
655            let att = match attention_mask {
656                Some(mask) => att.broadcast_add(mask)?,
657                None => att,
658            };
659
660            let att = candle_nn::ops::softmax_last_dim(&att)?;
661            // Convert to contiguous as matmul doesn't support strided vs for now.
662            att.matmul(&v.contiguous()?)?
663        };
664
665        let attn_out = if attention_mask.is_some() {
666            attn_out.transpose(1, 2)?.reshape((bs, seq_len, ()))?
667        } else {
668            attn_out.reshape((bs, seq_len, ()))?
669        };
670
671        self.o_proj.forward(&attn_out)
672    }
673
674    fn clear_kv_cache(&mut self) {
675        self.kv_cache = None
676    }
677}
678
679struct Mlp {
680    gate: Linear,
681    up: Linear,
682    down: Linear,
683    act: Activation,
684}
685
686impl Mlp {
687    fn new(
688        cfg: &DeepSeekV2Config,
689        vb: VarBuilder,
690        hidden_size: Option<usize>,
691        intermediate_size: Option<usize>,
692    ) -> Result<Self> {
693        let hidden_size = hidden_size.unwrap_or(cfg.hidden_size);
694        let intermediate_size = intermediate_size.unwrap_or(cfg.intermediate_size);
695
696        Ok(Self {
697            gate: candle_nn::linear_no_bias(hidden_size, intermediate_size, vb.pp("gate_proj"))?,
698            up: candle_nn::linear_no_bias(hidden_size, intermediate_size, vb.pp("up_proj"))?,
699            down: candle_nn::linear_no_bias(intermediate_size, hidden_size, vb.pp("down_proj"))?,
700            act: cfg.hidden_act,
701        })
702    }
703
704    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
705        let lhs = self.gate.forward(xs)?.apply(&self.act)?;
706        let rhs = self.up.forward(xs)?;
707        self.down.forward(&(&lhs * &rhs)?)
708    }
709}
710
711struct MoeGate {
712    weight: Tensor,
713    cfg: DeepSeekV2Config,
714    top_k: usize,
715    n_routed_experts: usize,
716}
717
718impl MoeGate {
719    fn new(cfg: &DeepSeekV2Config, vb: VarBuilder, n_routed_experts: usize) -> Result<Self> {
720        let weight = vb.get((n_routed_experts, cfg.hidden_size), "weight")?;
721        Ok(Self {
722            weight,
723            cfg: cfg.clone(),
724            top_k: cfg.num_experts_per_tok.unwrap(),
725            n_routed_experts,
726        })
727    }
728
729    /// (topk_idx, topk_weight)
730    fn forward(&self, xs: &Tensor) -> Result<(Tensor, Tensor)> {
731        let (bs, seq_len, h) = xs.dims3()?;
732        // Compute gating score
733        let xs = xs.reshape(((), h))?;
734        let logits = xs
735            .to_dtype(DType::F32)?
736            .broadcast_matmul(&self.weight.t()?.to_dtype(DType::F32)?)?;
737        let scores = match self.cfg.scoring_func {
738            ScoringFunc::Softmax => candle_nn::ops::softmax_last_dim(&logits)?,
739        };
740
741        // Select top-k experts
742        let (mut topk_weight, topk_idx) = match self.cfg.topk_method {
743            TopkMethod::Greedy => {
744                let TopKOutput { values, indices } = scores.topk_unsorted(self.top_k)?;
745                (values, indices)
746            }
747            TopkMethod::GroupLimitedGreedy => {
748                // (n, n_group)
749                let group_scores = scores
750                    .reshape((bs * seq_len, self.cfg.n_group, ()))?
751                    .max(D::Minus1)?;
752                // (n, topk_group)
753                let group_idx = scores.topk_unsorted(self.cfg.topk_group)?.indices;
754                // (n, n_group)
755                let group_mask = group_scores.zeros_like()?.scatter_add(
756                    &group_idx,
757                    &group_idx.ones_like()?.to_dtype(group_scores.dtype())?,
758                    1,
759                )?;
760                // (n, e)
761                let score_mask = group_mask
762                    .unsqueeze(D::Minus1)?
763                    .expand((
764                        bs * seq_len,
765                        self.cfg.n_group,
766                        self.n_routed_experts / self.cfg.n_group,
767                    ))?
768                    .reshape((bs, seq_len, ()))?;
769                // (n, e)
770                // Invert the mask
771                let tmp_scores = masked_fill(&score_mask, &(1. - &score_mask.ne(0.)?)?, 0.)?;
772                let TopKOutput { values, indices } = tmp_scores.topk_unsorted(self.top_k)?;
773                (values, indices)
774            }
775        };
776
777        if self.top_k > 1 && self.cfg.norm_topk_prob {
778            let denominator = (topk_weight.sum_keepdim(D::Minus1)? + 1e-20)?;
779            topk_weight = (topk_weight / denominator)?;
780        } else {
781            topk_weight = (topk_weight * self.cfg.routed_scaling_factor)?;
782        }
783        Ok((topk_idx, topk_weight))
784    }
785}
786
787struct Moe {
788    experts: Vec<Mlp>,
789    shared_experts: Option<Mlp>,
790    gate: MoeGate,
791}
792
793impl Moe {
794    fn new(
795        cfg: &DeepSeekV2Config,
796        vb: VarBuilder,
797
798        n_shared_experts: Option<usize>,
799        n_routed_experts: usize,
800    ) -> Result<Self> {
801        let mut experts = Vec::with_capacity(n_routed_experts);
802        for i in 0..n_routed_experts {
803            let vb_e = vb.pp("experts").pp(i);
804            experts.push(Mlp::new(cfg, vb_e, None, Some(cfg.moe_intermediate_size))?);
805        }
806        let shared_experts = if let Some(n_shared_experts) = n_shared_experts {
807            let intermediate_size = cfg.moe_intermediate_size * n_shared_experts;
808            Some(Mlp::new(
809                cfg,
810                vb.pp("shared_experts"),
811                None,
812                Some(intermediate_size),
813            )?)
814        } else {
815            None
816        };
817        let gate = MoeGate::new(cfg, vb.pp("gate"), n_routed_experts)?;
818        Ok(Self {
819            experts,
820            shared_experts,
821            gate,
822        })
823    }
824
825    fn moe_infer(&self, xs: &Tensor, topk_ids: &Tensor, topk_weight: &Tensor) -> Result<Tensor> {
826        let mut y = xs.zeros_like()?;
827        let counts = topk_ids
828            .flatten_all()?
829            .bincount(self.experts.len() as u32)?;
830        for (i, expert) in self.experts.iter().enumerate() {
831            if counts[i] == 0 {
832                continue;
833            }
834            let idx_top = topk_ids.eq(i as f64)?.nonzero()?.t()?;
835            let idx = &idx_top.i(0)?.contiguous()?;
836            let top = &idx_top.i(1)?.contiguous()?;
837
838            y = y.index_add(
839                idx,
840                &expert.forward(&xs.index_select(idx, 0)?)?.broadcast_mul(
841                    &topk_weight
842                        .index_select(idx, 0)?
843                        .gather(&top.unsqueeze(1)?, 1)?
844                        .squeeze(1)?
845                        .unsqueeze(D::Minus1)?
846                        .to_dtype(xs.dtype())?,
847                )?,
848                0,
849            )?;
850        }
851
852        Ok(y)
853    }
854
855    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
856        let identity = xs.clone();
857        let orig_shape = xs.shape();
858        let (topk_idx, topk_weight) = self.gate.forward(xs)?;
859        let xs = xs.reshape(((), xs.dim(D::Minus1)?))?;
860
861        let mut y = self
862            .moe_infer(&xs, &topk_idx, &topk_weight)?
863            .reshape(orig_shape)?;
864        if let Some(ref shared_experts) = self.shared_experts {
865            y = (y + shared_experts.forward(&identity)?)?;
866        }
867        Ok(y)
868    }
869}
870
871enum MoeOrMlp {
872    Moe(Moe),
873    Mlp(Mlp),
874}
875
876impl MoeOrMlp {
877    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
878        match self {
879            Self::Mlp(mlp) => mlp.forward(xs),
880            Self::Moe(moe) => moe.forward(xs),
881        }
882    }
883}
884
885struct DecoderLayer {
886    input_layernorm: RmsNorm,
887    post_attention_layernorm: RmsNorm,
888    attn: Attention,
889    moe_or_mlp: MoeOrMlp,
890}
891
892impl DecoderLayer {
893    fn new(
894        rotary_emb: Arc<DeepSeekV2RotaryEmbedding>,
895        cfg: &DeepSeekV2Config,
896        vb: VarBuilder,
897        layer_idx: usize,
898    ) -> Result<Self> {
899        let attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
900        let input_layernorm =
901            rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
902        let post_attention_layernorm = rms_norm(
903            cfg.hidden_size,
904            cfg.rms_norm_eps,
905            vb.pp("post_attention_layernorm"),
906        )?;
907        let moe_or_mlp = if cfg.n_routed_experts.is_some()
908            && layer_idx >= cfg.first_k_dense_replace
909            && layer_idx % cfg.moe_layer_freq == 0
910        {
911            MoeOrMlp::Moe(Moe::new(
912                cfg,
913                vb.pp("mlp"),
914                cfg.n_shared_experts,
915                cfg.n_routed_experts.unwrap(),
916            )?)
917        } else {
918            MoeOrMlp::Mlp(Mlp::new(cfg, vb.pp("mlp"), None, None)?)
919        };
920
921        Ok(Self {
922            input_layernorm,
923            post_attention_layernorm,
924            attn,
925            moe_or_mlp,
926        })
927    }
928
929    fn forward(
930        &mut self,
931        xs: &Tensor,
932        attention_mask: Option<&Tensor>,
933        seqlen_offset: usize,
934    ) -> Result<Tensor> {
935        let residual = xs;
936        let xs = self.input_layernorm.forward(xs)?;
937        let xs = self.attn.forward(&xs, attention_mask, seqlen_offset)?;
938        let xs = (xs + residual)?;
939        let residual = &xs;
940        let xs = self
941            .moe_or_mlp
942            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
943        residual + xs
944    }
945
946    fn clear_kv_cache(&mut self) {
947        self.attn.clear_kv_cache();
948    }
949}
950
951pub struct DeepSeekV2 {
952    lm_head: Linear,
953    embed_tokens: Embedding,
954    norm: RmsNorm,
955    layers: Vec<DecoderLayer>,
956    dtype: DType,
957    device: Device,
958}
959
960impl DeepSeekV2 {
961    pub fn new(cfg: &DeepSeekV2Config, vb: VarBuilder) -> Result<Self> {
962        let vb_m = vb.pp("model");
963
964        let embed_tokens = embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
965        let lm_head = if !cfg.tie_word_embeddings {
966            candle_nn::linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
967        } else {
968            candle_nn::Linear::new(embed_tokens.embeddings().clone(), None)
969        };
970        let norm = rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
971
972        let rope_cfg = DeepSeekV2RopeConfig {
973            rope_scaling: cfg.rope_scaling.clone(),
974            max_position_embeddings: cfg.max_position_embeddings,
975            rope_theta: cfg.rope_theta,
976            qk_rope_head_dim: cfg.qk_rope_head_dim,
977        };
978        let rotary_emb = Arc::new(DeepSeekV2RotaryEmbedding::new(
979            &rope_cfg,
980            vb.dtype(),
981            vb.device(),
982        )?);
983
984        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
985        let vb_l = vb_m.pp("layers");
986        for layer_idx in 0..cfg.num_hidden_layers {
987            let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx), layer_idx)?;
988            layers.push(layer)
989        }
990
991        Ok(Self {
992            lm_head,
993            embed_tokens,
994            norm,
995            layers,
996            dtype: vb.dtype(),
997            device: vb.device().clone(),
998        })
999    }
1000
1001    fn prepare_decoder_attention_mask(
1002        &self,
1003        b_size: usize,
1004        tgt_len: usize,
1005        seqlen_offset: usize,
1006    ) -> Result<Tensor> {
1007        let mask: Vec<_> = (0..tgt_len)
1008            .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
1009            .collect();
1010        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
1011        let mask = if seqlen_offset > 0 {
1012            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
1013            Tensor::cat(&[&mask0, &mask], D::Minus1)?
1014        } else {
1015            mask
1016        };
1017        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
1018            .to_dtype(self.dtype)
1019    }
1020
1021    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
1022        let (bs, seq_len) = input_ids.dims2()?;
1023        let mut xs = self.embed_tokens.forward(input_ids)?;
1024        let attention_mask = if seq_len == 1 {
1025            None
1026        } else {
1027            let mask = self.prepare_decoder_attention_mask(bs, seq_len, seqlen_offset)?;
1028            Some(mask)
1029        };
1030        for layer in &mut self.layers {
1031            xs = layer.forward(
1032                &xs,
1033                attention_mask
1034                    .as_ref()
1035                    .map(|m| m.to_device(xs.device()).unwrap())
1036                    .as_ref(),
1037                seqlen_offset,
1038            )?;
1039        }
1040        let xs = xs.apply(&self.norm)?;
1041        let xs = xs.i((.., seq_len - 1, ..))?.contiguous()?;
1042        let logits = self.lm_head.forward(&xs)?;
1043        logits.to_dtype(DType::F32)
1044    }
1045
1046    pub fn clear_kv_cache(&mut self) {
1047        for layer in self.layers.iter_mut() {
1048            layer.clear_kv_cache();
1049        }
1050    }
1051}