candle_transformers/models/
mixformer.rs

//! MixFormer (Microsoft's Phi Architecture)
//!
//! See "Textbooks Are All You Need II: phi-1.5 technical report", Li et al. 2023
//! - [Arxiv](https://arxiv.org/abs/2309.05463)
//! - [Hugging Face](https://huggingface.co/microsoft/phi-1_5)
//!
7
8use crate::models::with_tracing::{linear, Embedding as E, Linear};
9/// MixFormer model.
10/// https://huggingface.co/microsoft/phi-1_5
11/// https://arxiv.org/abs/2309.05463
12use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
13use candle_nn::{Activation, VarBuilder};
14use serde::Deserialize;
15
/// Number of positions for which each attention layer precomputes its rotary sin/cos tables.
const MAX_SEQ_LEN: usize = 4_096;
17
18// https://huggingface.co/microsoft/phi-1_5/blob/d38e6f954ec29b96fe2cf033937dad64e279b5d9/configuration_mixformer_sequential.py
19#[derive(Debug, Clone, PartialEq, Deserialize)]
20pub struct Config {
21    pub(crate) vocab_size: usize,
22    pub(crate) n_positions: usize,
23    pub(crate) n_embd: usize,
24    pub(crate) n_layer: usize,
25    pub(crate) n_inner: Option<usize>,
26    pub(crate) n_head: usize,
27    pub(crate) rotary_dim: usize,
28    pub(crate) activation_function: Activation,
29    pub(crate) layer_norm_epsilon: f64,
30    pub(crate) tie_word_embeddings: bool,
31    pub(crate) pad_vocab_size_multiple: usize,
32}
33
34impl Config {
35    pub fn v1() -> Self {
36        Self {
37            vocab_size: 50304,
38            n_positions: 2048,
39            n_embd: 1024,
40            n_layer: 20,
41            n_inner: None,
42            n_head: 16,
43            rotary_dim: usize::min(32, 1024 / 16),
44            activation_function: Activation::Gelu,
45            layer_norm_epsilon: 1e-5,
46            tie_word_embeddings: false,
47            pad_vocab_size_multiple: 64,
48        }
49    }
50
51    pub fn v1_5() -> Self {
52        Self {
53            vocab_size: 51200,
54            n_positions: 2048,
55            n_embd: 2048,
56            n_layer: 24,
57            n_inner: None,
58            n_head: 32,
59            rotary_dim: usize::min(32, 2048 / 32),
60            activation_function: Activation::Gelu,
61            layer_norm_epsilon: 1e-5,
62            tie_word_embeddings: false,
63            pad_vocab_size_multiple: 64,
64        }
65    }
66
67    pub fn v2() -> Self {
68        Self {
69            vocab_size: 51200,
70            n_positions: 2048,
71            n_embd: 2560,
72            n_layer: 32,
73            n_inner: None,
74            n_head: 32,
75            rotary_dim: usize::min(32, 2560 / 32),
76            activation_function: Activation::Gelu,
77            layer_norm_epsilon: 1e-5,
78            tie_word_embeddings: false,
79            pad_vocab_size_multiple: 64,
80        }
81    }
82
83    // https://huggingface.co/teknium/Puffin-Phi-v2/blob/main/config.json
84    pub fn puffin_phi_v2() -> Self {
85        Self {
86            vocab_size: 50304,
87            n_positions: 2048,
88            n_embd: 2048,
89            n_layer: 24,
90            n_inner: None,
91            n_head: 32,
92            rotary_dim: usize::min(32, 2048 / 32),
93            activation_function: Activation::Gelu,
94            layer_norm_epsilon: 1e-5,
95            tie_word_embeddings: false,
96            pad_vocab_size_multiple: 64,
97        }
98    }
99
100    // https://huggingface.co/teknium/Phi-Hermes-1.3B/blob/main/config.json
101    pub fn phi_hermes_1_3b() -> Self {
102        Self {
103            vocab_size: 50304,
104            n_positions: 2048,
105            n_embd: 2048,
106            n_layer: 24,
107            n_inner: None,
108            n_head: 32,
109            rotary_dim: usize::min(32, 2048 / 32),
110            activation_function: Activation::NewGelu,
111            layer_norm_epsilon: 1e-5,
112            tie_word_embeddings: false,
113            pad_vocab_size_multiple: 64,
114        }
115    }
116}
117
118#[derive(Debug, Clone)]
119struct Embedding {
120    wte: E,
121}
122
123impl Embedding {
124    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
125        let wte = E::new(cfg.vocab_size, cfg.n_embd, vb.pp("wte"))?;
126        Ok(Self { wte })
127    }
128}
129
130impl Module for Embedding {
131    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
132        self.wte.forward(xs)
133    }
134}
135
136fn get_mask(size: usize, dtype: DType, device: &Device) -> Result<Tensor> {
137    let mask: Vec<_> = (0..size)
138        .flat_map(|i| (0..size).map(move |j| if j > i { f32::NEG_INFINITY } else { 0. }))
139        .collect();
140    Tensor::from_slice(&mask, (size, size), device)?.to_dtype(dtype)
141}
142
143#[derive(Debug, Clone)]
144struct RotaryEmbedding {
145    sin: Tensor,
146    cos: Tensor,
147}
148
149impl RotaryEmbedding {
150    fn new(dim: usize, max_seq_len: usize, dtype: DType, dev: &Device) -> Result<Self> {
151        let inv_freq: Vec<_> = (0..dim)
152            .step_by(2)
153            .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
154            .collect();
155        let inv_freq_len = inv_freq.len();
156        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
157        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
158            .to_dtype(DType::F32)?
159            .reshape((max_seq_len, 1))?;
160        let freqs = t.matmul(&inv_freq)?;
161        Ok(Self {
162            sin: freqs.sin()?.to_dtype(dtype)?,
163            cos: freqs.cos()?.to_dtype(dtype)?,
164        })
165    }
166
167    fn apply_rotary_emb_qkv(
168        &self,
169        qkv: &Tensor,
170        seqlen_offset: usize,
171    ) -> Result<(Tensor, Tensor, Tensor)> {
172        let (_b_size, seqlen, three, _, _headdim) = qkv.dims5()?;
173        if three != 3 {
174            candle::bail!("unexpected shape for qkv {:?}", qkv.shape())
175        }
176        let (_rotary_seqlen, rotary_dim) = self.cos.dims2()?;
177        let rotary_dim = rotary_dim * 2;
178        let q_rot = qkv.i((.., .., 0, .., ..rotary_dim))?.contiguous()?;
179        let q_pass = qkv.i((.., .., 0, .., rotary_dim..))?;
180        let k_rot = qkv.i((.., .., 1, .., ..rotary_dim))?.contiguous()?;
181        let k_pass = qkv.i((.., .., 1, .., rotary_dim..))?;
182        let c = self.cos.narrow(0, seqlen_offset, seqlen)?;
183        let s = self.sin.narrow(0, seqlen_offset, seqlen)?;
184        let q_rot = candle_nn::rotary_emb::rope_thd(&q_rot, &c, &s)?;
185        let k_rot = candle_nn::rotary_emb::rope_thd(&k_rot, &c, &s)?;
186        let q = Tensor::cat(&[&q_rot, &q_pass], D::Minus1)?;
187        let k = Tensor::cat(&[&k_rot, &k_pass], D::Minus1)?;
188        let v = qkv.i((.., .., 2))?;
189        Ok((q, k, v))
190    }
191}
192
193#[derive(Debug, Clone)]
194#[allow(clippy::upper_case_acronyms)]
195struct MLP {
196    fc1: Linear,
197    fc2: Linear,
198    act: Activation,
199    span: tracing::Span,
200}
201
202impl MLP {
203    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
204        let n_inner = cfg.n_inner.unwrap_or(4 * cfg.n_embd);
205        let fc1 = linear(cfg.n_embd, n_inner, vb.pp("fc1"))?;
206        let fc2 = linear(n_inner, cfg.n_embd, vb.pp("fc2"))?;
207        Ok(Self {
208            fc1,
209            fc2,
210            act: cfg.activation_function,
211            span: tracing::span!(tracing::Level::TRACE, "mlp"),
212        })
213    }
214}
215
216impl Module for MLP {
217    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
218        let _enter = self.span.enter();
219        xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
220    }
221}
222
223#[derive(Debug, Clone)]
224struct CausalLMHead {
225    ln: candle_nn::LayerNorm,
226    linear: Linear,
227}
228
229impl CausalLMHead {
230    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
231        let ln = candle_nn::layer_norm(cfg.n_embd, cfg.layer_norm_epsilon, vb.pp("ln"))?;
232        let linear = linear(cfg.n_embd, cfg.vocab_size, vb.pp("linear"))?;
233        Ok(Self { ln, linear })
234    }
235}
236
237impl Module for CausalLMHead {
238    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
239        xs.apply(&self.ln)?
240            .apply(&self.linear)?
241            .to_dtype(DType::F32)
242    }
243}
244
245#[derive(Debug, Clone)]
246#[allow(clippy::upper_case_acronyms)]
247struct MHA {
248    wqkv: Linear,
249    out_proj: Linear,
250    rotary_emb: RotaryEmbedding,
251    kv_cache: Option<(Tensor, Tensor)>,
252    head_dim: usize,
253    softmax_scale: f64,
254    span: tracing::Span,
255    span_rope: tracing::Span,
256    span_mask: tracing::Span,
257    span_softmax: tracing::Span,
258}
259
260impl MHA {
261    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
262        let head_dim = cfg.n_embd / cfg.n_head;
263        let op_size = cfg.n_embd;
264        let wqkv = linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?;
265        let out_proj = linear(op_size, cfg.n_embd, vb.pp("out_proj"))?;
266        let rotary_emb =
267            RotaryEmbedding::new(cfg.rotary_dim, MAX_SEQ_LEN, vb.dtype(), vb.device())?;
268        let softmax_scale = 1f64 / (head_dim as f64).sqrt();
269        Ok(Self {
270            wqkv,
271            out_proj,
272            head_dim,
273            kv_cache: None,
274            rotary_emb,
275            softmax_scale,
276            span: tracing::span!(tracing::Level::TRACE, "mha"),
277            span_rope: tracing::span!(tracing::Level::TRACE, "rope"),
278            span_mask: tracing::span!(tracing::Level::TRACE, "mask"),
279            span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
280        })
281    }
282
283    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
284        let _enter = self.span.enter();
285        let (b_size, seq_len, _n_embd) = xs.dims3()?;
286        let qkv = self
287            .wqkv
288            .forward(xs)?
289            .reshape((b_size, seq_len, 3, (), self.head_dim))?;
290        let seqlen_offset = match &self.kv_cache {
291            None => 0,
292            Some((prev_k, _)) => prev_k.dim(1)?,
293        };
294        // In the python implementation, a single tensor is returned with the third axis of size 3.
295        let (q, k, v) = {
296            let _enter = self.span_rope.enter();
297            self.rotary_emb.apply_rotary_emb_qkv(&qkv, seqlen_offset)?
298        };
299        let (k, v) = match &self.kv_cache {
300            None => (k, v),
301            Some((prev_k, prev_v)) => {
302                let k = Tensor::cat(&[prev_k, &k], 1)?;
303                let v = Tensor::cat(&[prev_v, &v], 1)?;
304                (k, v)
305            }
306        };
307        self.kv_cache = Some((k.clone(), v.clone()));
308        // scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
309        let q = q.transpose(1, 2)?.flatten_to(1)?; // b*h, t, d
310        let k = k.transpose(1, 2)?.flatten_to(1)?; // b*h, s, d
311        let v = v.transpose(1, 2)?.flatten_to(1)?; // b*h, s, d
312        let attn_weights = (q.matmul(&k.t()?)? * self.softmax_scale)?; // b*h, t, s
313
314        // causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0, device=scores.device), 1)
315        // scores = scores + causal_mask.to(dtype=scores.dtype)
316        let attn_weights = match mask {
317            None => attn_weights,
318            Some(mask) => {
319                let _enter = self.span_mask.enter();
320                attn_weights.broadcast_add(mask)?
321            }
322        };
323        let attn_weights = {
324            let _enter = self.span_softmax.enter();
325            candle_nn::ops::softmax_last_dim(&attn_weights)?
326        };
327
328        // output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
329        // attn_weights: b*h,t,s, v: b*h,s,d
330        let attn_output = attn_weights.matmul(&v)?;
331        // b*h,t,d
332        let attn_output = attn_output
333            .reshape((b_size, (), seq_len, self.head_dim))?
334            .transpose(1, 2)?
335            .flatten_from(D::Minus2)?;
336        attn_output.apply(&self.out_proj)
337    }
338
339    fn clear_kv_cache(&mut self) {
340        self.kv_cache = None
341    }
342}
343
344#[derive(Debug, Clone)]
345struct ParallelBlock {
346    ln: candle_nn::LayerNorm,
347    mixer: MHA,
348    mlp: MLP,
349    span: tracing::Span,
350}
351
352impl ParallelBlock {
353    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
354        let ln = candle_nn::layer_norm(cfg.n_embd, cfg.layer_norm_epsilon, vb.pp("ln"))?;
355        let mixer = MHA::new(cfg, vb.pp("mixer"))?;
356        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
357        Ok(Self {
358            ln,
359            mixer,
360            mlp,
361            span: tracing::span!(tracing::Level::TRACE, "block"),
362        })
363    }
364
365    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
366        let _enter = self.span.enter();
367        let residual = xs;
368        let xs = xs.apply(&self.ln)?;
369        let attn_outputs = self.mixer.forward(&xs, mask)?;
370        let feed_forward_hidden_states = self.mlp.forward(&xs)?;
371        attn_outputs + feed_forward_hidden_states + residual
372    }
373
374    fn clear_kv_cache(&mut self) {
375        self.mixer.clear_kv_cache()
376    }
377}
378
379#[derive(Debug, Clone)]
380pub struct MixFormerSequentialForCausalLM {
381    embedding: Embedding,
382    blocks: Vec<ParallelBlock>,
383    head: CausalLMHead,
384    span: tracing::Span,
385}
386
387impl MixFormerSequentialForCausalLM {
388    pub fn new_v2(cfg: &Config, vb: VarBuilder) -> Result<Self> {
389        let vb_head = vb.pp("lm_head");
390        let vb = vb.pp("transformer");
391        let embedding = Embedding::new(cfg, vb.pp("embd"))?;
392        let mut blocks = Vec::new();
393        for i in 0..cfg.n_layer {
394            let block = ParallelBlock::new(cfg, vb.pp("h").pp(i))?;
395            blocks.push(block)
396        }
397        let head = CausalLMHead::new(cfg, vb_head)?;
398        Ok(Self {
399            embedding,
400            blocks,
401            head,
402            span: tracing::span!(tracing::Level::TRACE, "mixformer"),
403        })
404    }
405
406    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
407        let vb = vb.pp("layers");
408        let embedding = Embedding::new(cfg, vb.pp(0))?;
409        let mut blocks = Vec::new();
410        for i in 0..cfg.n_layer {
411            let block = ParallelBlock::new(cfg, vb.pp(i + 1))?;
412            blocks.push(block)
413        }
414        let head = CausalLMHead::new(cfg, vb.pp(cfg.n_layer + 1))?;
415        Ok(Self {
416            embedding,
417            blocks,
418            head,
419            span: tracing::span!(tracing::Level::TRACE, "mixformer"),
420        })
421    }
422
423    pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
424        let _enter = self.span.enter();
425        let (_b_size, seq_len) = xs.dims2()?;
426        let mut xs = xs.apply(&self.embedding)?;
427        let mask = if seq_len <= 1 {
428            None
429        } else {
430            Some(get_mask(seq_len, xs.dtype(), xs.device())?)
431        };
432        for block in self.blocks.iter_mut() {
433            xs = block.forward(&xs, mask.as_ref())?
434        }
435        xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
436    }
437
438    pub fn forward_with_img(
439        &mut self,
440        bos_token: &Tensor,
441        xs: &Tensor,
442        img_embeds: &Tensor,
443    ) -> Result<Tensor> {
444        let _enter = self.span.enter();
445        let xs = xs.apply(&self.embedding)?;
446        let bos_token = bos_token.apply(&self.embedding)?;
447        // Python implementation sequence order is <bos token embedding><img embedding><rest of text embedding>
448        // https://github.com/vikhyat/moondream/blob/a9d788a20d1543fb1479edc54106e88cff7759d3/moondream/moondream.py#L43-L56
449        let mut xs = Tensor::cat(&[bos_token, img_embeds.clone(), xs], 1)?;
450        let (_b_size, seq_len, _embds) = xs.dims3()?;
451        let mask = Some(get_mask(seq_len, xs.dtype(), xs.device())?);
452        for block in self.blocks.iter_mut() {
453            xs = block.forward(&xs, mask.as_ref())?
454        }
455        let xs = xs
456            .narrow(1, seq_len - 1, 1)?
457            .apply(&self.head)?
458            .squeeze(1)?;
459        Ok(xs)
460    }
461
462    pub fn clear_kv_cache(&mut self) {
463        self.blocks.iter_mut().for_each(|b| b.clear_kv_cache())
464    }
465}