1use super::with_tracing::{linear, linear_no_bias, Embedding, Linear};
8use candle::{DType, Device, IndexOp, Result, Tensor, D};
9use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder};
10use serde::Deserialize;
11
/// Floating-point dtype associated with this model (`f32`).
///
/// NOTE(review): not referenced by the code in this file; presumably used by callers
/// when building the `VarBuilder` for the weights — verify against call sites.
pub const DTYPE: DType = DType::F32;
13
/// Position-encoding scheme named in the model configuration.
///
/// Deserialised from lowercase strings (`"absolute"`, `"alibi"`). `BertEncoder::new`
/// rejects every value other than `Alibi`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum PositionEmbeddingType {
    /// Learned absolute position embeddings (not supported by this implementation).
    Absolute,
    /// Attention with Linear Biases: a per-head penalty proportional to token distance.
    Alibi,
}
20
/// Hyper-parameters of the Jina BERT model, deserialisable from a config file.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    /// Number of rows in the word-embedding table.
    pub vocab_size: usize,
    /// Model width; also the input/output width of every layer.
    pub hidden_size: usize,
    /// Number of `BertLayer`s stacked in the encoder.
    pub num_hidden_layers: usize,
    /// Number of attention heads; each head has width `hidden_size / num_attention_heads`.
    pub num_attention_heads: usize,
    /// Width of each half of the gated feed-forward projection.
    pub intermediate_size: usize,
    /// Activation named by the config. NOTE: the GLU feed-forward in this file always
    /// uses GELU and does not read this field.
    pub hidden_act: candle_nn::Activation,
    /// Longest sequence the encoder accepts; also sizes the ALiBi bias.
    pub max_position_embeddings: usize,
    /// Number of rows in the token-type table (only type id 0 is ever looked up here).
    pub type_vocab_size: usize,
    /// Weight-initialisation std-dev; not read by the code in this file.
    pub initializer_range: f64,
    /// Epsilon passed to every `LayerNorm`.
    pub layer_norm_eps: f64,
    /// Padding token id; not read by the code in this file.
    pub pad_token_id: usize,
    /// Position-encoding scheme; only `Alibi` is accepted by the encoder.
    pub position_embedding_type: PositionEmbeddingType,
}
37
38impl Config {
39 pub fn v2_base() -> Self {
40 Self {
42 vocab_size: 30528,
43 hidden_size: 768,
44 num_hidden_layers: 12,
45 num_attention_heads: 12,
46 intermediate_size: 3072,
47 hidden_act: candle_nn::Activation::Gelu,
48 max_position_embeddings: 8192,
49 type_vocab_size: 2,
50 initializer_range: 0.02,
51 layer_norm_eps: 1e-12,
52 pad_token_id: 0,
53 position_embedding_type: PositionEmbeddingType::Alibi,
54 }
55 }
56
57 #[allow(clippy::too_many_arguments)]
58 pub fn new(
59 vocab_size: usize,
60 hidden_size: usize,
61 num_hidden_layers: usize,
62 num_attention_heads: usize,
63 intermediate_size: usize,
64 hidden_act: candle_nn::Activation,
65 max_position_embeddings: usize,
66 type_vocab_size: usize,
67 initializer_range: f64,
68 layer_norm_eps: f64,
69 pad_token_id: usize,
70 position_embedding_type: PositionEmbeddingType,
71 ) -> Self {
72 Config {
73 vocab_size,
74 hidden_size,
75 num_hidden_layers,
76 num_attention_heads,
77 intermediate_size,
78 hidden_act,
79 max_position_embeddings,
80 type_vocab_size,
81 initializer_range,
82 layer_norm_eps,
83 pad_token_id,
84 position_embedding_type,
85 }
86 }
87}
88
89#[derive(Clone, Debug)]
90struct BertEmbeddings {
91 word_embeddings: Embedding,
92 token_type_embeddings: Embedding,
94 layer_norm: LayerNorm,
95 span: tracing::Span,
96}
97
98impl BertEmbeddings {
99 fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
100 let word_embeddings =
101 Embedding::new(cfg.vocab_size, cfg.hidden_size, vb.pp("word_embeddings"))?;
102 let token_type_embeddings = Embedding::new(
103 cfg.type_vocab_size,
104 cfg.hidden_size,
105 vb.pp("token_type_embeddings"),
106 )?;
107 let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
108 Ok(Self {
109 word_embeddings,
110 token_type_embeddings,
111 layer_norm,
112 span: tracing::span!(tracing::Level::TRACE, "embeddings"),
113 })
114 }
115}
116
117impl Module for BertEmbeddings {
118 fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
119 let _enter = self.span.enter();
120 let (b_size, seq_len) = input_ids.dims2()?;
121 let input_embeddings = self.word_embeddings.forward(input_ids)?;
122 let token_type_embeddings = Tensor::zeros(seq_len, DType::U32, input_ids.device())?
123 .broadcast_left(b_size)?
124 .apply(&self.token_type_embeddings)?;
125 let embeddings = (&input_embeddings + token_type_embeddings)?;
126 let embeddings = self.layer_norm.forward(&embeddings)?;
127 Ok(embeddings)
128 }
129}
130
131#[derive(Clone, Debug)]
132struct BertSelfAttention {
133 query: Linear,
134 key: Linear,
135 value: Linear,
136 num_attention_heads: usize,
137 attention_head_size: usize,
138 span: tracing::Span,
139 span_softmax: tracing::Span,
140}
141
142impl BertSelfAttention {
143 fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
144 let attention_head_size = cfg.hidden_size / cfg.num_attention_heads;
145 let all_head_size = cfg.num_attention_heads * attention_head_size;
146 let hidden_size = cfg.hidden_size;
147 let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
148 let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
149 let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
150 Ok(Self {
151 query,
152 key,
153 value,
154 num_attention_heads: cfg.num_attention_heads,
155 attention_head_size,
156 span: tracing::span!(tracing::Level::TRACE, "self-attn"),
157 span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
158 })
159 }
160
161 fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
162 let mut x_shape = xs.dims().to_vec();
163 x_shape.pop();
164 x_shape.push(self.num_attention_heads);
165 x_shape.push(self.attention_head_size);
166 xs.reshape(x_shape)?.transpose(1, 2)?.contiguous()
167 }
168
169 fn forward(&self, xs: &Tensor, bias: &Tensor) -> Result<Tensor> {
170 let _enter = self.span.enter();
171 let query_layer = self.query.forward(xs)?;
172 let key_layer = self.key.forward(xs)?;
173 let value_layer = self.value.forward(xs)?;
174
175 let query_layer = self.transpose_for_scores(&query_layer)?;
176 let key_layer = self.transpose_for_scores(&key_layer)?;
177 let value_layer = self.transpose_for_scores(&value_layer)?;
178
179 let attention_scores = query_layer.matmul(&key_layer.t()?)?;
180 let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
181 let attention_scores = attention_scores.broadcast_add(bias)?;
182 let attention_probs = {
183 let _enter_sm = self.span_softmax.enter();
184 candle_nn::ops::softmax_last_dim(&attention_scores)?
185 };
186 let context_layer = attention_probs.matmul(&value_layer)?;
187 let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
188 let context_layer = context_layer.flatten_from(D::Minus2)?;
189 Ok(context_layer)
190 }
191}
192
193#[derive(Clone, Debug)]
194struct BertSelfOutput {
195 dense: Linear,
196 layer_norm: LayerNorm,
197 span: tracing::Span,
198}
199
200impl BertSelfOutput {
201 fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
202 let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
203 let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
204 Ok(Self {
205 dense,
206 layer_norm,
207 span: tracing::span!(tracing::Level::TRACE, "self-out"),
208 })
209 }
210
211 fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
212 let _enter = self.span.enter();
213 let xs = self.dense.forward(xs)?;
214 self.layer_norm.forward(&(xs + input_tensor)?)
215 }
216}
217
218#[derive(Clone, Debug)]
219struct BertAttention {
220 self_attention: BertSelfAttention,
221 self_output: BertSelfOutput,
222 span: tracing::Span,
223}
224
225impl BertAttention {
226 fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
227 let self_attention = BertSelfAttention::new(vb.pp("self"), cfg)?;
228 let self_output = BertSelfOutput::new(vb.pp("output"), cfg)?;
229 Ok(Self {
230 self_attention,
231 self_output,
232 span: tracing::span!(tracing::Level::TRACE, "attn"),
233 })
234 }
235
236 fn forward(&self, xs: &Tensor, bias: &Tensor) -> Result<Tensor> {
237 let _enter = self.span.enter();
238 let self_outputs = self.self_attention.forward(xs, bias)?;
239 let attention_output = self.self_output.forward(&self_outputs, xs)?;
240 Ok(attention_output)
241 }
242}
243
244#[derive(Clone, Debug)]
245struct BertGLUMLP {
246 gated_layers: Linear,
247 act: candle_nn::Activation,
248 wo: Linear,
249 layernorm: LayerNorm,
250 intermediate_size: usize,
251}
252
253impl BertGLUMLP {
254 fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
255 let gated_layers = linear_no_bias(
256 cfg.hidden_size,
257 cfg.intermediate_size * 2,
258 vb.pp("gated_layers"),
259 )?;
260 let act = candle_nn::Activation::Gelu; let wo = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("wo"))?;
262 let layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("layernorm"))?;
263 Ok(Self {
264 gated_layers,
265 act,
266 wo,
267 layernorm,
268 intermediate_size: cfg.intermediate_size,
269 })
270 }
271}
272
273impl Module for BertGLUMLP {
274 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
275 let residual = xs;
276 let xs = xs.apply(&self.gated_layers)?;
277 let gated = xs.narrow(D::Minus1, 0, self.intermediate_size)?;
278 let non_gated = xs.narrow(D::Minus1, self.intermediate_size, self.intermediate_size)?;
279 let xs = (gated.apply(&self.act) * non_gated)?.apply(&self.wo);
280 (xs + residual)?.apply(&self.layernorm)
281 }
282}
283
284#[derive(Clone, Debug)]
285struct BertLayer {
286 attention: BertAttention,
287 mlp: BertGLUMLP,
288 span: tracing::Span,
289}
290
291impl BertLayer {
292 fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
293 let attention = BertAttention::new(vb.pp("attention"), cfg)?;
294 let mlp = BertGLUMLP::new(vb.pp("mlp"), cfg)?;
295 Ok(Self {
296 attention,
297 mlp,
298 span: tracing::span!(tracing::Level::TRACE, "layer"),
299 })
300 }
301
302 fn forward(&self, xs: &Tensor, bias: &Tensor) -> Result<Tensor> {
303 let _enter = self.span.enter();
304 self.attention.forward(xs, bias)?.apply(&self.mlp)
305 }
306}
307
308fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
309 let n_heads = cfg.num_attention_heads;
310 let seq_len = cfg.max_position_embeddings;
311 let alibi_bias = Tensor::arange(0, seq_len as i64, &Device::Cpu)?.to_dtype(DType::F32)?;
312 let alibi_bias = {
313 let a1 = alibi_bias.reshape((1, seq_len))?;
314 let a2 = alibi_bias.reshape((seq_len, 1))?;
315 a1.broadcast_sub(&a2)?.abs()?.broadcast_left(n_heads)?
316 };
317 let mut n_heads2 = 1;
318 while n_heads2 < n_heads {
319 n_heads2 *= 2
320 }
321 let slopes = (1..=n_heads2)
322 .map(|v| -1f32 / 2f32.powf((v * 8) as f32 / n_heads2 as f32))
323 .collect::<Vec<_>>();
324 let slopes = if n_heads2 == n_heads {
325 slopes
326 } else {
327 slopes
328 .iter()
329 .skip(1)
330 .step_by(2)
331 .chain(slopes.iter().step_by(2))
332 .take(n_heads)
333 .cloned()
334 .collect::<Vec<f32>>()
335 };
336 let slopes = Tensor::new(slopes, &Device::Cpu)?.reshape((1, (), 1, 1))?;
337 alibi_bias.to_dtype(DType::F32)?.broadcast_mul(&slopes)
338}
339
340#[derive(Clone, Debug)]
341struct BertEncoder {
342 alibi: Tensor,
343 layers: Vec<BertLayer>,
344 span: tracing::Span,
345}
346
347impl BertEncoder {
348 fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
349 if cfg.position_embedding_type != PositionEmbeddingType::Alibi {
350 candle::bail!("only alibi is supported as a position-embedding-type")
351 }
352 let layers = (0..cfg.num_hidden_layers)
353 .map(|index| BertLayer::new(vb.pp(format!("layer.{index}")), cfg))
354 .collect::<Result<Vec<_>>>()?;
355 let span = tracing::span!(tracing::Level::TRACE, "encoder");
356 let alibi = build_alibi_bias(cfg)?.to_device(vb.device())?;
357 Ok(Self {
358 alibi,
359 layers,
360 span,
361 })
362 }
363}
364
365impl Module for BertEncoder {
366 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
367 let _enter = self.span.enter();
368 let seq_len = xs.dim(1)?;
369 let alibi_bias = self.alibi.i((.., .., ..seq_len, ..seq_len))?;
370 let mut xs = xs.clone();
371 for layer in self.layers.iter() {
372 xs = layer.forward(&xs, &alibi_bias)?
373 }
374 Ok(xs)
375 }
376}
377
378#[derive(Clone, Debug)]
379pub struct BertModel {
380 embeddings: BertEmbeddings,
381 encoder: BertEncoder,
382 pub device: Device,
383 span: tracing::Span,
384}
385
386impl BertModel {
387 pub fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
388 let embeddings = BertEmbeddings::new(vb.pp("embeddings"), cfg)?;
389 let encoder = BertEncoder::new(vb.pp("encoder"), cfg)?;
390 Ok(Self {
391 embeddings,
392 encoder,
393 device: vb.device().clone(),
394 span: tracing::span!(tracing::Level::TRACE, "model"),
395 })
396 }
397}
398
399impl Module for BertModel {
400 fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
401 let _enter = self.span.enter();
402 let embedding_output = self.embeddings.forward(input_ids)?;
403 let sequence_output = self.encoder.forward(&embedding_output)?;
404 Ok(sequence_output)
405 }
406}