// candle_transformers/models/mistral.rs

//! Mistral model, the dense architecture that Mixtral builds upon.
//!
//! See Mistral and Mixtral at:
//! - [Hugging Face](https://huggingface.co/docs/transformers/model_doc/mistral)
//! - [Hugging Face](https://huggingface.co/docs/transformers/model_doc/mixtral)
//! - [Github](https://github.com/mistralai/mistral-src)
//!

8use crate::models::with_tracing::{linear_no_bias, Linear, RmsNorm};
// Mistral LLM, https://github.com/mistralai/mistral-src
10use candle::{DType, Device, Module, Result, Tensor, D};
11use candle_nn::{Activation, VarBuilder};
12use std::sync::Arc;
13
/// Serde fallback for `Config::num_attention_heads` when the JSON omits it.
fn default_num_attention_heads() -> usize {
    const DEFAULT_HEADS: usize = 32;
    DEFAULT_HEADS
}

/// Serde fallback for `Config::use_flash_attn`: the flash-attention path is opt-in.
fn default_use_flash_attn() -> bool {
    bool::default()
}
21
22fn default_hidden_act() -> candle_nn::Activation {
23    candle_nn::Activation::Silu
24}
25
26#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
27pub struct Config {
28    pub vocab_size: usize,
29    pub hidden_size: usize,
30    pub intermediate_size: usize,
31    pub num_hidden_layers: usize,
32    #[serde(default = "default_num_attention_heads")]
33    pub num_attention_heads: usize,
34    pub head_dim: Option<usize>,
35    pub num_key_value_heads: usize,
36    #[serde(default = "default_hidden_act")]
37    pub hidden_act: Activation,
38    pub max_position_embeddings: usize,
39    pub rms_norm_eps: f64,
40    pub rope_theta: f64,
41    pub sliding_window: Option<usize>,
42    #[serde(default = "default_use_flash_attn")]
43    pub use_flash_attn: bool,
44}
45
46impl Config {
47    // https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
48    pub fn config_7b_v0_1(use_flash_attn: bool) -> Self {
49        Self {
50            vocab_size: 32000,
51            hidden_size: 4096,
52            intermediate_size: 14336,
53            num_hidden_layers: 32,
54            num_attention_heads: 32,
55            head_dim: None,
56            num_key_value_heads: 8,
57            hidden_act: Activation::Silu,
58            max_position_embeddings: 32768,
59            rms_norm_eps: 1e-5,
60            rope_theta: 10_000.,
61            sliding_window: Some(4096),
62            use_flash_attn,
63        }
64    }
65
66    // https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca/blob/main/config.json
67    // https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/config.json
68    pub fn config_chat_ml(use_flash_attn: bool) -> Self {
69        Self {
70            vocab_size: 32002,
71            hidden_size: 4096,
72            intermediate_size: 14336,
73            num_hidden_layers: 32,
74            num_attention_heads: 32,
75            head_dim: None,
76            num_key_value_heads: 8,
77            hidden_act: Activation::Silu,
78            max_position_embeddings: 32768,
79            rms_norm_eps: 1e-5,
80            rope_theta: 10_000.,
81            sliding_window: Some(4096),
82            use_flash_attn,
83        }
84    }
85
86    // https://huggingface.co/amazon/MistralLite/blob/main/config.json
87    pub fn config_amazon_mistral_lite(use_flash_attn: bool) -> Self {
88        Self {
89            vocab_size: 32003,
90            hidden_size: 4096,
91            intermediate_size: 14336,
92            num_hidden_layers: 32,
93            num_attention_heads: 32,
94            head_dim: None,
95            num_key_value_heads: 8,
96            hidden_act: Activation::Silu,
97            max_position_embeddings: 32768,
98            rms_norm_eps: 1e-5,
99            rope_theta: 10_000.,
100            sliding_window: Some(4096),
101            use_flash_attn,
102        }
103    }
104
105    fn head_dim(&self) -> usize {
106        self.head_dim
107            .unwrap_or(self.hidden_size / self.num_attention_heads)
108    }
109}
110
111#[derive(Debug, Clone)]
112struct RotaryEmbedding {
113    sin: Tensor,
114    cos: Tensor,
115}
116
117impl RotaryEmbedding {
118    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
119        let rope_theta = cfg.rope_theta as f32;
120        let dim = cfg.head_dim();
121        let max_seq_len = cfg.max_position_embeddings;
122        let inv_freq: Vec<_> = (0..dim)
123            .step_by(2)
124            .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
125            .collect();
126        let inv_freq_len = inv_freq.len();
127        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
128        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
129            .to_dtype(DType::F32)?
130            .reshape((max_seq_len, 1))?;
131        let freqs = t.matmul(&inv_freq)?;
132        Ok(Self {
133            sin: freqs.sin()?.to_dtype(dtype)?,
134            cos: freqs.cos()?.to_dtype(dtype)?,
135        })
136    }
137
138    fn apply_rotary_emb_qkv(
139        &self,
140        q: &Tensor,
141        k: &Tensor,
142        seqlen_offset: usize,
143    ) -> Result<(Tensor, Tensor)> {
144        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
145        let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
146        let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
147        let q_embed = candle_nn::rotary_emb::rope(q, &cos, &sin)?;
148        let k_embed = candle_nn::rotary_emb::rope(k, &cos, &sin)?;
149        Ok((q_embed, k_embed))
150    }
151}
152
153#[derive(Debug, Clone)]
154#[allow(clippy::upper_case_acronyms)]
155struct MLP {
156    gate_proj: Linear,
157    up_proj: Linear,
158    down_proj: Linear,
159    act_fn: Activation,
160}
161
162impl MLP {
163    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
164        let hidden_sz = cfg.hidden_size;
165        let intermediate_sz = cfg.intermediate_size;
166        let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
167        let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
168        let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
169        Ok(Self {
170            gate_proj,
171            up_proj,
172            down_proj,
173            act_fn: cfg.hidden_act,
174        })
175    }
176}
177
178impl Module for MLP {
179    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
180        let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
181        let rhs = xs.apply(&self.up_proj)?;
182        (lhs * rhs)?.apply(&self.down_proj)
183    }
184}
185
186#[cfg(feature = "flash-attn")]
187fn flash_attn(
188    q: &Tensor,
189    k: &Tensor,
190    v: &Tensor,
191    softmax_scale: f32,
192    causal: bool,
193) -> Result<Tensor> {
194    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
195}
196
197#[cfg(not(feature = "flash-attn"))]
198fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
199    unimplemented!("compile with '--features flash-attn'")
200}
201
202#[derive(Debug, Clone)]
203struct Attention {
204    q_proj: Linear,
205    k_proj: Linear,
206    v_proj: Linear,
207    o_proj: Linear,
208    num_heads: usize,
209    num_kv_heads: usize,
210    num_kv_groups: usize,
211    head_dim: usize,
212    rotary_emb: Arc<RotaryEmbedding>,
213    kv_cache: Option<(Tensor, Tensor)>,
214    use_flash_attn: bool,
215}
216
217impl Attention {
218    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
219        let hidden_sz = cfg.hidden_size;
220        let num_heads = cfg.num_attention_heads;
221        let num_kv_heads = cfg.num_key_value_heads;
222        let num_kv_groups = num_heads / num_kv_heads;
223        let head_dim = cfg.head_dim();
224        let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
225        let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
226        let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
227        let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
228        Ok(Self {
229            q_proj,
230            k_proj,
231            v_proj,
232            o_proj,
233            num_heads,
234            num_kv_heads,
235            num_kv_groups,
236            head_dim,
237            rotary_emb,
238            kv_cache: None,
239            use_flash_attn: cfg.use_flash_attn,
240        })
241    }
242
243    fn forward(
244        &mut self,
245        xs: &Tensor,
246        attention_mask: Option<&Tensor>,
247        seqlen_offset: usize,
248    ) -> Result<Tensor> {
249        let (b_sz, q_len, _) = xs.dims3()?;
250
251        let query_states = self.q_proj.forward(xs)?;
252        let key_states = self.k_proj.forward(xs)?;
253        let value_states = self.v_proj.forward(xs)?;
254
255        let query_states = query_states
256            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
257            .transpose(1, 2)?
258            .contiguous()?;
259        let key_states = key_states
260            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
261            .transpose(1, 2)?
262            .contiguous()?;
263        let value_states = value_states
264            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
265            .transpose(1, 2)?
266            .contiguous()?;
267
268        let (query_states, key_states) =
269            self.rotary_emb
270                .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
271
272        let (key_states, value_states) = match &self.kv_cache {
273            None => (key_states, value_states),
274            Some((prev_k, prev_v)) => {
275                let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
276                let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
277                (key_states, value_states)
278            }
279        };
280        self.kv_cache = Some((key_states.clone(), value_states.clone()));
281
282        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
283        let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
284
285        let attn_output = if self.use_flash_attn {
286            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
287            let q = query_states.transpose(1, 2)?;
288            let k = key_states.transpose(1, 2)?;
289            let v = value_states.transpose(1, 2)?;
290            let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
291            flash_attn(&q, &k, &v, softmax_scale, q_len > 1)?.transpose(1, 2)?
292        } else {
293            let scale = 1f64 / f64::sqrt(self.head_dim as f64);
294            let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
295
296            let attn_weights = match attention_mask {
297                None => attn_weights,
298                Some(mask) => attn_weights.broadcast_add(mask)?,
299            };
300            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
301            attn_weights.matmul(&value_states)?
302        };
303        attn_output
304            .transpose(1, 2)?
305            .reshape((b_sz, q_len, self.num_heads * self.head_dim))?
306            .apply(&self.o_proj)
307    }
308
309    fn clear_kv_cache(&mut self) {
310        self.kv_cache = None
311    }
312}
313
314#[derive(Debug, Clone)]
315struct DecoderLayer {
316    self_attn: Attention,
317    mlp: MLP,
318    input_layernorm: RmsNorm,
319    post_attention_layernorm: RmsNorm,
320}
321
322impl DecoderLayer {
323    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
324        let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
325        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
326        let input_layernorm =
327            RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
328        let post_attention_layernorm = RmsNorm::new(
329            cfg.hidden_size,
330            cfg.rms_norm_eps,
331            vb.pp("post_attention_layernorm"),
332        )?;
333        Ok(Self {
334            self_attn,
335            mlp,
336            input_layernorm,
337            post_attention_layernorm,
338        })
339    }
340
341    fn forward(
342        &mut self,
343        xs: &Tensor,
344        attention_mask: Option<&Tensor>,
345        seqlen_offset: usize,
346    ) -> Result<Tensor> {
347        let residual = xs;
348        let xs = self.input_layernorm.forward(xs)?;
349        let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
350        let xs = (xs + residual)?;
351        let residual = &xs;
352        let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
353        residual + xs
354    }
355
356    fn clear_kv_cache(&mut self) {
357        self.self_attn.clear_kv_cache()
358    }
359}
360
361#[derive(Debug, Clone)]
362pub struct Model {
363    embed_tokens: candle_nn::Embedding,
364    layers: Vec<DecoderLayer>,
365    norm: RmsNorm,
366    lm_head: Linear,
367    sliding_window: Option<usize>,
368    device: Device,
369    dtype: DType,
370}
371
372impl Model {
373    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
374        let vb_m = vb.pp("model");
375        let embed_tokens =
376            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
377        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
378        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
379        let vb_l = vb_m.pp("layers");
380        for layer_idx in 0..cfg.num_hidden_layers {
381            let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
382            layers.push(layer)
383        }
384        let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
385        let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
386        Ok(Self {
387            embed_tokens,
388            layers,
389            norm,
390            lm_head,
391            sliding_window: cfg.sliding_window,
392            device: vb.device().clone(),
393            dtype: vb.dtype(),
394        })
395    }
396
397    fn prepare_decoder_attention_mask(
398        &self,
399        tgt_len: usize,
400        seqlen_offset: usize,
401    ) -> Result<Tensor> {
402        let sliding_window = self.sliding_window.unwrap_or(tgt_len + 1);
403        let mask: Vec<_> = (0..tgt_len)
404            .flat_map(|i| {
405                (0..tgt_len).map(move |j| {
406                    if i < j || j + sliding_window < i {
407                        f32::NEG_INFINITY
408                    } else {
409                        0.
410                    }
411                })
412            })
413            .collect();
414        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
415        let mask = if seqlen_offset > 0 {
416            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
417            Tensor::cat(&[&mask0, &mask], D::Minus1)?
418        } else {
419            mask
420        };
421        mask.expand((1, 1, tgt_len, tgt_len + seqlen_offset))?
422            .to_dtype(self.dtype)
423    }
424
425    pub fn embed_tokens(&self) -> &candle_nn::Embedding {
426        &self.embed_tokens
427    }
428
429    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
430        let (_b_size, seq_len) = input_ids.dims2()?;
431        let attention_mask = if seq_len <= 1 {
432            None
433        } else {
434            let mask = self.prepare_decoder_attention_mask(seq_len, seqlen_offset)?;
435            Some(mask)
436        };
437        let mut xs = self.embed_tokens.forward(input_ids)?;
438        for layer in self.layers.iter_mut() {
439            xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
440        }
441        xs.narrow(1, seq_len - 1, 1)?
442            .apply(&self.norm)?
443            .apply(&self.lm_head)
444    }
445
446    pub fn forward_embeds(
447        &mut self,
448        xs: &Tensor,
449        attn_mask: Option<&Tensor>,
450        seqlen_offset: usize,
451    ) -> Result<Tensor> {
452        let (_b_size, seq_len, _) = xs.dims3()?;
453        let mut xs = xs.clone();
454        for layer in self.layers.iter_mut() {
455            xs = layer.forward(&xs, attn_mask, seqlen_offset)?
456        }
457        xs.narrow(1, seq_len - 1, 1)?
458            .apply(&self.norm)?
459            .apply(&self.lm_head)
460    }
461
462    pub fn clear_kv_cache(&mut self) {
463        for layer in self.layers.iter_mut() {
464            layer.clear_kv_cache()
465        }
466    }
467}