1use crate::models::with_tracing::{linear_no_bias, Linear, RmsNorm};
19use candle::{DType, Device, Module, Result, Tensor, D};
20use candle_nn::{Activation, VarBuilder};
21use std::sync::Arc;
22
23#[derive(Debug, Clone, PartialEq)]
24pub struct Config {
25 pub(crate) vocab_size: usize,
26 pub(crate) hidden_size: usize,
27 pub(crate) intermediate_size: usize,
28 pub(crate) num_hidden_layers: usize,
29 pub(crate) num_attention_heads: usize,
30 pub(crate) num_key_value_heads: usize,
31 pub(crate) hidden_act: Activation,
32 pub(crate) max_position_embeddings: usize,
33 pub(crate) rms_norm_eps: f64,
34 pub(crate) rope_theta: f64,
35}
36
37impl Config {
38 pub fn config_6b() -> Self {
39 Self {
40 vocab_size: 64000,
41 hidden_size: 4096,
42 intermediate_size: 11008,
43 num_hidden_layers: 32,
44 num_attention_heads: 32,
45 num_key_value_heads: 4,
46 hidden_act: Activation::Silu,
47 max_position_embeddings: 4096,
48 rms_norm_eps: 1e-5,
49 rope_theta: 5_000_000.,
50 }
51 }
52
53 pub fn config_34b() -> Self {
54 Self {
55 vocab_size: 64000,
56 hidden_size: 7168,
57 intermediate_size: 20480,
58 num_hidden_layers: 60,
59 num_attention_heads: 56,
60 num_key_value_heads: 8,
61 hidden_act: Activation::Silu,
62 max_position_embeddings: 4096,
63 rms_norm_eps: 1e-5,
64 rope_theta: 5_000_000.,
65 }
66 }
67}
68
69#[derive(Debug, Clone)]
70struct RotaryEmbedding {
71 sin: Tensor,
72 cos: Tensor,
73}
74
75fn rotate_half(xs: &Tensor) -> Result<Tensor> {
76 let last_dim = xs.dim(D::Minus1)?;
77 let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
78 let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
79 Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
80}
81
82impl RotaryEmbedding {
83 fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
84 let dim = cfg.hidden_size / cfg.num_attention_heads;
85 let max_seq_len = cfg.max_position_embeddings;
86 let inv_freq: Vec<_> = (0..dim)
87 .step_by(2)
88 .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
89 .collect();
90 let inv_freq_len = inv_freq.len();
91 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
92 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
93 .to_dtype(dtype)?
94 .reshape((max_seq_len, 1))?;
95 let freqs = t.matmul(&inv_freq)?;
96 let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
97 Ok(Self {
98 sin: freqs.sin()?,
99 cos: freqs.cos()?,
100 })
101 }
102
103 fn apply_rotary_emb_qkv(
104 &self,
105 q: &Tensor,
106 k: &Tensor,
107 seqlen_offset: usize,
108 ) -> Result<(Tensor, Tensor)> {
109 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
110 let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
111 let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
112 let cos = cos.unsqueeze(0)?.unsqueeze(0)?; let sin = sin.unsqueeze(0)?.unsqueeze(0)?; let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
115 let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
116 Ok((q_embed, k_embed))
117 }
118}
119
120#[derive(Debug, Clone)]
121#[allow(clippy::upper_case_acronyms)]
122struct MLP {
123 gate_proj: Linear,
124 up_proj: Linear,
125 down_proj: Linear,
126 act_fn: Activation,
127}
128
129impl MLP {
130 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
131 let hidden_sz = cfg.hidden_size;
132 let intermediate_sz = cfg.intermediate_size;
133 let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
134 let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
135 let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
136 Ok(Self {
137 gate_proj,
138 up_proj,
139 down_proj,
140 act_fn: cfg.hidden_act,
141 })
142 }
143}
144
145impl Module for MLP {
146 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
147 let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
148 let rhs = xs.apply(&self.up_proj)?;
149 (lhs * rhs)?.apply(&self.down_proj)
150 }
151}
152
153#[derive(Debug, Clone)]
154struct Attention {
155 q_proj: Linear,
156 k_proj: Linear,
157 v_proj: Linear,
158 o_proj: Linear,
159 num_heads: usize,
160 num_kv_heads: usize,
161 num_kv_groups: usize,
162 head_dim: usize,
163 hidden_size: usize,
164 rotary_emb: Arc<RotaryEmbedding>,
165 kv_cache: Option<(Tensor, Tensor)>,
166}
167
168impl Attention {
169 fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
170 let hidden_sz = cfg.hidden_size;
171 let num_heads = cfg.num_attention_heads;
172 let num_kv_heads = cfg.num_key_value_heads;
173 let num_kv_groups = num_heads / num_kv_heads;
174 let head_dim = hidden_sz / num_heads;
175 let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
176 let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
177 let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
178 let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
179 Ok(Self {
180 q_proj,
181 k_proj,
182 v_proj,
183 o_proj,
184 num_heads,
185 num_kv_heads,
186 num_kv_groups,
187 head_dim,
188 hidden_size: hidden_sz,
189 rotary_emb,
190 kv_cache: None,
191 })
192 }
193
194 fn forward(
195 &mut self,
196 xs: &Tensor,
197 attention_mask: Option<&Tensor>,
198 seqlen_offset: usize,
199 ) -> Result<Tensor> {
200 let (b_sz, q_len, _) = xs.dims3()?;
201
202 let query_states = self.q_proj.forward(xs)?;
203 let key_states = self.k_proj.forward(xs)?;
204 let value_states = self.v_proj.forward(xs)?;
205
206 let query_states = query_states
207 .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
208 .transpose(1, 2)?;
209 let key_states = key_states
210 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
211 .transpose(1, 2)?;
212 let value_states = value_states
213 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
214 .transpose(1, 2)?;
215
216 let (query_states, key_states) =
217 self.rotary_emb
218 .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
219
220 let (key_states, value_states) = match &self.kv_cache {
221 None => (key_states, value_states),
222 Some((prev_k, prev_v)) => {
223 let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
224 let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
225 (key_states, value_states)
226 }
227 };
228 self.kv_cache = Some((key_states.clone(), value_states.clone()));
229
230 let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
231 let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
232
233 let attn_output = {
234 let scale = 1f64 / f64::sqrt(self.head_dim as f64);
235 let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
236
237 let attn_weights = match attention_mask {
238 None => attn_weights,
239 Some(mask) => attn_weights.broadcast_add(mask)?,
240 };
241 let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
242 attn_weights.matmul(&value_states)?
243 };
244 attn_output
245 .transpose(1, 2)?
246 .reshape((b_sz, q_len, self.hidden_size))?
247 .apply(&self.o_proj)
248 }
249}
250
251#[derive(Debug, Clone)]
252struct DecoderLayer {
253 self_attn: Attention,
254 mlp: MLP,
255 ln1: RmsNorm,
256 ln2: RmsNorm,
257}
258
259impl DecoderLayer {
260 fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
261 let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
262 let mlp = MLP::new(cfg, vb.pp("mlp"))?;
263 let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
264 let ln2 = RmsNorm::new(
265 cfg.hidden_size,
266 cfg.rms_norm_eps,
267 vb.pp("post_attention_layernorm"),
268 )?;
269 Ok(Self {
270 self_attn,
271 mlp,
272 ln1,
273 ln2,
274 })
275 }
276
277 fn forward(
278 &mut self,
279 xs: &Tensor,
280 attention_mask: Option<&Tensor>,
281 seqlen_offset: usize,
282 ) -> Result<Tensor> {
283 let residual = xs;
284 let xs = self.ln1.forward(xs)?;
285 let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
286 let xs = (xs + residual)?;
287 let residual = &xs;
288 let xs = xs.apply(&self.ln2)?.apply(&self.mlp)?;
289 residual + xs
290 }
291}
292
293#[derive(Debug, Clone)]
294pub struct Model {
295 embed_tokens: candle_nn::Embedding,
296 layers: Vec<DecoderLayer>,
297 norm: RmsNorm,
298 lm_head: Linear,
299 device: Device,
300 dtype: DType,
301}
302
303impl Model {
304 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
305 let vb_m = vb.pp("model");
306 let embed_tokens =
307 candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
308 let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
309 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
310 let vb_l = vb_m.pp("layers");
311 for layer_idx in 0..cfg.num_hidden_layers {
312 let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
313 layers.push(layer)
314 }
315 let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
316 let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
317 Ok(Self {
318 embed_tokens,
319 layers,
320 norm,
321 lm_head,
322 device: vb.device().clone(),
323 dtype: vb.dtype(),
324 })
325 }
326
327 fn prepare_decoder_attention_mask(
328 &self,
329 b_size: usize,
330 tgt_len: usize,
331 seqlen_offset: usize,
332 ) -> Result<Tensor> {
333 let mask: Vec<_> = (0..tgt_len)
335 .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
336 .collect();
337 let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
338 let mask = if seqlen_offset > 0 {
339 let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
340 Tensor::cat(&[&mask0, &mask], D::Minus1)?
341 } else {
342 mask
343 };
344 mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
345 .to_dtype(self.dtype)
346 }
347
348 pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
349 let (b_size, seq_len) = input_ids.dims2()?;
350 let attention_mask = if seq_len <= 1 {
351 None
352 } else {
353 let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
354 Some(mask)
355 };
356 let mut xs = self.embed_tokens.forward(input_ids)?;
357 for layer in self.layers.iter_mut() {
358 xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
359 }
360 xs.narrow(1, seq_len - 1, 1)?
361 .apply(&self.norm)?
362 .apply(&self.lm_head)
363 }
364}