1use std::collections::HashMap;
2
3use candle::{bail, Context, DType, Device, Module, Result, Tensor, D};
4use candle_nn::{
5 conv1d, embedding, layer_norm, Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder,
6};
7use serde::{Deserialize, Deserializer};
8
/// Default floating-point dtype exported by this model module.
///
/// NOTE(review): nothing in the visible portion of this file reads `DTYPE`; callers presumably
/// use it when constructing the `VarBuilder` — confirm.
pub const DTYPE: DType = DType::F32;
10
/// Activation functions selectable through `Config::hidden_act` / `Config::pooler_hidden_act`.
///
/// Because of `rename_all = "lowercase"`, the accepted config strings are `"gelu"`,
/// `"geluapproximate"` and `"relu"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum HiddenAct {
    /// Exact GELU, evaluated with `Tensor::gelu_erf`.
    Gelu,
    /// Approximate GELU, evaluated with `Tensor::gelu`.
    GeluApproximate,
    /// Rectified linear unit, evaluated with `Tensor::relu`.
    Relu,
}
19
/// Applies a [`HiddenAct`] activation inside a `"hidden-act"` tracing span.
pub struct HiddenActLayer {
    /// Which activation `forward` evaluates.
    act: HiddenAct,
    /// TRACE-level span entered for the duration of each `forward` call.
    span: tracing::Span,
}
24
25impl HiddenActLayer {
26 fn new(act: HiddenAct) -> Self {
27 let span = tracing::span!(tracing::Level::TRACE, "hidden-act");
28 Self { act, span }
29 }
30
31 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
32 let _enter = self.span.enter();
33 match self.act {
34 HiddenAct::Gelu => xs.gelu_erf(),
36 HiddenAct::GeluApproximate => xs.gelu(),
37 HiddenAct::Relu => xs.relu(),
38 }
39 }
40}
41
/// Kind of absolute position embedding. Only `Absolute` exists, and it is the default.
///
/// NOTE(review): not referenced in the visible portion of this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
enum PositionEmbeddingType {
    #[default]
    Absolute,
}
48
/// Class index -> label name. Its length sets the output width of `DebertaV2NERModel`'s classifier.
pub type Id2Label = HashMap<u32, String>;
/// Label name -> class index (the inverse direction of [`Id2Label`]).
pub type Label2Id = HashMap<String, u32>;
51
/// DeBERTa-v2 hyper-parameters, deserialized with serde.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    /// Number of rows in the word-embedding table.
    pub vocab_size: usize,
    /// Width of the hidden states produced by each layer.
    pub hidden_size: usize,
    /// Number of `DebertaV2Layer`s in the encoder.
    pub num_hidden_layers: usize,
    /// Attention heads per layer; `hidden_size` must be a multiple of it.
    pub num_attention_heads: usize,
    /// Width of the feed-forward intermediate projection.
    pub intermediate_size: usize,
    /// Activation applied after the intermediate projection.
    pub hidden_act: HiddenAct,
    /// Dropout probability for hidden states.
    pub hidden_dropout_prob: f64,
    /// Dropout probability for attention probabilities.
    pub attention_probs_dropout_prob: f64,
    /// Size of the absolute position table. Also the fallback for `max_relative_positions < 1`.
    pub max_position_embeddings: usize,
    /// Number of token types; `0` disables token-type embeddings.
    pub type_vocab_size: usize,
    /// Not read in the visible portion of this file.
    pub initializer_range: f64,
    /// Epsilon used by every LayerNorm.
    pub layer_norm_eps: f64,
    /// Enables disentangled relative-position attention.
    pub relative_attention: bool,
    /// Relative distance range. Values `< 1` fall back to `max_position_embeddings`.
    pub max_relative_positions: isize,
    /// Not read in the visible portion of this file.
    pub pad_token_id: Option<usize>,
    /// Adds absolute position embeddings to the input embeddings.
    pub position_biased_input: bool,
    /// Enabled disentangled terms (e.g. `"c2p"`, `"p2c"`). Accepts a list or a `|`-separated string.
    #[serde(deserialize_with = "deserialize_pos_att_type")]
    pub pos_att_type: Vec<String>,
    /// Relative-position bucket count. Values `> 0` enable bucketing; absent is treated as `-1`.
    pub position_buckets: Option<isize>,
    /// Reuse the query/key projections for positions; absent means `false`.
    pub share_att_key: Option<bool>,
    /// Per-head width; defaults to `hidden_size / num_attention_heads`.
    pub attention_head_size: Option<usize>,
    /// Raw embedding width; defaults to `hidden_size`. A projection is added when the two differ.
    pub embedding_size: Option<usize>,
    /// Relative-embedding normalization (e.g. `"layer_norm"`); absent means `"none"`.
    pub norm_rel_ebd: Option<String>,
    /// Kernel size of the optional `ConvLayer`; `> 0` enables it.
    pub conv_kernel_size: Option<usize>,
    /// Groups of the optional `ConvLayer`; defaults to `1`.
    pub conv_groups: Option<usize>,
    /// Activation name of the optional `ConvLayer`; defaults to `"tanh"`.
    pub conv_act: Option<String>,
    /// Class index -> label mapping for classification heads.
    pub id2label: Option<Id2Label>,
    /// Label -> class index mapping.
    pub label2id: Option<Label2Id>,
    /// NOTE(review): presumably used by the sequence-classification pooler outside this view.
    pub pooler_dropout: Option<f64>,
    /// NOTE(review): presumably used by the sequence-classification pooler outside this view.
    pub pooler_hidden_act: Option<HiddenAct>,
    /// NOTE(review): presumably used by the sequence-classification pooler outside this view.
    pub pooler_hidden_size: Option<usize>,
    /// NOTE(review): presumably used by the sequence-classification head outside this view.
    pub cls_dropout: Option<f64>,
}
87
88fn deserialize_pos_att_type<'de, D>(deserializer: D) -> std::result::Result<Vec<String>, D::Error>
89where
90 D: Deserializer<'de>,
91{
92 #[derive(Deserialize, Debug)]
93 #[serde(untagged)]
94 enum StringOrVec {
95 String(String),
96 Vec(Vec<String>),
97 }
98
99 match StringOrVec::deserialize(deserializer)? {
100 StringOrVec::String(s) => Ok(s.split('|').map(String::from).collect()),
101 StringOrVec::Vec(v) => Ok(v),
102 }
103}
104
105pub struct StableDropout {
108 _drop_prob: f64,
109 _count: usize,
110}
111
112impl StableDropout {
113 pub fn new(drop_prob: f64) -> Self {
114 Self {
115 _drop_prob: drop_prob,
116 _count: 0,
117 }
118 }
119
120 pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
121 Ok(x.clone())
122 }
123}
124
/// Input embedding block. Computes word embeddings, plus optional absolute-position and
/// token-type embeddings. Then applies an optional projection to `hidden_size`, LayerNorm,
/// optional masking and dropout.
pub struct DebertaV2Embeddings {
    /// Device on which default token-type ids are allocated.
    device: Device,
    /// `vocab_size x embedding_size` token table.
    word_embeddings: Embedding,
    /// Present only when `config.position_biased_input` is set.
    position_embeddings: Option<Embedding>,
    /// Present only when `config.type_vocab_size > 0`.
    token_type_embeddings: Option<Embedding>,
    layer_norm: LayerNorm,
    dropout: StableDropout,
    /// Cached ids `[0, max_position_embeddings)` with a leading dimension of 1.
    position_ids: Tensor,
    config: Config,
    /// Width of the raw embeddings; may differ from `config.hidden_size`.
    embedding_size: usize,
    /// `embedding_size -> hidden_size` projection, present only when the two differ.
    embed_proj: Option<candle_nn::Linear>,
}
138
139impl DebertaV2Embeddings {
140 pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
141 let device = vb.device().clone();
142 let config = config.clone();
143
144 let embedding_size = config.embedding_size.unwrap_or(config.hidden_size);
145
146 let word_embeddings =
147 embedding(config.vocab_size, embedding_size, vb.pp("word_embeddings"))?;
148
149 let position_embeddings = if config.position_biased_input {
150 Some(embedding(
151 config.max_position_embeddings,
152 embedding_size,
153 vb.pp("position_embeddings"),
154 )?)
155 } else {
156 None
157 };
158
159 let token_type_embeddings: Option<Embedding> = if config.type_vocab_size > 0 {
160 Some(candle_nn::embedding(
161 config.type_vocab_size,
162 config.hidden_size,
163 vb.pp("token_type_embeddings"),
164 )?)
165 } else {
166 None
167 };
168
169 let embed_proj: Option<candle_nn::Linear> = if embedding_size != config.hidden_size {
170 Some(candle_nn::linear_no_bias(
171 embedding_size,
172 config.hidden_size,
173 vb.pp("embed_proj"),
174 )?)
175 } else {
176 None
177 };
178
179 let layer_norm = layer_norm(
180 config.hidden_size,
181 config.layer_norm_eps,
182 vb.pp("LayerNorm"),
183 )?;
184
185 let dropout = StableDropout::new(config.hidden_dropout_prob);
186
187 let position_ids =
188 Tensor::arange(0, config.max_position_embeddings as u32, &device)?.unsqueeze(0)?;
189
190 Ok(Self {
191 word_embeddings,
192 position_embeddings,
193 token_type_embeddings,
194 layer_norm,
195 dropout,
196 position_ids,
197 device,
198 config,
199 embedding_size,
200 embed_proj,
201 })
202 }
203
204 pub fn forward(
205 &self,
206 input_ids: Option<&Tensor>,
207 token_type_ids: Option<&Tensor>,
208 position_ids: Option<&Tensor>,
209 mask: Option<&Tensor>,
210 inputs_embeds: Option<&Tensor>,
211 ) -> Result<Tensor> {
212 let (input_shape, input_embeds) = match (input_ids, inputs_embeds) {
213 (Some(ids), None) => {
214 let embs = self.word_embeddings.forward(ids)?;
215 (ids.dims(), embs)
216 }
217 (None, Some(e)) => (e.dims(), e.clone()),
218 (None, None) => {
219 bail!("Must specify either input_ids or inputs_embeds")
220 }
221 (Some(_), Some(_)) => {
222 bail!("Can't specify both input_ids and inputs_embeds")
223 }
224 };
225
226 let seq_length = match input_shape.last() {
227 Some(v) => *v,
228 None => bail!("DebertaV2Embeddings invalid input shape"),
229 };
230
231 let position_ids = match position_ids {
232 Some(v) => v.clone(),
233 None => self.position_ids.narrow(1, 0, seq_length)?,
234 };
235
236 let token_type_ids = match token_type_ids {
237 Some(ids) => ids.clone(),
238 None => Tensor::zeros(input_shape, DType::U32, &self.device)?,
239 };
240
241 let position_embeddings = match &self.position_embeddings {
242 Some(emb) => emb.forward(&position_ids)?,
243 None => Tensor::zeros_like(&input_embeds)?,
244 };
245
246 let mut embeddings = input_embeds;
247
248 if self.config.position_biased_input {
249 embeddings = embeddings.add(&position_embeddings)?;
250 }
251
252 if self.config.type_vocab_size > 0 {
253 embeddings = self.token_type_embeddings.as_ref().map_or_else(
254 || bail!("token_type_embeddings must be set when type_vocab_size > 0"),
255 |token_type_embeddings| {
256 embeddings.add(&token_type_embeddings.forward(&token_type_ids)?)
257 },
258 )?;
259 }
260
261 if self.embedding_size != self.config.hidden_size {
262 embeddings = if let Some(embed_proj) = &self.embed_proj {
263 embed_proj.forward(&embeddings)?
264 } else {
265 bail!("embed_proj must exist if embedding_size != config.hidden_size");
266 }
267 }
268
269 embeddings = self.layer_norm.forward(&embeddings)?;
270
271 if let Some(mask) = mask {
272 let mut mask = mask.clone();
273 if mask.dims() != embeddings.dims() {
274 if mask.dims().len() == 4 {
275 mask = mask.squeeze(1)?.squeeze(1)?;
276 }
277 mask = mask.unsqueeze(2)?;
278 }
279
280 mask = mask.to_dtype(embeddings.dtype())?;
281 embeddings = embeddings.broadcast_mul(&mask)?;
282 }
283
284 self.dropout.forward(&embeddings)
285 }
286}
287
288struct XSoftmax {}
290
291impl XSoftmax {
292 pub fn apply(input: &Tensor, mask: &Tensor, dim: D, device: &Device) -> Result<Tensor> {
293 let mut rmask = mask.broadcast_as(input.shape())?.to_dtype(DType::F32)?;
295
296 rmask = rmask
297 .broadcast_lt(&Tensor::new(&[1.0_f32], device)?)?
298 .to_dtype(DType::U8)?;
299
300 let min_value_tensor = Tensor::new(&[f32::MIN], device)?.broadcast_as(input.shape())?;
301 let mut output = rmask.where_cond(&min_value_tensor, input)?;
302
303 output = candle_nn::ops::softmax(&output, dim)?;
304
305 let t_zeroes = Tensor::new(&[0f32], device)?.broadcast_as(input.shape())?;
306 output = rmask.where_cond(&t_zeroes, &output)?;
307
308 Ok(output)
309 }
310}
311
/// DeBERTa-v2 disentangled self-attention: content-to-content scores plus the optional
/// content-to-position (`c2p`) and position-to-content (`p2c`) terms.
pub struct DebertaV2DisentangledSelfAttention {
    config: Config,
    num_attention_heads: usize,
    query_proj: candle_nn::Linear,
    key_proj: candle_nn::Linear,
    value_proj: candle_nn::Linear,
    /// Applied to attention probabilities (identity; see `StableDropout`).
    dropout: StableDropout,
    device: Device,
    /// Copy of `config.relative_attention`.
    relative_attention: bool,
    /// Applied to relative embeddings; `Some` only when `relative_attention` is set.
    pos_dropout: Option<StableDropout>,
    /// `config.position_buckets`, or `-1` when absent.
    position_buckets: isize,
    /// `config.max_relative_positions`, replaced by `max_position_embeddings` when `< 1`
    /// and relative attention is on.
    max_relative_positions: isize,
    /// Half the number of relative-embedding rows used: `position_buckets` when `> 0`,
    /// else `max_relative_positions`; `0` without relative attention.
    pos_ebd_size: isize,
    /// Reuse `query_proj` / `key_proj` for the positional projections.
    share_att_key: bool,
    /// Positional key projection (`c2p`, unshared keys only).
    pos_key_proj: Option<candle_nn::Linear>,
    /// Positional query projection (`p2c`, unshared keys only).
    pos_query_proj: Option<candle_nn::Linear>,
}
330
331impl DebertaV2DisentangledSelfAttention {
332 pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
333 let config = config.clone();
334 let vb = vb.clone();
335
336 if config.hidden_size % config.num_attention_heads != 0 {
337 return Err(candle::Error::Msg(format!(
338 "The hidden size {} is not a multiple of the number of attention heads {}",
339 config.hidden_size, config.num_attention_heads
340 )));
341 }
342
343 let num_attention_heads = config.num_attention_heads;
344
345 let attention_head_size = config
346 .attention_head_size
347 .unwrap_or(config.hidden_size / config.num_attention_heads);
348
349 let all_head_size = num_attention_heads * attention_head_size;
350
351 let query_proj = candle_nn::linear(config.hidden_size, all_head_size, vb.pp("query_proj"))?;
352 let key_proj = candle_nn::linear(config.hidden_size, all_head_size, vb.pp("key_proj"))?;
353 let value_proj = candle_nn::linear(config.hidden_size, all_head_size, vb.pp("value_proj"))?;
354
355 let share_att_key = config.share_att_key.unwrap_or(false);
356 let relative_attention = config.relative_attention;
357 let mut max_relative_positions = config.max_relative_positions;
358
359 let mut pos_ebd_size: isize = 0;
360 let position_buckets = config.position_buckets.unwrap_or(-1);
361 let mut pos_dropout: Option<StableDropout> = None;
362 let mut pos_key_proj: Option<candle_nn::Linear> = None;
363 let mut pos_query_proj: Option<candle_nn::Linear> = None;
364
365 if relative_attention {
366 if max_relative_positions < 1 {
367 max_relative_positions = config.max_position_embeddings as isize;
368 }
369 pos_ebd_size = max_relative_positions;
370 if position_buckets > 0 {
371 pos_ebd_size = position_buckets
372 }
373
374 pos_dropout = Some(StableDropout::new(config.hidden_dropout_prob));
375
376 if !share_att_key {
377 if config.pos_att_type.iter().any(|s| s == "c2p") {
378 pos_key_proj = Some(candle_nn::linear(
379 config.hidden_size,
380 all_head_size,
381 vb.pp("pos_key_proj"),
382 )?);
383 }
384 if config.pos_att_type.iter().any(|s| s == "p2c") {
385 pos_query_proj = Some(candle_nn::linear(
386 config.hidden_size,
387 all_head_size,
388 vb.pp("pos_query_proj"),
389 )?);
390 }
391 }
392 }
393
394 let dropout = StableDropout::new(config.attention_probs_dropout_prob);
395 let device = vb.device().clone();
396
397 Ok(Self {
398 config,
399 num_attention_heads,
400 query_proj,
401 key_proj,
402 value_proj,
403 dropout,
404 device,
405 relative_attention,
406 pos_dropout,
407 position_buckets,
408 max_relative_positions,
409 pos_ebd_size,
410 share_att_key,
411 pos_key_proj,
412 pos_query_proj,
413 })
414 }
415
416 pub fn forward(
417 &self,
418 hidden_states: &Tensor,
419 attention_mask: &Tensor,
420 query_states: Option<&Tensor>,
421 relative_pos: Option<&Tensor>,
422 rel_embeddings: Option<&Tensor>,
423 ) -> Result<Tensor> {
424 let query_states = match query_states {
425 Some(qs) => qs,
426 None => hidden_states,
427 };
428
429 let query_layer = self.transpose_for_scores(&self.query_proj.forward(query_states)?)?;
430 let key_layer = self.transpose_for_scores(&self.key_proj.forward(query_states)?)?;
431 let value_layer = self.transpose_for_scores(&self.value_proj.forward(query_states)?)?;
432
433 let mut rel_att: Option<Tensor> = None;
434
435 let mut scale_factor: usize = 1;
436
437 if self.config.pos_att_type.iter().any(|s| s == "c2p") {
438 scale_factor += 1;
439 }
440
441 if self.config.pos_att_type.iter().any(|s| s == "p2c") {
442 scale_factor += 1;
443 }
444
445 let scale = {
446 let q_size = query_layer.dim(D::Minus1)?;
447 Tensor::new(&[(q_size * scale_factor) as f32], &self.device)?.sqrt()?
448 };
449
450 let mut attention_scores: Tensor = {
451 let key_layer_transposed = key_layer.t()?;
452 let div = key_layer_transposed
453 .broadcast_div(scale.to_dtype(query_layer.dtype())?.as_ref())?;
454 query_layer.matmul(&div)?
455 };
456
457 if self.relative_attention {
458 if let Some(rel_embeddings) = rel_embeddings {
459 let rel_embeddings = self
460 .pos_dropout
461 .as_ref()
462 .context("relative_attention requires pos_dropout")?
463 .forward(rel_embeddings)?;
464 rel_att = Some(self.disentangled_attention_bias(
465 query_layer,
466 key_layer,
467 relative_pos,
468 rel_embeddings,
469 scale_factor,
470 )?);
471 }
472 }
473
474 if let Some(rel_att) = rel_att {
475 attention_scores = attention_scores.broadcast_add(&rel_att)?;
476 }
477
478 attention_scores = attention_scores.reshape((
479 (),
480 self.num_attention_heads,
481 attention_scores.dim(D::Minus2)?,
482 attention_scores.dim(D::Minus1)?,
483 ))?;
484
485 let mut attention_probs =
486 XSoftmax::apply(&attention_scores, attention_mask, D::Minus1, &self.device)?;
487
488 attention_probs = self.dropout.forward(&attention_probs)?;
489
490 let mut context_layer = attention_probs
491 .reshape((
492 (),
493 attention_probs.dim(D::Minus2)?,
494 attention_probs.dim(D::Minus1)?,
495 ))?
496 .matmul(&value_layer)?;
497
498 context_layer = context_layer
499 .reshape((
500 (),
501 self.num_attention_heads,
502 context_layer.dim(D::Minus2)?,
503 context_layer.dim(D::Minus1)?,
504 ))?
505 .permute((0, 2, 1, 3))?
506 .contiguous()?;
507
508 let dims = context_layer.dims();
509
510 context_layer = match dims.len() {
511 2 => context_layer.reshape(())?,
512 3 => context_layer.reshape((dims[0], ()))?,
513 4 => context_layer.reshape((dims[0], dims[1], ()))?,
514 5 => context_layer.reshape((dims[0], dims[1], dims[2], ()))?,
515 _ => {
516 bail!(
517 "Invalid shape for DisentabgledSelfAttention context layer: {:?}",
518 dims
519 )
520 }
521 };
522
523 Ok(context_layer)
524 }
525
526 fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
527 let dims = xs.dims().to_vec();
528 match dims.len() {
529 3 => {
530 let reshaped = xs.reshape((dims[0], dims[1], self.num_attention_heads, ()))?;
531
532 reshaped.transpose(1, 2)?.contiguous()?.reshape((
533 (),
534 reshaped.dim(1)?,
535 reshaped.dim(D::Minus1)?,
536 ))
537 }
538 shape => {
539 bail!("Invalid shape for transpose_for_scores. Expected 3 dimensions, got {shape}")
540 }
541 }
542 }
543
544 fn disentangled_attention_bias(
545 &self,
546 query_layer: Tensor,
547 key_layer: Tensor,
548 relative_pos: Option<&Tensor>,
549 rel_embeddings: Tensor,
550 scale_factor: usize,
551 ) -> Result<Tensor> {
552 let mut relative_pos = relative_pos.map_or(
553 build_relative_position(
554 query_layer.dim(D::Minus2)?,
555 key_layer.dim(D::Minus2)?,
556 &self.device,
557 Some(self.position_buckets),
558 Some(self.max_relative_positions),
559 )?,
560 |pos| pos.clone(),
561 );
562
563 relative_pos = match relative_pos.dims().len() {
564 2 => relative_pos.unsqueeze(0)?.unsqueeze(0)?,
565 3 => relative_pos.unsqueeze(1)?,
566 other => {
567 bail!("Relative position ids must be of dim 2 or 3 or 4. Got dim of size {other}")
568 }
569 };
570
571 let att_span = self.pos_ebd_size;
572
573 let rel_embeddings = rel_embeddings
574 .narrow(0, 0, (att_span * 2) as usize)?
575 .unsqueeze(0)?;
576
577 let mut pos_query_layer: Option<Tensor> = None;
578 let mut pos_key_layer: Option<Tensor> = None;
579
580 let repeat_with = query_layer.dim(0)? / self.num_attention_heads;
581 if self.share_att_key {
582 pos_query_layer = Some(
583 self.transpose_for_scores(&self.query_proj.forward(&rel_embeddings)?)?
584 .repeat(repeat_with)?,
585 );
586
587 pos_key_layer = Some(
588 self.transpose_for_scores(&self.key_proj.forward(&rel_embeddings)?)?
589 .repeat(repeat_with)?,
590 )
591 } else {
592 if self.config.pos_att_type.iter().any(|s| s == "c2p") {
593 pos_key_layer = Some(
594 self.transpose_for_scores(
595 &self
596 .pos_key_proj
597 .as_ref()
598 .context(
599 "Need pos_key_proj when share_att_key is false or not specified",
600 )?
601 .forward(&rel_embeddings)?,
602 )?
603 .repeat(repeat_with)?,
604 )
605 }
606 if self.config.pos_att_type.iter().any(|s| s == "p2c") {
607 pos_query_layer = Some(self.transpose_for_scores(&self
608 .pos_query_proj
609 .as_ref()
610 .context("Need a pos_query_proj when share_att_key is false or not specified")?
611 .forward(&rel_embeddings)?)?.repeat(repeat_with)?)
612 }
613 }
614
615 let mut score = Tensor::new(&[0 as f32], &self.device)?;
616
617 if self.config.pos_att_type.iter().any(|s| s == "c2p") {
618 let pos_key_layer = pos_key_layer.context("c2p without pos_key_layer")?;
619
620 let scale = Tensor::new(
621 &[(pos_key_layer.dim(D::Minus1)? * scale_factor) as f32],
622 &self.device,
623 )?
624 .sqrt()?;
625
626 let mut c2p_att = query_layer.matmul(&pos_key_layer.t()?)?;
627
628 let c2p_pos = relative_pos
629 .broadcast_add(&Tensor::new(&[att_span as i64], &self.device)?)?
630 .clamp(0 as f32, (att_span * 2 - 1) as f32)?;
631
632 c2p_att = c2p_att.gather(
633 &c2p_pos
634 .squeeze(0)?
635 .expand(&[
636 query_layer.dim(0)?,
637 query_layer.dim(1)?,
638 relative_pos.dim(D::Minus1)?,
639 ])?
640 .contiguous()?,
641 D::Minus1,
642 )?;
643
644 score = score.broadcast_add(
645 &c2p_att.broadcast_div(scale.to_dtype(c2p_att.dtype())?.as_ref())?,
646 )?;
647 }
648
649 if self.config.pos_att_type.iter().any(|s| s == "p2c") {
650 let pos_query_layer = pos_query_layer.context("p2c without pos_key_layer")?;
651
652 let scale = Tensor::new(
653 &[(pos_query_layer.dim(D::Minus1)? * scale_factor) as f32],
654 &self.device,
655 )?
656 .sqrt()?;
657
658 let r_pos = {
659 if key_layer.dim(D::Minus2)? != query_layer.dim(D::Minus2)? {
660 build_relative_position(
661 key_layer.dim(D::Minus2)?,
662 key_layer.dim(D::Minus2)?,
663 &self.device,
664 Some(self.position_buckets),
665 Some(self.max_relative_positions),
666 )?
667 .unsqueeze(0)?
668 } else {
669 relative_pos
670 }
671 };
672
673 let p2c_pos = r_pos
674 .to_dtype(DType::F32)?
675 .neg()?
676 .broadcast_add(&Tensor::new(&[att_span as f32], &self.device)?)?
677 .clamp(0f32, (att_span * 2 - 1) as f32)?;
678
679 let p2c_att = key_layer
680 .matmul(&pos_query_layer.t()?)?
681 .gather(
682 &p2c_pos
683 .squeeze(0)?
684 .expand(&[
685 query_layer.dim(0)?,
686 key_layer.dim(D::Minus2)?,
687 key_layer.dim(D::Minus2)?,
688 ])?
689 .contiguous()?
690 .to_dtype(DType::U32)?,
691 D::Minus1,
692 )?
693 .t()?;
694
695 score =
696 score.broadcast_add(&p2c_att.broadcast_div(&scale.to_dtype(p2c_att.dtype())?)?)?;
697 }
698
699 Ok(score)
700 }
701}
702
703pub struct DebertaV2Attention {
705 dsa: DebertaV2DisentangledSelfAttention,
706 output: DebertaV2SelfOutput,
707}
708
709impl DebertaV2Attention {
710 pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
711 let dsa = DebertaV2DisentangledSelfAttention::load(vb.pp("attention.self"), config)?;
712 let output = DebertaV2SelfOutput::load(vb.pp("attention.output"), config)?;
713 Ok(Self { dsa, output })
714 }
715
716 fn forward(
717 &self,
718 hidden_states: &Tensor,
719 attention_mask: &Tensor,
720 query_states: Option<&Tensor>,
721 relative_pos: Option<&Tensor>,
722 rel_embeddings: Option<&Tensor>,
723 ) -> Result<Tensor> {
724 let self_output = self.dsa.forward(
725 hidden_states,
726 attention_mask,
727 query_states,
728 relative_pos,
729 rel_embeddings,
730 )?;
731
732 self.output
733 .forward(&self_output, query_states.unwrap_or(hidden_states))
734 }
735}
736
737pub struct DebertaV2SelfOutput {
739 dense: candle_nn::Linear,
740 layer_norm: LayerNorm,
741 dropout: StableDropout,
742}
743
744impl DebertaV2SelfOutput {
745 pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
746 let dense = candle_nn::linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
747 let layer_norm = candle_nn::layer_norm(
748 config.hidden_size,
749 config.layer_norm_eps,
750 vb.pp("LayerNorm"),
751 )?;
752 let dropout = StableDropout::new(config.hidden_dropout_prob);
753 Ok(Self {
754 dense,
755 layer_norm,
756 dropout,
757 })
758 }
759
760 pub fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
761 let mut hidden_states = self.dense.forward(hidden_states)?;
762 hidden_states = self.dropout.forward(&hidden_states)?;
763 self.layer_norm
764 .forward(&hidden_states.broadcast_add(input_tensor)?)
765 }
766}
767
768pub struct DebertaV2Intermediate {
770 dense: candle_nn::Linear,
771 intermediate_act: HiddenActLayer,
772}
773
774impl DebertaV2Intermediate {
775 pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
776 let dense = candle_nn::linear(
777 config.hidden_size,
778 config.intermediate_size,
779 vb.pp("intermediate.dense"),
780 )?;
781 let intermediate_act = HiddenActLayer::new(config.hidden_act);
782 Ok(Self {
783 dense,
784 intermediate_act,
785 })
786 }
787
788 pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
789 self.intermediate_act
790 .forward(&self.dense.forward(hidden_states)?)
791 }
792}
793
794pub struct DebertaV2Output {
796 dense: candle_nn::Linear,
797 layer_norm: LayerNorm,
798 dropout: StableDropout,
799}
800
801impl DebertaV2Output {
802 pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
803 let dense = candle_nn::linear(
804 config.intermediate_size,
805 config.hidden_size,
806 vb.pp("output.dense"),
807 )?;
808 let layer_norm = candle_nn::layer_norm(
809 config.hidden_size,
810 config.layer_norm_eps,
811 vb.pp("output.LayerNorm"),
812 )?;
813 let dropout = StableDropout::new(config.hidden_dropout_prob);
814 Ok(Self {
815 dense,
816 layer_norm,
817 dropout,
818 })
819 }
820
821 pub fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
822 let mut hidden_states = self.dense.forward(hidden_states)?;
823 hidden_states = self.dropout.forward(&hidden_states)?;
824 hidden_states = {
825 let to_norm = hidden_states.broadcast_add(input_tensor)?;
826 self.layer_norm.forward(&to_norm)?
827 };
828 Ok(hidden_states)
829 }
830}
831
832pub struct DebertaV2Layer {
834 attention: DebertaV2Attention,
835 intermediate: DebertaV2Intermediate,
836 output: DebertaV2Output,
837}
838
839impl DebertaV2Layer {
840 pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
841 let attention = DebertaV2Attention::load(vb.clone(), config)?;
842 let intermediate = DebertaV2Intermediate::load(vb.clone(), config)?;
843 let output = DebertaV2Output::load(vb.clone(), config)?;
844 Ok(Self {
845 attention,
846 intermediate,
847 output,
848 })
849 }
850
851 fn forward(
852 &self,
853 hidden_states: &Tensor,
854 attention_mask: &Tensor,
855 query_states: Option<&Tensor>,
856 relative_pos: Option<&Tensor>,
857 rel_embeddings: Option<&Tensor>,
858 ) -> Result<Tensor> {
859 let attention_output = self.attention.forward(
860 hidden_states,
861 attention_mask,
862 query_states,
863 relative_pos,
864 rel_embeddings,
865 )?;
866
867 let intermediate_output = self.intermediate.forward(&attention_output)?;
868
869 let layer_output = self
870 .output
871 .forward(&intermediate_output, &attention_output)?;
872
873 Ok(layer_output)
874 }
875}
876
877pub struct ConvLayer {
880 _conv_act: String,
881 _conv: Conv1d,
882 _layer_norm: LayerNorm,
883 _dropout: StableDropout,
884 _config: Config,
885}
886
887impl ConvLayer {
888 pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
889 let config = config.clone();
890 let kernel_size = config.conv_kernel_size.unwrap_or(3);
891 let groups = config.conv_groups.unwrap_or(1);
892 let conv_act: String = config.conv_act.clone().unwrap_or("tanh".to_string());
893
894 let conv_conf = Conv1dConfig {
895 padding: (kernel_size - 1) / 2,
896 groups,
897 ..Default::default()
898 };
899
900 let conv = conv1d(
901 config.hidden_size,
902 config.hidden_size,
903 kernel_size,
904 conv_conf,
905 vb.pp("conv"),
906 )?;
907
908 let layer_norm = layer_norm(
909 config.hidden_size,
910 config.layer_norm_eps,
911 vb.pp("LayerNorm"),
912 )?;
913
914 let dropout = StableDropout::new(config.hidden_dropout_prob);
915
916 Ok(Self {
917 _conv_act: conv_act,
918 _conv: conv,
919 _layer_norm: layer_norm,
920 _dropout: dropout,
921 _config: config,
922 })
923 }
924
925 pub fn forward(
926 &self,
927 _hidden_states: &Tensor,
928 _residual_states: &Tensor,
929 _input_mask: &Tensor,
930 ) -> Result<Tensor> {
931 todo!("Need a model that contains a conv layer to test against.")
932 }
933}
934
935pub struct DebertaV2Encoder {
937 layer: Vec<DebertaV2Layer>,
938 relative_attention: bool,
939 max_relative_positions: isize,
940 position_buckets: isize,
941 rel_embeddings: Option<Embedding>,
942 norm_rel_ebd: String,
943 layer_norm: Option<LayerNorm>,
944 conv: Option<ConvLayer>,
945 device: Device,
946}
947
948impl DebertaV2Encoder {
949 pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
950 let layer = (0..config.num_hidden_layers)
951 .map(|index| DebertaV2Layer::load(vb.pp(format!("layer.{index}")), config))
952 .collect::<Result<Vec<_>>>()?;
953
954 let relative_attention = config.relative_attention;
955 let mut max_relative_positions = config.max_relative_positions;
956
957 let position_buckets = config.position_buckets.unwrap_or(-1);
958
959 let mut rel_embeddings: Option<Embedding> = None;
960
961 if relative_attention {
962 if max_relative_positions < 1 {
963 max_relative_positions = config.max_position_embeddings as isize;
964 }
965
966 let mut pos_ebd_size = max_relative_positions * 2;
967
968 if position_buckets > 0 {
969 pos_ebd_size = position_buckets * 2;
970 }
971
972 rel_embeddings = Some(embedding(
973 pos_ebd_size as usize,
974 config.hidden_size,
975 vb.pp("rel_embeddings"),
976 )?);
977 }
978
979 let norm_rel_ebd = match config.norm_rel_ebd.as_ref() {
982 Some(nre) => nre.trim().to_string(),
983 None => "none".to_string(),
984 };
985
986 let layer_norm: Option<LayerNorm> = if norm_rel_ebd == "layer_norm" {
987 Some(layer_norm(
988 config.hidden_size,
989 config.layer_norm_eps,
990 vb.pp("LayerNorm"),
991 )?)
992 } else {
993 None
994 };
995
996 let conv: Option<ConvLayer> = if config.conv_kernel_size.unwrap_or(0) > 0 {
997 Some(ConvLayer::load(vb.pp("conv"), config)?)
998 } else {
999 None
1000 };
1001
1002 Ok(Self {
1003 layer,
1004 relative_attention,
1005 max_relative_positions,
1006 position_buckets,
1007 rel_embeddings,
1008 norm_rel_ebd,
1009 layer_norm,
1010 conv,
1011 device: vb.device().clone(),
1012 })
1013 }
1014
1015 pub fn forward(
1016 &self,
1017 hidden_states: &Tensor,
1018 attention_mask: &Tensor,
1019 query_states: Option<&Tensor>,
1020 relative_pos: Option<&Tensor>,
1021 ) -> Result<Tensor> {
1022 let input_mask = if attention_mask.dims().len() <= 2 {
1023 attention_mask.clone()
1024 } else {
1025 attention_mask
1026 .sum_keepdim(attention_mask.rank() - 2)?
1027 .gt(0.)?
1028 };
1029
1030 let attention_mask = self.get_attention_mask(attention_mask.clone())?;
1031
1032 let relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)?;
1033
1034 let mut next_kv: Tensor = hidden_states.clone();
1035 let rel_embeddings = self.get_rel_embedding()?;
1036 let mut output_states = next_kv.to_owned();
1037 let mut query_states: Option<Tensor> = query_states.cloned();
1038
1039 for (i, layer_module) in self.layer.iter().enumerate() {
1040 output_states = layer_module.forward(
1045 next_kv.as_ref(),
1046 &attention_mask,
1047 query_states.as_ref(),
1048 relative_pos.as_ref(),
1049 rel_embeddings.as_ref(),
1050 )?;
1051
1052 if i == 0 {
1053 if let Some(conv) = &self.conv {
1054 output_states = conv.forward(hidden_states, &output_states, &input_mask)?;
1055 }
1056 }
1057
1058 if query_states.is_some() {
1059 query_states = Some(output_states.clone());
1060 } else {
1061 next_kv = output_states.clone();
1062 }
1063 }
1064
1065 Ok(output_states)
1066 }
1067
1068 fn get_attention_mask(&self, mut attention_mask: Tensor) -> Result<Tensor> {
1069 match attention_mask.dims().len() {
1070 0..=2 => {
1071 let extended_attention_mask = attention_mask.unsqueeze(1)?.unsqueeze(2)?;
1072 attention_mask = extended_attention_mask.broadcast_mul(
1073 &extended_attention_mask
1074 .squeeze(D::Minus2)?
1075 .unsqueeze(D::Minus1)?,
1076 )?;
1077 }
1078 3 => attention_mask = attention_mask.unsqueeze(1)?,
1079 len => bail!("Unsupported attentiom mask size length: {len}"),
1080 }
1081
1082 Ok(attention_mask)
1083 }
1084
1085 fn get_rel_pos(
1086 &self,
1087 hidden_states: &Tensor,
1088 query_states: Option<&Tensor>,
1089 relative_pos: Option<&Tensor>,
1090 ) -> Result<Option<Tensor>> {
1091 if self.relative_attention && relative_pos.is_none() {
1092 let q = if let Some(query_states) = query_states {
1093 query_states.dim(D::Minus2)?
1094 } else {
1095 hidden_states.dim(D::Minus2)?
1096 };
1097
1098 return Ok(Some(build_relative_position(
1099 q,
1100 hidden_states.dim(D::Minus2)?,
1101 &self.device,
1102 Some(self.position_buckets),
1103 Some(self.max_relative_positions),
1104 )?));
1105 }
1106
1107 if relative_pos.is_some() {
1108 Ok(relative_pos.cloned())
1109 } else {
1110 Ok(None)
1111 }
1112 }
1113 fn get_rel_embedding(&self) -> Result<Option<Tensor>> {
1114 if !self.relative_attention {
1115 return Ok(None);
1116 }
1117
1118 let rel_embeddings = self
1119 .rel_embeddings
1120 .as_ref()
1121 .context("self.rel_embeddings not present when using relative_attention")?
1122 .embeddings()
1123 .clone();
1124
1125 if !self.norm_rel_ebd.contains("layer_norm") {
1126 return Ok(Some(rel_embeddings));
1127 }
1128
1129 let layer_normed_embeddings = self
1130 .layer_norm
1131 .as_ref()
1132 .context("DebertaV2Encoder layer_norm is None when norm_rel_ebd contains layer_norm")?
1133 .forward(&rel_embeddings)?;
1134
1135 Ok(Some(layer_normed_embeddings))
1136 }
1137}
1138
1139pub struct DebertaV2Model {
1141 embeddings: DebertaV2Embeddings,
1142 encoder: DebertaV2Encoder,
1143 z_steps: usize,
1144 pub device: Device,
1145}
1146
1147impl DebertaV2Model {
1148 pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
1149 let vb = vb.clone();
1150 let embeddings = DebertaV2Embeddings::load(vb.pp("embeddings"), config)?;
1151 let encoder = DebertaV2Encoder::load(vb.pp("encoder"), config)?;
1152 let z_steps: usize = 0;
1153
1154 Ok(Self {
1155 embeddings,
1156 encoder,
1157 z_steps,
1158 device: vb.device().clone(),
1159 })
1160 }
1161
1162 pub fn forward(
1163 &self,
1164 input_ids: &Tensor,
1165 token_type_ids: Option<Tensor>,
1166 attention_mask: Option<Tensor>,
1167 ) -> Result<Tensor> {
1168 let input_ids_shape = input_ids.shape();
1169
1170 let attention_mask = match attention_mask {
1171 Some(mask) => mask,
1172 None => Tensor::ones(input_ids_shape, DType::I64, &self.device)?,
1173 };
1174
1175 let token_type_ids = match token_type_ids {
1176 Some(ids) => ids,
1177 None => Tensor::zeros(input_ids_shape, DType::U32, &self.device)?,
1178 };
1179
1180 let embedding_output = self.embeddings.forward(
1181 Some(input_ids),
1182 Some(&token_type_ids),
1183 None,
1184 Some(&attention_mask),
1185 None,
1186 )?;
1187
1188 let encoder_output =
1189 self.encoder
1190 .forward(&embedding_output, &attention_mask, None, None)?;
1191
1192 if self.z_steps > 1 {
1193 todo!("Complete DebertaV2Model forward() when z_steps > 1 -- Needs a model to test this situation.")
1194 }
1195
1196 Ok(encoder_output)
1197 }
1198}
1199
/// One token-classification (NER) prediction.
///
/// NOTE(review): not constructed in the visible portion of this file. The field meanings below
/// are inferred from their names — confirm against the code that produces these.
#[derive(Debug)]
pub struct NERItem {
    /// Predicted entity label.
    pub entity: String,
    /// Text of the token the prediction refers to.
    pub word: String,
    /// Confidence of the prediction.
    pub score: f32,
    /// Start offset of the token (units unconfirmed).
    pub start: usize,
    /// End offset of the token (units unconfirmed).
    pub end: usize,
    /// Position of the token in the input sequence.
    pub index: usize,
}

