candle_transformers/models/
mixtral.rs

//! Mixtral Model, a sparse mixture-of-experts model based on the Mistral architecture.
//!
//! See Mixtral model details at:
//! - [Hugging Face](https://huggingface.co/docs/transformers/model_doc/mixtral)
//! - [Mixtral-8x7B Blog Post](https://mistral.ai/news/mixtral-of-experts/)
//!
//! The model uses a mixture-of-experts architecture with:
//! - 8 experts per layer
//! - Top-2 expert routing
//! - Sliding window attention
//! - RoPE embeddings
//!
//! References:
//! - [Hugging Face Implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py)
//! - [Mixtral Blog Post](https://mistral.ai/news/mixtral-of-experts/)
//!
17
18use crate::models::with_tracing::{linear_no_bias, Linear, RmsNorm};
19/// Mixtral Model
20/// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
21/// https://mistral.ai/news/mixtral-of-experts/
22use candle::{DType, Device, Module, Result, Tensor, D};
23use candle_nn::{Activation, VarBuilder};
24use serde::Deserialize;
25use std::sync::Arc;
26
27/// https://github.com/huggingface/transformers/blob/1a585c1222a56bcaecc070966d558d4a9d862e83/src/transformers/models/mixtral/configuration_mixtral.py#L113
28#[derive(Debug, Clone, PartialEq, Deserialize)]
29pub struct Config {
30    pub(crate) vocab_size: usize,
31    pub(crate) hidden_size: usize,
32    pub(crate) intermediate_size: usize,
33    pub(crate) num_hidden_layers: usize,
34    pub(crate) num_attention_heads: usize,
35    pub(crate) num_key_value_heads: usize,
36    pub(crate) hidden_act: Activation,
37    pub(crate) max_position_embeddings: usize,
38    pub(crate) rms_norm_eps: f64,
39    pub(crate) rope_theta: f64,
40    pub(crate) sliding_window: usize,
41    pub(crate) num_experts_per_tok: usize,
42    pub(crate) num_local_experts: usize,
43    pub(crate) use_flash_attn: bool,
44}
45
46impl Config {
47    /// https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/blob/main/config.json
48    pub fn v0_1_8x7b(use_flash_attn: bool) -> Self {
49        Self {
50            vocab_size: 32000,
51            hidden_size: 4096,
52            intermediate_size: 14336,
53            num_hidden_layers: 32,
54            num_attention_heads: 32,
55            num_key_value_heads: 8,
56            hidden_act: Activation::Silu,
57            max_position_embeddings: 32768,
58            rms_norm_eps: 1e-5,
59            rope_theta: 1e6,
60            sliding_window: 4096,
61            num_experts_per_tok: 2,
62            num_local_experts: 8,
63            use_flash_attn,
64        }
65    }
66}
67
68#[derive(Debug, Clone)]
69struct RotaryEmbedding {
70    sin: Tensor,
71    cos: Tensor,
72}
73
74fn rotate_half(xs: &Tensor) -> Result<Tensor> {
75    let last_dim = xs.dim(D::Minus1)?;
76    let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
77    let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
78    Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
79}
80
81impl RotaryEmbedding {
82    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
83        let dim = cfg.hidden_size / cfg.num_attention_heads;
84        let max_seq_len = cfg.max_position_embeddings;
85        let inv_freq: Vec<_> = (0..dim)
86            .step_by(2)
87            .map(|i| 1f32 / (cfg.rope_theta as f32).powf(i as f32 / dim as f32))
88            .collect();
89        let inv_freq_len = inv_freq.len();
90        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
91        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
92            .to_dtype(dtype)?
93            .reshape((max_seq_len, 1))?;
94        let freqs = t.matmul(&inv_freq)?;
95        let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
96        Ok(Self {
97            sin: freqs.sin()?,
98            cos: freqs.cos()?,
99        })
100    }
101
102    fn apply_rotary_emb_qkv(
103        &self,
104        q: &Tensor,
105        k: &Tensor,
106        seqlen_offset: usize,
107    ) -> Result<(Tensor, Tensor)> {
108        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
109        let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
110        let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
111        let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
112        let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
113        let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
114        let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
115        Ok((q_embed, k_embed))
116    }
117}
118
/// Forwards to the `candle_flash_attn` kernel; only compiled with the `flash-attn` feature.
#[cfg(feature = "flash-attn")]
fn flash_attn(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    softmax_scale: f32,
    causal: bool,
) -> Result<Tensor> {
    let attn = candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)?;
    Ok(attn)
}
129
130#[cfg(not(feature = "flash-attn"))]
131fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
132    unimplemented!("compile with '--features flash-attn'")
133}
134
135#[derive(Debug, Clone)]
136struct Attention {
137    q_proj: Linear,
138    k_proj: Linear,
139    v_proj: Linear,
140    o_proj: Linear,
141    num_heads: usize,
142    num_kv_heads: usize,
143    num_kv_groups: usize,
144    head_dim: usize,
145    hidden_size: usize,
146    rotary_emb: Arc<RotaryEmbedding>,
147    kv_cache: Option<(Tensor, Tensor)>,
148    use_flash_attn: bool,
149}
150
151impl Attention {
152    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
153        let hidden_sz = cfg.hidden_size;
154        let num_heads = cfg.num_attention_heads;
155        let num_kv_heads = cfg.num_key_value_heads;
156        let num_kv_groups = num_heads / num_kv_heads;
157        let head_dim = hidden_sz / num_heads;
158        let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
159        let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
160        let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
161        let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
162        Ok(Self {
163            q_proj,
164            k_proj,
165            v_proj,
166            o_proj,
167            num_heads,
168            num_kv_heads,
169            num_kv_groups,
170            head_dim,
171            hidden_size: hidden_sz,
172            rotary_emb,
173            kv_cache: None,
174            use_flash_attn: cfg.use_flash_attn,
175        })
176    }
177
178    fn forward(
179        &mut self,
180        xs: &Tensor,
181        attention_mask: Option<&Tensor>,
182        seqlen_offset: usize,
183    ) -> Result<Tensor> {
184        let (b_sz, q_len, _) = xs.dims3()?;
185
186        let query_states = self.q_proj.forward(xs)?;
187        let key_states = self.k_proj.forward(xs)?;
188        let value_states = self.v_proj.forward(xs)?;
189
190        let query_states = query_states
191            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
192            .transpose(1, 2)?;
193        let key_states = key_states
194            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
195            .transpose(1, 2)?;
196        let value_states = value_states
197            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
198            .transpose(1, 2)?;
199
200        let (query_states, key_states) =
201            self.rotary_emb
202                .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
203
204        let (key_states, value_states) = match &self.kv_cache {
205            None => (key_states, value_states),
206            Some((prev_k, prev_v)) => {
207                let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
208                let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
209                (key_states, value_states)
210            }
211        };
212        self.kv_cache = Some((key_states.clone(), value_states.clone()));
213
214        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
215        let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
216
217        let attn_output = if self.use_flash_attn {
218            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
219            let q = query_states.transpose(1, 2)?;
220            let k = key_states.transpose(1, 2)?;
221            let v = value_states.transpose(1, 2)?;
222            let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
223            flash_attn(&q, &k, &v, softmax_scale, q_len > 1)?.transpose(1, 2)?
224        } else {
225            let scale = 1f64 / f64::sqrt(self.head_dim as f64);
226            let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
227
228            let attn_weights = match attention_mask {
229                None => attn_weights,
230                Some(mask) => attn_weights.broadcast_add(mask)?,
231            };
232            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
233            attn_weights.matmul(&value_states)?
234        };
235        attn_output
236            .transpose(1, 2)?
237            .reshape((b_sz, q_len, self.hidden_size))?
238            .apply(&self.o_proj)
239    }
240}
241
242#[derive(Debug, Clone)]
243struct BlockSparseTop2MLP {
244    w1: Linear,
245    w2: Linear,
246    w3: Linear,
247    act_fn: Activation,
248}
249
250impl BlockSparseTop2MLP {
251    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
252        let hidden_sz = cfg.hidden_size;
253        let intermediate_sz = cfg.intermediate_size;
254        let w1 = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("w1"))?;
255        let w2 = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("w2"))?;
256        let w3 = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("w3"))?;
257        Ok(Self {
258            w1,
259            w2,
260            w3,
261            act_fn: cfg.hidden_act,
262        })
263    }
264}
265
266impl Module for BlockSparseTop2MLP {
267    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
268        let lhs = xs.apply(&self.w1)?.apply(&self.act_fn)?;
269        let rhs = xs.apply(&self.w3)?;
270        (lhs * rhs)?.apply(&self.w2)
271    }
272}
273
274#[derive(Debug, Clone)]
275struct SparseMoeBlock {
276    gate: Linear,
277    experts: Vec<BlockSparseTop2MLP>,
278    num_experts_per_tok: usize,
279}
280
281impl SparseMoeBlock {
282    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
283        let gate = linear_no_bias(cfg.hidden_size, cfg.num_local_experts, vb.pp("gate"))?;
284        let mut experts = Vec::with_capacity(cfg.num_local_experts);
285        let vb = vb.pp("experts");
286        for idx in 0..cfg.num_local_experts {
287            let expert = BlockSparseTop2MLP::new(cfg, vb.pp(idx))?;
288            experts.push(expert)
289        }
290        Ok(SparseMoeBlock {
291            gate,
292            experts,
293            num_experts_per_tok: cfg.num_experts_per_tok,
294        })
295    }
296}
297
298impl Module for SparseMoeBlock {
299    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
300        let (b_size, seq_len, hidden_dim) = xs.dims3()?;
301        let xs = xs.reshape(((), hidden_dim))?;
302        let router_logits = xs.apply(&self.gate)?;
303        let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;
304
305        // In order to extract topk, we extract the data from the tensor and manipulate it
306        // directly. Maybe we will want to use some custom ops instead at some point.
307        let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
308
309        // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
310        // top_x contains the row indexes to evaluate for each expert.
311        let mut top_x = vec![vec![]; self.experts.len()];
312        let mut selected_rws = vec![vec![]; self.experts.len()];
313        for (row_idx, rw) in routing_weights.iter().enumerate() {
314            let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
315            dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
316            let mut sum_routing_weights = 0f32;
317            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
318                let expert_idx = expert_idx as usize;
319                let routing_weight = rw[expert_idx];
320                sum_routing_weights += routing_weight;
321                top_x[expert_idx].push(row_idx as u32);
322            }
323            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
324                let expert_idx = expert_idx as usize;
325                let routing_weight = rw[expert_idx];
326                selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
327            }
328        }
329
330        // routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
331        // expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
332
333        let mut ys = xs.zeros_like()?;
334        for (expert_idx, expert_layer) in self.experts.iter().enumerate() {
335            let top_x = &top_x[expert_idx];
336            if top_x.is_empty() {
337                continue;
338            }
339            let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
340            let selected_rws =
341                Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?.reshape(((), 1))?;
342            // Index the correct hidden states and compute the expert hidden state for
343            // the current expert. We need to make sure to multiply the output hidden
344            // states by `routing_weights` on the corresponding tokens (top-1 and top-2)
345            let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
346            // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
347            let current_hidden_states = expert_layer.forward(&current_state)?;
348            let current_hidden_states = current_hidden_states.broadcast_mul(&selected_rws)?;
349            ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
350        }
351
352        let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
353        Ok(ys)
354    }
355}
356
357#[derive(Debug, Clone)]
358struct DecoderLayer {
359    self_attn: Attention,
360    block_sparse_moe: SparseMoeBlock,
361    input_layernorm: RmsNorm,
362    post_attention_layernorm: RmsNorm,
363}
364
365impl DecoderLayer {
366    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
367        let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
368        let block_sparse_moe = SparseMoeBlock::new(cfg, vb.pp("block_sparse_moe"))?;
369        let input_layernorm =
370            RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
371        let post_attention_layernorm = RmsNorm::new(
372            cfg.hidden_size,
373            cfg.rms_norm_eps,
374            vb.pp("post_attention_layernorm"),
375        )?;
376        Ok(Self {
377            self_attn,
378            block_sparse_moe,
379            input_layernorm,
380            post_attention_layernorm,
381        })
382    }
383
384    fn forward(
385        &mut self,
386        xs: &Tensor,
387        attention_mask: Option<&Tensor>,
388        seqlen_offset: usize,
389    ) -> Result<Tensor> {
390        let residual = xs;
391        let xs = self.input_layernorm.forward(xs)?;
392        let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
393        let xs = (xs + residual)?;
394        let residual = &xs;
395        let xs = xs
396            .apply(&self.post_attention_layernorm)?
397            .apply(&self.block_sparse_moe)?;
398        residual + xs
399    }
400}
401
402#[derive(Debug, Clone)]
403pub struct Model {
404    embed_tokens: candle_nn::Embedding,
405    layers: Vec<DecoderLayer>,
406    norm: RmsNorm,
407    lm_head: Linear,
408    sliding_window: usize,
409    device: Device,
410    dtype: DType,
411}
412
413impl Model {
414    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
415        let vb_m = vb.pp("model");
416        let embed_tokens =
417            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
418        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
419        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
420        let vb_l = vb_m.pp("layers");
421        for layer_idx in 0..cfg.num_hidden_layers {
422            let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
423            layers.push(layer)
424        }
425        let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
426        let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
427        Ok(Self {
428            embed_tokens,
429            layers,
430            norm,
431            lm_head,
432            sliding_window: cfg.sliding_window,
433            device: vb.device().clone(),
434            dtype: vb.dtype(),
435        })
436    }
437
438    fn prepare_decoder_attention_mask(
439        &self,
440        b_size: usize,
441        tgt_len: usize,
442        seqlen_offset: usize,
443    ) -> Result<Tensor> {
444        // Sliding window mask?
445        let mask: Vec<_> = (0..tgt_len)
446            .flat_map(|i| {
447                (0..tgt_len).map(move |j| {
448                    if i < j || j + self.sliding_window < i {
449                        f32::NEG_INFINITY
450                    } else {
451                        0.
452                    }
453                })
454            })
455            .collect();
456        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
457        let mask = if seqlen_offset > 0 {
458            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
459            Tensor::cat(&[&mask0, &mask], D::Minus1)?
460        } else {
461            mask
462        };
463        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
464            .to_dtype(self.dtype)
465    }
466
467    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
468        let (b_size, seq_len) = input_ids.dims2()?;
469        let attention_mask = if seq_len <= 1 {
470            None
471        } else {
472            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
473            Some(mask)
474        };
475        let mut xs = self.embed_tokens.forward(input_ids)?;
476        for layer in self.layers.iter_mut() {
477            xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
478        }
479        xs.narrow(1, seq_len - 1, 1)?
480            .apply(&self.norm)?
481            .apply(&self.lm_head)
482    }
483}