candle_transformers/models/
distilbert.rs

//! Implementation of DistilBert, a distilled version of BERT.
//!
//! See:
//! - ["DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter"](https://arxiv.org/abs/1910.01108)
//!
6use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
7use candle::{DType, Device, Result, Tensor};
8use candle_nn::{Embedding, Module, VarBuilder};
9use serde::Deserialize;
10
11pub const DTYPE: DType = DType::F32;
12
13fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
14    let shape = mask.shape();
15    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
16    let m = mask.where_cond(&on_true, on_false)?;
17    Ok(m)
18}
19
20#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
21#[serde(rename_all = "lowercase")]
22enum HiddenAct {
23    Gelu,
24    Relu,
25}
26
27struct HiddenActLayer {
28    act: HiddenAct,
29    span: tracing::Span,
30}
31
32impl HiddenActLayer {
33    fn new(act: HiddenAct) -> Self {
34        let span = tracing::span!(tracing::Level::TRACE, "hidden-act");
35        Self { act, span }
36    }
37}
38
39impl Module for HiddenActLayer {
40    fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
41        let _enter = self.span.enter();
42        match self.act {
43            // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
44            HiddenAct::Gelu => xs.gelu(),
45            HiddenAct::Relu => xs.relu(),
46        }
47    }
48}
49
50#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
51#[serde(rename_all = "lowercase")]
52enum PositionEmbeddingType {
53    #[default]
54    Absolute,
55}
56
57#[derive(Debug, Clone, PartialEq, Deserialize)]
58pub struct Config {
59    vocab_size: usize,
60    dim: usize,
61    n_layers: usize,
62    n_heads: usize,
63    hidden_dim: usize,
64    activation: HiddenAct,
65    max_position_embeddings: usize,
66    initializer_range: f64,
67    pad_token_id: usize,
68    #[serde(default)]
69    position_embedding_type: PositionEmbeddingType,
70    #[serde(default)]
71    use_cache: bool,
72    model_type: Option<String>,
73}
74
75impl Default for Config {
76    fn default() -> Self {
77        Self {
78            vocab_size: 30522,
79            dim: 768,
80            n_layers: 12,
81            n_heads: 12,
82            hidden_dim: 3072,
83            activation: HiddenAct::Gelu,
84            max_position_embeddings: 512,
85            initializer_range: 0.02,
86            pad_token_id: 0,
87            position_embedding_type: PositionEmbeddingType::Absolute,
88            use_cache: true,
89            model_type: Some("distilbert".to_string()),
90        }
91    }
92}
93
94struct Embeddings {
95    word_embeddings: Embedding,
96    position_embeddings: Embedding,
97    layer_norm: LayerNorm,
98    span: tracing::Span,
99}
100
101impl Embeddings {
102    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
103        let word_embeddings =
104            candle_nn::embedding(config.vocab_size, config.dim, vb.pp("word_embeddings"))?;
105        let position_embeddings = candle_nn::embedding(
106            config.max_position_embeddings,
107            config.dim,
108            vb.pp("position_embeddings"),
109        )?;
110        let layer_norm = layer_norm(config.dim, 1e-12, vb.pp("LayerNorm"))?;
111        Ok(Self {
112            word_embeddings,
113            position_embeddings,
114            layer_norm,
115            span: tracing::span!(tracing::Level::TRACE, "embeddings"),
116        })
117    }
118
119    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
120        let _enter = self.span.enter();
121        let (_bsize, seq_len) = input_ids.dims2()?;
122        let input_embeddings = self.word_embeddings.forward(input_ids)?;
123        let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
124        let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
125        let embeddings =
126            input_embeddings.broadcast_add(&self.position_embeddings.forward(&position_ids)?)?;
127
128        let embeddings = self.layer_norm.forward(&embeddings)?;
129        Ok(embeddings)
130    }
131}
132
133struct MultiHeadSelfAttention {
134    q_lin: Linear,
135    k_lin: Linear,
136    v_lin: Linear,
137    out_lin: Linear,
138    n_heads: usize,
139    attention_head_size: usize,
140    span: tracing::Span,
141}
142
143impl MultiHeadSelfAttention {
144    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
145        let attention_head_size = config.dim / config.n_heads;
146        let all_head_size = config.n_heads * attention_head_size;
147        let dim = config.dim;
148        let q_lin = linear(dim, all_head_size, vb.pp("q_lin"))?;
149        let v_lin = linear(dim, all_head_size, vb.pp("v_lin"))?;
150        let k_lin = linear(dim, all_head_size, vb.pp("k_lin"))?;
151        let out_lin = linear(all_head_size, dim, vb.pp("out_lin"))?;
152        Ok(Self {
153            q_lin,
154            k_lin,
155            v_lin,
156            out_lin,
157            n_heads: config.n_heads,
158            attention_head_size,
159            span: tracing::span!(tracing::Level::TRACE, "attention"),
160        })
161    }
162}
163
164impl MultiHeadSelfAttention {
165    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
166        let _enter = self.span.enter();
167        let (bs, q_length, _dim) = hidden_states.dims3()?;
168
169        let dim_per_head = self.attention_head_size;
170        let q = self.q_lin.forward(hidden_states)?;
171        let k = self.k_lin.forward(hidden_states)?;
172        let v = self.v_lin.forward(hidden_states)?;
173
174        let q = q
175            .reshape((bs, q_length, self.n_heads, dim_per_head))?
176            .transpose(1, 2)?;
177        let k = k
178            .reshape((bs, q_length, self.n_heads, dim_per_head))?
179            .transpose(1, 2)?;
180        let v = v
181            .reshape((bs, q_length, self.n_heads, dim_per_head))?
182            .transpose(1, 2)?;
183
184        let q: Tensor = (q / (dim_per_head as f64).sqrt())?;
185        let scores = q.matmul(&k.transpose(2, 3)?.contiguous()?)?;
186        let mask = attention_mask.broadcast_as(scores.shape())?;
187
188        let scores = masked_fill(&scores.to_dtype(DType::F32)?, &mask, f32::NEG_INFINITY)?;
189        let weights = candle_nn::ops::softmax(&scores, candle::D::Minus1)?;
190
191        let context = weights.matmul(&v.contiguous()?)?;
192        let context = context
193            .transpose(1, 2)?
194            .reshape((bs, q_length, self.n_heads * dim_per_head))?
195            .contiguous()?;
196        let context = self.out_lin.forward(&context)?;
197
198        Ok(context)
199    }
200}
201
202#[allow(clippy::upper_case_acronyms)]
203struct FFN {
204    lin1: Linear,
205    lin2: Linear,
206    activation: HiddenActLayer,
207    span: tracing::Span,
208}
209
210impl FFN {
211    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
212        let lin1 = linear(config.dim, config.hidden_dim, vb.pp("lin1"))?;
213        let lin2 = linear(config.hidden_dim, config.dim, vb.pp("lin2"))?;
214        Ok(Self {
215            lin1,
216            lin2,
217            activation: HiddenActLayer::new(config.activation),
218            span: tracing::span!(tracing::Level::TRACE, "ffn"),
219        })
220    }
221}
222
223impl Module for FFN {
224    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
225        let _enter = self.span.enter();
226        hidden_states
227            .apply(&self.lin1)?
228            .apply(&self.activation)?
229            .apply(&self.lin2)
230    }
231}
232
233struct TransformerBlock {
234    attention: MultiHeadSelfAttention,
235    sa_layer_norm: LayerNorm,
236    ffn: FFN,
237    output_layer_norm: LayerNorm,
238    span: tracing::Span,
239}
240
241impl TransformerBlock {
242    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
243        let attention = MultiHeadSelfAttention::load(vb.pp("attention"), config)?;
244        let sa_layer_norm = layer_norm(config.dim, 1e-12, vb.pp("sa_layer_norm"))?;
245        let ffn = FFN::load(vb.pp("ffn"), config)?;
246        let output_layer_norm = layer_norm(config.dim, 1e-12, vb.pp("output_layer_norm"))?;
247        Ok(Self {
248            attention,
249            sa_layer_norm,
250            ffn,
251            output_layer_norm,
252            span: tracing::span!(tracing::Level::TRACE, "layer"),
253        })
254    }
255}
256
257impl TransformerBlock {
258    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
259        let _enter = self.span.enter();
260        let sa_output = self.attention.forward(hidden_states, attention_mask)?;
261        // TODO: Support cross-attention?
262        // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
263        // TODO: Support something similar to `apply_chunking_to_forward`?
264        let sa_output = sa_output.broadcast_add(hidden_states)?;
265        let sa_output = self.sa_layer_norm.forward(&sa_output)?;
266
267        let ffn_output = self.ffn.forward(&sa_output)?;
268        let ffn_output = (&ffn_output + sa_output)?;
269        let output = self.output_layer_norm.forward(&ffn_output)?;
270        Ok(output)
271    }
272}
273
274// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
275struct Transformer {
276    layers: Vec<TransformerBlock>,
277    span: tracing::Span,
278}
279
280impl Transformer {
281    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
282        let layers = (0..config.n_layers)
283            .map(|index| TransformerBlock::load(vb.pp(format!("layer.{index}")), config))
284            .collect::<Result<Vec<_>>>()?;
285        let span = tracing::span!(tracing::Level::TRACE, "encoder");
286        Ok(Transformer { layers, span })
287    }
288}
289
290impl Transformer {
291    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
292        let _enter = self.span.enter();
293        let mut hidden_states = hidden_states.clone();
294        // Use a loop rather than a fold as it's easier to modify when adding debug/...
295        for layer in self.layers.iter() {
296            hidden_states = layer.forward(&hidden_states, attention_mask)?;
297        }
298        Ok(hidden_states)
299    }
300}
301
302pub struct DistilBertModel {
303    embeddings: Embeddings,
304    transformer: Transformer,
305    pub device: Device,
306    span: tracing::Span,
307}
308
309impl DistilBertModel {
310    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
311        let (embeddings, transformer) = match (
312            Embeddings::load(vb.pp("embeddings"), config),
313            Transformer::load(vb.pp("transformer"), config),
314        ) {
315            (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
316            (Err(err), _) | (_, Err(err)) => {
317                if let Some(model_type) = &config.model_type {
318                    if let (Ok(embeddings), Ok(encoder)) = (
319                        Embeddings::load(vb.pp(format!("{model_type}.embeddings")), config),
320                        Transformer::load(vb.pp(format!("{model_type}.transformer")), config),
321                    ) {
322                        (embeddings, encoder)
323                    } else {
324                        return Err(err);
325                    }
326                } else {
327                    return Err(err);
328                }
329            }
330        };
331        Ok(Self {
332            embeddings,
333            transformer,
334            device: vb.device().clone(),
335            span: tracing::span!(tracing::Level::TRACE, "model"),
336        })
337    }
338
339    pub fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
340        let _enter = self.span.enter();
341        let embedding_output = self.embeddings.forward(input_ids)?;
342        let sequence_output = self
343            .transformer
344            .forward(&embedding_output, attention_mask)?;
345        Ok(sequence_output)
346    }
347}