/// One sequence-classification prediction.
///
/// NOTE(review): not constructed in the visible portion of this file.
#[derive(Debug)]
pub struct TextClassificationItem {
    /// Predicted class label.
    pub label: String,
    /// Confidence of the prediction.
    pub score: f32,
}
1215
/// Token-classification head: [`DebertaV2Model`] -> dropout -> bias-free linear classifier.
pub struct DebertaV2NERModel {
    /// Device the weights were loaded on.
    pub device: Device,
    deberta: DebertaV2Model,
    /// Invoked by `forward` with `train = false`.
    dropout: candle_nn::Dropout,
    /// `hidden_size -> number of labels` projection, loaded without bias.
    classifier: candle_nn::Linear,
}
1222
1223fn id2label_len(config: &Config, id2label: Option<HashMap<u32, String>>) -> Result<usize> {
1224 let id2label_len = match (&config.id2label, id2label) {
1225 (None, None) => bail!("Id2Label is either not present in the model configuration or not passed into DebertaV2NERModel::load as a parameter"),
1226 (None, Some(id2label_p)) => id2label_p.len(),
1227 (Some(id2label_c), None) => id2label_c.len(),
1228 (Some(id2label_c), Some(id2label_p)) => {
1229 if *id2label_c == id2label_p {
1230 id2label_c.len()
1231 } else {
1232 bail!("Id2Label is both present in the model configuration and provided as a parameter, and they are different.")
1233 }
1234 }
1235 };
1236 Ok(id2label_len)
1237}
1238
1239impl DebertaV2NERModel {
1240 pub fn load(vb: VarBuilder, config: &Config, id2label: Option<Id2Label>) -> Result<Self> {
1241 let id2label_len = id2label_len(config, id2label)?;
1242
1243 let deberta = DebertaV2Model::load(vb.clone(), config)?;
1244 let dropout = candle_nn::Dropout::new(config.hidden_dropout_prob as f32);
1245 let classifier: candle_nn::Linear = candle_nn::linear_no_bias(
1246 config.hidden_size,
1247 id2label_len,
1248 vb.root().pp("classifier"),
1249 )?;
1250
1251 Ok(Self {
1252 device: vb.device().clone(),
1253 deberta,
1254 dropout,
1255 classifier,
1256 })
1257 }
1258
1259 pub fn forward(
1260 &self,
1261 input_ids: &Tensor,
1262 token_type_ids: Option<Tensor>,
1263 attention_mask: Option<Tensor>,
1264 ) -> Result<Tensor> {
1265 let output = self
1266 .deberta
1267 .forward(input_ids, token_type_ids, attention_mask)?;
1268 let output = self.dropout.forward(&output, false)?;
1269 self.classifier.forward(&output)
1270 }
1271}
1272
/// DeBERTa-v2 with a pooled sequence-classification head.
pub struct DebertaV2SeqClassificationModel {
    /// Device the weights were loaded on.
    pub device: Device,
    /// Backbone producing the encoder hidden states.
    deberta: DebertaV2Model,
    /// Dropout between the pooler and the classifier (`cls_dropout`, else `hidden_dropout_prob`).
    dropout: StableDropout,
    /// Reduces the encoder output to a single vector per sequence.
    pooler: DebertaV2ContextPooler,
    /// Linear projection from the pooler output to the number of labels.
    classifier: candle_nn::Linear,
}
1280
1281impl DebertaV2SeqClassificationModel {
1282 pub fn load(vb: VarBuilder, config: &Config, id2label: Option<Id2Label>) -> Result<Self> {
1283 let id2label_len = id2label_len(config, id2label)?;
1284 let deberta = DebertaV2Model::load(vb.clone(), config)?;
1285 let pooler = DebertaV2ContextPooler::load(vb.clone(), config)?;
1286 let output_dim = pooler.output_dim()?;
1287 let classifier = candle_nn::linear(output_dim, id2label_len, vb.root().pp("classifier"))?;
1288 let dropout = match config.cls_dropout {
1289 Some(cls_dropout) => StableDropout::new(cls_dropout),
1290 None => StableDropout::new(config.hidden_dropout_prob),
1291 };
1292
1293 Ok(Self {
1294 device: vb.device().clone(),
1295 deberta,
1296 dropout,
1297 pooler,
1298 classifier,
1299 })
1300 }
1301
1302 pub fn forward(
1303 &self,
1304 input_ids: &Tensor,
1305 token_type_ids: Option<Tensor>,
1306 attention_mask: Option<Tensor>,
1307 ) -> Result<Tensor> {
1308 let encoder_layer = self
1309 .deberta
1310 .forward(input_ids, token_type_ids, attention_mask)?;
1311 let pooled_output = self.pooler.forward(&encoder_layer)?;
1312 let pooled_output = self.dropout.forward(&pooled_output)?;
1313 self.classifier.forward(&pooled_output)
1314 }
1315}
1316
/// Pools a sequence by taking its first token, applying dropout, a dense layer and an activation.
pub struct DebertaV2ContextPooler {
    /// Square projection of width `pooler_hidden_size`.
    dense: candle_nn::Linear,
    /// Dropout applied to the first-token vector before `dense`.
    dropout: StableDropout,
    /// Copy of the model config; `pooler_hidden_act` and `pooler_hidden_size` are read from it lazily.
    config: Config,
}
1322
1323impl DebertaV2ContextPooler {
1325 pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
1326 let pooler_hidden_size = config
1327 .pooler_hidden_size
1328 .context("config.pooler_hidden_size is required for DebertaV2ContextPooler")?;
1329
1330 let pooler_dropout = config
1331 .pooler_dropout
1332 .context("config.pooler_dropout is required for DebertaV2ContextPooler")?;
1333
1334 let dense = candle_nn::linear(
1335 pooler_hidden_size,
1336 pooler_hidden_size,
1337 vb.root().pp("pooler.dense"),
1338 )?;
1339
1340 let dropout = StableDropout::new(pooler_dropout);
1341
1342 Ok(Self {
1343 dense,
1344 dropout,
1345 config: config.clone(),
1346 })
1347 }
1348
1349 pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
1350 let context_token = hidden_states.narrow(1, 0, 1)?.squeeze(1)?;
1351 let context_token = self.dropout.forward(&context_token)?;
1352
1353 let pooled_output = self.dense.forward(&context_token.contiguous()?)?;
1354 let pooler_hidden_act = self
1355 .config
1356 .pooler_hidden_act
1357 .context("Could not obtain pooler hidden act from config")?;
1358
1359 HiddenActLayer::new(pooler_hidden_act).forward(&pooled_output)
1360 }
1361
1362 pub fn output_dim(&self) -> Result<usize> {
1363 self.config.pooler_hidden_size.context("DebertaV2ContextPooler cannot return output_dim (pooler_hidden_size) since it is not specified in the model config")
1364 }
1365}
1366
1367pub(crate) fn build_relative_position(
1369 query_size: usize,
1370 key_size: usize,
1371 device: &Device,
1372 bucket_size: Option<isize>,
1373 max_position: Option<isize>,
1374) -> Result<Tensor> {
1375 let q_ids = Tensor::arange(0, query_size as i64, device)?.unsqueeze(0)?;
1376 let k_ids: Tensor = Tensor::arange(0, key_size as i64, device)?.unsqueeze(D::Minus1)?;
1377 let mut rel_pos_ids = k_ids.broadcast_sub(&q_ids)?;
1378 let bucket_size = bucket_size.unwrap_or(-1);
1379 let max_position = max_position.unwrap_or(-1);
1380
1381 if bucket_size > 0 && max_position > 0 {
1382 rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position, device)?;
1383 }
1384
1385 rel_pos_ids = rel_pos_ids.to_dtype(DType::I64)?;
1386 rel_pos_ids = rel_pos_ids.narrow(0, 0, query_size)?;
1387 rel_pos_ids.unsqueeze(0)
1388}
1389
1390pub(crate) fn make_log_bucket_position(
1392 relative_pos: Tensor,
1393 bucket_size: isize,
1394 max_position: isize,
1395 device: &Device,
1396) -> Result<Tensor> {
1397 let sign = relative_pos.to_dtype(DType::F32)?.sign()?;
1398
1399 let mid = bucket_size / 2;
1400
1401 let lt_mid = relative_pos.lt(mid as i64)?;
1402 let gt_neg_mid = relative_pos.gt(-mid as i64)?;
1403
1404 let condition = lt_mid
1405 .to_dtype(candle::DType::F32)?
1406 .mul(>_neg_mid.to_dtype(candle::DType::F32)?)?
1407 .to_dtype(DType::U8)?;
1408
1409 let on_true = Tensor::new(&[(mid - 1) as u32], device)?
1410 .broadcast_as(relative_pos.shape())?
1411 .to_dtype(relative_pos.dtype())?;
1412
1413 let on_false = relative_pos
1414 .to_dtype(DType::F32)?
1415 .abs()?
1416 .to_dtype(DType::I64)?;
1417
1418 let abs_pos = condition.where_cond(&on_true, &on_false)?;
1419
1420 let mid_as_tensor = Tensor::from_slice(&[mid as f32], (1,), device)?;
1421
1422 let log_pos = {
1423 let first_log = abs_pos
1424 .to_dtype(DType::F32)?
1425 .broadcast_div(&mid_as_tensor)?
1426 .log()?;
1427
1428 let second_log =
1429 Tensor::from_slice(&[((max_position as f32 - 1.0) / mid as f32)], (1,), device)?
1430 .log()?;
1431
1432 let first_div_second = first_log.broadcast_div(&second_log)?;
1433
1434 let to_ceil = first_div_second
1435 .broadcast_mul(Tensor::from_slice(&[(mid - 1) as f32], (1,), device)?.as_ref())?;
1436
1437 let ceil = to_ceil.ceil()?;
1438
1439 ceil.broadcast_add(&mid_as_tensor)?
1440 };
1441
1442 Ok({
1443 let abs_pos_lte_mid = abs_pos.to_dtype(DType::F32)?.broadcast_le(&mid_as_tensor)?;
1444 let relative_pos = relative_pos.to_dtype(relative_pos.dtype())?;
1445 let log_pos_mul_sign = log_pos.broadcast_mul(&sign.to_dtype(DType::F32)?)?;
1446 abs_pos_lte_mid.where_cond(&relative_pos.to_dtype(DType::F32)?, &log_pos_mul_sign)?
1447 })
1448}