1use candle::{DType, Device, IndexOp, Result, Tensor, D};
10use candle_nn::{
11 embedding, layer_norm_no_bias, linear, linear_no_bias, ops::softmax, Embedding, LayerNorm,
12 Linear, Module, VarBuilder,
13};
14use serde::Deserialize;
15
16use core::f32;
17use std::collections::HashMap;
18use std::sync::Arc;
19
/// Model hyper-parameters, deserialized from the model configuration.
///
/// Classifier-related keys are read from the same (flattened) level of the
/// configuration and end up in [`Config::classifier_config`] when present.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    /// Number of rows of the token embedding table (and of the tied decoder).
    pub vocab_size: usize,
    /// Width of the hidden states.
    pub hidden_size: usize,
    /// Number of encoder layers loaded from `model.layers.{i}`.
    pub num_hidden_layers: usize,
    /// Number of attention heads; the head size is `hidden_size / num_attention_heads`.
    pub num_attention_heads: usize,
    /// Inner width of the gated MLP (the input projection is twice this wide).
    pub intermediate_size: usize,
    /// Number of positions for which the rotary sin/cos tables are precomputed.
    pub max_position_embeddings: usize,
    /// Epsilon passed to every layer norm.
    pub layer_norm_eps: f64,
    /// Padding token id. Not read by the code visible in this file.
    pub pad_token_id: u32,
    /// Layers whose index is a multiple of this value use global attention;
    /// all other layers use local attention.
    pub global_attn_every_n_layers: usize,
    /// Rotary base used by global-attention layers.
    pub global_rope_theta: f64,
    /// Local attention window size; half of it is used as the maximum
    /// query/key distance when building the local mask.
    pub local_attention: usize,
    /// Rotary base used by local-attention layers.
    pub local_rope_theta: f64,
    /// Optional classification settings, flattened into the top-level config.
    #[serde(default)]
    #[serde(flatten)]
    pub classifier_config: Option<ClassifierConfig>,
}
38
/// Pooling strategy used by the sequence classification model.
///
/// Deserialized from lowercase strings (`"cls"`, `"mean"`); defaults to `CLS`.
#[derive(Debug, Clone, Deserialize, PartialEq, Copy, Default)]
#[serde(rename_all = "lowercase")]
pub enum ClassifierPooling {
    /// Selects the CLS pooling branch of the sequence classification forward pass.
    #[default]
    CLS,
    /// Selects the attention-mask-weighted mean over the sequence dimension.
    MEAN,
}
46
/// Classification-specific part of the configuration.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct ClassifierConfig {
    /// Mapping from class id to label; its length sets the classifier output size.
    pub id2label: HashMap<String, String>,
    /// Mapping from label to class id. Not read by the code visible in this file.
    pub label2id: HashMap<String, String>,
    /// Pooling strategy applied before the classification head.
    pub classifier_pooling: ClassifierPooling,
}
53
54#[derive(Debug, Clone)]
55struct RotaryEmbedding {
56 sin: Tensor,
57 cos: Tensor,
58}
59
60impl RotaryEmbedding {
61 fn new(dtype: DType, config: &Config, rope_theta: f64, dev: &Device) -> Result<Self> {
62 let dim = config.hidden_size / config.num_attention_heads;
63 let inv_freq: Vec<_> = (0..dim)
64 .step_by(2)
65 .map(|i| 1f32 / rope_theta.powf(i as f64 / dim as f64) as f32)
66 .collect();
67 let inv_freq_len = inv_freq.len();
68 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
69 let max_seq_len = config.max_position_embeddings;
70 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
71 .to_dtype(dtype)?
72 .reshape((max_seq_len, 1))?;
73 let freqs = t.matmul(&inv_freq)?;
74 Ok(Self {
75 sin: freqs.sin()?,
76 cos: freqs.cos()?,
77 })
78 }
79
80 fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
81 let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &self.cos, &self.sin)?;
82 let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &self.cos, &self.sin)?;
83 Ok((q_embed, k_embed))
84 }
85}
86
87#[derive(Clone)]
88struct ModernBertAttention {
89 qkv: Linear,
90 proj: Linear,
91 num_attention_heads: usize,
92 attention_head_size: usize,
93 rotary_emb: Arc<RotaryEmbedding>,
94}
95
96impl ModernBertAttention {
97 fn load(vb: VarBuilder, config: &Config, rotary_emb: Arc<RotaryEmbedding>) -> Result<Self> {
98 let num_attention_heads = config.num_attention_heads;
99 let attention_head_size = config.hidden_size / config.num_attention_heads;
100
101 let qkv = linear_no_bias(config.hidden_size, config.hidden_size * 3, vb.pp("Wqkv"))?;
102 let proj = linear_no_bias(config.hidden_size, config.hidden_size, vb.pp("Wo"))?;
103
104 Ok(Self {
105 qkv,
106 proj,
107 num_attention_heads,
108 attention_head_size,
109 rotary_emb,
110 })
111 }
112
113 fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
114 let xs = hidden_states.clone();
115 let (b, seq_len, d) = xs.dims3()?;
116 let qkv = xs
117 .apply(&self.qkv)?
118 .reshape((
119 b,
120 seq_len,
121 3,
122 self.num_attention_heads,
123 self.attention_head_size,
124 ))?
125 .permute((2, 0, 3, 1, 4))?;
126
127 let q = qkv.get(0)?;
128 let k = qkv.get(1)?;
129 let v = qkv.get(2)?;
130
131 let (q, k) = self.rotary_emb.apply_rotary_emb_qkv(&q, &k)?;
132
133 let scale = (self.attention_head_size as f64).powf(-0.5);
134 let q = (q * scale)?;
135
136 let att = q.matmul(&k.transpose(D::Minus2, D::Minus1)?)?;
137
138 let att = att.broadcast_add(attention_mask)?;
139 let att = softmax(&att, D::Minus1)?;
140
141 let xs = att.matmul(&v)?;
142
143 let xs = xs.transpose(1, 2)?.reshape((b, seq_len, d))?;
144 let xs = xs.apply(&self.proj)?;
145 let xs = xs.reshape((b, seq_len, d))?;
146
147 Ok(xs)
148 }
149}
150
151#[derive(Clone)]
152pub struct ModernBertMLP {
153 wi: Linear,
154 wo: Linear,
155}
156
157impl ModernBertMLP {
158 fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
159 let wi = linear_no_bias(
160 config.hidden_size,
161 config.intermediate_size * 2,
162 vb.pp("Wi"),
163 )?;
164 let wo = linear_no_bias(config.intermediate_size, config.hidden_size, vb.pp("Wo"))?;
165 Ok(Self { wi, wo })
166 }
167}
168
169impl Module for ModernBertMLP {
170 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
171 let xs = xs.apply(&self.wi)?;
172 let xs = xs.chunk(2, D::Minus1)?;
173 let xs = (&xs[0].gelu_erf()? * &xs[1])?.apply(&self.wo)?; Ok(xs)
175 }
176}
177
178#[derive(Clone)]
179pub struct ModernBertLayer {
180 attn: ModernBertAttention,
181 mlp: ModernBertMLP,
182 attn_norm: Option<LayerNorm>,
183 mlp_norm: LayerNorm,
184 uses_local_attention: bool,
185}
186
187impl ModernBertLayer {
188 fn load(
189 vb: VarBuilder,
190 config: &Config,
191 rotary_emb: Arc<RotaryEmbedding>,
192 uses_local_attention: bool,
193 ) -> Result<Self> {
194 let attn = ModernBertAttention::load(vb.pp("attn"), config, rotary_emb)?;
195 let mlp = ModernBertMLP::load(vb.pp("mlp"), config)?;
196 let attn_norm = layer_norm_no_bias(
197 config.hidden_size,
198 config.layer_norm_eps,
199 vb.pp("attn_norm"),
200 )
201 .ok();
202 let mlp_norm =
203 layer_norm_no_bias(config.hidden_size, config.layer_norm_eps, vb.pp("mlp_norm"))?;
204 Ok(Self {
205 attn,
206 mlp,
207 attn_norm,
208 mlp_norm,
209 uses_local_attention,
210 })
211 }
212
213 fn forward(
214 &self,
215 xs: &Tensor,
216 global_attention_mask: &Tensor,
217 local_attention_mask: &Tensor,
218 ) -> Result<Tensor> {
219 let residual = xs.clone();
220 let mut xs = xs.clone();
221 if let Some(norm) = &self.attn_norm {
222 xs = xs.apply(norm)?;
223 }
224
225 let attention_mask = if self.uses_local_attention {
226 &global_attention_mask.broadcast_add(local_attention_mask)?
227 } else {
228 global_attention_mask
229 };
230 let xs = self.attn.forward(&xs, attention_mask)?;
231 let xs = (xs + residual)?;
232 let mlp_out = xs.apply(&self.mlp_norm)?.apply(&self.mlp)?;
233 let xs = (xs + mlp_out)?;
234 Ok(xs)
235 }
236}
237
238#[derive(Clone)]
239pub struct ModernBertHead {
240 dense: Linear,
241 norm: LayerNorm,
242}
243
244impl ModernBertHead {
245 fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
246 let dense = linear_no_bias(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
247 let norm = layer_norm_no_bias(config.hidden_size, config.layer_norm_eps, vb.pp("norm"))?;
248 Ok(Self { dense, norm })
249 }
250}
251
252impl Module for ModernBertHead {
253 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
254 let xs = xs.apply(&self.dense)?.gelu_erf()?.apply(&self.norm)?;
255 Ok(xs)
256 }
257}
258
259#[derive(Clone)]
260pub struct ModernBertDecoder {
261 decoder: Linear,
262}
263
264impl ModernBertDecoder {
265 fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
266 let decoder_weights = vb.get(
268 (config.vocab_size, config.hidden_size),
269 "model.embeddings.tok_embeddings.weight",
270 )?;
271 let decoder_bias = vb.get(config.vocab_size, "decoder.bias")?;
272 let decoder = Linear::new(decoder_weights, Some(decoder_bias));
273 Ok(Self { decoder })
274 }
275}
276
277impl Module for ModernBertDecoder {
278 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
279 let xs = xs.apply(&self.decoder)?;
280 Ok(xs)
281 }
282}
283
284fn prepare_4d_attention_mask(
286 mask: &Tensor,
287 dtype: DType,
288 tgt_len: Option<usize>,
289) -> Result<Tensor> {
290 let bsz = mask.dim(0)?;
291 let src_len = mask.dim(1)?;
292 let tgt_len = tgt_len.unwrap_or(src_len);
293
294 let expanded_mask = mask
295 .unsqueeze(1)?
296 .unsqueeze(2)?
297 .expand((bsz, 1, tgt_len, src_len))?
298 .to_dtype(dtype)?;
299
300 let inverted_mask = (1.0 - expanded_mask)?;
301
302 (inverted_mask * f32::MIN as f64)?.to_dtype(dtype)
303}
304
305fn get_local_attention_mask(
307 seq_len: usize,
308 max_distance: usize,
309 device: &Device,
310) -> Result<Tensor> {
311 let mask: Vec<_> = (0..seq_len)
312 .flat_map(|i| {
313 (0..seq_len).map(move |j| {
314 if (j as i32 - i as i32).abs() > max_distance as i32 {
315 f32::NEG_INFINITY
316 } else {
317 0.
318 }
319 })
320 })
321 .collect();
322 Tensor::from_slice(&mask, (seq_len, seq_len), device)
323}
324
325#[derive(Clone)]
327pub struct ModernBert {
328 word_embeddings: Embedding,
329 norm: LayerNorm,
330 layers: Vec<ModernBertLayer>,
331 final_norm: LayerNorm,
332 local_attention_size: usize,
333}
334
335impl ModernBert {
336 pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
337 let word_embeddings = embedding(
338 config.vocab_size,
339 config.hidden_size,
340 vb.pp("model.embeddings.tok_embeddings"),
341 )?;
342 let norm = layer_norm_no_bias(
343 config.hidden_size,
344 config.layer_norm_eps,
345 vb.pp("model.embeddings.norm"),
346 )?;
347 let global_rotary_emb = Arc::new(RotaryEmbedding::new(
348 vb.dtype(),
349 config,
350 config.global_rope_theta,
351 vb.device(),
352 )?);
353 let local_rotary_emb = Arc::new(RotaryEmbedding::new(
354 vb.dtype(),
355 config,
356 config.local_rope_theta,
357 vb.device(),
358 )?);
359
360 let mut layers = Vec::with_capacity(config.num_hidden_layers);
361 for layer_id in 0..config.num_hidden_layers {
362 let layer_uses_local_attention = layer_id % config.global_attn_every_n_layers != 0;
363 layers.push(ModernBertLayer::load(
364 vb.pp(format!("model.layers.{layer_id}")),
365 config,
366 if layer_uses_local_attention {
367 local_rotary_emb.clone()
368 } else {
369 global_rotary_emb.clone()
370 },
371 layer_uses_local_attention,
372 )?);
373 }
374
375 let final_norm = layer_norm_no_bias(
376 config.hidden_size,
377 config.layer_norm_eps,
378 vb.pp("model.final_norm"),
379 )?;
380
381 Ok(Self {
382 word_embeddings,
383 norm,
384 layers,
385 final_norm,
386 local_attention_size: config.local_attention,
387 })
388 }
389
390 pub fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
391 let seq_len = xs.shape().dims()[1];
392 let global_attention_mask =
393 prepare_4d_attention_mask(mask, DType::F32, None)?.to_device(xs.device())?;
394 let local_attention_mask =
395 get_local_attention_mask(seq_len, self.local_attention_size / 2, xs.device())?;
396 let mut xs = xs.apply(&self.word_embeddings)?.apply(&self.norm)?;
397 for layer in self.layers.iter() {
398 xs = layer.forward(&xs, &global_attention_mask, &local_attention_mask)?;
399 }
400 let xs = xs.apply(&self.final_norm)?;
401 Ok(xs)
402 }
403}
404
405#[derive(Clone)]
407pub struct ModernBertForMaskedLM {
408 model: ModernBert,
409 decoder: ModernBertDecoder,
410 head: ModernBertHead,
411}
412
413impl ModernBertForMaskedLM {
414 pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
415 let model = ModernBert::load(vb.clone(), config)?;
416 let decoder = ModernBertDecoder::load(vb.clone(), config)?;
417 let head = ModernBertHead::load(vb.pp("head"), config)?;
418 Ok(Self {
419 model,
420 decoder,
421 head,
422 })
423 }
424
425 pub fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
426 let xs = self
427 .model
428 .forward(xs, mask)?
429 .apply(&self.head)?
430 .apply(&self.decoder)?;
431 Ok(xs)
432 }
433}
434
435#[derive(Clone)]
436pub struct ModernBertClassifier {
437 classifier: Linear,
438}
439
440impl ModernBertClassifier {
441 fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
442 let classifier = linear(
444 config.hidden_size,
445 config
446 .classifier_config
447 .as_ref()
448 .map(|cc| cc.id2label.len())
449 .unwrap_or_default(),
450 vb.pp("classifier"),
451 )?;
452 Ok(Self { classifier })
453 }
454}
455
456impl Module for ModernBertClassifier {
457 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
458 let xs = xs.apply(&self.classifier)?;
459 softmax(&xs, D::Minus1)
460 }
461}
462
463#[derive(Clone)]
464pub struct ModernBertForSequenceClassification {
465 model: ModernBert,
466 head: ModernBertHead,
467 classifier: ModernBertClassifier,
468 classifier_pooling: ClassifierPooling,
469}
470
471impl ModernBertForSequenceClassification {
472 pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
473 let model = ModernBert::load(vb.clone(), config)?;
474 let classifier = ModernBertClassifier::load(vb.clone(), config)?;
475 let head = ModernBertHead::load(vb.pp("head"), config)?;
476 Ok(Self {
477 model,
478 head,
479 classifier,
480 classifier_pooling: config
481 .classifier_config
482 .as_ref()
483 .map(|cc| cc.classifier_pooling)
484 .unwrap_or_default(),
485 })
486 }
487
488 pub fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
489 let output = self.model.forward(xs, mask)?;
490 let last_hidden_state = match self.classifier_pooling {
491 ClassifierPooling::CLS => output.i((.., .., 0))?,
492 ClassifierPooling::MEAN => {
493 let unsqueezed_mask = &mask.unsqueeze(D::Minus1)?.to_dtype(DType::F32)?;
494 let sum_output = output.broadcast_mul(unsqueezed_mask)?.sum(1)?;
495 sum_output.broadcast_div(&mask.sum_keepdim(1)?.to_dtype(DType::F32)?)?
496 }
497 };
498 let xs = self
499 .head
500 .forward(&last_hidden_state)?
501 .apply(&self.classifier)?;
502 Ok(xs)
503 }
504}