candle_transformers/models/
quantized_mpt.rs

//! Quantized MPT model implementation.
//!
//! MPT (MPT-7B) is a causal transformer model series optimized for code generation.
//! This implementation provides quantization for reduced memory and compute.
//!
//! Key characteristics:
//! - Grouped-Query Attention (GQA); multi-query attention is the `kv_n_heads == 1` case
//! - Support for KV-caching
//! - Pre-computed ALiBi attention biases
//! - Support for 8-bit quantization
//!
//! References:
//! - [Replit Code Models](https://huggingface.co/replit/replit-code-v1_5-3b)
//! - [MPT-7B Implementation](https://github.com/mosaicml/llm-foundry)
//!
16/// MPT model used by replit-code-v1_5-3b
17/// https://huggingface.co/replit/replit-code-v1_5-3b/blob/main/modeling_mpt.py
18///
19use crate::quantized_nn::{layer_norm_no_bias, linear_no_bias, Embedding, Linear};
20pub use crate::quantized_var_builder::VarBuilder;
21/// MPT model used by replit-code-v1_5-3b
22/// https://huggingface.co/replit/replit-code-v1_5-3b/blob/main/modeling_mpt.py
23use candle::{IndexOp, Module, Result, Tensor, D};
24use candle_nn::LayerNorm;
25
26pub use super::mpt::Config;
27
28#[derive(Debug, Clone)]
29struct GroupedQueryAttention {
30    wqkv: Linear,
31    out_proj: Linear,
32    kv_cache: Option<(Tensor, Tensor)>,
33    softmax_scale: f64,
34    head_dim: usize,
35    d_model: usize,
36    n_heads: usize,
37    kv_n_heads: usize,
38    attn_bias: Tensor,
39    span: tracing::Span,
40}
41
42impl GroupedQueryAttention {
43    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
44        let head_dim = cfg.d_model / cfg.n_heads;
45        let wqkv_size = cfg.d_model + 2 * cfg.kv_n_heads * head_dim;
46        let wqkv = linear_no_bias(cfg.d_model, wqkv_size, vb.pp("Wqkv"))?;
47        let softmax_scale = 1f64 / (head_dim as f64).sqrt();
48        let out_proj = linear_no_bias(cfg.d_model, cfg.d_model, vb.pp("out_proj"))?;
49        let attn_bias = super::mpt::build_alibi_bias(cfg)?.to_device(vb.device())?;
50        Ok(Self {
51            wqkv,
52            out_proj,
53            kv_cache: None,
54            softmax_scale,
55            head_dim,
56            d_model: cfg.d_model,
57            n_heads: cfg.n_heads,
58            kv_n_heads: cfg.kv_n_heads,
59            attn_bias,
60            span: tracing::span!(tracing::Level::TRACE, "gqa"),
61        })
62    }
63
64    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
65        let _enter = self.span.enter();
66        let (b_size, seq_len, _n_embd) = xs.dims3()?;
67        let qkv = self.wqkv.forward(xs)?;
68        let query = qkv.narrow(2, 0, self.d_model)?;
69        let kv_size = self.kv_n_heads * self.head_dim;
70        let key = qkv.narrow(2, self.d_model, kv_size)?;
71        let value = qkv.narrow(2, self.d_model + kv_size, kv_size)?;
72        // scaled_multihead_dot_product_attention
73        let query = query
74            .reshape((b_size, seq_len, self.n_heads, ()))?
75            .transpose(1, 2)?; // b,h,s,d
76        let key = key
77            .reshape((b_size, seq_len, self.kv_n_heads, ()))?
78            .permute((0, 2, 3, 1))?; // b,h,d,s
79        let value = value
80            .reshape((b_size, seq_len, self.kv_n_heads, ()))?
81            .transpose(1, 2)?; // b,h,s,d
82        let (key, value) = match &self.kv_cache {
83            None => (key, value),
84            Some((prev_k, prev_v)) => {
85                let k = Tensor::cat(&[prev_k, &key], 3)?;
86                let v = Tensor::cat(&[prev_v, &value], 2)?;
87                (k, v)
88            }
89        };
90        self.kv_cache = Some((key.clone(), value.clone()));
91        let query = query.contiguous()?;
92        let key = crate::utils::repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
93        let value = crate::utils::repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
94        let attn_weights = (query.matmul(&key)? * self.softmax_scale)?;
95        let attn_bias = {
96            let s_q = query.dim(D::Minus2)?;
97            let s_k = key.dim(D::Minus1)?;
98            let (_, _, a_q, a_k) = self.attn_bias.dims4()?;
99            let start_q = a_q.saturating_sub(s_q);
100            let start_k = a_k.saturating_sub(s_k);
101            self.attn_bias.i((.., .., start_q.., start_k..))?
102        };
103        let attn_weights = attn_weights.broadcast_add(&attn_bias)?;
104        let attn_weights = match mask {
105            None => attn_weights,
106            Some(mask) => super::mpt::masked_fill(
107                &attn_weights,
108                &mask.broadcast_as(attn_weights.shape())?,
109                f32::NEG_INFINITY,
110            )?,
111        };
112        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
113        let attn_output = attn_weights
114            .matmul(&value)?
115            .transpose(1, 2)?
116            .flatten_from(D::Minus2)?;
117        let out = attn_output.apply(&self.out_proj)?;
118        Ok(out)
119    }
120}
121
122#[derive(Debug, Clone)]
123struct Ffn {
124    up_proj: Linear,
125    down_proj: Linear,
126}
127
128impl Ffn {
129    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
130        let hidden = cfg.d_model * cfg.expansion_ratio;
131        let up_proj = linear_no_bias(cfg.d_model, hidden, vb.pp("up_proj"))?;
132        let down_proj = linear_no_bias(hidden, cfg.d_model, vb.pp("down_proj"))?;
133        Ok(Self { up_proj, down_proj })
134    }
135}
136
137impl Module for Ffn {
138    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
139        xs.apply(&self.up_proj)?.gelu_erf()?.apply(&self.down_proj)
140    }
141}
142
143#[derive(Debug, Clone)]
144struct MPTBlock {
145    norm1: LayerNorm, // Do we need the low-precision variant?
146    attn: GroupedQueryAttention,
147    norm2: LayerNorm,
148    ffn: Ffn,
149}
150
151impl MPTBlock {
152    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
153        let norm1 = layer_norm_no_bias(cfg.d_model, 1e-5, vb.pp("norm_1"))?;
154        let norm2 = layer_norm_no_bias(cfg.d_model, 1e-5, vb.pp("norm_2"))?;
155        let attn = GroupedQueryAttention::new(cfg, vb.pp("attn"))?;
156        let ffn = Ffn::new(cfg, vb.pp("ffn"))?;
157        Ok(Self {
158            norm1,
159            attn,
160            norm2,
161            ffn,
162        })
163    }
164
165    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
166        let residual = xs;
167        let xs = xs.apply(&self.norm1)?;
168        let xs = self.attn.forward(&xs, mask)?;
169        let xs = (xs + residual)?;
170        let residual = &xs;
171        let xs = xs.apply(&self.norm2)?.apply(&self.ffn)?;
172        xs + residual
173    }
174}
175
176#[derive(Debug, Clone)]
177pub struct Model {
178    wte: Embedding,
179    blocks: Vec<MPTBlock>,
180    norm_f: LayerNorm,
181}
182
183impl Model {
184    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
185        let wte = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("wte"))?;
186        let vb_b = vb.pp("blocks");
187        let mut blocks = Vec::with_capacity(cfg.n_layers);
188        for i in 0..cfg.n_layers {
189            let block = MPTBlock::new(cfg, vb_b.pp(i))?;
190            blocks.push(block)
191        }
192        let norm_f = layer_norm_no_bias(cfg.d_model, 1e-5, vb.pp("norm_f"))?;
193        Ok(Self {
194            wte,
195            blocks,
196            norm_f,
197        })
198    }
199
200    pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
201        let (_b_size, seq_len) = xs.dims2()?;
202        let mut xs = xs.apply(&self.wte)?;
203        let mask = if seq_len <= 1 {
204            None
205        } else {
206            Some(super::mpt::get_mask(seq_len, xs.device())?)
207        };
208        for block in self.blocks.iter_mut() {
209            xs = block.forward(&xs, mask.as_ref())?;
210        }
211        let xs = xs.apply(&self.norm_f)?;
212        let logits = xs
213            .narrow(1, seq_len - 1, 1)?
214            .squeeze(1)?
215            .matmul(&self.wte.embeddings().t()?)?
216            .squeeze(1)?;
217        Ok(logits)
218    }
219}