1use super::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
9use candle::{DType, Device, IndexOp, Result, Tensor, D};
10use candle_nn::{embedding, Embedding, Module, VarBuilder};
11use std::{collections::HashMap, f32::consts::PI};
12
/// Default maximum sequence length exported by this module.
///
/// NOTE(review): not referenced anywhere in this file; presumably consumed by callers
/// (e.g. example binaries) — confirm before changing.
pub const DEFAULT_MAX_SEQ_LEN: usize = 4_096;
14
15#[derive(Debug, Clone, serde::Deserialize, Default)]
16pub enum GraniteRopeType {
17 #[serde(rename = "granite")]
18 Granite,
19 #[default]
20 #[serde(rename = "default")]
21 Default,
22}
23
24#[derive(Debug, Clone, serde::Deserialize, Default)]
25pub struct GraniteRopeConfig {
26 pub factor: f32,
27 pub low_freq_factor: f32,
28 pub high_freq_factor: f32,
29 pub original_max_position_embeddings: usize,
30 pub rope_type: GraniteRopeType,
31}
32#[derive(Debug, Clone, serde::Deserialize)]
33#[serde(untagged)]
34pub enum GraniteEosToks {
35 Single(u32),
36 Multiple(Vec<u32>),
37}
38
39#[derive(Debug, Clone, serde::Deserialize)]
40pub struct GraniteConfig {
41 pub hidden_size: usize,
42 pub intermediate_size: usize,
43 pub vocab_size: usize,
44 pub num_hidden_layers: usize,
45 pub num_attention_heads: usize,
46 pub num_key_value_heads: Option<usize>,
47 pub rms_norm_eps: f64,
48 #[serde(default = "default_rope")]
49 pub rope_theta: f32,
50 pub bos_token_id: Option<u32>,
51 pub eos_token_id: Option<GraniteEosToks>,
52 pub rope_scaling: Option<GraniteRopeConfig>,
53 pub max_position_embeddings: usize,
54}
55
56impl GraniteConfig {
57 pub fn num_key_value_heads(&self) -> usize {
58 self.num_key_value_heads.unwrap_or(self.num_attention_heads)
59 }
60}
61
/// Serde default for `GraniteConfig::rope_theta` when the config omits it.
fn default_rope() -> f32 {
    1e4
}
65
66impl GraniteConfig {
67 pub fn into_config(self, use_flash_attn: bool) -> Config {
68 Config {
69 hidden_size: self.hidden_size,
70 intermediate_size: self.intermediate_size,
71 vocab_size: self.vocab_size,
72 num_hidden_layers: self.num_hidden_layers,
73 num_attention_heads: self.num_attention_heads,
74 num_key_value_heads: self.num_key_value_heads(),
75 rms_norm_eps: self.rms_norm_eps,
76 rope_theta: self.rope_theta,
77 use_flash_attn,
78 bos_token_id: self.bos_token_id,
79 eos_token_id: self.eos_token_id,
80 rope_scaling: self.rope_scaling,
81 max_position_embeddings: self.max_position_embeddings,
82 }
83 }
84}
85
86#[derive(Debug, Clone)]
87pub struct Config {
88 pub hidden_size: usize,
89 pub intermediate_size: usize,
90 pub vocab_size: usize,
91 pub num_hidden_layers: usize,
92 pub num_attention_heads: usize,
93 pub num_key_value_heads: usize,
94 pub use_flash_attn: bool,
95 pub rms_norm_eps: f64,
96 pub rope_theta: f32,
97 pub bos_token_id: Option<u32>,
98 pub eos_token_id: Option<GraniteEosToks>,
99 pub rope_scaling: Option<GraniteRopeConfig>,
100 pub max_position_embeddings: usize,
101}
102
103#[derive(Debug, Clone)]
104pub struct Cache {
105 masks: HashMap<usize, Tensor>,
106 pub use_kv_cache: bool,
107 kvs: Vec<Option<(Tensor, Tensor)>>,
108 cos: Tensor,
109 sin: Tensor,
110 device: Device,
111}
112
113fn calculate_default_inv_freq(cfg: &Config) -> Vec<f32> {
114 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
115 (0..head_dim)
116 .step_by(2)
117 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
118 .collect()
119}
120
121impl Cache {
122 pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> {
123 let theta = match &config.rope_scaling {
125 None
126 | Some(GraniteRopeConfig {
127 rope_type: GraniteRopeType::Default,
128 ..
129 }) => calculate_default_inv_freq(config),
130 Some(rope_scaling) => {
131 let low_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
132 / rope_scaling.low_freq_factor;
133 let high_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
134 / rope_scaling.high_freq_factor;
135
136 calculate_default_inv_freq(config)
137 .into_iter()
138 .map(|freq| {
139 let wavelen = 2. * PI / freq;
140 if wavelen < high_freq_wavelen {
141 freq
142 } else if wavelen > low_freq_wavelen {
143 freq / rope_scaling.factor
144 } else {
145 let smooth = (rope_scaling.original_max_position_embeddings as f32
146 / wavelen
147 - rope_scaling.low_freq_factor)
148 / (rope_scaling.high_freq_factor - rope_scaling.low_freq_factor);
149 (1. - smooth) * freq / rope_scaling.factor + smooth * freq
150 }
151 })
152 .collect::<Vec<_>>()
153 }
154 };
155
156 let theta = Tensor::new(theta, device)?;
157
158 let idx_theta = Tensor::arange(0, config.max_position_embeddings as u32, device)?
159 .to_dtype(DType::F32)?
160 .reshape((config.max_position_embeddings, 1))?
161 .matmul(&theta.reshape((1, theta.elem_count()))?)?;
162 let cos = idx_theta.cos()?.to_dtype(dtype)?;
163 let sin = idx_theta.sin()?.to_dtype(dtype)?;
164 Ok(Self {
165 masks: HashMap::new(),
166 use_kv_cache,
167 kvs: vec![None; config.num_hidden_layers],
168 device: device.clone(),
169 cos,
170 sin,
171 })
172 }
173
174 fn mask(&mut self, t: usize) -> Result<Tensor> {
175 if let Some(mask) = self.masks.get(&t) {
176 Ok(mask.clone())
177 } else {
178 let mask: Vec<_> = (0..t)
179 .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
180 .collect();
181 let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
182 self.masks.insert(t, mask.clone());
183 Ok(mask)
184 }
185 }
186}
187
188#[derive(Debug, Clone)]
189struct CausalSelfAttention {
190 q_proj: Linear,
191 k_proj: Linear,
192 v_proj: Linear,
193 o_proj: Linear,
194 num_attention_heads: usize,
195 num_key_value_heads: usize,
196 head_dim: usize,
197 use_flash_attn: bool,
198 span: tracing::Span,
199 span_rot: tracing::Span,
200 max_position_embeddings: usize,
201}
202
/// Fused attention through `candle_flash_attn`; all arguments are forwarded unchanged.
///
/// Only compiled when the `flash-attn` feature is enabled.
#[cfg(feature = "flash-attn")]
fn flash_attn(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32, causal: bool) -> Result<Tensor> {
    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
}
213
214#[cfg(not(feature = "flash-attn"))]
215fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
216 unimplemented!("compile with '--features flash-attn'")
217}
218
219impl CausalSelfAttention {
220 fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
221 let _enter = self.span_rot.enter();
222 let (_b_sz, _, seq_len, _hidden_size) = x.dims4()?;
223 let cos = cache.cos.narrow(0, index_pos, seq_len)?;
224 let sin = cache.sin.narrow(0, index_pos, seq_len)?;
225 candle_nn::rotary_emb::rope(x, &cos, &sin)
226 }
227
228 fn forward(
229 &self,
230 x: &Tensor,
231 index_pos: usize,
232 block_idx: usize,
233 cache: &mut Cache,
234 ) -> Result<Tensor> {
235 let _enter = self.span.enter();
236 let (b_sz, seq_len, hidden_size) = x.dims3()?;
237 let q = self.q_proj.forward(x)?;
238 let k = self.k_proj.forward(x)?;
239 let v = self.v_proj.forward(x)?;
240
241 let q = q
242 .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
243 .transpose(1, 2)?
244 .contiguous()?;
245 let k = k
246 .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
247 .transpose(1, 2)?
248 .contiguous()?;
249 let mut v = v
250 .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
251 .transpose(1, 2)?;
252
253 let q = self.apply_rotary_emb(&q, index_pos, cache)?;
254 let mut k = self.apply_rotary_emb(&k, index_pos, cache)?;
255
256 if cache.use_kv_cache {
257 if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] {
258 k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
259 v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
260 let k_seq_len = k.dims()[1];
261 if k_seq_len > self.max_position_embeddings {
262 k = k
263 .narrow(
264 D::Minus1,
265 k_seq_len - self.max_position_embeddings,
266 self.max_position_embeddings,
267 )?
268 .contiguous()?
269 }
270 let v_seq_len = v.dims()[1];
271 if v_seq_len > 2 * self.max_position_embeddings {
272 v = v
273 .narrow(
274 D::Minus1,
275 v_seq_len - self.max_position_embeddings,
276 self.max_position_embeddings,
277 )?
278 .contiguous()?
279 }
280 }
281 cache.kvs[block_idx] = Some((k.clone(), v.clone()))
282 }
283
284 let k = self.repeat_kv(k)?;
285 let v = self.repeat_kv(v)?;
286
287 let y = if self.use_flash_attn {
288 let q = q.transpose(1, 2)?;
290 let k = k.transpose(1, 2)?;
291 let v = v.transpose(1, 2)?;
292 let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
293 flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2)?
294 } else {
295 let in_dtype = q.dtype();
296 let q = q.to_dtype(DType::F32)?;
297 let k = k.to_dtype(DType::F32)?;
298 let v = v.to_dtype(DType::F32)?;
299 let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
300 let att = if seq_len == 1 {
301 att
302 } else {
303 let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
304 masked_fill(&att, &mask, f32::NEG_INFINITY)?
305 };
306 let att = candle_nn::ops::softmax(&att, D::Minus1)?;
307 att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
309 };
310 let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
311 let y = self.o_proj.forward(&y)?;
312 Ok(y)
313 }
314
315 fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
316 crate::utils::repeat_kv(x, self.num_attention_heads / self.num_key_value_heads)
317 }
318
319 fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
320 let span = tracing::span!(tracing::Level::TRACE, "attn");
321 let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
322 let size_in = cfg.hidden_size;
323 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
324 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
325 let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?;
326 let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?;
327 let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?;
328 let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?;
329 Ok(Self {
330 q_proj,
331 k_proj,
332 v_proj,
333 o_proj,
334 num_attention_heads: cfg.num_attention_heads,
335 num_key_value_heads: cfg.num_key_value_heads,
336 head_dim: cfg.hidden_size / cfg.num_attention_heads,
337 use_flash_attn: cfg.use_flash_attn,
338 span,
339 span_rot,
340 max_position_embeddings: cfg.max_position_embeddings,
341 })
342 }
343}
344
345fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
346 let shape = mask.shape();
347 let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
348 let m = mask.where_cond(&on_true, on_false)?;
349 Ok(m)
350}
351
352#[derive(Debug, Clone)]
353struct Mlp {
354 c_fc1: Linear,
355 c_fc2: Linear,
356 c_proj: Linear,
357 span: tracing::Span,
358}
359
360impl Mlp {
361 fn forward(&self, x: &Tensor) -> Result<Tensor> {
362 let _enter = self.span.enter();
363 let x = (candle_nn::ops::silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
364 self.c_proj.forward(&x)
365 }
366
367 fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
368 let span = tracing::span!(tracing::Level::TRACE, "mlp");
369 let h_size = cfg.hidden_size;
370 let i_size = cfg.intermediate_size;
371 let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
372 let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?;
373 let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?;
374 Ok(Self {
375 c_fc1,
376 c_fc2,
377 c_proj,
378 span,
379 })
380 }
381}
382
383#[derive(Debug, Clone)]
384struct Block {
385 rms_1: RmsNorm,
386 attn: CausalSelfAttention,
387 rms_2: RmsNorm,
388 mlp: Mlp,
389 span: tracing::Span,
390}
391
392impl Block {
393 fn forward(
394 &self,
395 x: &Tensor,
396 index_pos: usize,
397 block_idx: usize,
398 cache: &mut Cache,
399 ) -> Result<Tensor> {
400 let _enter = self.span.enter();
401 let residual = x;
402 let x = self.rms_1.forward(x)?;
403 let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?;
404 let residual = &x;
405 let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
406 Ok(x)
407 }
408
409 fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
410 let span = tracing::span!(tracing::Level::TRACE, "block");
411 let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
412 let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
413 let rms_1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
414 let rms_2 = RmsNorm::new(
415 cfg.hidden_size,
416 cfg.rms_norm_eps,
417 vb.pp("post_attention_layernorm"),
418 )?;
419 Ok(Self {
420 rms_1,
421 attn,
422 rms_2,
423 mlp,
424 span,
425 })
426 }
427}
428
429#[derive(Debug, Clone)]
430pub struct Granite {
431 wte: Embedding,
432 blocks: Vec<Block>,
433 ln_f: RmsNorm,
434 lm_head: Linear,
435}
436
437impl Granite {
438 pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
439 let (_b_sz, seq_len) = x.dims2()?;
440 let mut x = self.wte.forward(x)?;
441 for (block_idx, block) in self.blocks.iter().enumerate() {
442 x = block.forward(&x, index_pos, block_idx, cache)?;
443 }
444 let x = self.ln_f.forward(&x)?;
445 let x = x.i((.., seq_len - 1, ..))?.contiguous()?;
446 let logits = self.lm_head.forward(&x)?;
447 logits.to_dtype(DType::F32)
448 }
449
450 pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
451 let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
452 let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
453 let ln_f = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
454 let blocks: Vec<_> = (0..cfg.num_hidden_layers)
455 .map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cfg).unwrap())
456 .collect();
457
458 Ok(Self {
459 wte,
460 blocks,
461 ln_f,
462 lm_head,
463 })
464 }
465}