1use crate::models::with_tracing::{linear, Linear};
2use candle::{DType, Module, Result, Tensor};
3use candle_nn::{
4 embedding, layer_norm, ops::softmax_last_dim, Activation, Embedding, LayerNorm, VarBuilder,
5};
6
/// Model hyper-parameters, deserialized with `serde`.
///
/// No field carries a serde default, so all of them must be present in the deserialized
/// input. Several fields are not read by the code in this file; this is noted per field.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
    /// Embedding width and input/output width of the attention and feed-forward blocks.
    pub hidden_size: usize,
    /// Epsilon passed to every layer norm built in this file.
    pub layer_norm_eps: f64,
    /// Not read by the code in this file.
    pub attention_probs_dropout_prob: f32,
    /// Not read by the code in this file.
    pub hidden_dropout_prob: f32,
    /// Number of attention heads; the per-head size is `hidden_size / num_attention_heads`.
    pub num_attention_heads: usize,
    /// Not read by the code in this file.
    pub position_embedding_type: String,
    /// Output width of the intermediate dense layer (and input width of the layer that
    /// projects back to `hidden_size`).
    pub intermediate_size: usize,
    /// Activation applied after the intermediate dense layer.
    pub hidden_act: Activation,
    /// Number of encoder layers that are built.
    pub num_hidden_layers: usize,
    /// Number of rows of the word embedding table.
    pub vocab_size: usize,
    /// Number of rows of the position embedding table.
    pub max_position_embeddings: usize,
    /// Number of rows of the token-type embedding table.
    pub type_vocab_size: usize,
    /// Token id treated as padding when position ids are derived from the input ids.
    pub pad_token_id: u32,
}
23
24struct XLMRobertaEmbeddings {
25 word_embeddings: Embedding,
26 position_embeddings: Option<Embedding>,
27 token_type_embeddings: Embedding,
28 layer_norm: LayerNorm,
29 padding_idx: u32,
30 span: tracing::Span,
31}
32
33impl XLMRobertaEmbeddings {
34 fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
35 let word_embeddings = embedding(
36 config.vocab_size,
37 config.hidden_size,
38 vb.pp("word_embeddings"),
39 )?;
40 let position_embeddings = embedding(
41 config.max_position_embeddings,
42 config.hidden_size,
43 vb.pp("position_embeddings"),
44 )?;
45 let token_type_embeddings = embedding(
46 config.type_vocab_size,
47 config.hidden_size,
48 vb.pp("token_type_embeddings"),
49 )?;
50 let layer_norm = layer_norm(
51 config.hidden_size,
52 config.layer_norm_eps,
53 vb.pp("LayerNorm"),
54 )?;
55 Ok(Self {
56 word_embeddings,
57 position_embeddings: Some(position_embeddings),
58 token_type_embeddings,
59 layer_norm,
60 padding_idx: config.pad_token_id,
61 span: tracing::span!(tracing::Level::TRACE, "embeddings"),
62 })
63 }
64
65 fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
66 let _enter = self.span.enter();
67 let (_bsize, _) = input_ids.dims2()?;
68 let input_embeddings = self.word_embeddings.forward(input_ids)?;
69 let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
70 let mut embeddings = (&input_embeddings + token_type_embeddings)?;
71 if let Some(position_embeddings) = &self.position_embeddings {
72 let mask = input_ids
73 .ne(self.padding_idx)?
74 .to_dtype(input_embeddings.dtype())?;
75 let cumsum = mask.cumsum(1)?;
76 let position_ids = (cumsum * mask)?
77 .broadcast_add(
78 &Tensor::try_from(self.padding_idx)?
79 .to_dtype(input_embeddings.dtype())?
80 .to_device(input_embeddings.device())?,
81 )?
82 .to_dtype(candle::DType::U32)?;
83 embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?;
84 }
85 let embeddings = self.layer_norm.forward(&embeddings)?;
86 Ok(embeddings)
87 }
88}
89
90struct XLMRobertaSelfAttention {
91 num_attention_heads: usize,
92 attention_head_size: usize,
93 all_head_size: usize,
94 query: Linear,
95 key: Linear,
96 value: Linear,
97}
98
99impl XLMRobertaSelfAttention {
100 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
101 let attention_head_size = cfg.hidden_size / cfg.num_attention_heads;
102 let all_head_size = cfg.num_attention_heads * attention_head_size;
103 Ok(Self {
104 num_attention_heads: cfg.num_attention_heads,
105 attention_head_size,
106 all_head_size,
107 query: linear(cfg.hidden_size, all_head_size, vb.pp("query"))?,
108 key: linear(cfg.hidden_size, all_head_size, vb.pp("key"))?,
109 value: linear(cfg.hidden_size, all_head_size, vb.pp("value"))?,
110 })
111 }
112
113 fn transpose_for_scores(&self, x: &Tensor) -> Result<Tensor> {
114 let mut new_x_shape = x.dims().to_vec();
115 new_x_shape[2] = self.num_attention_heads;
116 new_x_shape.push(self.attention_head_size);
117 let x = x.reshape(new_x_shape)?;
118 x.permute((0, 2, 1, 3))?.contiguous()
119 }
120
121 fn forward(
122 &self,
123 hidden_states: &Tensor,
124 encoder_hidden_states: Option<&Tensor>,
125 attention_mask: &Tensor,
126 past_key_value: Option<(&Tensor, &Tensor)>,
127 encoder_attention_mask: Option<&Tensor>,
128 ) -> Result<Tensor> {
129 let mixed_query_layer = self.query.forward(hidden_states)?;
130 let is_cross_attention = encoder_hidden_states.is_some();
131 let (key_layer, value_layer, attention_mask) = if is_cross_attention
132 && past_key_value.is_some()
133 {
134 let key_layer = past_key_value.unwrap().0.clone();
135 let value_layer = past_key_value.unwrap().1.clone();
136 let attention_mask = encoder_attention_mask.unwrap().clone();
137 (key_layer, value_layer, Some(attention_mask))
138 } else if is_cross_attention {
139 let key_layer =
140 self.transpose_for_scores(&self.key.forward(encoder_hidden_states.unwrap())?)?;
141 let value_layer =
142 self.transpose_for_scores(&self.value.forward(encoder_hidden_states.unwrap())?)?;
143 let attention_mask = encoder_attention_mask.unwrap();
144 (key_layer, value_layer, Some(attention_mask.clone()))
145 } else if past_key_value.is_some() {
146 let mut key_layer = self.transpose_for_scores(&self.key.forward(hidden_states)?)?;
147 let mut value_layer = self.transpose_for_scores(&self.value.forward(hidden_states)?)?;
148 key_layer = Tensor::cat(
149 &[
150 past_key_value.clone().as_ref().unwrap().0.clone(),
151 key_layer,
152 ],
153 2,
154 )?;
155 value_layer = Tensor::cat(
156 &[past_key_value.as_ref().unwrap().1.clone(), value_layer],
157 2,
158 )?;
159 (key_layer, value_layer, Some(attention_mask.clone()))
160 } else {
161 let key_layer = self.transpose_for_scores(&self.key.forward(hidden_states)?)?;
162 let value_layer = self.transpose_for_scores(&self.value.forward(hidden_states)?)?;
163 (key_layer, value_layer, Some(attention_mask.clone()))
164 };
165
166 let query_layer = self.transpose_for_scores(&mixed_query_layer)?;
167 let mut attention_scores = query_layer.matmul(&key_layer.transpose(2, 3)?)?;
168 let scale = 1f64 / f64::sqrt(self.attention_head_size as f64);
169
170 attention_scores = (attention_scores * scale)?;
171 attention_scores = match attention_mask {
172 None => attention_scores,
173 Some(mask) => {
174 attention_scores.broadcast_add(&mask.to_dtype(attention_scores.dtype())?)?
175 }
176 };
177 let attention_probs = softmax_last_dim(&attention_scores)?;
178
179 let context_layer = attention_probs
180 .matmul(&value_layer)?
181 .permute((0, 2, 1, 3))?
182 .contiguous()?;
183 let mut new_context_layer_shape =
184 context_layer.dims()[..context_layer.dims().len() - 2].to_vec();
185 new_context_layer_shape.push(self.all_head_size);
186 let context_layer = context_layer.reshape(new_context_layer_shape)?;
187
188 Ok(context_layer)
189 }
190}
191
192struct XLMRobertaSelfOutput {
193 dense: Linear,
194 layernorm: LayerNorm,
195}
196
197impl XLMRobertaSelfOutput {
198 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
199 let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
200 let layernorm =
201 candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
202 Ok(Self { dense, layernorm })
203 }
204
205 fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
206 let hidden_states = self.dense.forward(hidden_states)?;
207 let hidden_states = self.layernorm.forward(&(hidden_states + input_tensor)?)?;
208 Ok(hidden_states)
209 }
210}
211
212struct XLMRobertaAttention {
213 output: XLMRobertaSelfOutput,
214 self_attention: XLMRobertaSelfAttention,
215}
216
217impl XLMRobertaAttention {
218 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
219 let output = XLMRobertaSelfOutput::new(cfg, vb.pp("output"))?;
220 let self_attention = XLMRobertaSelfAttention::new(cfg, vb.pp("self"))?;
221 Ok(Self {
222 output,
223 self_attention,
224 })
225 }
226
227 fn forward(
228 &self,
229 hidden_states: &Tensor,
230 attention_mask: &Tensor,
231 encoder_hidden_states: Option<&Tensor>,
232 encoder_attention_mask: Option<&Tensor>,
233 past_key_value: Option<(&Tensor, &Tensor)>,
234 ) -> Result<(Tensor, Tensor)> {
235 let self_outputs = self.self_attention.forward(
236 hidden_states,
237 encoder_hidden_states,
238 attention_mask,
239 past_key_value,
240 encoder_attention_mask,
241 )?;
242 let attention_output = self.output.forward(&self_outputs, hidden_states)?;
243 Ok((attention_output, self_outputs))
244 }
245}
246
247struct XLMRobertaOutput {
248 dense: Linear,
249 layernorm: LayerNorm,
250}
251
252impl XLMRobertaOutput {
253 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
254 let dense = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("dense"))?;
255 let layernorm =
256 candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
257 Ok(Self { dense, layernorm })
258 }
259
260 fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
261 let hidden_states = self.dense.forward(hidden_states)?;
262 let hidden_states = self.layernorm.forward(&(hidden_states + input_tensor)?)?;
263 Ok(hidden_states)
264 }
265}
266
267struct XLMRobertaIntermediate {
268 dense: Linear,
269 intermediate_act_fn: Activation,
270}
271
272impl XLMRobertaIntermediate {
273 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
274 let dense = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("dense"))?;
275 let intermediate_act_fn = cfg.hidden_act;
276 Ok(Self {
277 dense,
278 intermediate_act_fn,
279 })
280 }
281
282 fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
283 let hidden_states = self.dense.forward(hidden_states)?;
284 let hidden_states = self.intermediate_act_fn.forward(&hidden_states)?;
285 Ok(hidden_states)
286 }
287}
288
289struct XLMRobertaLayer {
290 attention: XLMRobertaAttention,
291 intermediate: XLMRobertaIntermediate,
292 output: XLMRobertaOutput,
293}
294
295impl XLMRobertaLayer {
296 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
297 let attention = XLMRobertaAttention::new(cfg, vb.pp("attention"))?;
298 let intermediate = XLMRobertaIntermediate::new(cfg, vb.pp("intermediate"))?;
299 let output = XLMRobertaOutput::new(cfg, vb.pp("output"))?;
300 Ok(Self {
301 attention,
302 intermediate,
303 output,
304 })
305 }
306
307 fn forward(
308 &self,
309 hidden_states: &Tensor,
310 attention_mask: &Tensor,
311 encoder_hidden_states: Option<&Tensor>,
312 encoder_attention_mask: Option<&Tensor>,
313 past_key_value: Option<(&Tensor, &Tensor)>,
314 ) -> Result<(Tensor, Tensor)> {
315 let self_attention_outputs = self.attention.forward(
316 hidden_states,
317 attention_mask,
318 encoder_hidden_states,
319 encoder_attention_mask,
320 past_key_value,
321 )?;
322 let attention_output = self_attention_outputs.0;
323 let outputs = self_attention_outputs.1;
324 let intermediate_output = self.intermediate.forward(&attention_output)?;
325 let layer_output = self
326 .output
327 .forward(&intermediate_output, &attention_output)?;
328 Ok((layer_output, outputs))
329 }
330}
331
332struct XLMRobertaEncoder {
333 layers: Vec<XLMRobertaLayer>,
334}
335
336impl XLMRobertaEncoder {
337 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
338 let layers = (0..cfg.num_hidden_layers)
339 .map(|i| XLMRobertaLayer::new(cfg, vb.pp(format!("layer.{}", i))))
340 .collect::<Result<Vec<_>>>()?;
341 Ok(Self { layers })
342 }
343
344 fn forward(
345 &self,
346 hidden_states: &Tensor,
347 attention_mask: &Tensor,
348 encoder_hidden_states: Option<&Tensor>,
349 encoder_attention_mask: Option<&Tensor>,
350 past_key_value: Option<(&Tensor, &Tensor)>,
351 ) -> Result<Tensor> {
352 let mut hidden_states = hidden_states.clone();
353 for layer_module in self.layers.iter() {
354 let layer_outputs = layer_module.forward(
355 &hidden_states,
356 attention_mask,
357 encoder_hidden_states,
358 encoder_attention_mask,
359 past_key_value,
360 )?;
361 hidden_states = layer_outputs.0;
362 }
363 Ok(hidden_states)
364 }
365}
366
367pub struct XLMRobertaModel {
368 encoder: XLMRobertaEncoder,
369 embeddings: XLMRobertaEmbeddings,
370}
371
372impl XLMRobertaModel {
373 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
374 let encoder = XLMRobertaEncoder::new(cfg, vb.pp("encoder"))?;
375 let embeddings = XLMRobertaEmbeddings::load(vb.pp("embeddings"), cfg)?;
376 Ok(Self {
377 encoder,
378 embeddings,
379 })
380 }
381
382 pub fn forward(
383 &self,
384 input_ids: &Tensor,
385 attention_mask: &Tensor,
386 token_type_ids: &Tensor,
387 past_key_value: Option<(&Tensor, &Tensor)>,
388 encoder_hidden_states: Option<&Tensor>,
389 encoder_attention_mask: Option<&Tensor>,
390 ) -> Result<Tensor> {
391 let hidden_states = self.embeddings.forward(input_ids, token_type_ids)?;
392 let attention_mask = prepare_4d_attention_mask(attention_mask, DType::F32, None)?
393 .to_device(hidden_states.device())?;
394 let hidden_states = self.encoder.forward(
395 &hidden_states,
396 &attention_mask,
397 encoder_hidden_states,
398 encoder_attention_mask,
399 past_key_value,
400 )?;
401 Ok(hidden_states)
402 }
403}
404
405struct XLMRobertaLMHead {
406 dense: Linear,
407 layer_norm: LayerNorm,
408}
409
410impl XLMRobertaLMHead {
411 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
412 let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
413 let layer_norm =
414 candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("layer_norm"))?;
415 Ok(Self { dense, layer_norm })
416 }
417
418 fn forward(&self, hidden_states: &Tensor, shared_embeddings: &Tensor) -> Result<Tensor> {
419 let hidden_states = self.dense.forward(hidden_states)?;
420 let hidden_states = candle_nn::Activation::Gelu.forward(&hidden_states)?;
421 let hidden_states = self.layer_norm.forward(&hidden_states)?;
422 let hidden_states = hidden_states.broadcast_matmul(shared_embeddings)?;
423 Ok(hidden_states)
424 }
425}
426
427pub struct XLMRobertaForMaskedLM {
428 roberta: XLMRobertaModel,
429 lm_head: XLMRobertaLMHead,
430}
431
432impl XLMRobertaForMaskedLM {
433 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
434 let roberta = XLMRobertaModel::new(cfg, vb.pp("roberta"))?;
435 let lm_head = XLMRobertaLMHead::new(cfg, vb.pp("lm_head"))?;
436 Ok(Self { roberta, lm_head })
437 }
438
439 pub fn forward(
440 &self,
441 input_ids: &Tensor,
442 attention_mask: &Tensor,
443 token_type_ids: &Tensor,
444 past_key_value: Option<(&Tensor, &Tensor)>,
445 encoder_hidden_states: Option<&Tensor>,
446 encoder_attention_mask: Option<&Tensor>,
447 ) -> Result<Tensor> {
448 let hidden_states = self.roberta.forward(
449 input_ids,
450 attention_mask,
451 token_type_ids,
452 past_key_value,
453 encoder_hidden_states,
454 encoder_attention_mask,
455 )?;
456 let lm_logits = self.lm_head.forward(
457 &hidden_states,
458 &self
459 .roberta
460 .embeddings
461 .word_embeddings
462 .embeddings()
463 .t()?
464 .unsqueeze(0)?,
465 )?;
466 Ok(lm_logits)
467 }
468}
469
470struct XLMRobertaClassificationHead {
471 dense: Linear,
472 out_proj: Linear,
473}
474
475impl XLMRobertaClassificationHead {
476 fn new(num_labels: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
477 let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
478 let out_proj = linear(cfg.hidden_size, num_labels, vb.pp("out_proj"))?;
479 Ok(Self { dense, out_proj })
480 }
481
482 fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
483 let cls_states = hidden_states.get_on_dim(1, 0)?.contiguous()?;
484 let hidden_states = self.dense.forward(&cls_states)?;
485 let hidden_states = candle_nn::Activation::GeluPytorchTanh.forward(&hidden_states)?;
486 let hidden_states = self.out_proj.forward(&hidden_states)?;
487 Ok(hidden_states)
488 }
489}
490
491pub struct XLMRobertaForSequenceClassification {
492 roberta: XLMRobertaModel,
493 classifier: XLMRobertaClassificationHead,
494}
495
496impl XLMRobertaForSequenceClassification {
497 pub fn new(num_labels: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
498 let roberta = XLMRobertaModel::new(cfg, vb.pp("roberta"))?;
499 let classifier = XLMRobertaClassificationHead::new(num_labels, cfg, vb.pp("classifier"))?;
500 Ok(Self {
501 roberta,
502 classifier,
503 })
504 }
505
506 pub fn forward(
507 &self,
508 input_ids: &Tensor,
509 attention_mask: &Tensor,
510 token_type_ids: &Tensor,
511 ) -> Result<Tensor> {
512 let hidden_states =
513 self.roberta
514 .forward(input_ids, attention_mask, token_type_ids, None, None, None)?;
515 self.classifier.forward(&hidden_states)
516 }
517}
518
519fn prepare_4d_attention_mask(
520 mask: &Tensor,
521 dtype: DType,
522 tgt_len: Option<usize>,
523) -> Result<Tensor> {
524 let bsz = mask.dim(0)?;
525 let src_len = mask.dim(1)?;
526 let tgt_len = tgt_len.unwrap_or(src_len);
527
528 let expanded_mask = mask
529 .unsqueeze(1)?
530 .unsqueeze(2)?
531 .expand((bsz, 1, tgt_len, src_len))?
532 .to_dtype(dtype)?;
533
534 let inverted_mask = (1.0 - expanded_mask)?;
535
536 (inverted_mask * get_dtype_min_val(dtype))?.to_dtype(dtype)
537}
538
539fn get_dtype_min_val(dtype: DType) -> f64 {
540 match dtype {
541 DType::F32 => f32::MIN as f64,
542 DType::F64 => f64::MIN,
543 _ => panic!("Unsupported data type"),
544 }
545}