1use crate::models::{
3 mistral::Config,
4 with_tracing::{linear_no_bias, Linear, RmsNorm},
5};
6use crate::utils::repeat_kv;
7use candle::{DType, Device, Module, Result, Tensor};
8use candle_nn::{Activation, VarBuilder};
9use std::sync::Arc;
10
/// Precomputed rotary position embedding tables.
///
/// Both tensors are built once in `RotaryEmbedding::new`: one row per position in
/// `0..max_position_embeddings` and one column per even index in `0..head_dim`.
/// `apply_rotary_emb_qkv` slices the rows it needs on every call.
#[derive(Debug, Clone)]
struct RotaryEmbedding {
    /// Sine of `position * inverse_frequency`.
    sin: Tensor,
    /// Cosine of `position * inverse_frequency`.
    cos: Tensor,
}
16
17impl RotaryEmbedding {
18 fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
19 let rope_theta = cfg.rope_theta as f32;
20 let dim = cfg.hidden_size / cfg.num_attention_heads;
21 let max_seq_len = cfg.max_position_embeddings;
22 let inv_freq: Vec<_> = (0..dim)
23 .step_by(2)
24 .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
25 .collect();
26 let inv_freq_len = inv_freq.len();
27 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
28 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
29 .to_dtype(dtype)?
30 .reshape((max_seq_len, 1))?;
31 let freqs = t.matmul(&inv_freq)?;
32 Ok(Self {
33 sin: freqs.sin()?,
34 cos: freqs.cos()?,
35 })
36 }
37
38 fn apply_rotary_emb_qkv(
39 &self,
40 q: &Tensor,
41 k: &Tensor,
42 seqlen_offset: usize,
43 ) -> Result<(Tensor, Tensor)> {
44 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
45 let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
46 let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
47 let q_embed = candle_nn::rotary_emb::rope(q, &cos, &sin)?;
48 let k_embed = candle_nn::rotary_emb::rope(k, &cos, &sin)?;
49 Ok((q_embed, k_embed))
50 }
51}
52
/// Gated feed-forward block computing `down_proj(act_fn(gate_proj(x)) * up_proj(x))`.
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    /// Bias-free projection from `hidden_size` to `intermediate_size`; its output goes
    /// through `act_fn`.
    gate_proj: Linear,
    /// Bias-free projection from `hidden_size` to `intermediate_size`; multiplied
    /// element-wise with the activated gate.
    up_proj: Linear,
    /// Bias-free projection from `intermediate_size` back to `hidden_size`.
    down_proj: Linear,
    /// Activation taken from `Config::hidden_act`.
    act_fn: Activation,
}
61
62impl MLP {
63 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
64 let hidden_sz = cfg.hidden_size;
65 let intermediate_sz = cfg.intermediate_size;
66 let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
67 let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
68 let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
69 Ok(Self {
70 gate_proj,
71 up_proj,
72 down_proj,
73 act_fn: cfg.hidden_act,
74 })
75 }
76}
77
78impl Module for MLP {
79 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
80 let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
81 let rhs = xs.apply(&self.up_proj)?;
82 (lhs * rhs)?.apply(&self.down_proj)
83 }
84}
85
/// Multi-head self-attention with separately sized key/value heads and rotary position
/// embeddings.
#[derive(Debug, Clone)]
struct Attention {
    /// Bias-free projection from `hidden_size` to `num_heads * head_dim`.
    q_proj: Linear,
    /// Bias-free projection from `hidden_size` to `num_kv_heads * head_dim`.
    k_proj: Linear,
    /// Bias-free projection from `hidden_size` to `num_kv_heads * head_dim`.
    v_proj: Linear,
    /// Bias-free projection from `num_heads * head_dim` back to `hidden_size`.
    o_proj: Linear,
    /// Number of query heads (`Config::num_attention_heads`).
    num_heads: usize,
    /// Number of key/value heads (`Config::num_key_value_heads`).
    num_kv_heads: usize,
    /// `num_heads / num_kv_heads`; the repeat count handed to `repeat_kv`.
    num_kv_groups: usize,
    /// `hidden_size / num_heads`.
    head_dim: usize,
    /// Width of the last dimension that the attention output is reshaped to.
    hidden_size: usize,
    /// Rotary tables shared between all layers.
    rotary_emb: Arc<RotaryEmbedding>,
}
99
100impl Attention {
101 fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
102 let hidden_sz = cfg.hidden_size;
103 let num_heads = cfg.num_attention_heads;
104 let num_kv_heads = cfg.num_key_value_heads;
105 let num_kv_groups = num_heads / num_kv_heads;
106 let head_dim = hidden_sz / num_heads;
107 let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
108 let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
109 let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
110 let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
111 Ok(Self {
112 q_proj,
113 k_proj,
114 v_proj,
115 o_proj,
116 num_heads,
117 num_kv_heads,
118 num_kv_groups,
119 head_dim,
120 hidden_size: hidden_sz,
121 rotary_emb,
122 })
123 }
124
125 fn forward(
126 &mut self,
127 xs: &Tensor,
128 attention_mask: Option<&Tensor>,
129 seqlen_offset: usize,
130 ) -> Result<Tensor> {
131 let (b_sz, q_len, _) = xs.dims3()?;
132
133 let query_states = self.q_proj.forward(xs)?;
134 let key_states = self.k_proj.forward(xs)?;
135 let value_states = self.v_proj.forward(xs)?;
136
137 let query_states = query_states
138 .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
139 .transpose(1, 2)?
140 .contiguous()?;
141
142 let key_states = key_states
143 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
144 .transpose(1, 2)?
145 .contiguous()?;
146 let value_states = value_states
147 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
148 .transpose(1, 2)?;
149
150 let (query_states, key_states) =
151 self.rotary_emb
152 .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
153
154 let key_states = repeat_kv(key_states, self.num_kv_groups)?;
155 let value_states = repeat_kv(value_states, self.num_kv_groups)?;
156
157 let scale = 1f64 / f64::sqrt(self.head_dim as f64);
158 let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
159
160 let attn_weights = match attention_mask {
161 None => attn_weights,
162 Some(mask) => attn_weights.broadcast_add(mask)?,
163 };
164 let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
165 let attn_output = attn_weights.matmul(&value_states)?;
166
167 attn_output
168 .transpose(1, 2)?
169 .reshape((b_sz, q_len, self.hidden_size))?
170 .apply(&self.o_proj)
171 }
172}
173
/// One transformer layer: normalized self-attention followed by a normalized MLP, each
/// wrapped in a residual connection.
#[derive(Debug, Clone)]
struct DecoderLayer {
    /// Self-attention sub-layer.
    self_attn: Attention,
    /// Gated feed-forward sub-layer.
    mlp: MLP,
    /// RMS norm applied to the layer input before self-attention.
    input_layernorm: RmsNorm,
    /// RMS norm applied to the attention residual sum before the MLP.
    post_attention_layernorm: RmsNorm,
}
181
182impl DecoderLayer {
183 fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
184 let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
185 let mlp = MLP::new(cfg, vb.pp("mlp"))?;
186 let input_layernorm =
187 RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
188 let post_attention_layernorm = RmsNorm::new(
189 cfg.hidden_size,
190 cfg.rms_norm_eps,
191 vb.pp("post_attention_layernorm"),
192 )?;
193 Ok(Self {
194 self_attn,
195 mlp,
196 input_layernorm,
197 post_attention_layernorm,
198 })
199 }
200
201 fn forward(
202 &mut self,
203 xs: &Tensor,
204 attention_mask: Option<&Tensor>,
205 seqlen_offset: usize,
206 ) -> Result<Tensor> {
207 let residual = xs;
208 let xs = self.input_layernorm.forward(xs)?;
209
210 let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
211
212 let xs = (xs + residual)?;
213 let residual = &xs;
214 let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
215 residual + xs
216 }
217}
218
/// Stack of decoder layers that maps token ids to normalized hidden states.
#[derive(Debug, Clone)]
pub struct Model {
    /// Token embedding table of size `vocab_size x hidden_size`.
    embed_tokens: candle_nn::Embedding,
    /// The `num_hidden_layers` decoder layers, applied in order.
    layers: Vec<DecoderLayer>,
    /// RMS norm applied to the output of the last layer.
    norm: RmsNorm,
    /// Copy of the configuration the model was built from.
    pub cfg: Config,
}
226
227impl Model {
228 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
229 let embed_tokens =
230 candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?;
231 let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?);
232 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
233 let vb_l = vb.pp("layers");
234 for layer_idx in 0..cfg.num_hidden_layers {
235 let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
236 layers.push(layer)
237 }
238 let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("norm"))?;
239 Ok(Self {
240 embed_tokens,
241 layers,
242 norm,
243 cfg: cfg.clone(),
244 })
245 }
246
247 pub fn forward(
249 &mut self,
250 attn_mask: &Tensor,
251 input_ids: &Tensor,
252 dtype: DType,
253 ) -> Result<Tensor> {
254 let mut xs = self.embed_tokens.forward(input_ids)?;
255
256 let attn_mask = prepare_4d_attention_mask(attn_mask, dtype, None)?;
258
259 for layer in self.layers.iter_mut() {
260 xs = layer.forward(&xs, Some(&attn_mask), 0)?;
261 }
262
263 xs.apply(&self.norm)
265 }
266}
267
268fn prepare_4d_attention_mask(
269 mask: &Tensor,
270 dtype: DType,
271 tgt_len: Option<usize>,
272) -> Result<Tensor> {
273 let bsz = mask.dims()[0];
274 let src_len = mask.dims()[1];
275 let tgt_len = tgt_len.unwrap_or(src_len);
276
277 let expanded_mask = mask
278 .unsqueeze(1)?
279 .unsqueeze(2)?
280 .expand((bsz, 1, tgt_len, src_len))?
281 .to_dtype(dtype)?;
282
283 let inverted_mask = (1.0 - expanded_mask)?;
284
285 (inverted_mask * get_dtype_min_val(dtype))?.to_dtype(dtype)
286}
287
288fn get_dtype_min_val(dtype: DType) -> f64 {
289 match dtype {
290 DType::F32 => f32::MIN as f64,
291 DType::F64 => f64::MIN,
292 _ => panic!("Unsupported data type"),
293 }
294}