1use super::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
8use candle::{DType, Device, IndexOp, Result, Tensor, D};
9use candle_nn::{embedding, Embedding, Module, VarBuilder};
10use std::{collections::HashMap, f32::consts::PI};
11
/// Context length used by the hard-coded `Config` presets for `max_position_embeddings`.
pub const DEFAULT_MAX_SEQ_LEN: usize = 4_096;
13
14#[derive(Debug, Clone, serde::Deserialize, Default)]
15pub enum Llama3RopeType {
16 #[serde(rename = "llama3")]
17 Llama3,
18 #[default]
19 #[serde(rename = "default")]
20 Default,
21}
22
23#[derive(Debug, Clone, serde::Deserialize, Default)]
24pub struct Llama3RopeConfig {
25 pub factor: f32,
26 pub low_freq_factor: f32,
27 pub high_freq_factor: f32,
28 pub original_max_position_embeddings: usize,
29 pub rope_type: Llama3RopeType,
30}
31#[derive(Debug, Clone, serde::Deserialize)]
32#[serde(untagged)]
33pub enum LlamaEosToks {
34 Single(u32),
35 Multiple(Vec<u32>),
36}
37
38#[derive(Debug, Clone, serde::Deserialize)]
39pub struct LlamaConfig {
40 pub hidden_size: usize,
41 pub intermediate_size: usize,
42 pub vocab_size: usize,
43 pub num_hidden_layers: usize,
44 pub num_attention_heads: usize,
45 pub num_key_value_heads: Option<usize>,
46 pub rms_norm_eps: f64,
47 #[serde(default = "default_rope")]
48 pub rope_theta: f32,
49 pub bos_token_id: Option<u32>,
50 pub eos_token_id: Option<LlamaEosToks>,
51 pub rope_scaling: Option<Llama3RopeConfig>,
52 pub max_position_embeddings: usize,
53 pub tie_word_embeddings: Option<bool>,
54}
55
56impl LlamaConfig {
57 pub fn num_key_value_heads(&self) -> usize {
58 self.num_key_value_heads.unwrap_or(self.num_attention_heads)
59 }
60}
61
/// Serde default for `LlamaConfig::rope_theta`: a RoPE base of ten thousand.
fn default_rope() -> f32 {
    1e4
}
65
66impl LlamaConfig {
67 pub fn into_config(self, use_flash_attn: bool) -> Config {
68 Config {
69 hidden_size: self.hidden_size,
70 intermediate_size: self.intermediate_size,
71 vocab_size: self.vocab_size,
72 num_hidden_layers: self.num_hidden_layers,
73 num_attention_heads: self.num_attention_heads,
74 num_key_value_heads: self.num_key_value_heads(),
75 rms_norm_eps: self.rms_norm_eps,
76 rope_theta: self.rope_theta,
77 use_flash_attn,
78 bos_token_id: self.bos_token_id,
79 eos_token_id: self.eos_token_id,
80 rope_scaling: self.rope_scaling,
81 max_position_embeddings: self.max_position_embeddings,
82 tie_word_embeddings: self.tie_word_embeddings.unwrap_or(false),
83 }
84 }
85}
86
87#[derive(Debug, Clone)]
88pub struct Config {
89 pub hidden_size: usize,
90 pub intermediate_size: usize,
91 pub vocab_size: usize,
92 pub num_hidden_layers: usize,
93 pub num_attention_heads: usize,
94 pub num_key_value_heads: usize,
95 pub use_flash_attn: bool,
96 pub rms_norm_eps: f64,
97 pub rope_theta: f32,
98 pub bos_token_id: Option<u32>,
99 pub eos_token_id: Option<LlamaEosToks>,
100 pub rope_scaling: Option<Llama3RopeConfig>,
101 pub max_position_embeddings: usize,
102 pub tie_word_embeddings: bool,
103}
104
105impl Config {
106 pub fn config_7b_v1(use_flash_attn: bool) -> Self {
107 Self {
108 hidden_size: 4096,
109 intermediate_size: 11008,
110 vocab_size: 32000,
111 num_hidden_layers: 32,
112 num_attention_heads: 32,
113 num_key_value_heads: 32,
114 use_flash_attn,
115 rms_norm_eps: 1e-6,
116 rope_theta: 10_000.0,
117 bos_token_id: None,
118 eos_token_id: None,
119 rope_scaling: None,
120 max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
121 tie_word_embeddings: false,
122 }
123 }
124
125 pub fn config_7b_v2(use_flash_attn: bool) -> Self {
126 Self {
127 hidden_size: 4096,
128 intermediate_size: 11008,
129 vocab_size: 32000,
130 num_hidden_layers: 32,
131 num_attention_heads: 32,
132 num_key_value_heads: 32,
133 use_flash_attn,
134 rms_norm_eps: 1e-5,
135 rope_theta: 10_000.0,
136 bos_token_id: None,
137 eos_token_id: None,
138 rope_scaling: None,
139 max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
140 tie_word_embeddings: false,
141 }
142 }
143}
144
145#[derive(Debug, Clone)]
146pub struct Cache {
147 masks: HashMap<usize, Tensor>,
148 pub use_kv_cache: bool,
149 kvs: Vec<Option<(Tensor, Tensor)>>,
150 cos: Tensor,
151 sin: Tensor,
152 device: Device,
153}
154
155fn calculate_default_inv_freq(cfg: &Config) -> Vec<f32> {
156 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
157 (0..head_dim)
158 .step_by(2)
159 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
160 .collect()
161}
162
163impl Cache {
164 pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> {
165 let theta = match &config.rope_scaling {
167 None
168 | Some(Llama3RopeConfig {
169 rope_type: Llama3RopeType::Default,
170 ..
171 }) => calculate_default_inv_freq(config),
172 Some(rope_scaling) => {
173 let low_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
174 / rope_scaling.low_freq_factor;
175 let high_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
176 / rope_scaling.high_freq_factor;
177
178 calculate_default_inv_freq(config)
179 .into_iter()
180 .map(|freq| {
181 let wavelen = 2. * PI / freq;
182 if wavelen < high_freq_wavelen {
183 freq
184 } else if wavelen > low_freq_wavelen {
185 freq / rope_scaling.factor
186 } else {
187 let smooth = (rope_scaling.original_max_position_embeddings as f32
188 / wavelen
189 - rope_scaling.low_freq_factor)
190 / (rope_scaling.high_freq_factor - rope_scaling.low_freq_factor);
191 (1. - smooth) * freq / rope_scaling.factor + smooth * freq
192 }
193 })
194 .collect::<Vec<_>>()
195 }
196 };
197
198 let theta = Tensor::new(theta, device)?;
199
200 let idx_theta = Tensor::arange(0, config.max_position_embeddings as u32, device)?
201 .to_dtype(DType::F32)?
202 .reshape((config.max_position_embeddings, 1))?
203 .matmul(&theta.reshape((1, theta.elem_count()))?)?;
204 let cos = idx_theta.cos()?.to_dtype(dtype)?;
207 let sin = idx_theta.sin()?.to_dtype(dtype)?;
208 Ok(Self {
209 masks: HashMap::new(),
210 use_kv_cache,
211 kvs: vec![None; config.num_hidden_layers],
212 device: device.clone(),
213 cos,
214 sin,
215 })
216 }
217
218 fn mask(&mut self, t: usize) -> Result<Tensor> {
219 if let Some(mask) = self.masks.get(&t) {
220 Ok(mask.clone())
221 } else {
222 let mask: Vec<_> = (0..t)
223 .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
224 .collect();
225 let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
226 self.masks.insert(t, mask.clone());
227 Ok(mask)
228 }
229 }
230}
231
232#[derive(Debug, Clone)]
233struct CausalSelfAttention {
234 q_proj: Linear,
235 k_proj: Linear,
236 v_proj: Linear,
237 o_proj: Linear,
238 num_attention_heads: usize,
239 num_key_value_heads: usize,
240 head_dim: usize,
241 use_flash_attn: bool,
242 span: tracing::Span,
243 span_rot: tracing::Span,
244 max_position_embeddings: usize,
245}
246
247#[cfg(feature = "flash-attn")]
248fn flash_attn(
249 q: &Tensor,
250 k: &Tensor,
251 v: &Tensor,
252 softmax_scale: f32,
253 causal: bool,
254) -> Result<Tensor> {
255 candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
256}
257
258#[cfg(not(feature = "flash-attn"))]
259fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
260 unimplemented!("compile with '--features flash-attn'")
261}
262
263impl CausalSelfAttention {
264 fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
265 let _enter = self.span_rot.enter();
266 let (_b_sz, _, seq_len, _hidden_size) = x.dims4()?;
267 let cos = cache.cos.narrow(0, index_pos, seq_len)?;
268 let sin = cache.sin.narrow(0, index_pos, seq_len)?;
269 candle_nn::rotary_emb::rope(x, &cos, &sin)
270 }
271
272 fn forward(
273 &self,
274 x: &Tensor,
275 index_pos: usize,
276 block_idx: usize,
277 cache: &mut Cache,
278 ) -> Result<Tensor> {
279 let _enter = self.span.enter();
280 let (b_sz, seq_len, hidden_size) = x.dims3()?;
281 let q = self.q_proj.forward(x)?;
282 let k = self.k_proj.forward(x)?;
283 let v = self.v_proj.forward(x)?;
284
285 let q = q
286 .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
287 .transpose(1, 2)?
288 .contiguous()?;
289 let k = k
290 .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
291 .transpose(1, 2)?
292 .contiguous()?;
293 let mut v = v
294 .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
295 .transpose(1, 2)?;
296
297 let q = self.apply_rotary_emb(&q, index_pos, cache)?;
298 let mut k = self.apply_rotary_emb(&k, index_pos, cache)?;
299
300 if cache.use_kv_cache {
301 if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] {
302 k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
303 v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
304 let k_seq_len = k.dims()[1];
305 if k_seq_len > self.max_position_embeddings {
306 k = k
307 .narrow(
308 D::Minus1,
309 k_seq_len - self.max_position_embeddings,
310 self.max_position_embeddings,
311 )?
312 .contiguous()?
313 }
314 let v_seq_len = v.dims()[1];
315 if v_seq_len > 2 * self.max_position_embeddings {
316 v = v
317 .narrow(
318 D::Minus1,
319 v_seq_len - self.max_position_embeddings,
320 self.max_position_embeddings,
321 )?
322 .contiguous()?
323 }
324 }
325 cache.kvs[block_idx] = Some((k.clone(), v.clone()))
326 }
327
328 let k = self.repeat_kv(k)?;
329 let v = self.repeat_kv(v)?;
330
331 let y = if self.use_flash_attn {
332 let q = q.transpose(1, 2)?;
334 let k = k.transpose(1, 2)?;
335 let v = v.transpose(1, 2)?;
336 let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
337 flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2)?
338 } else {
339 let in_dtype = q.dtype();
340 let q = q.to_dtype(DType::F32)?;
341 let k = k.to_dtype(DType::F32)?;
342 let v = v.to_dtype(DType::F32)?;
343 let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
344 let att = if seq_len == 1 {
345 att
346 } else {
347 let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
348 masked_fill(&att, &mask, f32::NEG_INFINITY)?
349 };
350
351 let att = candle_nn::ops::softmax_last_dim(&att)?;
352 att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
354 };
355 let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
356 let y = self.o_proj.forward(&y)?;
357 Ok(y)
358 }
359
360 fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
361 crate::utils::repeat_kv(x, self.num_attention_heads / self.num_key_value_heads)
362 }
363
364 fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
365 let span = tracing::span!(tracing::Level::TRACE, "attn");
366 let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
367 let size_in = cfg.hidden_size;
368 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
369 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
370 let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?;
371 let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?;
372 let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?;
373 let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?;
374 Ok(Self {
375 q_proj,
376 k_proj,
377 v_proj,
378 o_proj,
379 num_attention_heads: cfg.num_attention_heads,
380 num_key_value_heads: cfg.num_key_value_heads,
381 head_dim: cfg.hidden_size / cfg.num_attention_heads,
382 use_flash_attn: cfg.use_flash_attn,
383 span,
384 span_rot,
385 max_position_embeddings: cfg.max_position_embeddings,
386 })
387 }
388}
389
390fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
391 let shape = mask.shape();
392 let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
393 let m = mask.where_cond(&on_true, on_false)?;
394 Ok(m)
395}
396
397#[derive(Debug, Clone)]
398struct Mlp {
399 c_fc1: Linear,
400 c_fc2: Linear,
401 c_proj: Linear,
402 span: tracing::Span,
403}
404
405impl Mlp {
406 fn forward(&self, x: &Tensor) -> Result<Tensor> {
407 let _enter = self.span.enter();
408 let x = (candle_nn::ops::silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
409 self.c_proj.forward(&x)
410 }
411
412 fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
413 let span = tracing::span!(tracing::Level::TRACE, "mlp");
414 let h_size = cfg.hidden_size;
415 let i_size = cfg.intermediate_size;
416 let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
417 let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?;
418 let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?;
419 Ok(Self {
420 c_fc1,
421 c_fc2,
422 c_proj,
423 span,
424 })
425 }
426}
427
428#[derive(Debug, Clone)]
429struct Block {
430 rms_1: RmsNorm,
431 attn: CausalSelfAttention,
432 rms_2: RmsNorm,
433 mlp: Mlp,
434 span: tracing::Span,
435}
436
437impl Block {
438 fn forward(
439 &self,
440 x: &Tensor,
441 index_pos: usize,
442 block_idx: usize,
443 cache: &mut Cache,
444 ) -> Result<Tensor> {
445 let _enter = self.span.enter();
446 let residual = x;
447 let x = self.rms_1.forward(x)?;
448 let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?;
449 let residual = &x;
450 let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
451 Ok(x)
452 }
453
454 fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
455 let span = tracing::span!(tracing::Level::TRACE, "block");
456 let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
457 let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
458 let rms_1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
459 let rms_2 = RmsNorm::new(
460 cfg.hidden_size,
461 cfg.rms_norm_eps,
462 vb.pp("post_attention_layernorm"),
463 )?;
464 Ok(Self {
465 rms_1,
466 attn,
467 rms_2,
468 mlp,
469 span,
470 })
471 }
472}
473
474#[derive(Debug, Clone)]
475pub struct Llama {
476 wte: Embedding,
477 blocks: Vec<Block>,
478 ln_f: RmsNorm,
479 lm_head: Linear,
480}
481
482impl Llama {
483 pub fn embed(&self, x: &Tensor) -> Result<Tensor> {
485 self.wte.forward(x)
486 }
487 pub fn forward_input_embed(
489 &self,
490 input_embed: &Tensor,
491 index_pos: usize,
492 cache: &mut Cache,
493 ) -> Result<Tensor> {
494 let (_, seq_len, _) = input_embed.dims3()?;
495 let mut x = input_embed.clone();
496 for (block_idx, block) in self.blocks.iter().enumerate() {
497 x = block.forward(&x, index_pos, block_idx, cache)?;
498 }
499 let x = self.ln_f.forward(&x)?;
500 let x = x.i((.., seq_len - 1, ..))?.contiguous()?;
501 let logits = self.lm_head.forward(&x)?;
502 logits.to_dtype(DType::F32)
503 }
504
505 pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
506 let (_b_sz, seq_len) = x.dims2()?;
507 let mut x = self.wte.forward(x)?;
508 for (block_idx, block) in self.blocks.iter().enumerate() {
509 x = block.forward(&x, index_pos, block_idx, cache)?;
510 }
511 let x = self.ln_f.forward(&x)?;
512 let x = x.i((.., seq_len - 1, ..))?.contiguous()?;
513 let logits = self.lm_head.forward(&x)?;
514 logits.to_dtype(DType::F32)
515 }
516
517 pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
518 let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
519 let lm_head = if cfg.tie_word_embeddings {
520 Linear::from_weights(wte.embeddings().clone(), None)
521 } else {
522 linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
523 };
524 let ln_f = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
525 let blocks: Vec<_> = (0..cfg.num_hidden_layers)
526 .map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cfg).unwrap())
527 .collect();
528
529 Ok(Self {
530 wte,
531 blocks,
532 ln_f,
533 lm_head,
534 })
535 }
536}