candle_transformers/models/
bert.rs

//! BERT (Bidirectional Encoder Representations from Transformers)
//!
//! BERT is a general-purpose language model that can be used for various language tasks:
//! - Compute sentence embeddings for a prompt.
//! - Compute similarities between a set of sentences.
//! - [Arxiv](https://arxiv.org/abs/1810.04805) "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding"
//! - Upstream [Github repo](https://github.com/google-research/bert).
//! - See bert in [candle-examples](https://github.com/huggingface/candle/tree/main/candle-examples/) for runnable code.
//!
10use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
11use candle::{DType, Device, Result, Tensor};
12use candle_nn::{embedding, Embedding, Module, VarBuilder};
13use serde::Deserialize;
14
15pub const DTYPE: DType = DType::F32;
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
18#[serde(rename_all = "lowercase")]
19pub enum HiddenAct {
20    Gelu,
21    GeluApproximate,
22    Relu,
23}
24
25#[derive(Clone)]
26struct HiddenActLayer {
27    act: HiddenAct,
28    span: tracing::Span,
29}
30
31impl HiddenActLayer {
32    fn new(act: HiddenAct) -> Self {
33        let span = tracing::span!(tracing::Level::TRACE, "hidden-act");
34        Self { act, span }
35    }
36
37    fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
38        let _enter = self.span.enter();
39        match self.act {
40            // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
41            HiddenAct::Gelu => xs.gelu_erf(),
42            HiddenAct::GeluApproximate => xs.gelu(),
43            HiddenAct::Relu => xs.relu(),
44        }
45    }
46}
47
48#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
49#[serde(rename_all = "lowercase")]
50pub enum PositionEmbeddingType {
51    #[default]
52    Absolute,
53}
54
55// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
56#[derive(Debug, Clone, PartialEq, Deserialize)]
57pub struct Config {
58    pub vocab_size: usize,
59    pub hidden_size: usize,
60    pub num_hidden_layers: usize,
61    pub num_attention_heads: usize,
62    pub intermediate_size: usize,
63    pub hidden_act: HiddenAct,
64    pub hidden_dropout_prob: f64,
65    pub max_position_embeddings: usize,
66    pub type_vocab_size: usize,
67    pub initializer_range: f64,
68    pub layer_norm_eps: f64,
69    pub pad_token_id: usize,
70    #[serde(default)]
71    pub position_embedding_type: PositionEmbeddingType,
72    #[serde(default)]
73    pub use_cache: bool,
74    pub classifier_dropout: Option<f64>,
75    pub model_type: Option<String>,
76}
77
78impl Default for Config {
79    fn default() -> Self {
80        Self {
81            vocab_size: 30522,
82            hidden_size: 768,
83            num_hidden_layers: 12,
84            num_attention_heads: 12,
85            intermediate_size: 3072,
86            hidden_act: HiddenAct::Gelu,
87            hidden_dropout_prob: 0.1,
88            max_position_embeddings: 512,
89            type_vocab_size: 2,
90            initializer_range: 0.02,
91            layer_norm_eps: 1e-12,
92            pad_token_id: 0,
93            position_embedding_type: PositionEmbeddingType::Absolute,
94            use_cache: true,
95            classifier_dropout: None,
96            model_type: Some("bert".to_string()),
97        }
98    }
99}
100
101impl Config {
102    fn _all_mini_lm_l6_v2() -> Self {
103        // https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/blob/main/config.json
104        Self {
105            vocab_size: 30522,
106            hidden_size: 384,
107            num_hidden_layers: 6,
108            num_attention_heads: 12,
109            intermediate_size: 1536,
110            hidden_act: HiddenAct::Gelu,
111            hidden_dropout_prob: 0.1,
112            max_position_embeddings: 512,
113            type_vocab_size: 2,
114            initializer_range: 0.02,
115            layer_norm_eps: 1e-12,
116            pad_token_id: 0,
117            position_embedding_type: PositionEmbeddingType::Absolute,
118            use_cache: true,
119            classifier_dropout: None,
120            model_type: Some("bert".to_string()),
121        }
122    }
123}
124
125#[derive(Clone)]
126struct Dropout {
127    #[allow(dead_code)]
128    pr: f64,
129}
130
131impl Dropout {
132    fn new(pr: f64) -> Self {
133        Self { pr }
134    }
135}
136
137impl Module for Dropout {
138    fn forward(&self, x: &Tensor) -> Result<Tensor> {
139        // TODO
140        Ok(x.clone())
141    }
142}
143
144// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
145struct BertEmbeddings {
146    word_embeddings: Embedding,
147    position_embeddings: Option<Embedding>,
148    token_type_embeddings: Embedding,
149    layer_norm: LayerNorm,
150    dropout: Dropout,
151    span: tracing::Span,
152}
153
154impl BertEmbeddings {
155    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
156        let word_embeddings = embedding(
157            config.vocab_size,
158            config.hidden_size,
159            vb.pp("word_embeddings"),
160        )?;
161        let position_embeddings = embedding(
162            config.max_position_embeddings,
163            config.hidden_size,
164            vb.pp("position_embeddings"),
165        )?;
166        let token_type_embeddings = embedding(
167            config.type_vocab_size,
168            config.hidden_size,
169            vb.pp("token_type_embeddings"),
170        )?;
171        let layer_norm = layer_norm(
172            config.hidden_size,
173            config.layer_norm_eps,
174            vb.pp("LayerNorm"),
175        )?;
176        Ok(Self {
177            word_embeddings,
178            position_embeddings: Some(position_embeddings),
179            token_type_embeddings,
180            layer_norm,
181            dropout: Dropout::new(config.hidden_dropout_prob),
182            span: tracing::span!(tracing::Level::TRACE, "embeddings"),
183        })
184    }
185
186    fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
187        let _enter = self.span.enter();
188        let (_bsize, seq_len) = input_ids.dims2()?;
189        let input_embeddings = self.word_embeddings.forward(input_ids)?;
190        let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
191        let mut embeddings = (&input_embeddings + token_type_embeddings)?;
192        if let Some(position_embeddings) = &self.position_embeddings {
193            // TODO: Proper absolute positions?
194            let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
195            let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
196            embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?
197        }
198        let embeddings = self.layer_norm.forward(&embeddings)?;
199        let embeddings = self.dropout.forward(&embeddings)?;
200        Ok(embeddings)
201    }
202}
203
204#[derive(Clone)]
205struct BertSelfAttention {
206    query: Linear,
207    key: Linear,
208    value: Linear,
209    dropout: Dropout,
210    num_attention_heads: usize,
211    attention_head_size: usize,
212    span: tracing::Span,
213    span_softmax: tracing::Span,
214}
215
216impl BertSelfAttention {
217    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
218        let attention_head_size = config.hidden_size / config.num_attention_heads;
219        let all_head_size = config.num_attention_heads * attention_head_size;
220        let dropout = Dropout::new(config.hidden_dropout_prob);
221        let hidden_size = config.hidden_size;
222        let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
223        let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
224        let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
225        Ok(Self {
226            query,
227            key,
228            value,
229            dropout,
230            num_attention_heads: config.num_attention_heads,
231            attention_head_size,
232            span: tracing::span!(tracing::Level::TRACE, "self-attn"),
233            span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
234        })
235    }
236
237    fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
238        let mut new_x_shape = xs.dims().to_vec();
239        new_x_shape.pop();
240        new_x_shape.push(self.num_attention_heads);
241        new_x_shape.push(self.attention_head_size);
242        let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
243        xs.contiguous()
244    }
245
246    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
247        let _enter = self.span.enter();
248        let query_layer = self.query.forward(hidden_states)?;
249        let key_layer = self.key.forward(hidden_states)?;
250        let value_layer = self.value.forward(hidden_states)?;
251
252        let query_layer = self.transpose_for_scores(&query_layer)?;
253        let key_layer = self.transpose_for_scores(&key_layer)?;
254        let value_layer = self.transpose_for_scores(&value_layer)?;
255
256        let attention_scores = query_layer.matmul(&key_layer.t()?)?;
257        let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
258        let attention_scores = attention_scores.broadcast_add(attention_mask)?;
259        let attention_probs = {
260            let _enter_sm = self.span_softmax.enter();
261            candle_nn::ops::softmax(&attention_scores, candle::D::Minus1)?
262        };
263        let attention_probs = self.dropout.forward(&attention_probs)?;
264
265        let context_layer = attention_probs.matmul(&value_layer)?;
266        let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
267        let context_layer = context_layer.flatten_from(candle::D::Minus2)?;
268        Ok(context_layer)
269    }
270}
271
272#[derive(Clone)]
273struct BertSelfOutput {
274    dense: Linear,
275    layer_norm: LayerNorm,
276    dropout: Dropout,
277    span: tracing::Span,
278}
279
280impl BertSelfOutput {
281    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
282        let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
283        let layer_norm = layer_norm(
284            config.hidden_size,
285            config.layer_norm_eps,
286            vb.pp("LayerNorm"),
287        )?;
288        let dropout = Dropout::new(config.hidden_dropout_prob);
289        Ok(Self {
290            dense,
291            layer_norm,
292            dropout,
293            span: tracing::span!(tracing::Level::TRACE, "self-out"),
294        })
295    }
296
297    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
298        let _enter = self.span.enter();
299        let hidden_states = self.dense.forward(hidden_states)?;
300        let hidden_states = self.dropout.forward(&hidden_states)?;
301        self.layer_norm.forward(&(hidden_states + input_tensor)?)
302    }
303}
304
305// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L392
306#[derive(Clone)]
307struct BertAttention {
308    self_attention: BertSelfAttention,
309    self_output: BertSelfOutput,
310    span: tracing::Span,
311}
312
313impl BertAttention {
314    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
315        let self_attention = BertSelfAttention::load(vb.pp("self"), config)?;
316        let self_output = BertSelfOutput::load(vb.pp("output"), config)?;
317        Ok(Self {
318            self_attention,
319            self_output,
320            span: tracing::span!(tracing::Level::TRACE, "attn"),
321        })
322    }
323
324    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
325        let _enter = self.span.enter();
326        let self_outputs = self.self_attention.forward(hidden_states, attention_mask)?;
327        let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
328        Ok(attention_output)
329    }
330}
331
332// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L441
333#[derive(Clone)]
334struct BertIntermediate {
335    dense: Linear,
336    intermediate_act: HiddenActLayer,
337    span: tracing::Span,
338}
339
340impl BertIntermediate {
341    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
342        let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?;
343        Ok(Self {
344            dense,
345            intermediate_act: HiddenActLayer::new(config.hidden_act),
346            span: tracing::span!(tracing::Level::TRACE, "inter"),
347        })
348    }
349}
350
351impl Module for BertIntermediate {
352    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
353        let _enter = self.span.enter();
354        let hidden_states = self.dense.forward(hidden_states)?;
355        let ys = self.intermediate_act.forward(&hidden_states)?;
356        Ok(ys)
357    }
358}
359
360// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L456
361#[derive(Clone)]
362struct BertOutput {
363    dense: Linear,
364    layer_norm: LayerNorm,
365    dropout: Dropout,
366    span: tracing::Span,
367}
368
369impl BertOutput {
370    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
371        let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?;
372        let layer_norm = layer_norm(
373            config.hidden_size,
374            config.layer_norm_eps,
375            vb.pp("LayerNorm"),
376        )?;
377        let dropout = Dropout::new(config.hidden_dropout_prob);
378        Ok(Self {
379            dense,
380            layer_norm,
381            dropout,
382            span: tracing::span!(tracing::Level::TRACE, "out"),
383        })
384    }
385
386    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
387        let _enter = self.span.enter();
388        let hidden_states = self.dense.forward(hidden_states)?;
389        let hidden_states = self.dropout.forward(&hidden_states)?;
390        self.layer_norm.forward(&(hidden_states + input_tensor)?)
391    }
392}
393
394// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470
395#[derive(Clone)]
396pub struct BertLayer {
397    attention: BertAttention,
398    intermediate: BertIntermediate,
399    output: BertOutput,
400    span: tracing::Span,
401}
402
403impl BertLayer {
404    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
405        let attention = BertAttention::load(vb.pp("attention"), config)?;
406        let intermediate = BertIntermediate::load(vb.pp("intermediate"), config)?;
407        let output = BertOutput::load(vb.pp("output"), config)?;
408        Ok(Self {
409            attention,
410            intermediate,
411            output,
412            span: tracing::span!(tracing::Level::TRACE, "layer"),
413        })
414    }
415
416    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
417        let _enter = self.span.enter();
418        let attention_output = self.attention.forward(hidden_states, attention_mask)?;
419        // TODO: Support cross-attention?
420        // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
421        // TODO: Support something similar to `apply_chunking_to_forward`?
422        let intermediate_output = self.intermediate.forward(&attention_output)?;
423        let layer_output = self
424            .output
425            .forward(&intermediate_output, &attention_output)?;
426        Ok(layer_output)
427    }
428}
429
430// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
431#[derive(Clone)]
432pub struct BertEncoder {
433    pub layers: Vec<BertLayer>,
434    span: tracing::Span,
435}
436
437impl BertEncoder {
438    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
439        let layers = (0..config.num_hidden_layers)
440            .map(|index| BertLayer::load(vb.pp(format!("layer.{index}")), config))
441            .collect::<Result<Vec<_>>>()?;
442        let span = tracing::span!(tracing::Level::TRACE, "encoder");
443        Ok(BertEncoder { layers, span })
444    }
445
446    pub fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
447        let _enter = self.span.enter();
448        let mut hidden_states = hidden_states.clone();
449        // Use a loop rather than a fold as it's easier to modify when adding debug/...
450        for layer in self.layers.iter() {
451            hidden_states = layer.forward(&hidden_states, attention_mask)?
452        }
453        Ok(hidden_states)
454    }
455}
456
457// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874
458pub struct BertModel {
459    embeddings: BertEmbeddings,
460    encoder: BertEncoder,
461    pub device: Device,
462    span: tracing::Span,
463}
464
465impl BertModel {
466    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
467        let (embeddings, encoder) = match (
468            BertEmbeddings::load(vb.pp("embeddings"), config),
469            BertEncoder::load(vb.pp("encoder"), config),
470        ) {
471            (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
472            (Err(err), _) | (_, Err(err)) => {
473                if let Some(model_type) = &config.model_type {
474                    if let (Ok(embeddings), Ok(encoder)) = (
475                        BertEmbeddings::load(vb.pp(format!("{model_type}.embeddings")), config),
476                        BertEncoder::load(vb.pp(format!("{model_type}.encoder")), config),
477                    ) {
478                        (embeddings, encoder)
479                    } else {
480                        return Err(err);
481                    }
482                } else {
483                    return Err(err);
484                }
485            }
486        };
487        Ok(Self {
488            embeddings,
489            encoder,
490            device: vb.device().clone(),
491            span: tracing::span!(tracing::Level::TRACE, "model"),
492        })
493    }
494
495    pub fn forward(
496        &self,
497        input_ids: &Tensor,
498        token_type_ids: &Tensor,
499        attention_mask: Option<&Tensor>,
500    ) -> Result<Tensor> {
501        let _enter = self.span.enter();
502        let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
503        let attention_mask = match attention_mask {
504            Some(attention_mask) => attention_mask.clone(),
505            None => input_ids.ones_like()?,
506        };
507        // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
508        let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?;
509        let sequence_output = self.encoder.forward(&embedding_output, &attention_mask)?;
510        Ok(sequence_output)
511    }
512}
513
514fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<Tensor> {
515    let attention_mask = match attention_mask.rank() {
516        3 => attention_mask.unsqueeze(1)?,
517        2 => attention_mask.unsqueeze(1)?.unsqueeze(1)?,
518        _ => candle::bail!("Wrong shape for input_ids or attention_mask"),
519    };
520    let attention_mask = attention_mask.to_dtype(dtype)?;
521    // torch.finfo(dtype).min
522    (attention_mask.ones_like()? - &attention_mask)?
523        .broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
524}
525
526//https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L752-L766
527struct BertPredictionHeadTransform {
528    dense: Linear,
529    activation: HiddenActLayer,
530    layer_norm: LayerNorm,
531}
532
533impl BertPredictionHeadTransform {
534    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
535        let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
536        let activation = HiddenActLayer::new(config.hidden_act);
537        let layer_norm = layer_norm(
538            config.hidden_size,
539            config.layer_norm_eps,
540            vb.pp("LayerNorm"),
541        )?;
542        Ok(Self {
543            dense,
544            activation,
545            layer_norm,
546        })
547    }
548}
549
550impl Module for BertPredictionHeadTransform {
551    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
552        let hidden_states = self
553            .activation
554            .forward(&self.dense.forward(hidden_states)?)?;
555        self.layer_norm.forward(&hidden_states)
556    }
557}
558
559// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L769C1-L790C1
560pub struct BertLMPredictionHead {
561    transform: BertPredictionHeadTransform,
562    decoder: Linear,
563}
564
565impl BertLMPredictionHead {
566    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
567        let transform = BertPredictionHeadTransform::load(vb.pp("transform"), config)?;
568        let decoder = linear(config.hidden_size, config.vocab_size, vb.pp("decoder"))?;
569        Ok(Self { transform, decoder })
570    }
571}
572
573impl Module for BertLMPredictionHead {
574    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
575        self.decoder
576            .forward(&self.transform.forward(hidden_states)?)
577    }
578}
579
580// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L792
581pub struct BertOnlyMLMHead {
582    predictions: BertLMPredictionHead,
583}
584
585impl BertOnlyMLMHead {
586    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
587        let predictions = BertLMPredictionHead::load(vb.pp("predictions"), config)?;
588        Ok(Self { predictions })
589    }
590}
591
592impl Module for BertOnlyMLMHead {
593    fn forward(&self, sequence_output: &Tensor) -> Result<Tensor> {
594        self.predictions.forward(sequence_output)
595    }
596}
597
598pub struct BertForMaskedLM {
599    bert: BertModel,
600    cls: BertOnlyMLMHead,
601}
602
603impl BertForMaskedLM {
604    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
605        let bert = BertModel::load(vb.pp("bert"), config)?;
606        let cls = BertOnlyMLMHead::load(vb.pp("cls"), config)?;
607        Ok(Self { bert, cls })
608    }
609
610    pub fn forward(
611        &self,
612        input_ids: &Tensor,
613        token_type_ids: &Tensor,
614        attention_mask: Option<&Tensor>,
615    ) -> Result<Tensor> {
616        let sequence_output = self
617            .bert
618            .forward(input_ids, token_type_ids, attention_mask)?;
619        self.cls.forward(&sequence_output)
620    }
621}