1use crate::models::with_tracing::{linear, Embedding as E, Linear};
9use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
13use candle_nn::{Activation, VarBuilder};
14use serde::Deserialize;
15
/// Number of positions for which each attention layer precomputes its rotary
/// sine/cosine tables (passed to `RotaryEmbedding::new` in `MHA::new`).
const MAX_SEQ_LEN: usize = 4096;
17
/// Hyper-parameters of a MixFormer (Phi) checkpoint; deserializable with serde.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    /// Rows of the token embedding table and output width of the LM head.
    pub(crate) vocab_size: usize,
    /// Context length declared by the checkpoint. Not read by the code shown here;
    /// the rotary tables are sized by `MAX_SEQ_LEN` instead.
    pub(crate) n_positions: usize,
    /// Model (residual stream) width.
    pub(crate) n_embd: usize,
    /// Number of `ParallelBlock`s in the stack.
    pub(crate) n_layer: usize,
    /// Hidden width of the MLP; `None` means `4 * n_embd`.
    pub(crate) n_inner: Option<usize>,
    /// Number of attention heads; the head dimension is `n_embd / n_head`.
    pub(crate) n_head: usize,
    /// Size passed to the rotary embedding; it determines how many leading
    /// per-head dimensions of q/k are rotated.
    pub(crate) rotary_dim: usize,
    /// Activation applied between the two MLP projections.
    pub(crate) activation_function: Activation,
    /// Epsilon used by every layer norm.
    pub(crate) layer_norm_epsilon: f64,
    /// Read from the config; not consulted by the code shown here.
    pub(crate) tie_word_embeddings: bool,
    /// Read from the config; not consulted by the code shown here.
    pub(crate) pad_vocab_size_multiple: usize,
}
33
34impl Config {
35 pub fn v1() -> Self {
36 Self {
37 vocab_size: 50304,
38 n_positions: 2048,
39 n_embd: 1024,
40 n_layer: 20,
41 n_inner: None,
42 n_head: 16,
43 rotary_dim: usize::min(32, 1024 / 16),
44 activation_function: Activation::Gelu,
45 layer_norm_epsilon: 1e-5,
46 tie_word_embeddings: false,
47 pad_vocab_size_multiple: 64,
48 }
49 }
50
51 pub fn v1_5() -> Self {
52 Self {
53 vocab_size: 51200,
54 n_positions: 2048,
55 n_embd: 2048,
56 n_layer: 24,
57 n_inner: None,
58 n_head: 32,
59 rotary_dim: usize::min(32, 2048 / 32),
60 activation_function: Activation::Gelu,
61 layer_norm_epsilon: 1e-5,
62 tie_word_embeddings: false,
63 pad_vocab_size_multiple: 64,
64 }
65 }
66
67 pub fn v2() -> Self {
68 Self {
69 vocab_size: 51200,
70 n_positions: 2048,
71 n_embd: 2560,
72 n_layer: 32,
73 n_inner: None,
74 n_head: 32,
75 rotary_dim: usize::min(32, 2560 / 32),
76 activation_function: Activation::Gelu,
77 layer_norm_epsilon: 1e-5,
78 tie_word_embeddings: false,
79 pad_vocab_size_multiple: 64,
80 }
81 }
82
83 pub fn puffin_phi_v2() -> Self {
85 Self {
86 vocab_size: 50304,
87 n_positions: 2048,
88 n_embd: 2048,
89 n_layer: 24,
90 n_inner: None,
91 n_head: 32,
92 rotary_dim: usize::min(32, 2048 / 32),
93 activation_function: Activation::Gelu,
94 layer_norm_epsilon: 1e-5,
95 tie_word_embeddings: false,
96 pad_vocab_size_multiple: 64,
97 }
98 }
99
100 pub fn phi_hermes_1_3b() -> Self {
102 Self {
103 vocab_size: 50304,
104 n_positions: 2048,
105 n_embd: 2048,
106 n_layer: 24,
107 n_inner: None,
108 n_head: 32,
109 rotary_dim: usize::min(32, 2048 / 32),
110 activation_function: Activation::NewGelu,
111 layer_norm_epsilon: 1e-5,
112 tie_word_embeddings: false,
113 pad_vocab_size_multiple: 64,
114 }
115 }
116}
117
118#[derive(Debug, Clone)]
119struct Embedding {
120 wte: E,
121}
122
123impl Embedding {
124 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
125 let wte = E::new(cfg.vocab_size, cfg.n_embd, vb.pp("wte"))?;
126 Ok(Self { wte })
127 }
128}
129
130impl Module for Embedding {
131 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
132 self.wte.forward(xs)
133 }
134}
135
136fn get_mask(size: usize, dtype: DType, device: &Device) -> Result<Tensor> {
137 let mask: Vec<_> = (0..size)
138 .flat_map(|i| (0..size).map(move |j| if j > i { f32::NEG_INFINITY } else { 0. }))
139 .collect();
140 Tensor::from_slice(&mask, (size, size), device)?.to_dtype(dtype)
141}
142
143#[derive(Debug, Clone)]
144struct RotaryEmbedding {
145 sin: Tensor,
146 cos: Tensor,
147}
148
149impl RotaryEmbedding {
150 fn new(dim: usize, max_seq_len: usize, dtype: DType, dev: &Device) -> Result<Self> {
151 let inv_freq: Vec<_> = (0..dim)
152 .step_by(2)
153 .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
154 .collect();
155 let inv_freq_len = inv_freq.len();
156 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
157 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
158 .to_dtype(DType::F32)?
159 .reshape((max_seq_len, 1))?;
160 let freqs = t.matmul(&inv_freq)?;
161 Ok(Self {
162 sin: freqs.sin()?.to_dtype(dtype)?,
163 cos: freqs.cos()?.to_dtype(dtype)?,
164 })
165 }
166
167 fn apply_rotary_emb_qkv(
168 &self,
169 qkv: &Tensor,
170 seqlen_offset: usize,
171 ) -> Result<(Tensor, Tensor, Tensor)> {
172 let (_b_size, seqlen, three, _, _headdim) = qkv.dims5()?;
173 if three != 3 {
174 candle::bail!("unexpected shape for qkv {:?}", qkv.shape())
175 }
176 let (_rotary_seqlen, rotary_dim) = self.cos.dims2()?;
177 let rotary_dim = rotary_dim * 2;
178 let q_rot = qkv.i((.., .., 0, .., ..rotary_dim))?.contiguous()?;
179 let q_pass = qkv.i((.., .., 0, .., rotary_dim..))?;
180 let k_rot = qkv.i((.., .., 1, .., ..rotary_dim))?.contiguous()?;
181 let k_pass = qkv.i((.., .., 1, .., rotary_dim..))?;
182 let c = self.cos.narrow(0, seqlen_offset, seqlen)?;
183 let s = self.sin.narrow(0, seqlen_offset, seqlen)?;
184 let q_rot = candle_nn::rotary_emb::rope_thd(&q_rot, &c, &s)?;
185 let k_rot = candle_nn::rotary_emb::rope_thd(&k_rot, &c, &s)?;
186 let q = Tensor::cat(&[&q_rot, &q_pass], D::Minus1)?;
187 let k = Tensor::cat(&[&k_rot, &k_pass], D::Minus1)?;
188 let v = qkv.i((.., .., 2))?;
189 Ok((q, k, v))
190 }
191}
192
193#[derive(Debug, Clone)]
194#[allow(clippy::upper_case_acronyms)]
195struct MLP {
196 fc1: Linear,
197 fc2: Linear,
198 act: Activation,
199 span: tracing::Span,
200}
201
202impl MLP {
203 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
204 let n_inner = cfg.n_inner.unwrap_or(4 * cfg.n_embd);
205 let fc1 = linear(cfg.n_embd, n_inner, vb.pp("fc1"))?;
206 let fc2 = linear(n_inner, cfg.n_embd, vb.pp("fc2"))?;
207 Ok(Self {
208 fc1,
209 fc2,
210 act: cfg.activation_function,
211 span: tracing::span!(tracing::Level::TRACE, "mlp"),
212 })
213 }
214}
215
216impl Module for MLP {
217 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
218 let _enter = self.span.enter();
219 xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
220 }
221}
222
223#[derive(Debug, Clone)]
224struct CausalLMHead {
225 ln: candle_nn::LayerNorm,
226 linear: Linear,
227}
228
229impl CausalLMHead {
230 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
231 let ln = candle_nn::layer_norm(cfg.n_embd, cfg.layer_norm_epsilon, vb.pp("ln"))?;
232 let linear = linear(cfg.n_embd, cfg.vocab_size, vb.pp("linear"))?;
233 Ok(Self { ln, linear })
234 }
235}
236
237impl Module for CausalLMHead {
238 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
239 xs.apply(&self.ln)?
240 .apply(&self.linear)?
241 .to_dtype(DType::F32)
242 }
243}
244
245#[derive(Debug, Clone)]
246#[allow(clippy::upper_case_acronyms)]
247struct MHA {
248 wqkv: Linear,
249 out_proj: Linear,
250 rotary_emb: RotaryEmbedding,
251 kv_cache: Option<(Tensor, Tensor)>,
252 head_dim: usize,
253 softmax_scale: f64,
254 span: tracing::Span,
255 span_rope: tracing::Span,
256 span_mask: tracing::Span,
257 span_softmax: tracing::Span,
258}
259
260impl MHA {
261 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
262 let head_dim = cfg.n_embd / cfg.n_head;
263 let op_size = cfg.n_embd;
264 let wqkv = linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?;
265 let out_proj = linear(op_size, cfg.n_embd, vb.pp("out_proj"))?;
266 let rotary_emb =
267 RotaryEmbedding::new(cfg.rotary_dim, MAX_SEQ_LEN, vb.dtype(), vb.device())?;
268 let softmax_scale = 1f64 / (head_dim as f64).sqrt();
269 Ok(Self {
270 wqkv,
271 out_proj,
272 head_dim,
273 kv_cache: None,
274 rotary_emb,
275 softmax_scale,
276 span: tracing::span!(tracing::Level::TRACE, "mha"),
277 span_rope: tracing::span!(tracing::Level::TRACE, "rope"),
278 span_mask: tracing::span!(tracing::Level::TRACE, "mask"),
279 span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
280 })
281 }
282
283 fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
284 let _enter = self.span.enter();
285 let (b_size, seq_len, _n_embd) = xs.dims3()?;
286 let qkv = self
287 .wqkv
288 .forward(xs)?
289 .reshape((b_size, seq_len, 3, (), self.head_dim))?;
290 let seqlen_offset = match &self.kv_cache {
291 None => 0,
292 Some((prev_k, _)) => prev_k.dim(1)?,
293 };
294 let (q, k, v) = {
296 let _enter = self.span_rope.enter();
297 self.rotary_emb.apply_rotary_emb_qkv(&qkv, seqlen_offset)?
298 };
299 let (k, v) = match &self.kv_cache {
300 None => (k, v),
301 Some((prev_k, prev_v)) => {
302 let k = Tensor::cat(&[prev_k, &k], 1)?;
303 let v = Tensor::cat(&[prev_v, &v], 1)?;
304 (k, v)
305 }
306 };
307 self.kv_cache = Some((k.clone(), v.clone()));
308 let q = q.transpose(1, 2)?.flatten_to(1)?; let k = k.transpose(1, 2)?.flatten_to(1)?; let v = v.transpose(1, 2)?.flatten_to(1)?; let attn_weights = (q.matmul(&k.t()?)? * self.softmax_scale)?; let attn_weights = match mask {
317 None => attn_weights,
318 Some(mask) => {
319 let _enter = self.span_mask.enter();
320 attn_weights.broadcast_add(mask)?
321 }
322 };
323 let attn_weights = {
324 let _enter = self.span_softmax.enter();
325 candle_nn::ops::softmax_last_dim(&attn_weights)?
326 };
327
328 let attn_output = attn_weights.matmul(&v)?;
331 let attn_output = attn_output
333 .reshape((b_size, (), seq_len, self.head_dim))?
334 .transpose(1, 2)?
335 .flatten_from(D::Minus2)?;
336 attn_output.apply(&self.out_proj)
337 }
338
339 fn clear_kv_cache(&mut self) {
340 self.kv_cache = None
341 }
342}
343
344#[derive(Debug, Clone)]
345struct ParallelBlock {
346 ln: candle_nn::LayerNorm,
347 mixer: MHA,
348 mlp: MLP,
349 span: tracing::Span,
350}
351
352impl ParallelBlock {
353 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
354 let ln = candle_nn::layer_norm(cfg.n_embd, cfg.layer_norm_epsilon, vb.pp("ln"))?;
355 let mixer = MHA::new(cfg, vb.pp("mixer"))?;
356 let mlp = MLP::new(cfg, vb.pp("mlp"))?;
357 Ok(Self {
358 ln,
359 mixer,
360 mlp,
361 span: tracing::span!(tracing::Level::TRACE, "block"),
362 })
363 }
364
365 fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
366 let _enter = self.span.enter();
367 let residual = xs;
368 let xs = xs.apply(&self.ln)?;
369 let attn_outputs = self.mixer.forward(&xs, mask)?;
370 let feed_forward_hidden_states = self.mlp.forward(&xs)?;
371 attn_outputs + feed_forward_hidden_states + residual
372 }
373
374 fn clear_kv_cache(&mut self) {
375 self.mixer.clear_kv_cache()
376 }
377}
378
379#[derive(Debug, Clone)]
380pub struct MixFormerSequentialForCausalLM {
381 embedding: Embedding,
382 blocks: Vec<ParallelBlock>,
383 head: CausalLMHead,
384 span: tracing::Span,
385}
386
387impl MixFormerSequentialForCausalLM {
388 pub fn new_v2(cfg: &Config, vb: VarBuilder) -> Result<Self> {
389 let vb_head = vb.pp("lm_head");
390 let vb = vb.pp("transformer");
391 let embedding = Embedding::new(cfg, vb.pp("embd"))?;
392 let mut blocks = Vec::new();
393 for i in 0..cfg.n_layer {
394 let block = ParallelBlock::new(cfg, vb.pp("h").pp(i))?;
395 blocks.push(block)
396 }
397 let head = CausalLMHead::new(cfg, vb_head)?;
398 Ok(Self {
399 embedding,
400 blocks,
401 head,
402 span: tracing::span!(tracing::Level::TRACE, "mixformer"),
403 })
404 }
405
406 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
407 let vb = vb.pp("layers");
408 let embedding = Embedding::new(cfg, vb.pp(0))?;
409 let mut blocks = Vec::new();
410 for i in 0..cfg.n_layer {
411 let block = ParallelBlock::new(cfg, vb.pp(i + 1))?;
412 blocks.push(block)
413 }
414 let head = CausalLMHead::new(cfg, vb.pp(cfg.n_layer + 1))?;
415 Ok(Self {
416 embedding,
417 blocks,
418 head,
419 span: tracing::span!(tracing::Level::TRACE, "mixformer"),
420 })
421 }
422
423 pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
424 let _enter = self.span.enter();
425 let (_b_size, seq_len) = xs.dims2()?;
426 let mut xs = xs.apply(&self.embedding)?;
427 let mask = if seq_len <= 1 {
428 None
429 } else {
430 Some(get_mask(seq_len, xs.dtype(), xs.device())?)
431 };
432 for block in self.blocks.iter_mut() {
433 xs = block.forward(&xs, mask.as_ref())?
434 }
435 xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
436 }
437
438 pub fn forward_with_img(
439 &mut self,
440 bos_token: &Tensor,
441 xs: &Tensor,
442 img_embeds: &Tensor,
443 ) -> Result<Tensor> {
444 let _enter = self.span.enter();
445 let xs = xs.apply(&self.embedding)?;
446 let bos_token = bos_token.apply(&self.embedding)?;
447 let mut xs = Tensor::cat(&[bos_token, img_embeds.clone(), xs], 1)?;
450 let (_b_size, seq_len, _embds) = xs.dims3()?;
451 let mask = Some(get_mask(seq_len, xs.dtype(), xs.device())?);
452 for block in self.blocks.iter_mut() {
453 xs = block.forward(&xs, mask.as_ref())?
454 }
455 let xs = xs
456 .narrow(1, seq_len - 1, 1)?
457 .apply(&self.head)?
458 .squeeze(1)?;
459 Ok(xs)
460 }
461
462 pub fn clear_kv_cache(&mut self) {
463 self.blocks.iter_mut().for_each(|b| b.clear_kv_cache())
464 }
465}