1use crate::models::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
23use candle::{DType, Device, Module, Result, Tensor, D};
24use candle_nn::VarBuilder;
25use std::sync::Arc;
26
27#[derive(Debug, Clone, serde::Deserialize)]
29pub struct Config {
30 pub vocab_size: usize,
31 pub hidden_act: candle_nn::Activation,
32 pub hidden_size: usize,
33 pub intermediate_size: usize,
34 pub num_hidden_layers: usize,
35 pub num_attention_heads: usize,
36 pub num_key_value_heads: usize,
37 pub rms_norm_eps: f64,
38 pub rope_theta: f64,
39 pub bos_token_id: Option<u32>,
40 pub eos_token_id: Option<u32>,
41 pub rope_scaling: Option<String>,
42 pub max_position_embeddings: usize,
43}
44
45impl Config {
46 pub fn head_dim(&self) -> usize {
47 self.hidden_size / self.num_attention_heads
48 }
49}
50
51#[derive(Debug, Clone)]
52pub struct RotaryEmbedding {
53 sin: Tensor,
54 cos: Tensor,
55}
56
57impl RotaryEmbedding {
58 pub fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
59 let dim = cfg.head_dim();
60 let max_seq_len = cfg.max_position_embeddings;
61 let inv_freq: Vec<_> = (0..dim)
62 .step_by(2)
63 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
64 .collect();
65 let inv_freq_len = inv_freq.len();
66 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
67 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
68 .to_dtype(dtype)?
69 .reshape((max_seq_len, 1))?;
70 let freqs = t.matmul(&inv_freq)?;
71 Ok(Self {
72 sin: freqs.sin()?,
73 cos: freqs.cos()?,
74 })
75 }
76
77 pub fn apply_rotary_emb_qkv(
78 &self,
79 q: &Tensor,
80 k: &Tensor,
81 seqlen_offset: usize,
82 ) -> Result<(Tensor, Tensor)> {
83 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
84 let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
85 let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
86 let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
87 let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
88 Ok((q_embed, k_embed))
89 }
90}
91
92#[derive(Debug, Clone)]
93struct Attention {
94 qkv_proj: Linear,
95 o_proj: Linear,
96 num_heads: usize,
97 num_kv_heads: usize,
98 num_kv_groups: usize,
99 head_dim: usize,
100 rotary_emb: Arc<RotaryEmbedding>,
101 kv_cache: Option<(Tensor, Tensor)>,
102}
103
104impl Attention {
105 fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
106 let num_heads = cfg.num_attention_heads;
107 let num_kv_heads = cfg.num_key_value_heads;
108 let head_dim = cfg.head_dim();
109 let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim;
110 let qkv_proj = linear(cfg.hidden_size, op_size, vb.pp("qkv_proj"))?;
111 let o_proj = linear(num_heads * head_dim, cfg.hidden_size, vb.pp("o_proj"))?;
112 Ok(Self {
113 qkv_proj,
114 o_proj,
115 rotary_emb,
116 kv_cache: None,
117 num_heads,
118 num_kv_heads,
119 num_kv_groups: num_heads / num_kv_heads,
120 head_dim,
121 })
122 }
123
124 fn forward(
125 &mut self,
126 xs: &Tensor,
127 attention_mask: Option<&Tensor>,
128 seqlen_offset: usize,
129 ) -> Result<Tensor> {
130 let (b_sz, q_len, _) = xs.dims3()?;
131
132 let qkv = self.qkv_proj.forward(xs)?;
133 let query_pos = self.num_heads * self.head_dim;
134 let query_states = qkv.narrow(D::Minus1, 0, query_pos)?;
135 let key_states = qkv.narrow(D::Minus1, query_pos, self.num_kv_heads * self.head_dim)?;
136 let value_states = qkv.narrow(
137 D::Minus1,
138 query_pos + self.num_kv_heads * self.head_dim,
139 self.num_kv_heads * self.head_dim,
140 )?;
141
142 let query_states = query_states
143 .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
144 .transpose(1, 2)?;
145 let key_states = key_states
146 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
147 .transpose(1, 2)?;
148 let value_states = value_states
149 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
150 .transpose(1, 2)?;
151
152 let (query_states, key_states) =
153 self.rotary_emb
154 .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
155
156 let (key_states, value_states) = match &self.kv_cache {
157 None => (key_states, value_states),
158 Some((prev_k, prev_v)) => {
159 let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
160 let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
161 (key_states, value_states)
162 }
163 };
164 self.kv_cache = Some((key_states.clone(), value_states.clone()));
165
166 let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
167 let value_states =
168 crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
169
170 let attn_output = {
171 let scale = 1f64 / f64::sqrt(self.head_dim as f64);
172 let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
173
174 let attn_weights = match attention_mask {
175 None => attn_weights,
176 Some(mask) => attn_weights.broadcast_add(mask)?,
177 };
178 let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
179 attn_weights.matmul(&value_states)?
180 };
181 attn_output
182 .transpose(1, 2)?
183 .reshape((b_sz, q_len, ()))?
184 .apply(&self.o_proj)
185 }
186
187 fn clear_kv_cache(&mut self) {
188 self.kv_cache = None
189 }
190}
191
192#[derive(Debug, Clone)]
193struct Mlp {
194 gate_up_proj: Linear,
195 down_proj: Linear,
196 act_fn: candle_nn::Activation,
197 i_size: usize,
198}
199
200impl Mlp {
201 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
202 let hidden_size = cfg.hidden_size;
203 let i_size = cfg.intermediate_size;
204 let gate_up_proj = linear(hidden_size, 2 * i_size, vb.pp("gate_up_proj"))?;
205 let down_proj = linear(i_size, hidden_size, vb.pp("down_proj"))?;
206 Ok(Self {
207 gate_up_proj,
208 down_proj,
209 act_fn: cfg.hidden_act,
210 i_size,
211 })
212 }
213}
214
215impl Module for Mlp {
216 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
217 let up_states = xs.apply(&self.gate_up_proj)?;
218 let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
219 let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
220 let up_states = (up_states * gate.apply(&self.act_fn))?;
221 up_states.apply(&self.down_proj)
222 }
223}
224
225#[derive(Debug, Clone)]
226struct DecoderLayer {
227 self_attn: Attention,
228 mlp: Mlp,
229 input_layernorm: RmsNorm,
230 post_attention_layernorm: RmsNorm,
231}
232
233impl DecoderLayer {
234 fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
235 let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
236 let mlp = Mlp::new(cfg, vb.pp("mlp"))?;
237 let input_layernorm =
238 RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
239 let post_attention_layernorm = RmsNorm::new(
240 cfg.hidden_size,
241 cfg.rms_norm_eps,
242 vb.pp("post_attention_layernorm"),
243 )?;
244 Ok(Self {
245 self_attn,
246 mlp,
247 input_layernorm,
248 post_attention_layernorm,
249 })
250 }
251
252 fn forward(
253 &mut self,
254 xs: &Tensor,
255 attention_mask: Option<&Tensor>,
256 seqlen_offset: usize,
257 ) -> Result<Tensor> {
258 let residual = xs;
259 let xs = self.input_layernorm.forward(xs)?;
260 let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
261 let xs = (xs + residual)?;
262 let residual = &xs;
263 let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
264 residual + xs
265 }
266
267 fn clear_kv_cache(&mut self) {
268 self.self_attn.clear_kv_cache()
269 }
270}
271
272#[derive(Debug, Clone)]
273pub struct Model {
274 embed_tokens: candle_nn::Embedding,
275 layers: Vec<DecoderLayer>,
276 norm: RmsNorm,
277 lm_head: Linear,
278 device: Device,
279 dtype: DType,
280}
281
282impl Model {
283 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
284 let vb_m = vb.pp("model");
285 let embed_tokens =
286 candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
287 let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
288 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
289 let vb_l = vb_m.pp("layers");
290 for layer_idx in 0..cfg.num_hidden_layers {
291 let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
292 layers.push(layer)
293 }
294 let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
295 let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
296 Ok(Self {
297 embed_tokens,
298 layers,
299 norm,
300 lm_head,
301 device: vb.device().clone(),
302 dtype: vb.dtype(),
303 })
304 }
305
306 fn prepare_decoder_attention_mask(
307 &self,
308 b_size: usize,
309 tgt_len: usize,
310 seqlen_offset: usize,
311 ) -> Result<Tensor> {
312 let mask: Vec<_> = (0..tgt_len)
313 .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
314 .collect();
315 let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
316 let mask = if seqlen_offset > 0 {
317 let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
318 Tensor::cat(&[&mask0, &mask], D::Minus1)?
319 } else {
320 mask
321 };
322 mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
323 .to_dtype(self.dtype)
324 }
325
326 pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
327 let (b_size, seq_len) = input_ids.dims2()?;
328 let attention_mask = if seq_len <= 1 {
329 None
330 } else {
331 let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
332 Some(mask)
333 };
334 let mut xs = self.embed_tokens.forward(input_ids)?;
335 for layer in self.layers.iter_mut() {
336 xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
337 }
338 xs.narrow(1, seq_len - 1, 1)?
339 .apply(&self.norm)?
340 .apply(&self.lm_head)
341 }
342
343 pub fn clear_kv_cache(&mut self) {
344 for layer in self.layers.iter_mut() {
345 layer.clear_kv_cache()
346 }
347 }
348}