1use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
11use candle::{DType, Device, Result, Tensor};
12use candle_nn::{embedding, Embedding, Module, VarBuilder};
13use serde::Deserialize;
14
15pub const DTYPE: DType = DType::F32;
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
18#[serde(rename_all = "lowercase")]
19pub enum HiddenAct {
20 Gelu,
21 GeluApproximate,
22 Relu,
23}
24
25#[derive(Clone)]
26struct HiddenActLayer {
27 act: HiddenAct,
28 span: tracing::Span,
29}
30
31impl HiddenActLayer {
32 fn new(act: HiddenAct) -> Self {
33 let span = tracing::span!(tracing::Level::TRACE, "hidden-act");
34 Self { act, span }
35 }
36
37 fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
38 let _enter = self.span.enter();
39 match self.act {
40 HiddenAct::Gelu => xs.gelu_erf(),
42 HiddenAct::GeluApproximate => xs.gelu(),
43 HiddenAct::Relu => xs.relu(),
44 }
45 }
46}
47
48#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
49#[serde(rename_all = "lowercase")]
50pub enum PositionEmbeddingType {
51 #[default]
52 Absolute,
53}
54
55#[derive(Debug, Clone, PartialEq, Deserialize)]
57pub struct Config {
58 pub vocab_size: usize,
59 pub hidden_size: usize,
60 pub num_hidden_layers: usize,
61 pub num_attention_heads: usize,
62 pub intermediate_size: usize,
63 pub hidden_act: HiddenAct,
64 pub hidden_dropout_prob: f64,
65 pub max_position_embeddings: usize,
66 pub type_vocab_size: usize,
67 pub initializer_range: f64,
68 pub layer_norm_eps: f64,
69 pub pad_token_id: usize,
70 #[serde(default)]
71 pub position_embedding_type: PositionEmbeddingType,
72 #[serde(default)]
73 pub use_cache: bool,
74 pub classifier_dropout: Option<f64>,
75 pub model_type: Option<String>,
76}
77
78impl Default for Config {
79 fn default() -> Self {
80 Self {
81 vocab_size: 30522,
82 hidden_size: 768,
83 num_hidden_layers: 12,
84 num_attention_heads: 12,
85 intermediate_size: 3072,
86 hidden_act: HiddenAct::Gelu,
87 hidden_dropout_prob: 0.1,
88 max_position_embeddings: 512,
89 type_vocab_size: 2,
90 initializer_range: 0.02,
91 layer_norm_eps: 1e-12,
92 pad_token_id: 0,
93 position_embedding_type: PositionEmbeddingType::Absolute,
94 use_cache: true,
95 classifier_dropout: None,
96 model_type: Some("bert".to_string()),
97 }
98 }
99}
100
101impl Config {
102 fn _all_mini_lm_l6_v2() -> Self {
103 Self {
105 vocab_size: 30522,
106 hidden_size: 384,
107 num_hidden_layers: 6,
108 num_attention_heads: 12,
109 intermediate_size: 1536,
110 hidden_act: HiddenAct::Gelu,
111 hidden_dropout_prob: 0.1,
112 max_position_embeddings: 512,
113 type_vocab_size: 2,
114 initializer_range: 0.02,
115 layer_norm_eps: 1e-12,
116 pad_token_id: 0,
117 position_embedding_type: PositionEmbeddingType::Absolute,
118 use_cache: true,
119 classifier_dropout: None,
120 model_type: Some("bert".to_string()),
121 }
122 }
123}
124
/// Placeholder dropout layer that only records its drop probability.
#[derive(Clone)]
struct Dropout {
    // Stored but never read by the code in this file, hence the lint allowance.
    #[allow(dead_code)]
    pr: f64,
}

