candle_transformers/models/
quantized_mpt.rs1use crate::quantized_nn::{layer_norm_no_bias, linear_no_bias, Embedding, Linear};
20pub use crate::quantized_var_builder::VarBuilder;
21use candle::{IndexOp, Module, Result, Tensor, D};
24use candle_nn::LayerNorm;
25
26pub use super::mpt::Config;
27
28#[derive(Debug, Clone)]
29struct GroupedQueryAttention {
30 wqkv: Linear,
31 out_proj: Linear,
32 kv_cache: Option<(Tensor, Tensor)>,
33 softmax_scale: f64,
34 head_dim: usize,
35 d_model: usize,
36 n_heads: usize,
37 kv_n_heads: usize,
38 attn_bias: Tensor,
39 span: tracing::Span,
40}
41
42impl GroupedQueryAttention {
43 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
44 let head_dim = cfg.d_model / cfg.n_heads;
45 let wqkv_size = cfg.d_model + 2 * cfg.kv_n_heads * head_dim;
46 let wqkv = linear_no_bias(cfg.d_model, wqkv_size, vb.pp("Wqkv"))?;
47 let softmax_scale = 1f64 / (head_dim as f64).sqrt();
48 let out_proj = linear_no_bias(cfg.d_model, cfg.d_model, vb.pp("out_proj"))?;
49 let attn_bias = super::mpt::build_alibi_bias(cfg)?.to_device(vb.device())?;
50 Ok(Self {
51 wqkv,
52 out_proj,
53 kv_cache: None,
54 softmax_scale,
55 head_dim,
56 d_model: cfg.d_model,
57 n_heads: cfg.n_heads,
58 kv_n_heads: cfg.kv_n_heads,
59 attn_bias,
60 span: tracing::span!(tracing::Level::TRACE, "gqa"),
61 })
62 }
63
64 fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
65 let _enter = self.span.enter();
66 let (b_size, seq_len, _n_embd) = xs.dims3()?;
67 let qkv = self.wqkv.forward(xs)?;
68 let query = qkv.narrow(2, 0, self.d_model)?;
69 let kv_size = self.kv_n_heads * self.head_dim;
70 let key = qkv.narrow(2, self.d_model, kv_size)?;
71 let value = qkv.narrow(2, self.d_model + kv_size, kv_size)?;
72 let query = query
74 .reshape((b_size, seq_len, self.n_heads, ()))?
75 .transpose(1, 2)?; let key = key
77 .reshape((b_size, seq_len, self.kv_n_heads, ()))?
78 .permute((0, 2, 3, 1))?; let value = value
80 .reshape((b_size, seq_len, self.kv_n_heads, ()))?
81 .transpose(1, 2)?; let (key, value) = match &self.kv_cache {
83 None => (key, value),
84 Some((prev_k, prev_v)) => {
85 let k = Tensor::cat(&[prev_k, &key], 3)?;
86 let v = Tensor::cat(&[prev_v, &value], 2)?;
87 (k, v)
88 }
89 };
90 self.kv_cache = Some((key.clone(), value.clone()));
91 let query = query.contiguous()?;
92 let key = crate::utils::repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
93 let value = crate::utils::repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
94 let attn_weights = (query.matmul(&key)? * self.softmax_scale)?;
95 let attn_bias = {
96 let s_q = query.dim(D::Minus2)?;
97 let s_k = key.dim(D::Minus1)?;
98 let (_, _, a_q, a_k) = self.attn_bias.dims4()?;
99 let start_q = a_q.saturating_sub(s_q);
100 let start_k = a_k.saturating_sub(s_k);
101 self.attn_bias.i((.., .., start_q.., start_k..))?
102 };
103 let attn_weights = attn_weights.broadcast_add(&attn_bias)?;
104 let attn_weights = match mask {
105 None => attn_weights,
106 Some(mask) => super::mpt::masked_fill(
107 &attn_weights,
108 &mask.broadcast_as(attn_weights.shape())?,
109 f32::NEG_INFINITY,
110 )?,
111 };
112 let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
113 let attn_output = attn_weights
114 .matmul(&value)?
115 .transpose(1, 2)?
116 .flatten_from(D::Minus2)?;
117 let out = attn_output.apply(&self.out_proj)?;
118 Ok(out)
119 }
120}
121
122#[derive(Debug, Clone)]
123struct Ffn {
124 up_proj: Linear,
125 down_proj: Linear,
126}
127
128impl Ffn {
129 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
130 let hidden = cfg.d_model * cfg.expansion_ratio;
131 let up_proj = linear_no_bias(cfg.d_model, hidden, vb.pp("up_proj"))?;
132 let down_proj = linear_no_bias(hidden, cfg.d_model, vb.pp("down_proj"))?;
133 Ok(Self { up_proj, down_proj })
134 }
135}
136
137impl Module for Ffn {
138 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
139 xs.apply(&self.up_proj)?.gelu_erf()?.apply(&self.down_proj)
140 }
141}
142
143#[derive(Debug, Clone)]
144struct MPTBlock {
145 norm1: LayerNorm, attn: GroupedQueryAttention,
147 norm2: LayerNorm,
148 ffn: Ffn,
149}
150
151impl MPTBlock {
152 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
153 let norm1 = layer_norm_no_bias(cfg.d_model, 1e-5, vb.pp("norm_1"))?;
154 let norm2 = layer_norm_no_bias(cfg.d_model, 1e-5, vb.pp("norm_2"))?;
155 let attn = GroupedQueryAttention::new(cfg, vb.pp("attn"))?;
156 let ffn = Ffn::new(cfg, vb.pp("ffn"))?;
157 Ok(Self {
158 norm1,
159 attn,
160 norm2,
161 ffn,
162 })
163 }
164
165 fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
166 let residual = xs;
167 let xs = xs.apply(&self.norm1)?;
168 let xs = self.attn.forward(&xs, mask)?;
169 let xs = (xs + residual)?;
170 let residual = &xs;
171 let xs = xs.apply(&self.norm2)?.apply(&self.ffn)?;
172 xs + residual
173 }
174}
175
176#[derive(Debug, Clone)]
177pub struct Model {
178 wte: Embedding,
179 blocks: Vec<MPTBlock>,
180 norm_f: LayerNorm,
181}
182
183impl Model {
184 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
185 let wte = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("wte"))?;
186 let vb_b = vb.pp("blocks");
187 let mut blocks = Vec::with_capacity(cfg.n_layers);
188 for i in 0..cfg.n_layers {
189 let block = MPTBlock::new(cfg, vb_b.pp(i))?;
190 blocks.push(block)
191 }
192 let norm_f = layer_norm_no_bias(cfg.d_model, 1e-5, vb.pp("norm_f"))?;
193 Ok(Self {
194 wte,
195 blocks,
196 norm_f,
197 })
198 }
199
200 pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
201 let (_b_size, seq_len) = xs.dims2()?;
202 let mut xs = xs.apply(&self.wte)?;
203 let mask = if seq_len <= 1 {
204 None
205 } else {
206 Some(super::mpt::get_mask(seq_len, xs.device())?)
207 };
208 for block in self.blocks.iter_mut() {
209 xs = block.forward(&xs, mask.as_ref())?;
210 }
211 let xs = xs.apply(&self.norm_f)?;
212 let logits = xs
213 .narrow(1, seq_len - 1, 1)?
214 .squeeze(1)?
215 .matmul(&self.wte.embeddings().t()?)?
216 .squeeze(1)?;
217 Ok(logits)
218 }
219}