1use std::sync::Arc;
8
9use candle::{DType, Device, Module, Result, Tensor, D};
10use candle_nn::{linear_b as linear, Activation, Linear, VarBuilder};
11
/// Model hyper-parameters, deserialized from the checkpoint's JSON configuration.
#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
    /// Whether the q/k/v/o attention projections carry a bias term.
    pub attention_bias: bool,
    /// Size of each attention head; q/k/v are reshaped to `(num_heads, head_dim)` per token.
    pub head_dim: usize,
    /// Activation applied to the gate projection of the MLP.
    pub hidden_activation: Activation,
    /// Width of the token embeddings and of the residual stream.
    pub hidden_size: usize,
    /// Inner width of the gated MLP.
    pub intermediate_size: usize,
    /// Number of query heads.
    pub num_attention_heads: usize,
    /// Number of decoder layers.
    pub num_hidden_layers: usize,
    /// Number of key/value heads; each one is shared by
    /// `num_attention_heads / num_key_value_heads` query heads.
    pub num_key_value_heads: usize,
    /// Epsilon added to the mean square inside every RMS norm.
    pub rms_norm_eps: f64,
    /// Base used to derive the rotary-embedding inverse frequencies.
    pub rope_theta: f64,
    /// Number of rows in the token embedding, which is also reused as the output projection.
    pub vocab_size: usize,
    /// When set to `c`, final logits are soft-capped as `tanh(logits / c) * c`.
    pub final_logit_softcapping: Option<f64>,
    /// When set to `c`, attention scores are soft-capped as `tanh(s / c) * c`
    /// (eager attention path only).
    pub attn_logit_softcapping: Option<f64>,
    /// `query_pre_attn_scalar` value copied from the checkpoint configuration.
    /// NOTE(review): confirm whether the attention code uses this for its softmax scale.
    pub query_pre_attn_scalar: usize,
    /// Window length given to the rotating KV cache of sliding-window layers and used when
    /// building the attention mask.
    pub sliding_window: usize,
    /// Layer `i` (0-based) is a global-attention layer when
    /// `(i + 1) % sliding_window_pattern == 0`; every other layer uses a sliding window.
    pub sliding_window_pattern: usize,
    /// Number of positions for which rotary sin/cos tables are precomputed.
    pub max_position_embeddings: usize,
}
32
33#[derive(Debug, Clone)]
34struct RmsNorm {
35 weight: Tensor,
36 eps: f64,
37}
38
39impl RmsNorm {
40 fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
41 let weight = vb.get(dim, "weight")?;
42 Ok(Self { weight, eps })
43 }
44}
45
46impl Module for RmsNorm {
47 fn forward(&self, x: &Tensor) -> Result<Tensor> {
48 let x_dtype = x.dtype();
49 let internal_dtype = match x_dtype {
50 DType::F16 | DType::BF16 => DType::F32,
51 d => d,
52 };
53 let hidden_size = x.dim(D::Minus1)?;
54 let x = x.to_dtype(internal_dtype)?;
55 let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
56 let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
57 x_normed
58 .to_dtype(x_dtype)?
59 .broadcast_mul(&(&self.weight + 1.0)?)
60 }
61}
62
63#[derive(Debug, Clone)]
64struct RotaryEmbedding {
65 sin: Tensor,
66 cos: Tensor,
67}
68
69impl RotaryEmbedding {
70 fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
71 let dim = cfg.head_dim;
72 let max_seq_len = cfg.max_position_embeddings;
73 let inv_freq: Vec<_> = (0..dim)
74 .step_by(2)
75 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
76 .collect();
77 let inv_freq_len = inv_freq.len();
78 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
79 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
80 .to_dtype(dtype)?
81 .reshape((max_seq_len, 1))?;
82 let freqs = t.matmul(&inv_freq)?;
83 Ok(Self {
84 sin: freqs.sin()?,
85 cos: freqs.cos()?,
86 })
87 }
88
89 fn apply_rotary_emb_qkv(
90 &self,
91 q: &Tensor,
92 k: &Tensor,
93 seqlen_offset: usize,
94 ) -> Result<(Tensor, Tensor)> {
95 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
96 let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
97 let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
98 let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
99 let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
100 Ok((q_embed, k_embed))
101 }
102}
103
104#[derive(Debug, Clone)]
105#[allow(clippy::upper_case_acronyms)]
106struct MLP {
107 gate_proj: Linear,
108 up_proj: Linear,
109 down_proj: Linear,
110 act_fn: candle_nn::Activation,
111}
112
113impl MLP {
114 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
115 let hidden_sz = cfg.hidden_size;
116 let intermediate_sz = cfg.intermediate_size;
117 let gate_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("gate_proj"))?;
118 let up_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("up_proj"))?;
119 let down_proj = linear(intermediate_sz, hidden_sz, false, vb.pp("down_proj"))?;
120 Ok(Self {
121 gate_proj,
122 up_proj,
123 down_proj,
124 act_fn: cfg.hidden_activation,
125 })
126 }
127}
128
129impl Module for MLP {
130 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
131 let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
132 let rhs = xs.apply(&self.up_proj)?;
133 (lhs * rhs)?.apply(&self.down_proj)
134 }
135}
136
137#[derive(Debug, Clone)]
138enum KvCache {
139 Normal(candle_nn::kv_cache::KvCache),
140 Rotating(candle_nn::kv_cache::RotatingKvCache),
141}
142
143#[derive(Debug, Clone)]
144struct Attention {
145 q_proj: Linear,
146 k_proj: Linear,
147 v_proj: Linear,
148 o_proj: Linear,
149 q_norm: RmsNorm,
150 k_norm: RmsNorm,
151 num_heads: usize,
152 num_kv_heads: usize,
153 num_kv_groups: usize,
154 head_dim: usize,
155 attn_logit_softcapping: Option<f64>,
156 rotary_emb: Arc<RotaryEmbedding>,
157 kv_cache: KvCache,
158 use_flash_attn: bool,
159}
160
161impl Attention {
162 fn new(
163 rotary_emb: Arc<RotaryEmbedding>,
164 use_flash_attn: bool,
165 is_sliding: bool,
166 cfg: &Config,
167 vb: VarBuilder,
168 ) -> Result<Self> {
169 let hidden_sz = cfg.hidden_size;
170 let num_heads = cfg.num_attention_heads;
171 let num_kv_heads = cfg.num_key_value_heads;
172 let num_kv_groups = num_heads / num_kv_heads;
173 let head_dim = cfg.head_dim;
174 let bias = cfg.attention_bias;
175 let q_proj = linear(hidden_sz, num_heads * head_dim, bias, vb.pp("q_proj"))?;
176 let k_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("k_proj"))?;
177 let v_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("v_proj"))?;
178 let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?;
179 let q_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?;
180 let k_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?;
181 let kv_cache = if is_sliding {
182 KvCache::Rotating(candle_nn::kv_cache::RotatingKvCache::new(
183 2,
184 cfg.sliding_window,
185 ))
186 } else {
187 KvCache::Normal(candle_nn::kv_cache::KvCache::new(2, cfg.sliding_window))
188 };
189 Ok(Self {
190 q_proj,
191 k_proj,
192 v_proj,
193 o_proj,
194 q_norm,
195 k_norm,
196 num_heads,
197 num_kv_heads,
198 num_kv_groups,
199 head_dim,
200 attn_logit_softcapping: cfg.attn_logit_softcapping,
201 rotary_emb,
202 kv_cache,
203 use_flash_attn,
204 })
205 }
206
207 fn forward(
208 &mut self,
209 xs: &Tensor,
210 attention_mask: Option<&Tensor>,
211 seqlen_offset: usize,
212 ) -> Result<Tensor> {
213 let (b_sz, q_len, _) = xs.dims3()?;
214
215 let query_states = self.q_proj.forward(xs)?;
216 let key_states = self.k_proj.forward(xs)?;
217 let value_states = self.v_proj.forward(xs)?;
218
219 let query_states = query_states
220 .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
221 .transpose(1, 2)?;
222 let key_states = key_states
223 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
224 .transpose(1, 2)?;
225 let value_states = value_states
226 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
227 .transpose(1, 2)?;
228 let query_states = self.q_norm.forward(&query_states)?;
229 let key_states = self.k_norm.forward(&key_states)?;
230
231 let (query_states, key_states) =
232 self.rotary_emb
233 .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
234
235 let (key_states, value_states) = match &mut self.kv_cache {
236 KvCache::Normal(cache) => cache.append(&key_states, &value_states)?,
237 KvCache::Rotating(cache) => cache.append(&key_states, &value_states)?,
238 };
239
240 let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
241 let value_states =
242 crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
243
244 let attn_output = if self.use_flash_attn {
245 let q = query_states.transpose(1, 2)?;
247 let k = key_states.transpose(1, 2)?;
248 let v = value_states.transpose(1, 2)?;
249 let scale = 1f32 / (self.head_dim as f32).sqrt();
250 flash_attn(&q, &k, &v, scale, attention_mask.is_some())?.transpose(1, 2)?
251 } else {
252 let scale = 1f64 / f64::sqrt(self.head_dim as f64);
253 let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
254
255 let attn_weights = match self.attn_logit_softcapping {
256 None => attn_weights,
257 Some(sc) => ((attn_weights / sc)?.tanh()? * sc)?,
258 };
259
260 let attn_weights = match attention_mask {
261 None => attn_weights,
262 Some(mask) => attn_weights.broadcast_add(mask)?,
263 };
264 let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
265 attn_weights.matmul(&value_states)?
266 };
267 attn_output
268 .transpose(1, 2)?
269 .reshape((b_sz, q_len, ()))?
270 .apply(&self.o_proj)
271 }
272
273 fn clear_kv_cache(&mut self) {
274 match &mut self.kv_cache {
275 KvCache::Normal(c) => c.reset(),
276 KvCache::Rotating(c) => c.reset(),
277 }
278 }
279}
280
281#[cfg(feature = "flash-attn")]
282fn flash_attn(
283 q: &Tensor,
284 k: &Tensor,
285 v: &Tensor,
286 softmax_scale: f32,
287 causal: bool,
288) -> Result<Tensor> {
289 candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
290}
291
292#[cfg(not(feature = "flash-attn"))]
293fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
294 unimplemented!("compile with '--features flash-attn'")
295}
296
297#[derive(Debug, Clone)]
298struct DecoderLayer {
299 self_attn: Attention,
300 mlp: MLP,
301 input_layernorm: RmsNorm,
302 pre_feedforward_layernorm: RmsNorm,
303 post_feedforward_layernorm: RmsNorm,
304 post_attention_layernorm: RmsNorm,
305}
306
307impl DecoderLayer {
308 fn new(
309 rotary_emb: Arc<RotaryEmbedding>,
310 use_flash_attn: bool,
311 is_sliding: bool,
312 cfg: &Config,
313 vb: VarBuilder,
314 ) -> Result<Self> {
315 let self_attn = Attention::new(
316 rotary_emb,
317 use_flash_attn,
318 is_sliding,
319 cfg,
320 vb.pp("self_attn"),
321 )?;
322 let mlp = MLP::new(cfg, vb.pp("mlp"))?;
323 let input_layernorm =
324 RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
325 let pre_feedforward_layernorm = RmsNorm::new(
326 cfg.hidden_size,
327 cfg.rms_norm_eps,
328 vb.pp("pre_feedforward_layernorm"),
329 )?;
330 let post_feedforward_layernorm = RmsNorm::new(
331 cfg.hidden_size,
332 cfg.rms_norm_eps,
333 vb.pp("post_feedforward_layernorm"),
334 )?;
335 let post_attention_layernorm = RmsNorm::new(
336 cfg.hidden_size,
337 cfg.rms_norm_eps,
338 vb.pp("post_attention_layernorm"),
339 )?;
340 Ok(Self {
341 self_attn,
342 mlp,
343 input_layernorm,
344 pre_feedforward_layernorm,
345 post_feedforward_layernorm,
346 post_attention_layernorm,
347 })
348 }
349
350 fn forward(
351 &mut self,
352 xs: &Tensor,
353 attention_mask: Option<&Tensor>,
354 seqlen_offset: usize,
355 ) -> Result<Tensor> {
356 let residual = xs;
357 let xs = self.input_layernorm.forward(xs)?;
358 let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
359 let xs = xs.apply(&self.post_attention_layernorm)?;
360 let xs = (xs + residual)?;
361 let residual = &xs;
362 let xs = xs.apply(&self.pre_feedforward_layernorm)?;
363 let xs = xs.apply(&self.mlp)?;
364 let xs = xs.apply(&self.post_feedforward_layernorm)?;
365 residual + xs
366 }
367
368 fn clear_kv_cache(&mut self) {
369 self.self_attn.clear_kv_cache()
370 }
371}
372
373#[derive(Debug, Clone)]
374pub struct Model {
375 embed_tokens: candle_nn::Embedding,
376 layers: Vec<DecoderLayer>,
377 norm: RmsNorm,
378 lm_head: Linear,
379 final_logit_softcapping: Option<f64>,
380 device: Device,
381 dtype: DType,
382 hidden_size: usize,
383 sliding_window: usize,
384}
385
386impl Model {
387 pub fn new(use_flash_attn: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> {
388 let vb_m = vb.pp("model");
389 let embed_tokens =
390 candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
391 let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
392 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
393 let vb_l = vb_m.pp("layers");
394 for layer_idx in 0..cfg.num_hidden_layers {
395 let is_sliding = (layer_idx + 1) % cfg.sliding_window_pattern > 0;
396 let layer = DecoderLayer::new(
397 rotary_emb.clone(),
398 use_flash_attn,
399 is_sliding,
400 cfg,
401 vb_l.pp(layer_idx),
402 )?;
403 layers.push(layer)
404 }
405 let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
406 let lm_head = Linear::new(embed_tokens.embeddings().clone(), None);
407 Ok(Self {
408 embed_tokens,
409 layers,
410 norm,
411 lm_head,
412 final_logit_softcapping: cfg.final_logit_softcapping,
413 device: vb.device().clone(),
414 dtype: vb.dtype(),
415 hidden_size: cfg.hidden_size,
416 sliding_window: cfg.sliding_window,
417 })
418 }
419
420 fn prepare_decoder_attention_mask(
421 &self,
422 b_size: usize,
423 tgt_len: usize,
424 seqlen_offset: usize,
425 ) -> Result<Tensor> {
426 let mask: Vec<_> = match Some(self.sliding_window) {
427 None => (0..tgt_len)
428 .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
429 .collect(),
430 Some(sliding_window) => (0..tgt_len)
431 .flat_map(|i| {
432 (0..tgt_len).map(move |j| {
433 if i < j || j + sliding_window < i {
434 f32::NEG_INFINITY
435 } else {
436 0.
437 }
438 })
439 })
440 .collect(),
441 };
442 let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
443 let mask = if seqlen_offset > 0 {
444 let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
445 Tensor::cat(&[&mask0, &mask], D::Minus1)?
446 } else {
447 mask
448 };
449 mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
450 .to_dtype(self.dtype)
451 }
452
453 pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
454 let (b_size, seq_len) = input_ids.dims2()?;
455 let attention_mask = if seq_len <= 1 {
456 None
457 } else {
458 let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
459 Some(mask)
460 };
461 let xs = self.embed_tokens.forward(input_ids)?;
462 let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
463 for layer in self.layers.iter_mut() {
464 xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
465 }
466 let logits = xs
467 .narrow(1, seq_len - 1, 1)?
468 .apply(&self.norm)?
469 .apply(&self.lm_head)?;
470 let logits = match self.final_logit_softcapping {
471 None => logits,
472 Some(sc) => ((logits / sc)?.tanh()? * sc)?,
473 };
474
475 Ok(logits)
476 }
477
478 pub fn clear_kv_cache(&mut self) {
479 for layer in self.layers.iter_mut() {
480 layer.clear_kv_cache()
481 }
482 }
483}