candle_transformers/models/
quantized_mixformer.rs

//! Module containing the quantized MixFormer model implementation.
//!
//! MixFormer is a transformer variant for text generation built from parallel
//! attention/feed-forward blocks (each block has a single MLP; there is no
//! mixture-of-experts routing in this implementation).
//! Weights are stored quantized and loaded through `quantized_var_builder::VarBuilder`
//! to reduce memory usage.
//!
//! Key features:
//! - Parallel attention and feed-forward computation
//! - Rotary positional embeddings
//! - Key-value caching across `forward` calls (reset with `clear_kv_cache`)
//! - Quantized weights loaded via `quantized_var_builder`
//!
13
14use crate::quantized_nn::{layer_norm, linear, Linear};
15pub use crate::quantized_var_builder::VarBuilder;
16use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
17use candle_nn::Activation;
18
19pub use crate::models::mixformer::Config;
20
/// Number of positions for which the rotary sin/cos tables are precomputed
/// (passed to `RotaryEmbedding::new` by `MHA::new`).
const MAX_SEQ_LEN: usize = 4_096;
22
23#[derive(Debug, Clone)]
24struct Embedding {
25    wte: crate::quantized_nn::Embedding,
26}
27
28impl Embedding {
29    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
30        let wte = crate::quantized_nn::Embedding::new(cfg.vocab_size, cfg.n_embd, vb.pp("wte"))?;
31        Ok(Self { wte })
32    }
33}
34
35impl Module for Embedding {
36    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
37        self.wte.forward(xs)
38    }
39}
40
41fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
42    let mask: Vec<_> = (0..size)
43        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
44        .collect();
45    Tensor::from_slice(&mask, (size, size), device)
46}
47
48fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
49    let shape = mask.shape();
50    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
51    let m = mask.where_cond(&on_true, on_false)?;
52    Ok(m)
53}
54
55#[derive(Debug, Clone)]
56struct RotaryEmbedding {
57    sin: Tensor,
58    cos: Tensor,
59}
60
61impl RotaryEmbedding {
62    fn new(dim: usize, max_seq_len: usize, dev: &Device) -> Result<Self> {
63        let inv_freq: Vec<_> = (0..dim)
64            .step_by(2)
65            .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
66            .collect();
67        let inv_freq_len = inv_freq.len();
68        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
69        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
70            .to_dtype(DType::F32)?
71            .reshape((max_seq_len, 1))?;
72        let freqs = t.matmul(&inv_freq)?;
73        Ok(Self {
74            sin: freqs.sin()?,
75            cos: freqs.cos()?,
76        })
77    }
78
79    fn apply_rotary_emb_qkv(
80        &self,
81        qkv: &Tensor,
82        seqlen_offset: usize,
83    ) -> Result<(Tensor, Tensor, Tensor)> {
84        let (_b_size, seqlen, three, _, _headdim) = qkv.dims5()?;
85        if three != 3 {
86            candle::bail!("unexpected shape for qkv {:?}", qkv.shape())
87        }
88        let (_rotary_seqlen, rotary_dim) = self.cos.dims2()?;
89        let rotary_dim = rotary_dim * 2;
90        let q_rot = qkv.i((.., .., 0, .., ..rotary_dim))?;
91        let q_pass = qkv.i((.., .., 0, .., rotary_dim..))?;
92        let k_rot = qkv.i((.., .., 1, .., ..rotary_dim))?;
93        let k_pass = qkv.i((.., .., 1, .., rotary_dim..))?;
94        let q12 = q_rot.chunk(2, D::Minus1)?;
95        let k12 = k_rot.chunk(2, D::Minus1)?;
96        let (q1, q2) = (&q12[0], &q12[1]);
97        let (k1, k2) = (&k12[0], &k12[1]);
98        let c = self.cos.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
99        let s = self.sin.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
100        let q_rot = Tensor::cat(
101            &[
102                (q1.broadcast_mul(&c)? - q2.broadcast_mul(&s)?)?,
103                (q1.broadcast_mul(&s)? + q2.broadcast_mul(&c)?)?,
104            ],
105            D::Minus1,
106        )?;
107        let k_rot = Tensor::cat(
108            &[
109                (k1.broadcast_mul(&c)? - k2.broadcast_mul(&s)?)?,
110                (k1.broadcast_mul(&s)? + k2.broadcast_mul(&c)?)?,
111            ],
112            D::Minus1,
113        )?;
114        let q = Tensor::cat(&[&q_rot, &q_pass], D::Minus1)?;
115        let k = Tensor::cat(&[&k_rot, &k_pass], D::Minus1)?;
116        let v = qkv.i((.., .., 2))?;
117        Ok((q, k, v))
118    }
119}
120
121#[derive(Debug, Clone)]
122#[allow(clippy::upper_case_acronyms)]
123struct MLP {
124    fc1: Linear,
125    fc2: Linear,
126    act: Activation,
127}
128
129impl MLP {
130    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
131        let n_inner = cfg.n_inner.unwrap_or(4 * cfg.n_embd);
132        let fc1 = linear(cfg.n_embd, n_inner, vb.pp("fc1"))?;
133        let fc2 = linear(n_inner, cfg.n_embd, vb.pp("fc2"))?;
134        Ok(Self {
135            fc1,
136            fc2,
137            act: cfg.activation_function,
138        })
139    }
140}
141
142impl Module for MLP {
143    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
144        xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
145    }
146}
147
148#[derive(Debug, Clone)]
149struct CausalLMHead {
150    ln: candle_nn::LayerNorm,
151    linear: Linear,
152}
153
154impl CausalLMHead {
155    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
156        let ln = layer_norm(cfg.n_embd, cfg.layer_norm_epsilon, vb.pp("ln"))?;
157        let linear = linear(cfg.n_embd, cfg.vocab_size, vb.pp("linear"))?;
158        Ok(Self { ln, linear })
159    }
160}
161
162impl Module for CausalLMHead {
163    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
164        xs.apply(&self.ln)?
165            .apply(&self.linear)?
166            .to_dtype(DType::F32)
167    }
168}
169
170#[derive(Debug, Clone)]
171#[allow(clippy::upper_case_acronyms)]
172struct MHA {
173    wqkv: Linear,
174    out_proj: Linear,
175    rotary_emb: RotaryEmbedding,
176    kv_cache: Option<(Tensor, Tensor)>,
177    head_dim: usize,
178    n_head: usize,
179    softmax_scale: f64,
180    span: tracing::Span,
181}
182
183impl MHA {
184    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
185        let head_dim = cfg.n_embd / cfg.n_head;
186        let op_size = cfg.n_embd;
187        let wqkv = linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?;
188        let out_proj = linear(op_size, cfg.n_embd, vb.pp("out_proj"))?;
189        let rotary_emb = RotaryEmbedding::new(cfg.rotary_dim, MAX_SEQ_LEN, vb.device())?;
190        let softmax_scale = 1f64 / (head_dim as f64).sqrt();
191        Ok(Self {
192            wqkv,
193            out_proj,
194            head_dim,
195            n_head: cfg.n_head,
196            kv_cache: None,
197            rotary_emb,
198            softmax_scale,
199            span: tracing::span!(tracing::Level::TRACE, "mha"),
200        })
201    }
202
203    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
204        let _enter = self.span.enter();
205        let (b_size, seq_len, _n_embd) = xs.dims3()?;
206        let qkv = self
207            .wqkv
208            .forward(xs)?
209            .reshape((b_size, seq_len, 3, (), self.head_dim))?;
210        let seqlen_offset = match &self.kv_cache {
211            None => 0,
212            Some((prev_k, _)) => prev_k.dim(1)?,
213        };
214        // In the python implementation, a single tensor is returned with the third axis of size 3.
215        let (q, k, v) = self.rotary_emb.apply_rotary_emb_qkv(&qkv, seqlen_offset)?;
216        let (k, v) = match &self.kv_cache {
217            None => (k, v),
218            Some((prev_k, prev_v)) => {
219                let k = Tensor::cat(&[prev_k, &k], 1)?;
220                let v = Tensor::cat(&[prev_v, &v], 1)?;
221                (k, v)
222            }
223        };
224        self.kv_cache = Some((k.clone(), v.clone()));
225        // scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
226        let q = q.transpose(1, 2)?.flatten_to(1)?; // b*h, t, d
227        let k = k.transpose(1, 2)?.flatten_to(1)?; // b*h, s, d
228        let v = v.transpose(1, 2)?.flatten_to(1)?; // b*h, s, d
229        let attn_weights = (q.matmul(&k.t()?)? * self.softmax_scale)?; // b*h, t, s
230
231        // causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0, device=scores.device), 1)
232        // scores = scores + causal_mask.to(dtype=scores.dtype)
233        let attn_weights = match mask {
234            None => attn_weights,
235            Some(mask) => masked_fill(
236                &attn_weights,
237                &mask.broadcast_left(b_size * self.n_head)?,
238                f32::NEG_INFINITY,
239            )?,
240        };
241        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
242
243        // output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
244        // attn_weights: b*h,t,s, v: b*h,s,d
245        let attn_output = attn_weights.matmul(&v)?;
246        // b*h,t,d
247        let attn_output = attn_output
248            .reshape((b_size, (), seq_len, self.head_dim))?
249            .transpose(1, 2)?
250            .flatten_from(D::Minus2)?;
251        attn_output.apply(&self.out_proj)
252    }
253
254    fn clear_kv_cache(&mut self) {
255        self.kv_cache = None
256    }
257}
258
259#[derive(Debug, Clone)]
260struct ParallelBlock {
261    ln: candle_nn::LayerNorm,
262    mixer: MHA,
263    mlp: MLP,
264    span: tracing::Span,
265}
266
267impl ParallelBlock {
268    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
269        let ln = layer_norm(cfg.n_embd, cfg.layer_norm_epsilon, vb.pp("ln"))?;
270        let mixer = MHA::new(cfg, vb.pp("mixer"))?;
271        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
272        Ok(Self {
273            ln,
274            mixer,
275            mlp,
276            span: tracing::span!(tracing::Level::TRACE, "block"),
277        })
278    }
279
280    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
281        let _enter = self.span.enter();
282        let residual = xs;
283        let xs = xs.apply(&self.ln)?;
284        let attn_outputs = self.mixer.forward(&xs, mask)?;
285        let feed_forward_hidden_states = self.mlp.forward(&xs)?;
286        attn_outputs + feed_forward_hidden_states + residual
287    }
288
289    fn clear_kv_cache(&mut self) {
290        self.mixer.clear_kv_cache()
291    }
292}
293
294#[derive(Debug, Clone)]
295pub struct MixFormerSequentialForCausalLM {
296    embedding: Embedding,
297    blocks: Vec<ParallelBlock>,
298    head: CausalLMHead,
299    span: tracing::Span,
300}
301
302impl MixFormerSequentialForCausalLM {
303    pub fn new_v2(cfg: &Config, vb: VarBuilder) -> Result<Self> {
304        let vb_head = vb.pp("lm_head");
305        let vb = vb.pp("transformer");
306        let embedding = Embedding::new(cfg, vb.pp("embd"))?;
307        let mut blocks = Vec::new();
308        for i in 0..cfg.n_layer {
309            let block = ParallelBlock::new(cfg, vb.pp("h").pp(i))?;
310            blocks.push(block)
311        }
312        let head = CausalLMHead::new(cfg, vb_head)?;
313        Ok(Self {
314            embedding,
315            blocks,
316            head,
317            span: tracing::span!(tracing::Level::TRACE, "mixformer"),
318        })
319    }
320
321    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
322        let vb = vb.pp("layers");
323        let embedding = Embedding::new(cfg, vb.pp(0))?;
324        let mut blocks = Vec::new();
325        for i in 0..cfg.n_layer {
326            let block = ParallelBlock::new(cfg, vb.pp(i + 1))?;
327            blocks.push(block);
328        }
329        let head = CausalLMHead::new(cfg, vb.pp(cfg.n_layer + 1))?;
330        Ok(Self {
331            embedding,
332            blocks,
333            head,
334            span: tracing::span!(tracing::Level::TRACE, "mixformer"),
335        })
336    }
337
338    pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
339        let _enter = self.span.enter();
340        let (_b_size, seq_len) = xs.dims2()?;
341        let mut xs = xs.apply(&self.embedding)?;
342        let mask = if seq_len <= 1 {
343            None
344        } else {
345            Some(get_mask(seq_len, xs.device())?)
346        };
347        for block in self.blocks.iter_mut() {
348            xs = block.forward(&xs, mask.as_ref())?;
349        }
350        xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
351    }
352
353    pub fn forward_with_img(
354        &mut self,
355        bos_token: &Tensor,
356        xs: &Tensor,
357        img_embeds: &Tensor,
358    ) -> Result<Tensor> {
359        let _enter = self.span.enter();
360        let xs = xs.apply(&self.embedding)?;
361        let bos_token = bos_token.apply(&self.embedding)?;
362        // Python implementation sequence order is <bos token embedding><img embedding><rest of text embedding>
363        // https://github.com/vikhyat/moondream/blob/a9d788a20d1543fb1479edc54106e88cff7759d3/moondream/moondream.py#L43-L56
364        let mut xs = Tensor::cat(&[bos_token, img_embeds.clone(), xs], 1)?;
365        let (_b_size, seq_len, _embds) = xs.dims3()?;
366        let mask = Some(get_mask(seq_len, xs.device())?);
367        for block in self.blocks.iter_mut() {
368            xs = block.forward(&xs, mask.as_ref())?
369        }
370        let xs = xs
371            .narrow(1, seq_len - 1, 1)?
372            .apply(&self.head)?
373            .squeeze(1)?;
374        Ok(xs)
375    }
376
377    pub fn clear_kv_cache(&mut self) {
378        self.blocks.iter_mut().for_each(|b| b.clear_kv_cache())
379    }
380}