1use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm};
18use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
19use candle_nn::{Activation, VarBuilder};
20use std::sync::Arc;
21
/// Model hyper-parameters.
///
/// Derives `serde::Deserialize`; presumably loaded from a Hugging Face style `config.json`
/// (TODO confirm against the caller that constructs it).
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Config {
    /// Rows of the token embedding table and output width of the LM head.
    pub vocab_size: usize,
    /// Width of the residual stream; divided by `num_attention_heads` to get the head dimension.
    pub hidden_size: usize,
    /// Output width of the MLP `gate_proj` / `up_proj` (input width of `down_proj`).
    pub intermediate_size: usize,
    /// Number of stacked decoder layers.
    pub num_hidden_layers: usize,
    /// Number of query heads.
    pub num_attention_heads: usize,
    /// Number of key/value heads; each one is repeated
    /// `num_attention_heads / num_key_value_heads` times to match the query heads.
    pub num_key_value_heads: usize,
    /// Number of positions for which rotary sin/cos rows are precomputed.
    /// NOTE(review): `seqlen_offset + seq_len` beyond this presumably makes the rotary `narrow`
    /// return an error.
    pub max_position_embeddings: usize,
    /// Used by the causal mask: among the tokens of one forward call, query `i` cannot attend
    /// to key `j` when `j + sliding_window < i`.
    pub sliding_window: usize,
    /// NOTE(review): not read by the code visible in this file.
    pub max_window_layers: usize,
    /// NOTE(review): not read by the code visible in this file; `ModelForCausalLM::new` ties the
    /// LM head to the embeddings whenever no `lm_head.weight` tensor is present.
    pub tie_word_embeddings: bool,
    /// Base of the rotary inverse frequencies (`rope_theta^(-i / head_dim)`).
    pub rope_theta: f64,
    /// Epsilon passed to every `RmsNorm`.
    pub rms_norm_eps: f64,
    /// NOTE(review): not read by the code visible in this file; the causal mask applies
    /// `sliding_window` unconditionally.
    pub use_sliding_window: bool,
    /// Activation applied to the `gate_proj` output inside the MLP.
    pub hidden_act: Activation,
}
39
/// Precomputed rotary position embedding tables, shared (through an `Arc`) by every
/// attention layer.
#[derive(Debug, Clone)]
struct RotaryEmbedding {
    /// `sin(position * inv_freq)`: one row per position in `0..max_position_embeddings`,
    /// one column per even index of the head dimension.
    sin: Tensor,
    /// `cos(position * inv_freq)`, same shape as `sin`.
    cos: Tensor,
}
45
46impl RotaryEmbedding {
47 fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
48 let dim = cfg.hidden_size / cfg.num_attention_heads;
49 let max_seq_len = cfg.max_position_embeddings;
50 let inv_freq: Vec<_> = (0..dim)
51 .step_by(2)
52 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
53 .collect();
54 let inv_freq_len = inv_freq.len();
55 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
56 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
57 .to_dtype(dtype)?
58 .reshape((max_seq_len, 1))?;
59 let freqs = t.matmul(&inv_freq)?;
60 Ok(Self {
61 sin: freqs.sin()?,
62 cos: freqs.cos()?,
63 })
64 }
65
66 fn apply_rotary_emb_qkv(
67 &self,
68 q: &Tensor,
69 k: &Tensor,
70 seqlen_offset: usize,
71 ) -> Result<(Tensor, Tensor)> {
72 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
73 let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
74 let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
75 let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
76 let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
77 Ok((q_embed, k_embed))
78 }
79}
80
/// Gated feed-forward block computing `down_proj(act_fn(gate_proj(x)) * up_proj(x))`.
#[derive(Debug, Clone)]
// The conventional all-caps `MLP` name is kept on purpose.
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    /// `hidden_size -> intermediate_size`, output passed through `act_fn`.
    gate_proj: Linear,
    /// `hidden_size -> intermediate_size`, multiplied element-wise with the gated branch.
    up_proj: Linear,
    /// `intermediate_size -> hidden_size`.
    down_proj: Linear,
    /// Activation taken from `Config::hidden_act`.
    act_fn: Activation,
}
89
90impl MLP {
91 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
92 let hidden_sz = cfg.hidden_size;
93 let intermediate_sz = cfg.intermediate_size;
94 let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
95 let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
96 let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
97 Ok(Self {
98 gate_proj,
99 up_proj,
100 down_proj,
101 act_fn: cfg.hidden_act,
102 })
103 }
104}
105
106impl Module for MLP {
107 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
108 let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
109 let rhs = xs.apply(&self.up_proj)?;
110 (lhs * rhs)?.apply(&self.down_proj)
111 }
112}
113
/// Self-attention with grouped key/value heads, rotary position embeddings and a per-layer
/// kv cache.
#[derive(Debug, Clone)]
struct Attention {
    /// `hidden_size -> num_heads * head_dim`, with bias.
    q_proj: Linear,
    /// `hidden_size -> num_kv_heads * head_dim`, with bias.
    k_proj: Linear,
    /// `hidden_size -> num_kv_heads * head_dim`, with bias.
    v_proj: Linear,
    /// `num_heads * head_dim -> hidden_size`, without bias.
    o_proj: Linear,
    /// Number of query heads.
    num_heads: usize,
    /// Number of key/value heads.
    num_kv_heads: usize,
    /// `num_heads / num_kv_heads`: repeat factor passed to `repeat_kv`.
    num_kv_groups: usize,
    /// Per-head width, `hidden_size / num_heads`.
    head_dim: usize,
    /// Width the merged heads are reshaped to before `o_proj`.
    hidden_size: usize,
    /// Rotary tables shared with the other layers.
    rotary_emb: Arc<RotaryEmbedding>,
    /// Keys and values of every position processed so far, laid out
    /// `(batch, num_kv_heads, seq, head_dim)` and stored before `repeat_kv`.
    kv_cache: Option<(Tensor, Tensor)>,
}
128
129impl Attention {
130 fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
131 let hidden_sz = cfg.hidden_size;
132 let num_heads = cfg.num_attention_heads;
133 let num_kv_heads = cfg.num_key_value_heads;
134 let num_kv_groups = num_heads / num_kv_heads;
135 let head_dim = hidden_sz / num_heads;
136 let q_proj = linear(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
137 let k_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
138 let v_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
139 let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
140 Ok(Self {
141 q_proj,
142 k_proj,
143 v_proj,
144 o_proj,
145 num_heads,
146 num_kv_heads,
147 num_kv_groups,
148 head_dim,
149 hidden_size: hidden_sz,
150 rotary_emb,
151 kv_cache: None,
152 })
153 }
154
155 fn forward(
156 &mut self,
157 xs: &Tensor,
158 attention_mask: Option<&Tensor>,
159 seqlen_offset: usize,
160 ) -> Result<Tensor> {
161 let (b_sz, q_len, _) = xs.dims3()?;
162
163 let query_states = self.q_proj.forward(xs)?;
164 let key_states = self.k_proj.forward(xs)?;
165 let value_states = self.v_proj.forward(xs)?;
166
167 let query_states = query_states
168 .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
169 .transpose(1, 2)?;
170 let key_states = key_states
171 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
172 .transpose(1, 2)?;
173 let value_states = value_states
174 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
175 .transpose(1, 2)?;
176
177 let (query_states, key_states) =
178 self.rotary_emb
179 .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
180
181 let (key_states, value_states) = match &self.kv_cache {
182 None => (key_states, value_states),
183 Some((prev_k, prev_v)) => {
184 let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
185 let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
186 (key_states, value_states)
187 }
188 };
189 self.kv_cache = Some((key_states.clone(), value_states.clone()));
190
191 let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
192 let value_states =
193 crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
194
195 let attn_output = {
196 let scale = 1f64 / f64::sqrt(self.head_dim as f64);
197 let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
198
199 let attn_weights = match attention_mask {
200 None => attn_weights,
201 Some(mask) => attn_weights.broadcast_add(mask)?,
202 };
203 let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
204 attn_weights.matmul(&value_states)?
205 };
206 attn_output
207 .transpose(1, 2)?
208 .reshape((b_sz, q_len, self.hidden_size))?
209 .apply(&self.o_proj)
210 }
211
212 fn clear_kv_cache(&mut self) {
213 self.kv_cache = None
214 }
215}
216
/// One pre-norm transformer block: `RmsNorm` -> attention -> residual add, then
/// `RmsNorm` -> MLP -> residual add.
#[derive(Debug, Clone)]
struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
    /// Applied to the block input before attention.
    input_layernorm: RmsNorm,
    /// Applied to the post-attention residual stream before the MLP.
    post_attention_layernorm: RmsNorm,
}
224
225impl DecoderLayer {
226 fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
227 let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
228 let mlp = MLP::new(cfg, vb.pp("mlp"))?;
229 let input_layernorm =
230 RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
231 let post_attention_layernorm = RmsNorm::new(
232 cfg.hidden_size,
233 cfg.rms_norm_eps,
234 vb.pp("post_attention_layernorm"),
235 )?;
236 Ok(Self {
237 self_attn,
238 mlp,
239 input_layernorm,
240 post_attention_layernorm,
241 })
242 }
243
244 fn forward(
245 &mut self,
246 xs: &Tensor,
247 attention_mask: Option<&Tensor>,
248 seqlen_offset: usize,
249 ) -> Result<Tensor> {
250 let residual = xs;
251 let xs = self.input_layernorm.forward(xs)?;
252 let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
253 let xs = (xs + residual)?;
254 let residual = &xs;
255 let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
256 residual + xs
257 }
258
259 fn clear_kv_cache(&mut self) {
260 self.self_attn.clear_kv_cache()
261 }
262}
263
/// Decoder-only transformer trunk: token embeddings, a stack of decoder layers and a final
/// `RmsNorm`. `forward` returns normalized hidden states (no LM head).
#[derive(Debug, Clone)]
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    /// Final norm applied after the last decoder layer.
    norm: RmsNorm,
    /// Copied from `Config::sliding_window`; consulted when building the causal mask.
    sliding_window: usize,
    /// Device on which attention masks are created.
    device: Device,
    /// Dtype of the loaded weights; attention masks are cast to it.
    dtype: DType,
}
273
274impl Model {
275 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
276 let vb_m = vb.pp("model");
277 let embed_tokens =
278 candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
279 let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
280 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
281 let vb_l = vb_m.pp("layers");
282 for layer_idx in 0..cfg.num_hidden_layers {
283 let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
284 layers.push(layer)
285 }
286 let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
287 Ok(Self {
288 embed_tokens,
289 layers,
290 norm,
291 sliding_window: cfg.sliding_window,
292 device: vb.device().clone(),
293 dtype: vb.dtype(),
294 })
295 }
296
297 fn prepare_causal_attention_mask(
298 &self,
299 b_size: usize,
300 tgt_len: usize,
301 seqlen_offset: usize,
302 ) -> Result<Tensor> {
303 let mask: Vec<_> = (0..tgt_len)
305 .flat_map(|i| {
306 (0..tgt_len).map(move |j| {
307 if i < j || j + self.sliding_window < i {
308 f32::NEG_INFINITY
309 } else {
310 0.
311 }
312 })
313 })
314 .collect();
315 let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
316 let mask = if seqlen_offset > 0 {
317 let mask0 = Tensor::zeros((tgt_len, seqlen_offset), self.dtype, &self.device)?;
318 Tensor::cat(&[&mask0, &mask], D::Minus1)?
319 } else {
320 mask
321 };
322 mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
323 .to_dtype(self.dtype)
324 }
325
326 fn prepare_attention_mask(&self, attn_mask: &Tensor) -> Result<Tensor> {
327 let (b_sz, sql_len) = attn_mask.dims2()?;
328 let mut mask: Vec<Tensor> = vec![];
329 for b in 0..b_sz {
330 mask.push(attn_mask.i((b, ..))?.expand((1, 1, sql_len, sql_len))?);
331 }
332 let mask = Tensor::cat(&mask, 0)?;
333 let on_true = mask.zeros_like()?.to_dtype(self.dtype)?;
334 let on_false = Tensor::new(f32::NEG_INFINITY, &self.device)?
335 .broadcast_as(mask.shape())?
336 .to_dtype(self.dtype)?;
337 mask.where_cond(&on_true, &on_false)
338 }
339
340 pub fn forward(
341 &mut self,
342 input_ids: &Tensor,
343 seqlen_offset: usize,
344 attn_mask: Option<&Tensor>,
345 ) -> Result<Tensor> {
346 let (b_size, seq_len) = input_ids.dims2()?;
347 let attention_mask: Option<Tensor> = match attn_mask {
348 Some(mask) => Some(self.prepare_attention_mask(mask)?),
349 None => {
350 if seq_len <= 1 {
351 None
352 } else {
353 Some(self.prepare_causal_attention_mask(b_size, seq_len, seqlen_offset)?)
354 }
355 }
356 };
357 let mut xs = self.embed_tokens.forward(input_ids)?;
358 for layer in self.layers.iter_mut() {
359 xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
360 }
361 xs.apply(&self.norm)
362 }
363
364 pub fn clear_kv_cache(&mut self) {
365 for layer in self.layers.iter_mut() {
366 layer.clear_kv_cache()
367 }
368 }
369}
370
/// `Model` trunk followed by a language-modelling head producing vocabulary logits.
#[derive(Debug, Clone)]
pub struct ModelForCausalLM {
    base_model: Model,
    /// Either a separate `lm_head` weight or the token embedding matrix (tied weights).
    lm_head: Linear,
}
376
377impl ModelForCausalLM {
378 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
379 let base_model = Model::new(cfg, vb.clone())?;
380 let lm_head = if vb.contains_tensor("lm_head.weight") {
381 linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
382 } else {
383 Linear::from_weights(base_model.embed_tokens.embeddings().clone(), None)
384 };
385 Ok(Self {
386 base_model,
387 lm_head,
388 })
389 }
390
391 pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
392 let (_b_size, seq_len) = input_ids.dims2()?;
393 self.base_model
394 .forward(input_ids, seqlen_offset, None)?
395 .narrow(1, seq_len - 1, 1)?
396 .apply(&self.lm_head)
397 }
398
399 pub fn clear_kv_cache(&mut self) {
400 self.base_model.clear_kv_cache()
401 }
402}