candle_transformers/models/
modernbert.rs

//! ModernBERT
//!
//! ModernBERT is a modernized bidirectional encoder-only Transformer model.
//! - [Arxiv](https://arxiv.org/abs/2412.13663) "Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference"
//! - Upstream [GitHub repo](https://github.com/AnswerDotAI/ModernBERT).
//! - See modernbert in [candle-examples](https://github.com/huggingface/candle/tree/main/candle-examples/) for runnable code.
//!
8
9use candle::{DType, Device, IndexOp, Result, Tensor, D};
10use candle_nn::{
11    embedding, layer_norm_no_bias, linear, linear_no_bias, ops::softmax, Embedding, LayerNorm,
12    Linear, Module, VarBuilder,
13};
14use serde::Deserialize;
15
16use core::f32;
17use std::collections::HashMap;
18use std::sync::Arc;
19
20#[derive(Debug, Clone, PartialEq, Deserialize)]
21pub struct Config {
22    pub vocab_size: usize,
23    pub hidden_size: usize,
24    pub num_hidden_layers: usize,
25    pub num_attention_heads: usize,
26    pub intermediate_size: usize,
27    pub max_position_embeddings: usize,
28    pub layer_norm_eps: f64,
29    pub pad_token_id: u32,
30    pub global_attn_every_n_layers: usize,
31    pub global_rope_theta: f64,
32    pub local_attention: usize,
33    pub local_rope_theta: f64,
34    #[serde(default)]
35    #[serde(flatten)]
36    pub classifier_config: Option<ClassifierConfig>,
37}
38
39#[derive(Debug, Clone, Deserialize, PartialEq, Copy, Default)]
40#[serde(rename_all = "lowercase")]
41pub enum ClassifierPooling {
42    #[default]
43    CLS,
44    MEAN,
45}
46
47#[derive(Debug, Clone, PartialEq, Deserialize)]
48pub struct ClassifierConfig {
49    pub id2label: HashMap<String, String>,
50    pub label2id: HashMap<String, String>,
51    pub classifier_pooling: ClassifierPooling,
52}
53
54#[derive(Debug, Clone)]
55struct RotaryEmbedding {
56    sin: Tensor,
57    cos: Tensor,
58}
59
60impl RotaryEmbedding {
61    fn new(dtype: DType, config: &Config, rope_theta: f64, dev: &Device) -> Result<Self> {
62        let dim = config.hidden_size / config.num_attention_heads;
63        let inv_freq: Vec<_> = (0..dim)
64            .step_by(2)
65            .map(|i| 1f32 / rope_theta.powf(i as f64 / dim as f64) as f32)
66            .collect();
67        let inv_freq_len = inv_freq.len();
68        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
69        let max_seq_len = config.max_position_embeddings;
70        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
71            .to_dtype(dtype)?
72            .reshape((max_seq_len, 1))?;
73        let freqs = t.matmul(&inv_freq)?;
74        Ok(Self {
75            sin: freqs.sin()?,
76            cos: freqs.cos()?,
77        })
78    }
79
80    fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
81        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &self.cos, &self.sin)?;
82        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &self.cos, &self.sin)?;
83        Ok((q_embed, k_embed))
84    }
85}
86
87#[derive(Clone)]
88struct ModernBertAttention {
89    qkv: Linear,
90    proj: Linear,
91    num_attention_heads: usize,
92    attention_head_size: usize,
93    rotary_emb: Arc<RotaryEmbedding>,
94}
95
96impl ModernBertAttention {
97    fn load(vb: VarBuilder, config: &Config, rotary_emb: Arc<RotaryEmbedding>) -> Result<Self> {
98        let num_attention_heads = config.num_attention_heads;
99        let attention_head_size = config.hidden_size / config.num_attention_heads;
100
101        let qkv = linear_no_bias(config.hidden_size, config.hidden_size * 3, vb.pp("Wqkv"))?;
102        let proj = linear_no_bias(config.hidden_size, config.hidden_size, vb.pp("Wo"))?;
103
104        Ok(Self {
105            qkv,
106            proj,
107            num_attention_heads,
108            attention_head_size,
109            rotary_emb,
110        })
111    }
112
113    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
114        let xs = hidden_states.clone();
115        let (b, seq_len, d) = xs.dims3()?;
116        let qkv = xs
117            .apply(&self.qkv)?
118            .reshape((
119                b,
120                seq_len,
121                3,
122                self.num_attention_heads,
123                self.attention_head_size,
124            ))?
125            .permute((2, 0, 3, 1, 4))?;
126
127        let q = qkv.get(0)?;
128        let k = qkv.get(1)?;
129        let v = qkv.get(2)?;
130
131        let (q, k) = self.rotary_emb.apply_rotary_emb_qkv(&q, &k)?;
132
133        let scale = (self.attention_head_size as f64).powf(-0.5);
134        let q = (q * scale)?;
135
136        let att = q.matmul(&k.transpose(D::Minus2, D::Minus1)?)?;
137
138        let att = att.broadcast_add(attention_mask)?;
139        let att = softmax(&att, D::Minus1)?;
140
141        let xs = att.matmul(&v)?;
142
143        let xs = xs.transpose(1, 2)?.reshape((b, seq_len, d))?;
144        let xs = xs.apply(&self.proj)?;
145        let xs = xs.reshape((b, seq_len, d))?;
146
147        Ok(xs)
148    }
149}
150
151#[derive(Clone)]
152pub struct ModernBertMLP {
153    wi: Linear,
154    wo: Linear,
155}
156
157impl ModernBertMLP {
158    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
159        let wi = linear_no_bias(
160            config.hidden_size,
161            config.intermediate_size * 2,
162            vb.pp("Wi"),
163        )?;
164        let wo = linear_no_bias(config.intermediate_size, config.hidden_size, vb.pp("Wo"))?;
165        Ok(Self { wi, wo })
166    }
167}
168
169impl Module for ModernBertMLP {
170    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
171        let xs = xs.apply(&self.wi)?;
172        let xs = xs.chunk(2, D::Minus1)?;
173        let xs = (&xs[0].gelu_erf()? * &xs[1])?.apply(&self.wo)?; // GeGLU
174        Ok(xs)
175    }
176}
177
178#[derive(Clone)]
179pub struct ModernBertLayer {
180    attn: ModernBertAttention,
181    mlp: ModernBertMLP,
182    attn_norm: Option<LayerNorm>,
183    mlp_norm: LayerNorm,
184    uses_local_attention: bool,
185}
186
187impl ModernBertLayer {
188    fn load(
189        vb: VarBuilder,
190        config: &Config,
191        rotary_emb: Arc<RotaryEmbedding>,
192        uses_local_attention: bool,
193    ) -> Result<Self> {
194        let attn = ModernBertAttention::load(vb.pp("attn"), config, rotary_emb)?;
195        let mlp = ModernBertMLP::load(vb.pp("mlp"), config)?;
196        let attn_norm = layer_norm_no_bias(
197            config.hidden_size,
198            config.layer_norm_eps,
199            vb.pp("attn_norm"),
200        )
201        .ok();
202        let mlp_norm =
203            layer_norm_no_bias(config.hidden_size, config.layer_norm_eps, vb.pp("mlp_norm"))?;
204        Ok(Self {
205            attn,
206            mlp,
207            attn_norm,
208            mlp_norm,
209            uses_local_attention,
210        })
211    }
212
213    fn forward(
214        &self,
215        xs: &Tensor,
216        global_attention_mask: &Tensor,
217        local_attention_mask: &Tensor,
218    ) -> Result<Tensor> {
219        let residual = xs.clone();
220        let mut xs = xs.clone();
221        if let Some(norm) = &self.attn_norm {
222            xs = xs.apply(norm)?;
223        }
224
225        let attention_mask = if self.uses_local_attention {
226            &global_attention_mask.broadcast_add(local_attention_mask)?
227        } else {
228            global_attention_mask
229        };
230        let xs = self.attn.forward(&xs, attention_mask)?;
231        let xs = (xs + residual)?;
232        let mlp_out = xs.apply(&self.mlp_norm)?.apply(&self.mlp)?;
233        let xs = (xs + mlp_out)?;
234        Ok(xs)
235    }
236}
237
238#[derive(Clone)]
239pub struct ModernBertHead {
240    dense: Linear,
241    norm: LayerNorm,
242}
243
244impl ModernBertHead {
245    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
246        let dense = linear_no_bias(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
247        let norm = layer_norm_no_bias(config.hidden_size, config.layer_norm_eps, vb.pp("norm"))?;
248        Ok(Self { dense, norm })
249    }
250}
251
252impl Module for ModernBertHead {
253    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
254        let xs = xs.apply(&self.dense)?.gelu_erf()?.apply(&self.norm)?;
255        Ok(xs)
256    }
257}
258
259#[derive(Clone)]
260pub struct ModernBertDecoder {
261    decoder: Linear,
262}
263
264impl ModernBertDecoder {
265    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
266        // The decoder weights are tied with the embeddings layer weights
267        let decoder_weights = vb.get(
268            (config.vocab_size, config.hidden_size),
269            "model.embeddings.tok_embeddings.weight",
270        )?;
271        let decoder_bias = vb.get(config.vocab_size, "decoder.bias")?;
272        let decoder = Linear::new(decoder_weights, Some(decoder_bias));
273        Ok(Self { decoder })
274    }
275}
276
277impl Module for ModernBertDecoder {
278    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
279        let xs = xs.apply(&self.decoder)?;
280        Ok(xs)
281    }
282}
283
284// Global attention mask calculated from padded token inputs
285fn prepare_4d_attention_mask(
286    mask: &Tensor,
287    dtype: DType,
288    tgt_len: Option<usize>,
289) -> Result<Tensor> {
290    let bsz = mask.dim(0)?;
291    let src_len = mask.dim(1)?;
292    let tgt_len = tgt_len.unwrap_or(src_len);
293
294    let expanded_mask = mask
295        .unsqueeze(1)?
296        .unsqueeze(2)?
297        .expand((bsz, 1, tgt_len, src_len))?
298        .to_dtype(dtype)?;
299
300    let inverted_mask = (1.0 - expanded_mask)?;
301
302    (inverted_mask * f32::MIN as f64)?.to_dtype(dtype)
303}
304
305// Attention mask caused by the sliding window
306fn get_local_attention_mask(
307    seq_len: usize,
308    max_distance: usize,
309    device: &Device,
310) -> Result<Tensor> {
311    let mask: Vec<_> = (0..seq_len)
312        .flat_map(|i| {
313            (0..seq_len).map(move |j| {
314                if (j as i32 - i as i32).abs() > max_distance as i32 {
315                    f32::NEG_INFINITY
316                } else {
317                    0.
318                }
319            })
320        })
321        .collect();
322    Tensor::from_slice(&mask, (seq_len, seq_len), device)
323}
324
325// ModernBERT backbone
326#[derive(Clone)]
327pub struct ModernBert {
328    word_embeddings: Embedding,
329    norm: LayerNorm,
330    layers: Vec<ModernBertLayer>,
331    final_norm: LayerNorm,
332    local_attention_size: usize,
333}
334
335impl ModernBert {
336    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
337        let word_embeddings = embedding(
338            config.vocab_size,
339            config.hidden_size,
340            vb.pp("model.embeddings.tok_embeddings"),
341        )?;
342        let norm = layer_norm_no_bias(
343            config.hidden_size,
344            config.layer_norm_eps,
345            vb.pp("model.embeddings.norm"),
346        )?;
347        let global_rotary_emb = Arc::new(RotaryEmbedding::new(
348            vb.dtype(),
349            config,
350            config.global_rope_theta,
351            vb.device(),
352        )?);
353        let local_rotary_emb = Arc::new(RotaryEmbedding::new(
354            vb.dtype(),
355            config,
356            config.local_rope_theta,
357            vb.device(),
358        )?);
359
360        let mut layers = Vec::with_capacity(config.num_hidden_layers);
361        for layer_id in 0..config.num_hidden_layers {
362            let layer_uses_local_attention = layer_id % config.global_attn_every_n_layers != 0;
363            layers.push(ModernBertLayer::load(
364                vb.pp(format!("model.layers.{layer_id}")),
365                config,
366                if layer_uses_local_attention {
367                    local_rotary_emb.clone()
368                } else {
369                    global_rotary_emb.clone()
370                },
371                layer_uses_local_attention,
372            )?);
373        }
374
375        let final_norm = layer_norm_no_bias(
376            config.hidden_size,
377            config.layer_norm_eps,
378            vb.pp("model.final_norm"),
379        )?;
380
381        Ok(Self {
382            word_embeddings,
383            norm,
384            layers,
385            final_norm,
386            local_attention_size: config.local_attention,
387        })
388    }
389
390    pub fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
391        let seq_len = xs.shape().dims()[1];
392        let global_attention_mask =
393            prepare_4d_attention_mask(mask, DType::F32, None)?.to_device(xs.device())?;
394        let local_attention_mask =
395            get_local_attention_mask(seq_len, self.local_attention_size / 2, xs.device())?;
396        let mut xs = xs.apply(&self.word_embeddings)?.apply(&self.norm)?;
397        for layer in self.layers.iter() {
398            xs = layer.forward(&xs, &global_attention_mask, &local_attention_mask)?;
399        }
400        let xs = xs.apply(&self.final_norm)?;
401        Ok(xs)
402    }
403}
404
405// ModernBERT for the fill-mask task
406#[derive(Clone)]
407pub struct ModernBertForMaskedLM {
408    model: ModernBert,
409    decoder: ModernBertDecoder,
410    head: ModernBertHead,
411}
412
413impl ModernBertForMaskedLM {
414    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
415        let model = ModernBert::load(vb.clone(), config)?;
416        let decoder = ModernBertDecoder::load(vb.clone(), config)?;
417        let head = ModernBertHead::load(vb.pp("head"), config)?;
418        Ok(Self {
419            model,
420            decoder,
421            head,
422        })
423    }
424
425    pub fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
426        let xs = self
427            .model
428            .forward(xs, mask)?
429            .apply(&self.head)?
430            .apply(&self.decoder)?;
431        Ok(xs)
432    }
433}
434
435#[derive(Clone)]
436pub struct ModernBertClassifier {
437    classifier: Linear,
438}
439
440impl ModernBertClassifier {
441    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
442        // The decoder weights are tied with the embeddings layer weights
443        let classifier = linear(
444            config.hidden_size,
445            config
446                .classifier_config
447                .as_ref()
448                .map(|cc| cc.id2label.len())
449                .unwrap_or_default(),
450            vb.pp("classifier"),
451        )?;
452        Ok(Self { classifier })
453    }
454}
455
456impl Module for ModernBertClassifier {
457    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
458        let xs = xs.apply(&self.classifier)?;
459        softmax(&xs, D::Minus1)
460    }
461}
462
463#[derive(Clone)]
464pub struct ModernBertForSequenceClassification {
465    model: ModernBert,
466    head: ModernBertHead,
467    classifier: ModernBertClassifier,
468    classifier_pooling: ClassifierPooling,
469}
470
471impl ModernBertForSequenceClassification {
472    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
473        let model = ModernBert::load(vb.clone(), config)?;
474        let classifier = ModernBertClassifier::load(vb.clone(), config)?;
475        let head = ModernBertHead::load(vb.pp("head"), config)?;
476        Ok(Self {
477            model,
478            head,
479            classifier,
480            classifier_pooling: config
481                .classifier_config
482                .as_ref()
483                .map(|cc| cc.classifier_pooling)
484                .unwrap_or_default(),
485        })
486    }
487
488    pub fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
489        let output = self.model.forward(xs, mask)?;
490        let last_hidden_state = match self.classifier_pooling {
491            ClassifierPooling::CLS => output.i((.., .., 0))?,
492            ClassifierPooling::MEAN => {
493                let unsqueezed_mask = &mask.unsqueeze(D::Minus1)?.to_dtype(DType::F32)?;
494                let sum_output = output.broadcast_mul(unsqueezed_mask)?.sum(1)?;
495                sum_output.broadcast_div(&mask.sum_keepdim(1)?.to_dtype(DType::F32)?)?
496            }
497        };
498        let xs = self
499            .head
500            .forward(&last_hidden_state)?
501            .apply(&self.classifier)?;
502        Ok(xs)
503    }
504}