candle_transformers/models/
mpt.rs

//! Module implementing the MPT (MosaicML Pretrained Transformer) model
2//!
3//! References:
4//! - [MPT Model used by replit-code-v1_5-3b](https://huggingface.co/replit/replit-code-v1_5-3b/blob/main/modeling_mpt.py)
5//! - [Configuration](https://huggingface.co/replit/replit-code-v1_5-3b/blob/main/configuration_mpt.py)
6//!
7//! The model uses grouped query attention and alibi positional embeddings.
8
9use crate::models::with_tracing::{linear_no_bias, Embedding, Linear};
10/// MPT model used by replit-code-v1_5-3b
11/// https://huggingface.co/replit/replit-code-v1_5-3b/blob/main/modeling_mpt.py
12use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
13use candle_nn::{layer_norm, LayerNorm, VarBuilder};
14
15// https://huggingface.co/replit/replit-code-v1_5-3b/blob/main/configuration_mpt.py
/// Hyper-parameters of an MPT model.
///
/// Only the settings this implementation reads (plus `attn_alibi`, which is
/// stored but not consulted in this file) are represented.
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
    /// Width of the token embeddings and of each block's hidden state.
    pub(crate) d_model: usize,
    /// Number of query heads.
    pub(crate) n_heads: usize,
    /// Number of transformer blocks.
    pub(crate) n_layers: usize,
    /// Feed-forward hidden width, expressed as a multiple of `d_model`.
    pub(crate) expansion_ratio: usize,
    /// Number of positions covered by the pre-computed ALiBi bias.
    pub(crate) max_seq_len: usize,
    /// Number of rows in the token embedding table.
    pub(crate) vocab_size: usize,
    /// Number of key/value heads shared across the query heads.
    pub(crate) kv_n_heads: usize,
    /// When `true` the model is treated as non-causal (see [`Config::is_causal`]).
    pub(crate) attn_prefix_lm: bool,
    /// Whether ALiBi is enabled in the upstream configuration.
    pub(crate) attn_alibi: bool,
    /// Scaling constant used when deriving the per-head ALiBi slopes.
    pub(crate) attn_alibi_bias_max: usize,
}

impl Config {
    /// Returns the configuration of the `replit-code-v1_5-3b` checkpoint.
    pub fn replit_code_v1_5_3b() -> Self {
        Self {
            // Model dimensions.
            vocab_size: 32768,
            max_seq_len: 4096,
            d_model: 3072,
            n_layers: 32,
            expansion_ratio: 4,
            // Attention layout.
            n_heads: 24,
            kv_n_heads: 8,
            attn_prefix_lm: false,
            attn_alibi: true,
            attn_alibi_bias_max: 8,
        }
    }

