1use crate::models::with_tracing::{linear_no_bias, Embedding, Linear};
10use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
13use candle_nn::{layer_norm, LayerNorm, VarBuilder};
14
/// Hyper-parameters describing an MPT-style decoder.
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
    /// Width of the token embeddings and of the residual stream.
    pub(crate) d_model: usize,
    /// Number of query heads; the per-head width is `d_model / n_heads`.
    pub(crate) n_heads: usize,
    /// Number of transformer blocks stacked in the model.
    pub(crate) n_layers: usize,
    /// Multiplier giving the feed-forward hidden width (`d_model * expansion_ratio`).
    pub(crate) expansion_ratio: usize,
    /// Number of positions covered by the precomputed attention bias.
    pub(crate) max_seq_len: usize,
    /// Number of rows in the token embedding table.
    pub(crate) vocab_size: usize,
    /// Number of key/value heads in the fused attention projection.
    pub(crate) kv_n_heads: usize,
    /// When `true` the model is not causal (see [`Config::is_causal`]).
    pub(crate) attn_prefix_lm: bool,
    /// ALiBi switch.
    // NOTE(review): this flag is not read by any code in this file; the bias is
    // always built — confirm whether that is intended.
    pub(crate) attn_alibi: bool,
    /// Scales the exponent used when deriving the per-head bias slopes.
    pub(crate) attn_alibi_bias_max: usize,
}

impl Config {
    /// Returns the preset named after the `replit-code-v1_5-3b` checkpoint.
    pub fn replit_code_v1_5_3b() -> Self {
        Self {
            // Vocabulary and overall model size.
            vocab_size: 32_768,
            max_seq_len: 4_096,
            d_model: 3_072,
            n_layers: 32,
            expansion_ratio: 4,
            // Attention layout.
            n_heads: 24,
            kv_n_heads: 8,
            attn_prefix_lm: false,
            attn_alibi: true,
            attn_alibi_bias_max: 8,
        }
    }