impl Dropout {
    /// Creates a dropout layer remembering the drop probability `pr`.
    fn new(pr: f64) -> Self {
        Dropout { pr }
    }
}
136
137impl Module for Dropout {
138 fn forward(&self, x: &Tensor) -> Result<Tensor> {
139 Ok(x.clone())
141 }
142}
143
144struct BertEmbeddings {
146 word_embeddings: Embedding,
147 position_embeddings: Option<Embedding>,
148 token_type_embeddings: Embedding,
149 layer_norm: LayerNorm,
150 dropout: Dropout,
151 span: tracing::Span,
152}
153
154impl BertEmbeddings {
155 fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
156 let word_embeddings = embedding(
157 config.vocab_size,
158 config.hidden_size,
159 vb.pp("word_embeddings"),
160 )?;
161 let position_embeddings = embedding(
162 config.max_position_embeddings,
163 config.hidden_size,
164 vb.pp("position_embeddings"),
165 )?;
166 let token_type_embeddings = embedding(
167 config.type_vocab_size,
168 config.hidden_size,
169 vb.pp("token_type_embeddings"),
170 )?;
171 let layer_norm = layer_norm(
172 config.hidden_size,
173 config.layer_norm_eps,
174 vb.pp("LayerNorm"),
175 )?;
176 Ok(Self {
177 word_embeddings,
178 position_embeddings: Some(position_embeddings),
179 token_type_embeddings,
180 layer_norm,
181 dropout: Dropout::new(config.hidden_dropout_prob),
182 span: tracing::span!(tracing::Level::TRACE, "embeddings"),
183 })
184 }
185
186 fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
187 let _enter = self.span.enter();
188 let (_bsize, seq_len) = input_ids.dims2()?;
189 let input_embeddings = self.word_embeddings.forward(input_ids)?;
190 let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
191 let mut embeddings = (&input_embeddings + token_type_embeddings)?;
192 if let Some(position_embeddings) = &self.position_embeddings {
193 let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
195 let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
196 embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?
197 }
198 let embeddings = self.layer_norm.forward(&embeddings)?;
199 let embeddings = self.dropout.forward(&embeddings)?;
200 Ok(embeddings)
201 }
202}
203
204#[derive(Clone)]
205struct BertSelfAttention {
206 query: Linear,
207 key: Linear,
208 value: Linear,
209 dropout: Dropout,
210 num_attention_heads: usize,
211 attention_head_size: usize,
212 span: tracing::Span,
213 span_softmax: tracing::Span,
214}
215
216impl BertSelfAttention {
217 fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
218 let attention_head_size = config.hidden_size / config.num_attention_heads;
219 let all_head_size = config.num_attention_heads * attention_head_size;
220 let dropout = Dropout::new(config.hidden_dropout_prob);
221 let hidden_size = config.hidden_size;
222 let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
223 let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
224 let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
225 Ok(Self {
226 query,
227 key,
228 value,
229 dropout,
230 num_attention_heads: config.num_attention_heads,
231 attention_head_size,
232 span: tracing::span!(tracing::Level::TRACE, "self-attn"),
233 span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
234 })
235 }
236
237 fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
238 let mut new_x_shape = xs.dims().to_vec();
239 new_x_shape.pop();
240 new_x_shape.push(self.num_attention_heads);
241 new_x_shape.push(self.attention_head_size);
242 let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
243 xs.contiguous()
244 }
245
246 fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
247 let _enter = self.span.enter();
248 let query_layer = self.query.forward(hidden_states)?;
249 let key_layer = self.key.forward(hidden_states)?;
250 let value_layer = self.value.forward(hidden_states)?;
251
252 let query_layer = self.transpose_for_scores(&query_layer)?;
253 let key_layer = self.transpose_for_scores(&key_layer)?;
254 let value_layer = self.transpose_for_scores(&value_layer)?;
255
256 let attention_scores = query_layer.matmul(&key_layer.t()?)?;
257 let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
258 let attention_scores = attention_scores.broadcast_add(attention_mask)?;
259 let attention_probs = {
260 let _enter_sm = self.span_softmax.enter();
261 candle_nn::ops::softmax(&attention_scores, candle::D::Minus1)?
262 };
263 let attention_probs = self.dropout.forward(&attention_probs)?;
264
265 let context_layer = attention_probs.matmul(&value_layer)?;
266 let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
267 let context_layer = context_layer.flatten_from(candle::D::Minus2)?;
268 Ok(context_layer)
269 }
270}
271
272#[derive(Clone)]
273struct BertSelfOutput {
274 dense: Linear,
275 layer_norm: LayerNorm,
276 dropout: Dropout,
277 span: tracing::Span,
278}
279
280impl BertSelfOutput {
281 fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
282 let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
283 let layer_norm = layer_norm(
284 config.hidden_size,
285 config.layer_norm_eps,
286 vb.pp("LayerNorm"),
287 )?;
288 let dropout = Dropout::new(config.hidden_dropout_prob);
289 Ok(Self {
290 dense,
291 layer_norm,
292 dropout,
293 span: tracing::span!(tracing::Level::TRACE, "self-out"),
294 })
295 }
296
297 fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
298 let _enter = self.span.enter();
299 let hidden_states = self.dense.forward(hidden_states)?;
300 let hidden_states = self.dropout.forward(&hidden_states)?;
301 self.layer_norm.forward(&(hidden_states + input_tensor)?)
302 }
303}
304
305#[derive(Clone)]
307struct BertAttention {
308 self_attention: BertSelfAttention,
309 self_output: BertSelfOutput,
310 span: tracing::Span,
311}
312
313impl BertAttention {
314 fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
315 let self_attention = BertSelfAttention::load(vb.pp("self"), config)?;
316 let self_output = BertSelfOutput::load(vb.pp("output"), config)?;
317 Ok(Self {
318 self_attention,
319 self_output,
320 span: tracing::span!(tracing::Level::TRACE, "attn"),
321 })
322 }
323
324 fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
325 let _enter = self.span.enter();
326 let self_outputs = self.self_attention.forward(hidden_states, attention_mask)?;
327 let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
328 Ok(attention_output)
329 }
330}
331
332#[derive(Clone)]
334struct BertIntermediate {
335 dense: Linear,
336 intermediate_act: HiddenActLayer,
337 span: tracing::Span,
338}
339
340impl BertIntermediate {
341 fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
342 let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?;
343 Ok(Self {
344 dense,
345 intermediate_act: HiddenActLayer::new(config.hidden_act),
346 span: tracing::span!(tracing::Level::TRACE, "inter"),
347 })
348 }
349}
350
351impl Module for BertIntermediate {
352 fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
353 let _enter = self.span.enter();
354 let hidden_states = self.dense.forward(hidden_states)?;
355 let ys = self.intermediate_act.forward(&hidden_states)?;
356 Ok(ys)
357 }
358}
359
360#[derive(Clone)]
362struct BertOutput {
363 dense: Linear,
364 layer_norm: LayerNorm,
365 dropout: Dropout,
366 span: tracing::Span,
367}
368
369impl BertOutput {
370 fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
371 let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?;
372 let layer_norm = layer_norm(
373 config.hidden_size,
374 config.layer_norm_eps,
375 vb.pp("LayerNorm"),
376 )?;
377 let dropout = Dropout::new(config.hidden_dropout_prob);
378 Ok(Self {
379 dense,
380 layer_norm,
381 dropout,
382 span: tracing::span!(tracing::Level::TRACE, "out"),
383 })
384 }
385
386 fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
387 let _enter = self.span.enter();
388 let hidden_states = self.dense.forward(hidden_states)?;
389 let hidden_states = self.dropout.forward(&hidden_states)?;
390 self.layer_norm.forward(&(hidden_states + input_tensor)?)
391 }
392}
393
394#[derive(Clone)]
396pub struct BertLayer {
397 attention: BertAttention,
398 intermediate: BertIntermediate,
399 output: BertOutput,
400 span: tracing::Span,
401}
402
403impl BertLayer {
404 fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
405 let attention = BertAttention::load(vb.pp("attention"), config)?;
406 let intermediate = BertIntermediate::load(vb.pp("intermediate"), config)?;
407 let output = BertOutput::load(vb.pp("output"), config)?;
408 Ok(Self {
409 attention,
410 intermediate,
411 output,
412 span: tracing::span!(tracing::Level::TRACE, "layer"),
413 })
414 }
415
416 fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
417 let _enter = self.span.enter();
418 let attention_output = self.attention.forward(hidden_states, attention_mask)?;
419 let intermediate_output = self.intermediate.forward(&attention_output)?;
423 let layer_output = self
424 .output
425 .forward(&intermediate_output, &attention_output)?;
426 Ok(layer_output)
427 }
428}
429
430#[derive(Clone)]
432pub struct BertEncoder {
433 pub layers: Vec<BertLayer>,
434 span: tracing::Span,
435}
436
437impl BertEncoder {
438 pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
439 let layers = (0..config.num_hidden_layers)
440 .map(|index| BertLayer::load(vb.pp(format!("layer.{index}")), config))
441 .collect::<Result<Vec<_>>>()?;
442 let span = tracing::span!(tracing::Level::TRACE, "encoder");
443 Ok(BertEncoder { layers, span })
444 }
445
446 pub fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
447 let _enter = self.span.enter();
448 let mut hidden_states = hidden_states.clone();
449 for layer in self.layers.iter() {
451 hidden_states = layer.forward(&hidden_states, attention_mask)?
452 }
453 Ok(hidden_states)
454 }
455}
456
457pub struct BertModel {
459 embeddings: BertEmbeddings,
460 encoder: BertEncoder,
461 pub device: Device,
462 span: tracing::Span,
463}
464
465impl BertModel {
466 pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
467 let (embeddings, encoder) = match (
468 BertEmbeddings::load(vb.pp("embeddings"), config),
469 BertEncoder::load(vb.pp("encoder"), config),
470 ) {
471 (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
472 (Err(err), _) | (_, Err(err)) => {
473 if let Some(model_type) = &config.model_type {
474 if let (Ok(embeddings), Ok(encoder)) = (
475 BertEmbeddings::load(vb.pp(format!("{model_type}.embeddings")), config),
476 BertEncoder::load(vb.pp(format!("{model_type}.encoder")), config),
477 ) {
478 (embeddings, encoder)
479 } else {
480 return Err(err);
481 }
482 } else {
483 return Err(err);
484 }
485 }
486 };
487 Ok(Self {
488 embeddings,
489 encoder,
490 device: vb.device().clone(),
491 span: tracing::span!(tracing::Level::TRACE, "model"),
492 })
493 }
494
495 pub fn forward(
496 &self,
497 input_ids: &Tensor,
498 token_type_ids: &Tensor,
499 attention_mask: Option<&Tensor>,
500 ) -> Result<Tensor> {
501 let _enter = self.span.enter();
502 let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
503 let attention_mask = match attention_mask {
504 Some(attention_mask) => attention_mask.clone(),
505 None => input_ids.ones_like()?,
506 };
507 let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?;
509 let sequence_output = self.encoder.forward(&embedding_output, &attention_mask)?;
510 Ok(sequence_output)
511 }
512}
513
514fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<Tensor> {
515 let attention_mask = match attention_mask.rank() {
516 3 => attention_mask.unsqueeze(1)?,
517 2 => attention_mask.unsqueeze(1)?.unsqueeze(1)?,
518 _ => candle::bail!("Wrong shape for input_ids or attention_mask"),
519 };
520 let attention_mask = attention_mask.to_dtype(dtype)?;
521 (attention_mask.ones_like()? - &attention_mask)?
523 .broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
524}
525
526struct BertPredictionHeadTransform {
528 dense: Linear,
529 activation: HiddenActLayer,
530 layer_norm: LayerNorm,
531}
532
533impl BertPredictionHeadTransform {
534 fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
535 let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
536 let activation = HiddenActLayer::new(config.hidden_act);
537 let layer_norm = layer_norm(
538 config.hidden_size,
539 config.layer_norm_eps,
540 vb.pp("LayerNorm"),
541 )?;
542 Ok(Self {
543 dense,
544 activation,
545 layer_norm,
546 })
547 }
548}
549
550impl Module for BertPredictionHeadTransform {
551 fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
552 let hidden_states = self
553 .activation
554 .forward(&self.dense.forward(hidden_states)?)?;
555 self.layer_norm.forward(&hidden_states)
556 }
557}
558
559pub struct BertLMPredictionHead {
561 transform: BertPredictionHeadTransform,
562 decoder: Linear,
563}
564
565impl BertLMPredictionHead {
566 pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
567 let transform = BertPredictionHeadTransform::load(vb.pp("transform"), config)?;
568 let decoder = linear(config.hidden_size, config.vocab_size, vb.pp("decoder"))?;
569 Ok(Self { transform, decoder })
570 }
571}
572
573impl Module for BertLMPredictionHead {
574 fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
575 self.decoder
576 .forward(&self.transform.forward(hidden_states)?)
577 }
578}
579
580pub struct BertOnlyMLMHead {
582 predictions: BertLMPredictionHead,
583}
584
585impl BertOnlyMLMHead {
586 pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
587 let predictions = BertLMPredictionHead::load(vb.pp("predictions"), config)?;
588 Ok(Self { predictions })
589 }
590}
591
592impl Module for BertOnlyMLMHead {
593 fn forward(&self, sequence_output: &Tensor) -> Result<Tensor> {
594 self.predictions.forward(sequence_output)
595 }
596}
597
598pub struct BertForMaskedLM {
599 bert: BertModel,
600 cls: BertOnlyMLMHead,
601}
602
603impl BertForMaskedLM {
604 pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
605 let bert = BertModel::load(vb.pp("bert"), config)?;
606 let cls = BertOnlyMLMHead::load(vb.pp("cls"), config)?;
607 Ok(Self { bert, cls })
608 }
609
610 pub fn forward(
611 &self,
612 input_ids: &Tensor,
613 token_type_ids: &Tensor,
614 attention_mask: Option<&Tensor>,
615 ) -> Result<Tensor> {
616 let sequence_output = self
617 .bert
618 .forward(input_ids, token_type_ids, attention_mask)?;
619 self.cls.forward(&sequence_output)
620 }
621}