    /// Returns `true` unless the model is configured as a prefix language model.
    pub fn is_causal(&self) -> bool {
        !self.attn_prefix_lm
    }
}
50
51#[derive(Debug, Clone)]
52struct GroupedQueryAttention {
53    wqkv: Linear,
54    out_proj: Linear,
55    kv_cache: Option<(Tensor, Tensor)>,
56    softmax_scale: f64,
57    head_dim: usize,
58    d_model: usize,
59    n_heads: usize,
60    kv_n_heads: usize,
61    attn_bias: Tensor,
62    span: tracing::Span,
63}
64
65impl GroupedQueryAttention {
66    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
67        let head_dim = cfg.d_model / cfg.n_heads;
68        let wqkv_size = cfg.d_model + 2 * cfg.kv_n_heads * head_dim;
69        let wqkv = linear_no_bias(cfg.d_model, wqkv_size, vb.pp("Wqkv"))?;
70        let softmax_scale = 1f64 / (head_dim as f64).sqrt();
71        let out_proj = linear_no_bias(cfg.d_model, cfg.d_model, vb.pp("out_proj"))?;
72        let attn_bias = build_alibi_bias(cfg)?.to_device(vb.device())?;
73        Ok(Self {
74            wqkv,
75            out_proj,
76            kv_cache: None,
77            softmax_scale,
78            head_dim,
79            d_model: cfg.d_model,
80            n_heads: cfg.n_heads,
81            kv_n_heads: cfg.kv_n_heads,
82            attn_bias,
83            span: tracing::span!(tracing::Level::TRACE, "gqa"),
84        })
85    }
86
87    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
88        let _enter = self.span.enter();
89        let (b_size, seq_len, _n_embd) = xs.dims3()?;
90        let qkv = self.wqkv.forward(xs)?;
91        let query = qkv.narrow(2, 0, self.d_model)?;
92        let kv_size = self.kv_n_heads * self.head_dim;
93        let key = qkv.narrow(2, self.d_model, kv_size)?;
94        let value = qkv.narrow(2, self.d_model + kv_size, kv_size)?;
95        // scaled_multihead_dot_product_attention
96        let query = query
97            .reshape((b_size, seq_len, self.n_heads, ()))?
98            .transpose(1, 2)?; // b,h,s,d
99        let key = key
100            .reshape((b_size, seq_len, self.kv_n_heads, ()))?
101            .permute((0, 2, 3, 1))?; // b,h,d,s
102        let value = value
103            .reshape((b_size, seq_len, self.kv_n_heads, ()))?
104            .transpose(1, 2)?; // b,h,s,d
105        let (key, value) = match &self.kv_cache {
106            None => (key, value),
107            Some((prev_k, prev_v)) => {
108                let k = Tensor::cat(&[prev_k, &key], 3)?;
109                let v = Tensor::cat(&[prev_v, &value], 2)?;
110                (k, v)
111            }
112        };
113        self.kv_cache = Some((key.clone(), value.clone()));
114        let query = query.contiguous()?;
115        let key = crate::utils::repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
116        let value = crate::utils::repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
117        let attn_weights = (query.matmul(&key)? * self.softmax_scale)?;
118        let attn_bias = {
119            let s_q = query.dim(D::Minus2)?;
120            let s_k = key.dim(D::Minus1)?;
121            let (_, _, a_q, a_k) = self.attn_bias.dims4()?;
122            let start_q = a_q.saturating_sub(s_q);
123            let start_k = a_k.saturating_sub(s_k);
124            self.attn_bias.i((.., .., start_q.., start_k..))?
125        };
126        let attn_weights = attn_weights.broadcast_add(&attn_bias)?;
127        let attn_weights = match mask {
128            None => attn_weights,
129            Some(mask) => masked_fill(
130                &attn_weights,
131                &mask.broadcast_as(attn_weights.shape())?,
132                f32::NEG_INFINITY,
133            )?,
134        };
135        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
136        let attn_output = attn_weights
137            .matmul(&value)?
138            .transpose(1, 2)?
139            .flatten_from(D::Minus2)?;
140        let out = attn_output.apply(&self.out_proj)?;
141        Ok(out)
142    }
143}
144
145#[derive(Debug, Clone)]
146struct Ffn {
147    up_proj: Linear,
148    down_proj: Linear,
149}
150
151impl Ffn {
152    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
153        let hidden = cfg.d_model * cfg.expansion_ratio;
154        let up_proj = linear_no_bias(cfg.d_model, hidden, vb.pp("up_proj"))?;
155        let down_proj = linear_no_bias(hidden, cfg.d_model, vb.pp("down_proj"))?;
156        Ok(Self { up_proj, down_proj })
157    }
158}
159
160impl Module for Ffn {
161    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
162        xs.apply(&self.up_proj)?.gelu_erf()?.apply(&self.down_proj)
163    }
164}
165
166#[derive(Debug, Clone)]
167struct MPTBlock {
168    norm1: LayerNorm, // Do we need the low-precision variant?
169    attn: GroupedQueryAttention,
170    norm2: LayerNorm,
171    ffn: Ffn,
172}
173
174impl MPTBlock {
175    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
176        let ln_cfg = candle_nn::LayerNormConfig {
177            affine: false,
178            ..Default::default()
179        };
180        let norm1 = layer_norm(cfg.d_model, ln_cfg, vb.pp("norm_1"))?;
181        let norm2 = layer_norm(cfg.d_model, ln_cfg, vb.pp("norm_2"))?;
182        let attn = GroupedQueryAttention::new(cfg, vb.pp("attn"))?;
183        let ffn = Ffn::new(cfg, vb.pp("ffn"))?;
184        Ok(Self {
185            norm1,
186            attn,
187            norm2,
188            ffn,
189        })
190    }
191
192    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
193        let residual = xs;
194        let xs = xs.apply(&self.norm1)?;
195        let xs = self.attn.forward(&xs, mask)?;
196        let xs = (xs + residual)?;
197        let residual = &xs;
198        let xs = xs.apply(&self.norm2)?.apply(&self.ffn)?;
199        xs + residual
200    }
201}
202
203pub(crate) fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
204    let full = !cfg.is_causal();
205    let seq_len = cfg.max_seq_len;
206    let alibi_bias = Tensor::arange(1 - seq_len as i64, 1, &Device::Cpu)?;
207    let alibi_bias = if full {
208        let a1 = alibi_bias.reshape((1, 1, 1, seq_len))?;
209        let a2 = alibi_bias.reshape((1, 1, seq_len, 1))?;
210        a1.broadcast_sub(&a2)?.abs()?.neg()?
211    } else {
212        alibi_bias.reshape((1, 1, 1, seq_len))?
213    };
214    let mut n_heads2 = 1;
215    while n_heads2 < cfg.n_heads {
216        n_heads2 *= 2
217    }
218    let slopes = (1..=n_heads2)
219        .map(|v| 1f32 / 2f32.powf((v * cfg.attn_alibi_bias_max) as f32 / n_heads2 as f32))
220        .collect::<Vec<_>>();
221    let slopes = if n_heads2 == cfg.n_heads {
222        slopes
223    } else {
224        slopes
225            .iter()
226            .skip(1)
227            .step_by(2)
228            .chain(slopes.iter().step_by(2))
229            .take(cfg.n_heads)
230            .cloned()
231            .collect::<Vec<f32>>()
232    };
233    let slopes = Tensor::new(slopes, &Device::Cpu)?.reshape((1, (), 1, 1))?;
234    alibi_bias.to_dtype(DType::F32)?.broadcast_mul(&slopes)
235}
236
237#[derive(Debug, Clone)]
238pub struct Model {
239    wte: Embedding,
240    blocks: Vec<MPTBlock>,
241    norm_f: LayerNorm,
242}
243
244impl Model {
245    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
246        let wte = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("wte"))?;
247        let vb_b = vb.pp("blocks");
248        let mut blocks = Vec::with_capacity(cfg.n_layers);
249        for i in 0..cfg.n_layers {
250            let block = MPTBlock::new(cfg, vb_b.pp(i))?;
251            blocks.push(block)
252        }
253        let ln_cfg = candle_nn::LayerNormConfig {
254            affine: false,
255            ..Default::default()
256        };
257        let norm_f = candle_nn::layer_norm(cfg.d_model, ln_cfg, vb.pp("norm_f"))?;
258        Ok(Self {
259            wte,
260            blocks,
261            norm_f,
262        })
263    }
264
265    pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
266        let (_b_size, seq_len) = xs.dims2()?;
267        let mut xs = xs.apply(&self.wte)?;
268        let mask = if seq_len <= 1 {
269            None
270        } else {
271            Some(get_mask(seq_len, xs.device())?)
272        };
273        for block in self.blocks.iter_mut() {
274            xs = block.forward(&xs, mask.as_ref())?;
275        }
276        let xs = xs.apply(&self.norm_f)?;
277        let logits = xs
278            .narrow(1, seq_len - 1, 1)?
279            .squeeze(1)?
280            .matmul(&self.wte.embeddings().t()?)?
281            .squeeze(1)?;
282        Ok(logits)
283    }
284}
285
286pub(crate) fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
287    let mask: Vec<_> = (0..size)
288        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
289        .collect();
290    Tensor::from_slice(&mask, (size, size), device)
291}
292
293pub(crate) fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
294    let shape = mask.shape();
295    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
296    let m = mask.where_cond(&on_true, on_false)?;
297    Ok(m)
298}