1use crate::quantized_nn::{linear_no_bias, Embedding, Linear, RmsNorm};
19pub use crate::quantized_var_builder::VarBuilder;
20use candle::{DType, Device, Module, Result, Tensor, D};
21use candle_nn::Activation;
22use std::sync::Arc;
23
24pub use crate::models::mistral::Config;
25
26#[derive(Debug, Clone)]
27struct RotaryEmbedding {
28 sin: Tensor,
29 cos: Tensor,
30}
31
32impl RotaryEmbedding {
33 fn new(cfg: &Config, dev: &Device) -> Result<Self> {
34 let rope_theta = cfg.rope_theta as f32;
35 let dim = cfg.hidden_size / cfg.num_attention_heads;
36 let max_seq_len = cfg.max_position_embeddings;
37 let inv_freq: Vec<_> = (0..dim)
38 .step_by(2)
39 .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
40 .collect();
41 let inv_freq_len = inv_freq.len();
42 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
43 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
44 .to_dtype(DType::F32)?
45 .reshape((max_seq_len, 1))?;
46 let freqs = t.matmul(&inv_freq)?;
47 Ok(Self {
48 sin: freqs.sin()?,
49 cos: freqs.cos()?,
50 })
51 }
52
53 fn apply_rotary_emb_qkv(
54 &self,
55 q: &Tensor,
56 k: &Tensor,
57 seqlen_offset: usize,
58 ) -> Result<(Tensor, Tensor)> {
59 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
60 let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
61 let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
62 let q_embed = candle_nn::rotary_emb::rope(q, &cos, &sin)?;
63 let k_embed = candle_nn::rotary_emb::rope(k, &cos, &sin)?;
64 Ok((q_embed, k_embed))
65 }
66}
67
68#[derive(Debug, Clone)]
69#[allow(clippy::upper_case_acronyms)]
70struct MLP {
71 gate_proj: Linear,
72 up_proj: Linear,
73 down_proj: Linear,
74 act_fn: Activation,
75}
76
77impl MLP {
78 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
79 let hidden_sz = cfg.hidden_size;
80 let intermediate_sz = cfg.intermediate_size;
81 let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
82 let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
83 let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
84 Ok(Self {
85 gate_proj,
86 up_proj,
87 down_proj,
88 act_fn: cfg.hidden_act,
89 })
90 }
91}
92
93impl Module for MLP {
94 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
95 let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
96 let rhs = xs.apply(&self.up_proj)?;
97 (lhs * rhs)?.apply(&self.down_proj)
98 }
99}
100
101#[derive(Debug, Clone)]
102struct Attention {
103 q_proj: Linear,
104 k_proj: Linear,
105 v_proj: Linear,
106 o_proj: Linear,
107 num_heads: usize,
108 num_kv_heads: usize,
109 num_kv_groups: usize,
110 head_dim: usize,
111 hidden_size: usize,
112 rotary_emb: Arc<RotaryEmbedding>,
113 kv_cache: Option<(Tensor, Tensor)>,
114}
115
116impl Attention {
117 fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
118 let hidden_sz = cfg.hidden_size;
119 let num_heads = cfg.num_attention_heads;
120 let num_kv_heads = cfg.num_key_value_heads;
121 let num_kv_groups = num_heads / num_kv_heads;
122 let head_dim = hidden_sz / num_heads;
123 let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
124 let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
125 let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
126 let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
127 Ok(Self {
128 q_proj,
129 k_proj,
130 v_proj,
131 o_proj,
132 num_heads,
133 num_kv_heads,
134 num_kv_groups,
135 head_dim,
136 hidden_size: hidden_sz,
137 rotary_emb,
138 kv_cache: None,
139 })
140 }
141
142 fn forward(
143 &mut self,
144 xs: &Tensor,
145 attention_mask: Option<&Tensor>,
146 seqlen_offset: usize,
147 ) -> Result<Tensor> {
148 let (b_sz, q_len, _) = xs.dims3()?;
149
150 let query_states = self.q_proj.forward(xs)?;
151 let key_states = self.k_proj.forward(xs)?;
152 let value_states = self.v_proj.forward(xs)?;
153
154 let query_states = query_states
155 .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
156 .transpose(1, 2)?
157 .contiguous()?;
158 let key_states = key_states
159 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
160 .transpose(1, 2)?
161 .contiguous()?;
162 let value_states = value_states
163 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
164 .transpose(1, 2)?;
165
166 let (query_states, key_states) =
167 self.rotary_emb
168 .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
169
170 let (key_states, value_states) = match &self.kv_cache {
171 None => (key_states, value_states),
172 Some((prev_k, prev_v)) => {
173 let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
174 let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
175 (key_states, value_states)
176 }
177 };
178 self.kv_cache = Some((key_states.clone(), value_states.clone()));
179
180 let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
181 let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
182
183 let attn_output = {
184 let scale = 1f64 / f64::sqrt(self.head_dim as f64);
185 let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
186
187 let attn_weights = match attention_mask {
188 None => attn_weights,
189 Some(mask) => attn_weights.broadcast_add(mask)?,
190 };
191 let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
192 attn_weights.matmul(&value_states)?
193 };
194 attn_output
195 .transpose(1, 2)?
196 .reshape((b_sz, q_len, self.hidden_size))?
197 .apply(&self.o_proj)
198 }
199
200 fn clear_kv_cache(&mut self) {
201 self.kv_cache = None
202 }
203}
204
205#[derive(Debug, Clone)]
206struct DecoderLayer {
207 self_attn: Attention,
208 mlp: MLP,
209 input_layernorm: RmsNorm,
210 post_attention_layernorm: RmsNorm,
211}
212
213impl DecoderLayer {
214 fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
215 let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
216 let mlp = MLP::new(cfg, vb.pp("mlp"))?;
217 let input_layernorm =
218 RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
219 let post_attention_layernorm = RmsNorm::new(
220 cfg.hidden_size,
221 cfg.rms_norm_eps,
222 vb.pp("post_attention_layernorm"),
223 )?;
224 Ok(Self {
225 self_attn,
226 mlp,
227 input_layernorm,
228 post_attention_layernorm,
229 })
230 }
231
232 fn forward(
233 &mut self,
234 xs: &Tensor,
235 attention_mask: Option<&Tensor>,
236 seqlen_offset: usize,
237 ) -> Result<Tensor> {
238 let residual = xs;
239 let xs = self.input_layernorm.forward(xs)?;
240 let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
241 let xs = (xs + residual)?;
242 let residual = &xs;
243 let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
244 residual + xs
245 }
246
247 fn clear_kv_cache(&mut self) {
248 self.self_attn.clear_kv_cache()
249 }
250}
251
252#[derive(Debug, Clone)]
253pub struct Model {
254 embed_tokens: Embedding,
255 layers: Vec<DecoderLayer>,
256 norm: RmsNorm,
257 lm_head: Linear,
258 sliding_window: Option<usize>,
259 device: Device,
260}
261
262impl Model {
263 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
264 let vb_m = vb.pp("model");
265 let embed_tokens =
266 Embedding::new(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
267 let rotary_emb = Arc::new(RotaryEmbedding::new(cfg, vb_m.device())?);
268 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
269 let vb_l = vb_m.pp("layers");
270 for layer_idx in 0..cfg.num_hidden_layers {
271 let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
272 layers.push(layer)
273 }
274 let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
275 let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
276 Ok(Self {
277 embed_tokens,
278 layers,
279 norm,
280 lm_head,
281 sliding_window: cfg.sliding_window,
282 device: vb.device().clone(),
283 })
284 }
285
286 fn prepare_decoder_attention_mask(
287 &self,
288 tgt_len: usize,
289 seqlen_offset: usize,
290 ) -> Result<Tensor> {
291 let sliding_window = self.sliding_window.unwrap_or(tgt_len + 1);
292 let mask: Vec<_> = (0..tgt_len)
293 .flat_map(|i| {
294 (0..tgt_len).map(move |j| {
295 if i < j || j + sliding_window < i {
296 f32::NEG_INFINITY
297 } else {
298 0.
299 }
300 })
301 })
302 .collect();
303 let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
304 let mask = if seqlen_offset > 0 {
305 let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
306 Tensor::cat(&[&mask0, &mask], D::Minus1)?
307 } else {
308 mask
309 };
310 mask.expand((1, 1, tgt_len, tgt_len + seqlen_offset))?
311 .to_dtype(DType::F32)
312 }
313
314 pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
315 let (_b_size, seq_len) = input_ids.dims2()?;
316 let attention_mask = if seq_len <= 1 {
317 None
318 } else {
319 let mask = self.prepare_decoder_attention_mask(seq_len, seqlen_offset)?;
320 Some(mask)
321 };
322 let mut xs = self.embed_tokens.forward(input_ids)?;
323 for layer in self.layers.iter_mut() {
324 xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
325 }
326 xs.narrow(1, seq_len - 1, 1)?
327 .contiguous()?
328 .apply(&self.norm)?
329 .apply(&self.lm_head)
330 }
331
332 pub fn clear_kv_cache(&mut self) {
333 for layer in self.layers.iter_mut() {
334 layer.clear_kv_cache()
335 }
336 }
337}