    /// Reports whether attention is causal, i.e. prefix-LM mode is disabled.
    pub fn is_causal(&self) -> bool {
        !self.attn_prefix_lm
    }
}
50
/// Attention layer in which queries use `n_heads` heads while keys and values
/// use `kv_n_heads` heads, with an additive precomputed bias and a key/value
/// cache that grows across `forward` calls.
#[derive(Debug, Clone)]
struct GroupedQueryAttention {
    /// Fused bias-free projection producing query, key and value columns at once.
    wqkv: Linear,
    /// Bias-free projection applied to the concatenated head outputs.
    out_proj: Linear,
    /// Keys and values accumulated from earlier calls; `None` before the first call.
    /// Keys are stored as `(batch, kv_heads, head_dim, seq)` and values as
    /// `(batch, kv_heads, seq, head_dim)`.
    kv_cache: Option<(Tensor, Tensor)>,
    /// Factor applied to the raw scores: `1 / sqrt(head_dim)`.
    softmax_scale: f64,
    /// Width of one head: `d_model / n_heads`.
    head_dim: usize,
    /// Number of query columns in the fused projection output.
    d_model: usize,
    /// Number of query heads.
    n_heads: usize,
    /// Number of key/value heads.
    kv_n_heads: usize,
    /// Bias produced by `build_alibi_bias`, kept on the device of the weights
    /// and sliced on every call to match the current query/key lengths.
    attn_bias: Tensor,
    /// Tracing span entered for the duration of `forward`.
    span: tracing::Span,
}
64
65impl GroupedQueryAttention {
66 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
67 let head_dim = cfg.d_model / cfg.n_heads;
68 let wqkv_size = cfg.d_model + 2 * cfg.kv_n_heads * head_dim;
69 let wqkv = linear_no_bias(cfg.d_model, wqkv_size, vb.pp("Wqkv"))?;
70 let softmax_scale = 1f64 / (head_dim as f64).sqrt();
71 let out_proj = linear_no_bias(cfg.d_model, cfg.d_model, vb.pp("out_proj"))?;
72 let attn_bias = build_alibi_bias(cfg)?.to_device(vb.device())?;
73 Ok(Self {
74 wqkv,
75 out_proj,
76 kv_cache: None,
77 softmax_scale,
78 head_dim,
79 d_model: cfg.d_model,
80 n_heads: cfg.n_heads,
81 kv_n_heads: cfg.kv_n_heads,
82 attn_bias,
83 span: tracing::span!(tracing::Level::TRACE, "gqa"),
84 })
85 }
86
87 fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
88 let _enter = self.span.enter();
89 let (b_size, seq_len, _n_embd) = xs.dims3()?;
90 let qkv = self.wqkv.forward(xs)?;
91 let query = qkv.narrow(2, 0, self.d_model)?;
92 let kv_size = self.kv_n_heads * self.head_dim;
93 let key = qkv.narrow(2, self.d_model, kv_size)?;
94 let value = qkv.narrow(2, self.d_model + kv_size, kv_size)?;
95 let query = query
97 .reshape((b_size, seq_len, self.n_heads, ()))?
98 .transpose(1, 2)?; let key = key
100 .reshape((b_size, seq_len, self.kv_n_heads, ()))?
101 .permute((0, 2, 3, 1))?; let value = value
103 .reshape((b_size, seq_len, self.kv_n_heads, ()))?
104 .transpose(1, 2)?; let (key, value) = match &self.kv_cache {
106 None => (key, value),
107 Some((prev_k, prev_v)) => {
108 let k = Tensor::cat(&[prev_k, &key], 3)?;
109 let v = Tensor::cat(&[prev_v, &value], 2)?;
110 (k, v)
111 }
112 };
113 self.kv_cache = Some((key.clone(), value.clone()));
114 let query = query.contiguous()?;
115 let key = crate::utils::repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
116 let value = crate::utils::repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
117 let attn_weights = (query.matmul(&key)? * self.softmax_scale)?;
118 let attn_bias = {
119 let s_q = query.dim(D::Minus2)?;
120 let s_k = key.dim(D::Minus1)?;
121 let (_, _, a_q, a_k) = self.attn_bias.dims4()?;
122 let start_q = a_q.saturating_sub(s_q);
123 let start_k = a_k.saturating_sub(s_k);
124 self.attn_bias.i((.., .., start_q.., start_k..))?
125 };
126 let attn_weights = attn_weights.broadcast_add(&attn_bias)?;
127 let attn_weights = match mask {
128 None => attn_weights,
129 Some(mask) => masked_fill(
130 &attn_weights,
131 &mask.broadcast_as(attn_weights.shape())?,
132 f32::NEG_INFINITY,
133 )?,
134 };
135 let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
136 let attn_output = attn_weights
137 .matmul(&value)?
138 .transpose(1, 2)?
139 .flatten_from(D::Minus2)?;
140 let out = attn_output.apply(&self.out_proj)?;
141 Ok(out)
142 }
143}
144
/// Two-layer feed-forward network applied after attention in each block.
#[derive(Debug, Clone)]
struct Ffn {
    /// Bias-free projection from `d_model` to the hidden width.
    up_proj: Linear,
    /// Bias-free projection from the hidden width back to `d_model`.
    down_proj: Linear,
}
150
151impl Ffn {
152 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
153 let hidden = cfg.d_model * cfg.expansion_ratio;
154 let up_proj = linear_no_bias(cfg.d_model, hidden, vb.pp("up_proj"))?;
155 let down_proj = linear_no_bias(hidden, cfg.d_model, vb.pp("down_proj"))?;
156 Ok(Self { up_proj, down_proj })
157 }
158}
159
160impl Module for Ffn {
161 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
162 xs.apply(&self.up_proj)?.gelu_erf()?.apply(&self.down_proj)
163 }
164}
165
/// One transformer block: a normalised attention sub-layer followed by a
/// normalised feed-forward sub-layer, each wrapped in a residual connection.
#[derive(Debug, Clone)]
struct MPTBlock {
    /// Layer norm applied to the block input before attention.
    norm1: LayerNorm,
    /// Attention sub-layer; owns the key/value cache.
    attn: GroupedQueryAttention,
    /// Layer norm applied before the feed-forward network.
    norm2: LayerNorm,
    /// Feed-forward sub-layer.
    ffn: Ffn,
}
173
174impl MPTBlock {
175 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
176 let ln_cfg = candle_nn::LayerNormConfig {
177 affine: false,
178 ..Default::default()
179 };
180 let norm1 = layer_norm(cfg.d_model, ln_cfg, vb.pp("norm_1"))?;
181 let norm2 = layer_norm(cfg.d_model, ln_cfg, vb.pp("norm_2"))?;
182 let attn = GroupedQueryAttention::new(cfg, vb.pp("attn"))?;
183 let ffn = Ffn::new(cfg, vb.pp("ffn"))?;
184 Ok(Self {
185 norm1,
186 attn,
187 norm2,
188 ffn,
189 })
190 }
191
192 fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
193 let residual = xs;
194 let xs = xs.apply(&self.norm1)?;
195 let xs = self.attn.forward(&xs, mask)?;
196 let xs = (xs + residual)?;
197 let residual = &xs;
198 let xs = xs.apply(&self.norm2)?.apply(&self.ffn)?;
199 xs + residual
200 }
201}
202
203pub(crate) fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
204 let full = !cfg.is_causal();
205 let seq_len = cfg.max_seq_len;
206 let alibi_bias = Tensor::arange(1 - seq_len as i64, 1, &Device::Cpu)?;
207 let alibi_bias = if full {
208 let a1 = alibi_bias.reshape((1, 1, 1, seq_len))?;
209 let a2 = alibi_bias.reshape((1, 1, seq_len, 1))?;
210 a1.broadcast_sub(&a2)?.abs()?.neg()?
211 } else {
212 alibi_bias.reshape((1, 1, 1, seq_len))?
213 };
214 let mut n_heads2 = 1;
215 while n_heads2 < cfg.n_heads {
216 n_heads2 *= 2
217 }
218 let slopes = (1..=n_heads2)
219 .map(|v| 1f32 / 2f32.powf((v * cfg.attn_alibi_bias_max) as f32 / n_heads2 as f32))
220 .collect::<Vec<_>>();
221 let slopes = if n_heads2 == cfg.n_heads {
222 slopes
223 } else {
224 slopes
225 .iter()
226 .skip(1)
227 .step_by(2)
228 .chain(slopes.iter().step_by(2))
229 .take(cfg.n_heads)
230 .cloned()
231 .collect::<Vec<f32>>()
232 };
233 let slopes = Tensor::new(slopes, &Device::Cpu)?.reshape((1, (), 1, 1))?;
234 alibi_bias.to_dtype(DType::F32)?.broadcast_mul(&slopes)
235}
236
/// Decoder-only model: token embedding, a stack of blocks and a final layer norm.
/// The embedding matrix is reused (transposed) to produce the output logits.
#[derive(Debug, Clone)]
pub struct Model {
    /// Token embedding table, also used as the output projection.
    wte: Embedding,
    /// Transformer blocks applied in order.
    blocks: Vec<MPTBlock>,
    /// Layer norm applied after the last block.
    norm_f: LayerNorm,
}
243
244impl Model {
245 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
246 let wte = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("wte"))?;
247 let vb_b = vb.pp("blocks");
248 let mut blocks = Vec::with_capacity(cfg.n_layers);
249 for i in 0..cfg.n_layers {
250 let block = MPTBlock::new(cfg, vb_b.pp(i))?;
251 blocks.push(block)
252 }
253 let ln_cfg = candle_nn::LayerNormConfig {
254 affine: false,
255 ..Default::default()
256 };
257 let norm_f = candle_nn::layer_norm(cfg.d_model, ln_cfg, vb.pp("norm_f"))?;
258 Ok(Self {
259 wte,
260 blocks,
261 norm_f,
262 })
263 }
264
265 pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
266 let (_b_size, seq_len) = xs.dims2()?;
267 let mut xs = xs.apply(&self.wte)?;
268 let mask = if seq_len <= 1 {
269 None
270 } else {
271 Some(get_mask(seq_len, xs.device())?)
272 };
273 for block in self.blocks.iter_mut() {
274 xs = block.forward(&xs, mask.as_ref())?;
275 }
276 let xs = xs.apply(&self.norm_f)?;
277 let logits = xs
278 .narrow(1, seq_len - 1, 1)?
279 .squeeze(1)?
280 .matmul(&self.wte.embeddings().t()?)?
281 .squeeze(1)?;
282 Ok(logits)
283 }
284}
285
286pub(crate) fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
287 let mask: Vec<_> = (0..size)
288 .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
289 .collect();
290 Tensor::from_slice(&mask, (size, size), device)
291}
292
293pub(crate) fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
294 let shape = mask.shape();
295 let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
296 let m = mask.where_cond(&on_true, on_false)?;
297 Ok(m)
298}