// candle_transformers/models/xlm_roberta.rs

use crate::models::with_tracing::{linear, Linear};
use candle::{DType, Module, Result, Tensor};
use candle_nn::{
    embedding, layer_norm, ops::softmax_last_dim, Activation, Embedding, LayerNorm, VarBuilder,
};

7#[derive(Debug, Clone, serde::Deserialize)]
8pub struct Config {
9    pub hidden_size: usize,
10    pub layer_norm_eps: f64,
11    pub attention_probs_dropout_prob: f32,
12    pub hidden_dropout_prob: f32,
13    pub num_attention_heads: usize,
14    pub position_embedding_type: String,
15    pub intermediate_size: usize,
16    pub hidden_act: Activation,
17    pub num_hidden_layers: usize,
18    pub vocab_size: usize,
19    pub max_position_embeddings: usize,
20    pub type_vocab_size: usize,
21    pub pad_token_id: u32,
22}
23
24struct XLMRobertaEmbeddings {
25    word_embeddings: Embedding,
26    position_embeddings: Option<Embedding>,
27    token_type_embeddings: Embedding,
28    layer_norm: LayerNorm,
29    padding_idx: u32,
30    span: tracing::Span,
31}
32
33impl XLMRobertaEmbeddings {
34    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
35        let word_embeddings = embedding(
36            config.vocab_size,
37            config.hidden_size,
38            vb.pp("word_embeddings"),
39        )?;
40        let position_embeddings = embedding(
41            config.max_position_embeddings,
42            config.hidden_size,
43            vb.pp("position_embeddings"),
44        )?;
45        let token_type_embeddings = embedding(
46            config.type_vocab_size,
47            config.hidden_size,
48            vb.pp("token_type_embeddings"),
49        )?;
50        let layer_norm = layer_norm(
51            config.hidden_size,
52            config.layer_norm_eps,
53            vb.pp("LayerNorm"),
54        )?;
55        Ok(Self {
56            word_embeddings,
57            position_embeddings: Some(position_embeddings),
58            token_type_embeddings,
59            layer_norm,
60            padding_idx: config.pad_token_id,
61            span: tracing::span!(tracing::Level::TRACE, "embeddings"),
62        })
63    }
64
65    fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
66        let _enter = self.span.enter();
67        let (_bsize, _) = input_ids.dims2()?;
68        let input_embeddings = self.word_embeddings.forward(input_ids)?;
69        let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
70        let mut embeddings = (&input_embeddings + token_type_embeddings)?;
71        if let Some(position_embeddings) = &self.position_embeddings {
72            let mask = input_ids
73                .ne(self.padding_idx)?
74                .to_dtype(input_embeddings.dtype())?;
75            let cumsum = mask.cumsum(1)?;
76            let position_ids = (cumsum * mask)?
77                .broadcast_add(
78                    &Tensor::try_from(self.padding_idx)?
79                        .to_dtype(input_embeddings.dtype())?
80                        .to_device(input_embeddings.device())?,
81                )?
82                .to_dtype(candle::DType::U32)?;
83            embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?;
84        }
85        let embeddings = self.layer_norm.forward(&embeddings)?;
86        Ok(embeddings)
87    }
88}
89
90struct XLMRobertaSelfAttention {
91    num_attention_heads: usize,
92    attention_head_size: usize,
93    all_head_size: usize,
94    query: Linear,
95    key: Linear,
96    value: Linear,
97}
98
99impl XLMRobertaSelfAttention {
100    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
101        let attention_head_size = cfg.hidden_size / cfg.num_attention_heads;
102        let all_head_size = cfg.num_attention_heads * attention_head_size;
103        Ok(Self {
104            num_attention_heads: cfg.num_attention_heads,
105            attention_head_size,
106            all_head_size,
107            query: linear(cfg.hidden_size, all_head_size, vb.pp("query"))?,
108            key: linear(cfg.hidden_size, all_head_size, vb.pp("key"))?,
109            value: linear(cfg.hidden_size, all_head_size, vb.pp("value"))?,
110        })
111    }
112
113    fn transpose_for_scores(&self, x: &Tensor) -> Result<Tensor> {
114        let mut new_x_shape = x.dims().to_vec();
115        new_x_shape[2] = self.num_attention_heads;
116        new_x_shape.push(self.attention_head_size);
117        let x = x.reshape(new_x_shape)?;
118        x.permute((0, 2, 1, 3))?.contiguous()
119    }
120
121    fn forward(
122        &self,
123        hidden_states: &Tensor,
124        encoder_hidden_states: Option<&Tensor>,
125        attention_mask: &Tensor,
126        past_key_value: Option<(&Tensor, &Tensor)>,
127        encoder_attention_mask: Option<&Tensor>,
128    ) -> Result<Tensor> {
129        let mixed_query_layer = self.query.forward(hidden_states)?;
130        let is_cross_attention = encoder_hidden_states.is_some();
131        let (key_layer, value_layer, attention_mask) = if is_cross_attention
132            && past_key_value.is_some()
133        {
134            let key_layer = past_key_value.unwrap().0.clone();
135            let value_layer = past_key_value.unwrap().1.clone();
136            let attention_mask = encoder_attention_mask.unwrap().clone();
137            (key_layer, value_layer, Some(attention_mask))
138        } else if is_cross_attention {
139            let key_layer =
140                self.transpose_for_scores(&self.key.forward(encoder_hidden_states.unwrap())?)?;
141            let value_layer =
142                self.transpose_for_scores(&self.value.forward(encoder_hidden_states.unwrap())?)?;
143            let attention_mask = encoder_attention_mask.unwrap();
144            (key_layer, value_layer, Some(attention_mask.clone()))
145        } else if past_key_value.is_some() {
146            let mut key_layer = self.transpose_for_scores(&self.key.forward(hidden_states)?)?;
147            let mut value_layer = self.transpose_for_scores(&self.value.forward(hidden_states)?)?;
148            key_layer = Tensor::cat(
149                &[
150                    past_key_value.clone().as_ref().unwrap().0.clone(),
151                    key_layer,
152                ],
153                2,
154            )?;
155            value_layer = Tensor::cat(
156                &[past_key_value.as_ref().unwrap().1.clone(), value_layer],
157                2,
158            )?;
159            (key_layer, value_layer, Some(attention_mask.clone()))
160        } else {
161            let key_layer = self.transpose_for_scores(&self.key.forward(hidden_states)?)?;
162            let value_layer = self.transpose_for_scores(&self.value.forward(hidden_states)?)?;
163            (key_layer, value_layer, Some(attention_mask.clone()))
164        };
165
166        let query_layer = self.transpose_for_scores(&mixed_query_layer)?;
167        let mut attention_scores = query_layer.matmul(&key_layer.transpose(2, 3)?)?;
168        let scale = 1f64 / f64::sqrt(self.attention_head_size as f64);
169
170        attention_scores = (attention_scores * scale)?;
171        attention_scores = match attention_mask {
172            None => attention_scores,
173            Some(mask) => {
174                attention_scores.broadcast_add(&mask.to_dtype(attention_scores.dtype())?)?
175            }
176        };
177        let attention_probs = softmax_last_dim(&attention_scores)?;
178
179        let context_layer = attention_probs
180            .matmul(&value_layer)?
181            .permute((0, 2, 1, 3))?
182            .contiguous()?;
183        let mut new_context_layer_shape =
184            context_layer.dims()[..context_layer.dims().len() - 2].to_vec();
185        new_context_layer_shape.push(self.all_head_size);
186        let context_layer = context_layer.reshape(new_context_layer_shape)?;
187
188        Ok(context_layer)
189    }
190}
191
192struct XLMRobertaSelfOutput {
193    dense: Linear,
194    layernorm: LayerNorm,
195}
196
197impl XLMRobertaSelfOutput {
198    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
199        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
200        let layernorm =
201            candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
202        Ok(Self { dense, layernorm })
203    }
204
205    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
206        let hidden_states = self.dense.forward(hidden_states)?;
207        let hidden_states = self.layernorm.forward(&(hidden_states + input_tensor)?)?;
208        Ok(hidden_states)
209    }
210}
211
212struct XLMRobertaAttention {
213    output: XLMRobertaSelfOutput,
214    self_attention: XLMRobertaSelfAttention,
215}
216
217impl XLMRobertaAttention {
218    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
219        let output = XLMRobertaSelfOutput::new(cfg, vb.pp("output"))?;
220        let self_attention = XLMRobertaSelfAttention::new(cfg, vb.pp("self"))?;
221        Ok(Self {
222            output,
223            self_attention,
224        })
225    }
226
227    fn forward(
228        &self,
229        hidden_states: &Tensor,
230        attention_mask: &Tensor,
231        encoder_hidden_states: Option<&Tensor>,
232        encoder_attention_mask: Option<&Tensor>,
233        past_key_value: Option<(&Tensor, &Tensor)>,
234    ) -> Result<(Tensor, Tensor)> {
235        let self_outputs = self.self_attention.forward(
236            hidden_states,
237            encoder_hidden_states,
238            attention_mask,
239            past_key_value,
240            encoder_attention_mask,
241        )?;
242        let attention_output = self.output.forward(&self_outputs, hidden_states)?;
243        Ok((attention_output, self_outputs))
244    }
245}
246
247struct XLMRobertaOutput {
248    dense: Linear,
249    layernorm: LayerNorm,
250}
251
252impl XLMRobertaOutput {
253    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
254        let dense = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("dense"))?;
255        let layernorm =
256            candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
257        Ok(Self { dense, layernorm })
258    }
259
260    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
261        let hidden_states = self.dense.forward(hidden_states)?;
262        let hidden_states = self.layernorm.forward(&(hidden_states + input_tensor)?)?;
263        Ok(hidden_states)
264    }
265}
266
267struct XLMRobertaIntermediate {
268    dense: Linear,
269    intermediate_act_fn: Activation,
270}
271
272impl XLMRobertaIntermediate {
273    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
274        let dense = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("dense"))?;
275        let intermediate_act_fn = cfg.hidden_act;
276        Ok(Self {
277            dense,
278            intermediate_act_fn,
279        })
280    }
281
282    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
283        let hidden_states = self.dense.forward(hidden_states)?;
284        let hidden_states = self.intermediate_act_fn.forward(&hidden_states)?;
285        Ok(hidden_states)
286    }
287}
288
289struct XLMRobertaLayer {
290    attention: XLMRobertaAttention,
291    intermediate: XLMRobertaIntermediate,
292    output: XLMRobertaOutput,
293}
294
295impl XLMRobertaLayer {
296    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
297        let attention = XLMRobertaAttention::new(cfg, vb.pp("attention"))?;
298        let intermediate = XLMRobertaIntermediate::new(cfg, vb.pp("intermediate"))?;
299        let output = XLMRobertaOutput::new(cfg, vb.pp("output"))?;
300        Ok(Self {
301            attention,
302            intermediate,
303            output,
304        })
305    }
306
307    fn forward(
308        &self,
309        hidden_states: &Tensor,
310        attention_mask: &Tensor,
311        encoder_hidden_states: Option<&Tensor>,
312        encoder_attention_mask: Option<&Tensor>,
313        past_key_value: Option<(&Tensor, &Tensor)>,
314    ) -> Result<(Tensor, Tensor)> {
315        let self_attention_outputs = self.attention.forward(
316            hidden_states,
317            attention_mask,
318            encoder_hidden_states,
319            encoder_attention_mask,
320            past_key_value,
321        )?;
322        let attention_output = self_attention_outputs.0;
323        let outputs = self_attention_outputs.1;
324        let intermediate_output = self.intermediate.forward(&attention_output)?;
325        let layer_output = self
326            .output
327            .forward(&intermediate_output, &attention_output)?;
328        Ok((layer_output, outputs))
329    }
330}
331
332struct XLMRobertaEncoder {
333    layers: Vec<XLMRobertaLayer>,
334}
335
336impl XLMRobertaEncoder {
337    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
338        let layers = (0..cfg.num_hidden_layers)
339            .map(|i| XLMRobertaLayer::new(cfg, vb.pp(format!("layer.{}", i))))
340            .collect::<Result<Vec<_>>>()?;
341        Ok(Self { layers })
342    }
343
344    fn forward(
345        &self,
346        hidden_states: &Tensor,
347        attention_mask: &Tensor,
348        encoder_hidden_states: Option<&Tensor>,
349        encoder_attention_mask: Option<&Tensor>,
350        past_key_value: Option<(&Tensor, &Tensor)>,
351    ) -> Result<Tensor> {
352        let mut hidden_states = hidden_states.clone();
353        for layer_module in self.layers.iter() {
354            let layer_outputs = layer_module.forward(
355                &hidden_states,
356                attention_mask,
357                encoder_hidden_states,
358                encoder_attention_mask,
359                past_key_value,
360            )?;
361            hidden_states = layer_outputs.0;
362        }
363        Ok(hidden_states)
364    }
365}
366
367pub struct XLMRobertaModel {
368    encoder: XLMRobertaEncoder,
369    embeddings: XLMRobertaEmbeddings,
370}
371
372impl XLMRobertaModel {
373    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
374        let encoder = XLMRobertaEncoder::new(cfg, vb.pp("encoder"))?;
375        let embeddings = XLMRobertaEmbeddings::load(vb.pp("embeddings"), cfg)?;
376        Ok(Self {
377            encoder,
378            embeddings,
379        })
380    }
381
382    pub fn forward(
383        &self,
384        input_ids: &Tensor,
385        attention_mask: &Tensor,
386        token_type_ids: &Tensor,
387        past_key_value: Option<(&Tensor, &Tensor)>,
388        encoder_hidden_states: Option<&Tensor>,
389        encoder_attention_mask: Option<&Tensor>,
390    ) -> Result<Tensor> {
391        let hidden_states = self.embeddings.forward(input_ids, token_type_ids)?;
392        let attention_mask = prepare_4d_attention_mask(attention_mask, DType::F32, None)?
393            .to_device(hidden_states.device())?;
394        let hidden_states = self.encoder.forward(
395            &hidden_states,
396            &attention_mask,
397            encoder_hidden_states,
398            encoder_attention_mask,
399            past_key_value,
400        )?;
401        Ok(hidden_states)
402    }
403}
404
405struct XLMRobertaLMHead {
406    dense: Linear,
407    layer_norm: LayerNorm,
408}
409
410impl XLMRobertaLMHead {
411    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
412        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
413        let layer_norm =
414            candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("layer_norm"))?;
415        Ok(Self { dense, layer_norm })
416    }
417
418    fn forward(&self, hidden_states: &Tensor, shared_embeddings: &Tensor) -> Result<Tensor> {
419        let hidden_states = self.dense.forward(hidden_states)?;
420        let hidden_states = candle_nn::Activation::Gelu.forward(&hidden_states)?;
421        let hidden_states = self.layer_norm.forward(&hidden_states)?;
422        let hidden_states = hidden_states.broadcast_matmul(shared_embeddings)?;
423        Ok(hidden_states)
424    }
425}
426
427pub struct XLMRobertaForMaskedLM {
428    roberta: XLMRobertaModel,
429    lm_head: XLMRobertaLMHead,
430}
431
432impl XLMRobertaForMaskedLM {
433    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
434        let roberta = XLMRobertaModel::new(cfg, vb.pp("roberta"))?;
435        let lm_head = XLMRobertaLMHead::new(cfg, vb.pp("lm_head"))?;
436        Ok(Self { roberta, lm_head })
437    }
438
439    pub fn forward(
440        &self,
441        input_ids: &Tensor,
442        attention_mask: &Tensor,
443        token_type_ids: &Tensor,
444        past_key_value: Option<(&Tensor, &Tensor)>,
445        encoder_hidden_states: Option<&Tensor>,
446        encoder_attention_mask: Option<&Tensor>,
447    ) -> Result<Tensor> {
448        let hidden_states = self.roberta.forward(
449            input_ids,
450            attention_mask,
451            token_type_ids,
452            past_key_value,
453            encoder_hidden_states,
454            encoder_attention_mask,
455        )?;
456        let lm_logits = self.lm_head.forward(
457            &hidden_states,
458            &self
459                .roberta
460                .embeddings
461                .word_embeddings
462                .embeddings()
463                .t()?
464                .unsqueeze(0)?,
465        )?;
466        Ok(lm_logits)
467    }
468}
469
470struct XLMRobertaClassificationHead {
471    dense: Linear,
472    out_proj: Linear,
473}
474
475impl XLMRobertaClassificationHead {
476    fn new(num_labels: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
477        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
478        let out_proj = linear(cfg.hidden_size, num_labels, vb.pp("out_proj"))?;
479        Ok(Self { dense, out_proj })
480    }
481
482    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
483        let cls_states = hidden_states.get_on_dim(1, 0)?.contiguous()?;
484        let hidden_states = self.dense.forward(&cls_states)?;
485        let hidden_states = candle_nn::Activation::GeluPytorchTanh.forward(&hidden_states)?;
486        let hidden_states = self.out_proj.forward(&hidden_states)?;
487        Ok(hidden_states)
488    }
489}
490
491pub struct XLMRobertaForSequenceClassification {
492    roberta: XLMRobertaModel,
493    classifier: XLMRobertaClassificationHead,
494}
495
496impl XLMRobertaForSequenceClassification {
497    pub fn new(num_labels: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
498        let roberta = XLMRobertaModel::new(cfg, vb.pp("roberta"))?;
499        let classifier = XLMRobertaClassificationHead::new(num_labels, cfg, vb.pp("classifier"))?;
500        Ok(Self {
501            roberta,
502            classifier,
503        })
504    }
505
506    pub fn forward(
507        &self,
508        input_ids: &Tensor,
509        attention_mask: &Tensor,
510        token_type_ids: &Tensor,
511    ) -> Result<Tensor> {
512        let hidden_states =
513            self.roberta
514                .forward(input_ids, attention_mask, token_type_ids, None, None, None)?;
515        self.classifier.forward(&hidden_states)
516    }
517}
518
519fn prepare_4d_attention_mask(
520    mask: &Tensor,
521    dtype: DType,
522    tgt_len: Option<usize>,
523) -> Result<Tensor> {
524    let bsz = mask.dim(0)?;
525    let src_len = mask.dim(1)?;
526    let tgt_len = tgt_len.unwrap_or(src_len);
527
528    let expanded_mask = mask
529        .unsqueeze(1)?
530        .unsqueeze(2)?
531        .expand((bsz, 1, tgt_len, src_len))?
532        .to_dtype(dtype)?;
533
534    let inverted_mask = (1.0 - expanded_mask)?;
535
536    (inverted_mask * get_dtype_min_val(dtype))?.to_dtype(dtype)
537}
538
539fn get_dtype_min_val(dtype: DType) -> f64 {
540    match dtype {
541        DType::F32 => f32::MIN as f64,
542        DType::F64 => f64::MIN,
543        _ => panic!("Unsupported data type"),
544    }
545}