candle_transformers/models/nvembed_v2/
embedding.rs

//! Mistral LLM, https://github.com/mistralai/mistral-src
2use crate::models::{
3    mistral::Config,
4    with_tracing::{linear_no_bias, Linear, RmsNorm},
5};
6use crate::utils::repeat_kv;
7use candle::{DType, Device, Module, Result, Tensor};
8use candle_nn::{Activation, VarBuilder};
9use std::sync::Arc;
10
11#[derive(Debug, Clone)]
12struct RotaryEmbedding {
13    sin: Tensor,
14    cos: Tensor,
15}
16
17impl RotaryEmbedding {
18    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
19        let rope_theta = cfg.rope_theta as f32;
20        let dim = cfg.hidden_size / cfg.num_attention_heads;
21        let max_seq_len = cfg.max_position_embeddings;
22        let inv_freq: Vec<_> = (0..dim)
23            .step_by(2)
24            .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
25            .collect();
26        let inv_freq_len = inv_freq.len();
27        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
28        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
29            .to_dtype(dtype)?
30            .reshape((max_seq_len, 1))?;
31        let freqs = t.matmul(&inv_freq)?;
32        Ok(Self {
33            sin: freqs.sin()?,
34            cos: freqs.cos()?,
35        })
36    }
37
38    fn apply_rotary_emb_qkv(
39        &self,
40        q: &Tensor,
41        k: &Tensor,
42        seqlen_offset: usize,
43    ) -> Result<(Tensor, Tensor)> {
44        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
45        let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
46        let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
47        let q_embed = candle_nn::rotary_emb::rope(q, &cos, &sin)?;
48        let k_embed = candle_nn::rotary_emb::rope(k, &cos, &sin)?;
49        Ok((q_embed, k_embed))
50    }
51}
52
53#[derive(Debug, Clone)]
54#[allow(clippy::upper_case_acronyms)]
55struct MLP {
56    gate_proj: Linear,
57    up_proj: Linear,
58    down_proj: Linear,
59    act_fn: Activation,
60}
61
62impl MLP {
63    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
64        let hidden_sz = cfg.hidden_size;
65        let intermediate_sz = cfg.intermediate_size;
66        let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
67        let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
68        let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
69        Ok(Self {
70            gate_proj,
71            up_proj,
72            down_proj,
73            act_fn: cfg.hidden_act,
74        })
75    }
76}
77
78impl Module for MLP {
79    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
80        let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
81        let rhs = xs.apply(&self.up_proj)?;
82        (lhs * rhs)?.apply(&self.down_proj)
83    }
84}
85
86#[derive(Debug, Clone)]
87struct Attention {
88    q_proj: Linear,
89    k_proj: Linear,
90    v_proj: Linear,
91    o_proj: Linear,
92    num_heads: usize,
93    num_kv_heads: usize,
94    num_kv_groups: usize,
95    head_dim: usize,
96    hidden_size: usize,
97    rotary_emb: Arc<RotaryEmbedding>,
98}
99
100impl Attention {
101    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
102        let hidden_sz = cfg.hidden_size;
103        let num_heads = cfg.num_attention_heads;
104        let num_kv_heads = cfg.num_key_value_heads;
105        let num_kv_groups = num_heads / num_kv_heads;
106        let head_dim = hidden_sz / num_heads;
107        let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
108        let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
109        let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
110        let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
111        Ok(Self {
112            q_proj,
113            k_proj,
114            v_proj,
115            o_proj,
116            num_heads,
117            num_kv_heads,
118            num_kv_groups,
119            head_dim,
120            hidden_size: hidden_sz,
121            rotary_emb,
122        })
123    }
124
125    fn forward(
126        &mut self,
127        xs: &Tensor,
128        attention_mask: Option<&Tensor>,
129        seqlen_offset: usize,
130    ) -> Result<Tensor> {
131        let (b_sz, q_len, _) = xs.dims3()?;
132
133        let query_states = self.q_proj.forward(xs)?;
134        let key_states = self.k_proj.forward(xs)?;
135        let value_states = self.v_proj.forward(xs)?;
136
137        let query_states = query_states
138            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
139            .transpose(1, 2)?
140            .contiguous()?;
141
142        let key_states = key_states
143            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
144            .transpose(1, 2)?
145            .contiguous()?;
146        let value_states = value_states
147            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
148            .transpose(1, 2)?;
149
150        let (query_states, key_states) =
151            self.rotary_emb
152                .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
153
154        let key_states = repeat_kv(key_states, self.num_kv_groups)?;
155        let value_states = repeat_kv(value_states, self.num_kv_groups)?;
156
157        let scale = 1f64 / f64::sqrt(self.head_dim as f64);
158        let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
159
160        let attn_weights = match attention_mask {
161            None => attn_weights,
162            Some(mask) => attn_weights.broadcast_add(mask)?,
163        };
164        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
165        let attn_output = attn_weights.matmul(&value_states)?;
166
167        attn_output
168            .transpose(1, 2)?
169            .reshape((b_sz, q_len, self.hidden_size))?
170            .apply(&self.o_proj)
171    }
172}
173
174#[derive(Debug, Clone)]
175struct DecoderLayer {
176    self_attn: Attention,
177    mlp: MLP,
178    input_layernorm: RmsNorm,
179    post_attention_layernorm: RmsNorm,
180}
181
182impl DecoderLayer {
183    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
184        let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
185        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
186        let input_layernorm =
187            RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
188        let post_attention_layernorm = RmsNorm::new(
189            cfg.hidden_size,
190            cfg.rms_norm_eps,
191            vb.pp("post_attention_layernorm"),
192        )?;
193        Ok(Self {
194            self_attn,
195            mlp,
196            input_layernorm,
197            post_attention_layernorm,
198        })
199    }
200
201    fn forward(
202        &mut self,
203        xs: &Tensor,
204        attention_mask: Option<&Tensor>,
205        seqlen_offset: usize,
206    ) -> Result<Tensor> {
207        let residual = xs;
208        let xs = self.input_layernorm.forward(xs)?;
209
210        let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
211
212        let xs = (xs + residual)?;
213        let residual = &xs;
214        let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
215        residual + xs
216    }
217}
218
219#[derive(Debug, Clone)]
220pub struct Model {
221    embed_tokens: candle_nn::Embedding,
222    layers: Vec<DecoderLayer>,
223    norm: RmsNorm,
224    pub cfg: Config,
225}
226
227impl Model {
228    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
229        let embed_tokens =
230            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?;
231        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?);
232        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
233        let vb_l = vb.pp("layers");
234        for layer_idx in 0..cfg.num_hidden_layers {
235            let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
236            layers.push(layer)
237        }
238        let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("norm"))?;
239        Ok(Self {
240            embed_tokens,
241            layers,
242            norm,
243            cfg: cfg.clone(),
244        })
245    }
246
247    // Attn mask used to mask out padding tokens
248    pub fn forward(
249        &mut self,
250        attn_mask: &Tensor,
251        input_ids: &Tensor,
252        dtype: DType,
253    ) -> Result<Tensor> {
254        let mut xs = self.embed_tokens.forward(input_ids)?;
255
256        // Expand to 4d mask for sdpa
257        let attn_mask = prepare_4d_attention_mask(attn_mask, dtype, None)?;
258
259        for layer in self.layers.iter_mut() {
260            xs = layer.forward(&xs, Some(&attn_mask), 0)?;
261        }
262
263        // Return hiddens instead of logits
264        xs.apply(&self.norm)
265    }
266}
267
268fn prepare_4d_attention_mask(
269    mask: &Tensor,
270    dtype: DType,
271    tgt_len: Option<usize>,
272) -> Result<Tensor> {
273    let bsz = mask.dims()[0];
274    let src_len = mask.dims()[1];
275    let tgt_len = tgt_len.unwrap_or(src_len);
276
277    let expanded_mask = mask
278        .unsqueeze(1)?
279        .unsqueeze(2)?
280        .expand((bsz, 1, tgt_len, src_len))?
281        .to_dtype(dtype)?;
282
283    let inverted_mask = (1.0 - expanded_mask)?;
284
285    (inverted_mask * get_dtype_min_val(dtype))?.to_dtype(dtype)
286}
287
288fn get_dtype_min_val(dtype: DType) -> f64 {
289    match dtype {
290        DType::F32 => f32::MIN as f64,
291        DType::F64 => f64::MIN,
292        _ => panic!("Unsupported data type"),
293    }
294}