1use candle::{DType, Device, Module, Result, Tensor, D};
18use candle_nn::{linear_b, linear_no_bias, Activation, LayerNorm, Linear, VarBuilder};
19use std::sync::Arc;
20
21#[derive(Debug, Clone, serde::Deserialize)]
22pub struct Config {
23 pub vocab_size: usize,
24 pub hidden_size: usize,
25 pub intermediate_size: usize,
26 pub attention_bias: bool,
27 pub num_hidden_layers: usize,
28 pub num_attention_heads: usize,
29 pub num_key_value_heads: usize,
30 pub hidden_act: candle_nn::Activation,
31 pub max_position_embeddings: usize,
32 pub rope_theta: f64,
33 pub tie_word_embeddings: bool,
34 pub clip_qkv: Option<f64>,
35}
36
37#[derive(Debug, Clone)]
38struct RotaryEmbedding {
39 sin: Tensor,
40 cos: Tensor,
41}
42
43impl RotaryEmbedding {
44 fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
45 let dim = cfg.hidden_size / cfg.num_attention_heads;
46 let max_seq_len = cfg.max_position_embeddings;
47 let inv_freq: Vec<_> = (0..dim)
48 .step_by(2)
49 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
50 .collect();
51 let inv_freq_len = inv_freq.len();
52 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
53 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
54 .to_dtype(dtype)?
55 .reshape((max_seq_len, 1))?;
56 let freqs = t.matmul(&inv_freq)?;
57 Ok(Self {
58 sin: freqs.sin()?,
59 cos: freqs.cos()?,
60 })
61 }
62
63 fn apply_rotary_emb_qkv(
64 &self,
65 q: &Tensor,
66 k: &Tensor,
67 seqlen_offset: usize,
68 ) -> Result<(Tensor, Tensor)> {
69 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
70 let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
71 let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
72 let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
73 let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
74 Ok((q_embed, k_embed))
75 }
76}
77
78#[derive(Debug, Clone)]
79#[allow(clippy::upper_case_acronyms)]
80struct MLP {
81 gate_proj: Linear,
82 up_proj: Linear,
83 down_proj: Linear,
84 act_fn: Activation,
85}
86
87impl MLP {
88 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
89 let hidden_sz = cfg.hidden_size;
90 let intermediate_sz = cfg.intermediate_size;
91 let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
92 let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
93 let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
94 Ok(Self {
95 gate_proj,
96 up_proj,
97 down_proj,
98 act_fn: cfg.hidden_act,
99 })
100 }
101}
102
103impl Module for MLP {
104 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
105 let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
106 let rhs = xs.apply(&self.up_proj)?;
107 (lhs * rhs)?.apply(&self.down_proj)
108 }
109}
110
111#[derive(Debug, Clone)]
112struct Attention {
113 q_proj: Linear,
114 k_proj: Linear,
115 v_proj: Linear,
116 o_proj: Linear,
117 num_heads: usize,
118 num_kv_heads: usize,
119 num_kv_groups: usize,
120 head_dim: usize,
121 hidden_size: usize,
122 rotary_emb: Arc<RotaryEmbedding>,
123 qkv_clip: Option<f64>,
124 kv_cache: Option<(Tensor, Tensor)>,
125}
126
127impl Attention {
128 fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
129 let hidden_sz = cfg.hidden_size;
130 let num_heads = cfg.num_attention_heads;
131 let num_kv_heads = cfg.num_key_value_heads;
132 let num_kv_groups = num_heads / num_kv_heads;
133 let head_dim = hidden_sz / num_heads;
134 let b = cfg.attention_bias;
135 let qkv_clip = cfg.clip_qkv;
136 let q_proj = linear_b(hidden_sz, num_heads * head_dim, b, vb.pp("q_proj"))?;
137 let k_proj = linear_b(hidden_sz, num_kv_heads * head_dim, b, vb.pp("k_proj"))?;
138 let v_proj = linear_b(hidden_sz, num_kv_heads * head_dim, b, vb.pp("v_proj"))?;
139 let o_proj = linear_b(num_heads * head_dim, hidden_sz, b, vb.pp("o_proj"))?;
140 Ok(Self {
141 q_proj,
142 k_proj,
143 v_proj,
144 o_proj,
145 num_heads,
146 num_kv_heads,
147 num_kv_groups,
148 head_dim,
149 hidden_size: hidden_sz,
150 rotary_emb,
151 qkv_clip,
152 kv_cache: None,
153 })
154 }
155
156 fn forward(
157 &mut self,
158 xs: &Tensor,
159 attention_mask: Option<&Tensor>,
160 seqlen_offset: usize,
161 ) -> Result<Tensor> {
162 let (b_sz, q_len, _) = xs.dims3()?;
163
164 let query_states = self.q_proj.forward(xs)?;
165 let key_states = self.k_proj.forward(xs)?;
166 let value_states = self.v_proj.forward(xs)?;
167
168 let (query_states, key_states, value_states) = match &self.qkv_clip {
169 None => (query_states, key_states, value_states),
170 Some(qkv_clip) => {
171 let query_states = Tensor::clamp(&query_states, -qkv_clip, *qkv_clip)?;
172 let key_states = Tensor::clamp(&key_states, -qkv_clip, *qkv_clip)?;
173 let value_states = Tensor::clamp(&value_states, -qkv_clip, *qkv_clip)?;
174 (query_states, key_states, value_states)
175 }
176 };
177
178 let query_states = query_states
179 .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
180 .transpose(1, 2)?;
181 let key_states = key_states
182 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
183 .transpose(1, 2)?;
184 let value_states = value_states
185 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
186 .transpose(1, 2)?;
187
188 let (query_states, key_states) =
189 self.rotary_emb
190 .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
191
192 let (key_states, value_states) = match &self.kv_cache {
193 None => (key_states, value_states),
194 Some((prev_k, prev_v)) => {
195 let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
196 let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
197 (key_states, value_states)
198 }
199 };
200 self.kv_cache = Some((key_states.clone(), value_states.clone()));
201
202 let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
203 let value_states =
204 crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
205
206 let attn_output = {
207 let scale = 1f64 / f64::sqrt(self.head_dim as f64);
208 let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
209
210 let attn_weights = match attention_mask {
211 None => attn_weights,
212 Some(mask) => attn_weights.broadcast_add(mask)?,
213 };
214 let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
215 attn_weights.matmul(&value_states)?
216 };
217 attn_output
218 .transpose(1, 2)?
219 .reshape((b_sz, q_len, self.hidden_size))?
220 .apply(&self.o_proj)
221 }
222
223 fn clear_kv_cache(&mut self) {
224 self.kv_cache = None
225 }
226}
227
228#[derive(Debug, Clone)]
229struct DecoderLayer {
230 self_attn: Attention,
231 mlp: MLP,
232 input_layernorm: LayerNorm,
233 post_attention_layernorm: LayerNorm,
234}
235
236impl DecoderLayer {
237 fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
238 let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
239 let mlp = MLP::new(cfg, vb.pp("mlp"))?;
240 let ln_weight = Tensor::ones(cfg.hidden_size, vb.dtype(), vb.device())?;
241 let input_layernorm = LayerNorm::new_no_bias(ln_weight.clone(), 1e-5);
242 let post_attention_layernorm = LayerNorm::new_no_bias(ln_weight.clone(), 1e-5);
243 Ok(Self {
244 self_attn,
245 mlp,
246 input_layernorm,
247 post_attention_layernorm,
248 })
249 }
250
251 fn forward(
252 &mut self,
253 xs: &Tensor,
254 attention_mask: Option<&Tensor>,
255 seqlen_offset: usize,
256 ) -> Result<Tensor> {
257 let residual = xs;
258 let xs = self.input_layernorm.forward(xs)?;
259 let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
260 let xs = (xs + residual)?;
261 let residual = &xs;
262 let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
263 residual + xs
264 }
265
266 fn clear_kv_cache(&mut self) {
267 self.self_attn.clear_kv_cache()
268 }
269}
270
271#[derive(Debug, Clone)]
272pub struct Model {
273 embed_tokens: candle_nn::Embedding,
274 layers: Vec<DecoderLayer>,
275 norm: LayerNorm,
276 lm_head: Linear,
277 device: Device,
278 dtype: DType,
279}
280
281impl Model {
282 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
283 let vb_m = vb.pp("model");
284 let embed_tokens =
285 candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
286 let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
287 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
288 let vb_l = vb_m.pp("layers");
289 for layer_idx in 0..cfg.num_hidden_layers {
290 let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
291 layers.push(layer)
292 }
293 let ln_weight = Tensor::ones(cfg.hidden_size, vb.dtype(), vb.device())?;
294 let norm = LayerNorm::new_no_bias(ln_weight, 1e-5);
295 let lm_head = if cfg.tie_word_embeddings {
296 Linear::new(embed_tokens.embeddings().clone(), None)
297 } else {
298 linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
299 };
300 Ok(Self {
301 embed_tokens,
302 layers,
303 norm,
304 lm_head,
305 device: vb.device().clone(),
306 dtype: vb.dtype(),
307 })
308 }
309
310 fn prepare_decoder_attention_mask(
311 &self,
312 b_size: usize,
313 tgt_len: usize,
314 seqlen_offset: usize,
315 ) -> Result<Tensor> {
316 let mask: Vec<_> = (0..tgt_len)
318 .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
319 .collect();
320 let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
321 let mask = if seqlen_offset > 0 {
322 let mask0 = Tensor::zeros((tgt_len, seqlen_offset), self.dtype, &self.device)?;
323 Tensor::cat(&[&mask0, &mask], D::Minus1)?
324 } else {
325 mask
326 };
327 mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
328 .to_dtype(self.dtype)
329 }
330
331 pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
332 let (b_size, seq_len) = input_ids.dims2()?;
333 let attention_mask = if seq_len <= 1 {
334 None
335 } else {
336 let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
337 Some(mask)
338 };
339 let mut xs = self.embed_tokens.forward(input_ids)?;
340 for layer in self.layers.iter_mut() {
341 xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
342 }
343 xs.narrow(1, seq_len - 1, 1)?
344 .apply(&self.norm)?
345 .apply(&self.lm_head)
346 }
347
348 pub fn clear_kv_cache(&mut self) {
349 for layer in self.layers.iter_mut() {
350 layer.clear_kv_cache()
351 }
352 }
353}