candle_transformers/models/
debertav2.rs

1use std::collections::HashMap;
2
3use candle::{bail, Context, DType, Device, Module, Result, Tensor, D};
4use candle_nn::{
5    conv1d, embedding, layer_norm, Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder,
6};
7use serde::{Deserialize, Deserializer};
8
9pub const DTYPE: DType = DType::F32;
10
11// NOTE: HiddenAct and HiddenActLayer are both direct copies from bert.rs.
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
13#[serde(rename_all = "lowercase")]
14pub enum HiddenAct {
15    Gelu,
16    GeluApproximate,
17    Relu,
18}
19
20pub struct HiddenActLayer {
21    act: HiddenAct,
22    span: tracing::Span,
23}
24
25impl HiddenActLayer {
26    fn new(act: HiddenAct) -> Self {
27        let span = tracing::span!(tracing::Level::TRACE, "hidden-act");
28        Self { act, span }
29    }
30
31    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
32        let _enter = self.span.enter();
33        match self.act {
34            // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
35            HiddenAct::Gelu => xs.gelu_erf(),
36            HiddenAct::GeluApproximate => xs.gelu(),
37            HiddenAct::Relu => xs.relu(),
38        }
39    }
40}
41
42#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
43#[serde(rename_all = "lowercase")]
44enum PositionEmbeddingType {
45    #[default]
46    Absolute,
47}
48
/// Class index → label name (the config's `id2label` map).
pub type Id2Label = HashMap<u32, String>;
/// Label name → class index (the config's `label2id` map).
pub type Label2Id = HashMap<String, u32>;
51
52#[derive(Debug, Clone, PartialEq, Deserialize)]
53pub struct Config {
54    pub vocab_size: usize,
55    pub hidden_size: usize,
56    pub num_hidden_layers: usize,
57    pub num_attention_heads: usize,
58    pub intermediate_size: usize,
59    pub hidden_act: HiddenAct,
60    pub hidden_dropout_prob: f64,
61    pub attention_probs_dropout_prob: f64,
62    pub max_position_embeddings: usize,
63    pub type_vocab_size: usize,
64    pub initializer_range: f64,
65    pub layer_norm_eps: f64,
66    pub relative_attention: bool,
67    pub max_relative_positions: isize,
68    pub pad_token_id: Option<usize>,
69    pub position_biased_input: bool,
70    #[serde(deserialize_with = "deserialize_pos_att_type")]
71    pub pos_att_type: Vec<String>,
72    pub position_buckets: Option<isize>,
73    pub share_att_key: Option<bool>,
74    pub attention_head_size: Option<usize>,
75    pub embedding_size: Option<usize>,
76    pub norm_rel_ebd: Option<String>,
77    pub conv_kernel_size: Option<usize>,
78    pub conv_groups: Option<usize>,
79    pub conv_act: Option<String>,
80    pub id2label: Option<Id2Label>,
81    pub label2id: Option<Label2Id>,
82    pub pooler_dropout: Option<f64>,
83    pub pooler_hidden_act: Option<HiddenAct>,
84    pub pooler_hidden_size: Option<usize>,
85    pub cls_dropout: Option<f64>,
86}
87
88fn deserialize_pos_att_type<'de, D>(deserializer: D) -> std::result::Result<Vec<String>, D::Error>
89where
90    D: Deserializer<'de>,
91{
92    #[derive(Deserialize, Debug)]
93    #[serde(untagged)]
94    enum StringOrVec {
95        String(String),
96        Vec(Vec<String>),
97    }
98
99    match StringOrVec::deserialize(deserializer)? {
100        StringOrVec::String(s) => Ok(s.split('|').map(String::from).collect()),
101        StringOrVec::Vec(v) => Ok(v),
102    }
103}
104
105// NOTE: Dropout is probably not needed for now since this will primarily be used
106// in inferencing. However, for training/fine-tuning it will be necessary.
107pub struct StableDropout {
108    _drop_prob: f64,
109    _count: usize,
110}
111
112impl StableDropout {
113    pub fn new(drop_prob: f64) -> Self {
114        Self {
115            _drop_prob: drop_prob,
116            _count: 0,
117        }
118    }
119
120    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
121        Ok(x.clone())
122    }
123}
124
125// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L823
126pub struct DebertaV2Embeddings {
127    device: Device,
128    word_embeddings: Embedding,
129    position_embeddings: Option<Embedding>,
130    token_type_embeddings: Option<Embedding>,
131    layer_norm: LayerNorm,
132    dropout: StableDropout,
133    position_ids: Tensor,
134    config: Config,
135    embedding_size: usize,
136    embed_proj: Option<candle_nn::Linear>,
137}
138
139impl DebertaV2Embeddings {
140    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
141        let device = vb.device().clone();
142        let config = config.clone();
143
144        let embedding_size = config.embedding_size.unwrap_or(config.hidden_size);
145
146        let word_embeddings =
147            embedding(config.vocab_size, embedding_size, vb.pp("word_embeddings"))?;
148
149        let position_embeddings = if config.position_biased_input {
150            Some(embedding(
151                config.max_position_embeddings,
152                embedding_size,
153                vb.pp("position_embeddings"),
154            )?)
155        } else {
156            None
157        };
158
159        let token_type_embeddings: Option<Embedding> = if config.type_vocab_size > 0 {
160            Some(candle_nn::embedding(
161                config.type_vocab_size,
162                config.hidden_size,
163                vb.pp("token_type_embeddings"),
164            )?)
165        } else {
166            None
167        };
168
169        let embed_proj: Option<candle_nn::Linear> = if embedding_size != config.hidden_size {
170            Some(candle_nn::linear_no_bias(
171                embedding_size,
172                config.hidden_size,
173                vb.pp("embed_proj"),
174            )?)
175        } else {
176            None
177        };
178
179        let layer_norm = layer_norm(
180            config.hidden_size,
181            config.layer_norm_eps,
182            vb.pp("LayerNorm"),
183        )?;
184
185        let dropout = StableDropout::new(config.hidden_dropout_prob);
186
187        let position_ids =
188            Tensor::arange(0, config.max_position_embeddings as u32, &device)?.unsqueeze(0)?;
189
190        Ok(Self {
191            word_embeddings,
192            position_embeddings,
193            token_type_embeddings,
194            layer_norm,
195            dropout,
196            position_ids,
197            device,
198            config,
199            embedding_size,
200            embed_proj,
201        })
202    }
203
204    pub fn forward(
205        &self,
206        input_ids: Option<&Tensor>,
207        token_type_ids: Option<&Tensor>,
208        position_ids: Option<&Tensor>,
209        mask: Option<&Tensor>,
210        inputs_embeds: Option<&Tensor>,
211    ) -> Result<Tensor> {
212        let (input_shape, input_embeds) = match (input_ids, inputs_embeds) {
213            (Some(ids), None) => {
214                let embs = self.word_embeddings.forward(ids)?;
215                (ids.dims(), embs)
216            }
217            (None, Some(e)) => (e.dims(), e.clone()),
218            (None, None) => {
219                bail!("Must specify either input_ids or inputs_embeds")
220            }
221            (Some(_), Some(_)) => {
222                bail!("Can't specify both input_ids and inputs_embeds")
223            }
224        };
225
226        let seq_length = match input_shape.last() {
227            Some(v) => *v,
228            None => bail!("DebertaV2Embeddings invalid input shape"),
229        };
230
231        let position_ids = match position_ids {
232            Some(v) => v.clone(),
233            None => self.position_ids.narrow(1, 0, seq_length)?,
234        };
235
236        let token_type_ids = match token_type_ids {
237            Some(ids) => ids.clone(),
238            None => Tensor::zeros(input_shape, DType::U32, &self.device)?,
239        };
240
241        let position_embeddings = match &self.position_embeddings {
242            Some(emb) => emb.forward(&position_ids)?,
243            None => Tensor::zeros_like(&input_embeds)?,
244        };
245
246        let mut embeddings = input_embeds;
247
248        if self.config.position_biased_input {
249            embeddings = embeddings.add(&position_embeddings)?;
250        }
251
252        if self.config.type_vocab_size > 0 {
253            embeddings = self.token_type_embeddings.as_ref().map_or_else(
254                || bail!("token_type_embeddings must be set when type_vocab_size > 0"),
255                |token_type_embeddings| {
256                    embeddings.add(&token_type_embeddings.forward(&token_type_ids)?)
257                },
258            )?;
259        }
260
261        if self.embedding_size != self.config.hidden_size {
262            embeddings = if let Some(embed_proj) = &self.embed_proj {
263                embed_proj.forward(&embeddings)?
264            } else {
265                bail!("embed_proj must exist if embedding_size != config.hidden_size");
266            }
267        }
268
269        embeddings = self.layer_norm.forward(&embeddings)?;
270
271        if let Some(mask) = mask {
272            let mut mask = mask.clone();
273            if mask.dims() != embeddings.dims() {
274                if mask.dims().len() == 4 {
275                    mask = mask.squeeze(1)?.squeeze(1)?;
276                }
277                mask = mask.unsqueeze(2)?;
278            }
279
280            mask = mask.to_dtype(embeddings.dtype())?;
281            embeddings = embeddings.broadcast_mul(&mask)?;
282        }
283
284        self.dropout.forward(&embeddings)
285    }
286}
287
288// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L72
289struct XSoftmax {}
290
291impl XSoftmax {
292    pub fn apply(input: &Tensor, mask: &Tensor, dim: D, device: &Device) -> Result<Tensor> {
293        // NOTE: At the time of this writing, candle does not have a logical-not operator.
294        let mut rmask = mask.broadcast_as(input.shape())?.to_dtype(DType::F32)?;
295
296        rmask = rmask
297            .broadcast_lt(&Tensor::new(&[1.0_f32], device)?)?
298            .to_dtype(DType::U8)?;
299
300        let min_value_tensor = Tensor::new(&[f32::MIN], device)?.broadcast_as(input.shape())?;
301        let mut output = rmask.where_cond(&min_value_tensor, input)?;
302
303        output = candle_nn::ops::softmax(&output, dim)?;
304
305        let t_zeroes = Tensor::new(&[0f32], device)?.broadcast_as(input.shape())?;
306        output = rmask.where_cond(&t_zeroes, &output)?;
307
308        Ok(output)
309    }
310}
311
312// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L605
313pub struct DebertaV2DisentangledSelfAttention {
314    config: Config,
315    num_attention_heads: usize,
316    query_proj: candle_nn::Linear,
317    key_proj: candle_nn::Linear,
318    value_proj: candle_nn::Linear,
319    dropout: StableDropout,
320    device: Device,
321    relative_attention: bool,
322    pos_dropout: Option<StableDropout>,
323    position_buckets: isize,
324    max_relative_positions: isize,
325    pos_ebd_size: isize,
326    share_att_key: bool,
327    pos_key_proj: Option<candle_nn::Linear>,
328    pos_query_proj: Option<candle_nn::Linear>,
329}
330
331impl DebertaV2DisentangledSelfAttention {
332    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
333        let config = config.clone();
334        let vb = vb.clone();
335
336        if config.hidden_size % config.num_attention_heads != 0 {
337            return Err(candle::Error::Msg(format!(
338                "The hidden size {} is not a multiple of the number of attention heads {}",
339                config.hidden_size, config.num_attention_heads
340            )));
341        }
342
343        let num_attention_heads = config.num_attention_heads;
344
345        let attention_head_size = config
346            .attention_head_size
347            .unwrap_or(config.hidden_size / config.num_attention_heads);
348
349        let all_head_size = num_attention_heads * attention_head_size;
350
351        let query_proj = candle_nn::linear(config.hidden_size, all_head_size, vb.pp("query_proj"))?;
352        let key_proj = candle_nn::linear(config.hidden_size, all_head_size, vb.pp("key_proj"))?;
353        let value_proj = candle_nn::linear(config.hidden_size, all_head_size, vb.pp("value_proj"))?;
354
355        let share_att_key = config.share_att_key.unwrap_or(false);
356        let relative_attention = config.relative_attention;
357        let mut max_relative_positions = config.max_relative_positions;
358
359        let mut pos_ebd_size: isize = 0;
360        let position_buckets = config.position_buckets.unwrap_or(-1);
361        let mut pos_dropout: Option<StableDropout> = None;
362        let mut pos_key_proj: Option<candle_nn::Linear> = None;
363        let mut pos_query_proj: Option<candle_nn::Linear> = None;
364
365        if relative_attention {
366            if max_relative_positions < 1 {
367                max_relative_positions = config.max_position_embeddings as isize;
368            }
369            pos_ebd_size = max_relative_positions;
370            if position_buckets > 0 {
371                pos_ebd_size = position_buckets
372            }
373
374            pos_dropout = Some(StableDropout::new(config.hidden_dropout_prob));
375
376            if !share_att_key {
377                if config.pos_att_type.iter().any(|s| s == "c2p") {
378                    pos_key_proj = Some(candle_nn::linear(
379                        config.hidden_size,
380                        all_head_size,
381                        vb.pp("pos_key_proj"),
382                    )?);
383                }
384                if config.pos_att_type.iter().any(|s| s == "p2c") {
385                    pos_query_proj = Some(candle_nn::linear(
386                        config.hidden_size,
387                        all_head_size,
388                        vb.pp("pos_query_proj"),
389                    )?);
390                }
391            }
392        }
393
394        let dropout = StableDropout::new(config.attention_probs_dropout_prob);
395        let device = vb.device().clone();
396
397        Ok(Self {
398            config,
399            num_attention_heads,
400            query_proj,
401            key_proj,
402            value_proj,
403            dropout,
404            device,
405            relative_attention,
406            pos_dropout,
407            position_buckets,
408            max_relative_positions,
409            pos_ebd_size,
410            share_att_key,
411            pos_key_proj,
412            pos_query_proj,
413        })
414    }
415
416    pub fn forward(
417        &self,
418        hidden_states: &Tensor,
419        attention_mask: &Tensor,
420        query_states: Option<&Tensor>,
421        relative_pos: Option<&Tensor>,
422        rel_embeddings: Option<&Tensor>,
423    ) -> Result<Tensor> {
424        let query_states = match query_states {
425            Some(qs) => qs,
426            None => hidden_states,
427        };
428
429        let query_layer = self.transpose_for_scores(&self.query_proj.forward(query_states)?)?;
430        let key_layer = self.transpose_for_scores(&self.key_proj.forward(query_states)?)?;
431        let value_layer = self.transpose_for_scores(&self.value_proj.forward(query_states)?)?;
432
433        let mut rel_att: Option<Tensor> = None;
434
435        let mut scale_factor: usize = 1;
436
437        if self.config.pos_att_type.iter().any(|s| s == "c2p") {
438            scale_factor += 1;
439        }
440
441        if self.config.pos_att_type.iter().any(|s| s == "p2c") {
442            scale_factor += 1;
443        }
444
445        let scale = {
446            let q_size = query_layer.dim(D::Minus1)?;
447            Tensor::new(&[(q_size * scale_factor) as f32], &self.device)?.sqrt()?
448        };
449
450        let mut attention_scores: Tensor = {
451            let key_layer_transposed = key_layer.t()?;
452            let div = key_layer_transposed
453                .broadcast_div(scale.to_dtype(query_layer.dtype())?.as_ref())?;
454            query_layer.matmul(&div)?
455        };
456
457        if self.relative_attention {
458            if let Some(rel_embeddings) = rel_embeddings {
459                let rel_embeddings = self
460                    .pos_dropout
461                    .as_ref()
462                    .context("relative_attention requires pos_dropout")?
463                    .forward(rel_embeddings)?;
464                rel_att = Some(self.disentangled_attention_bias(
465                    query_layer,
466                    key_layer,
467                    relative_pos,
468                    rel_embeddings,
469                    scale_factor,
470                )?);
471            }
472        }
473
474        if let Some(rel_att) = rel_att {
475            attention_scores = attention_scores.broadcast_add(&rel_att)?;
476        }
477
478        attention_scores = attention_scores.reshape((
479            (),
480            self.num_attention_heads,
481            attention_scores.dim(D::Minus2)?,
482            attention_scores.dim(D::Minus1)?,
483        ))?;
484
485        let mut attention_probs =
486            XSoftmax::apply(&attention_scores, attention_mask, D::Minus1, &self.device)?;
487
488        attention_probs = self.dropout.forward(&attention_probs)?;
489
490        let mut context_layer = attention_probs
491            .reshape((
492                (),
493                attention_probs.dim(D::Minus2)?,
494                attention_probs.dim(D::Minus1)?,
495            ))?
496            .matmul(&value_layer)?;
497
498        context_layer = context_layer
499            .reshape((
500                (),
501                self.num_attention_heads,
502                context_layer.dim(D::Minus2)?,
503                context_layer.dim(D::Minus1)?,
504            ))?
505            .permute((0, 2, 1, 3))?
506            .contiguous()?;
507
508        let dims = context_layer.dims();
509
510        context_layer = match dims.len() {
511            2 => context_layer.reshape(())?,
512            3 => context_layer.reshape((dims[0], ()))?,
513            4 => context_layer.reshape((dims[0], dims[1], ()))?,
514            5 => context_layer.reshape((dims[0], dims[1], dims[2], ()))?,
515            _ => {
516                bail!(
517                    "Invalid shape for DisentabgledSelfAttention context layer: {:?}",
518                    dims
519                )
520            }
521        };
522
523        Ok(context_layer)
524    }
525
526    fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
527        let dims = xs.dims().to_vec();
528        match dims.len() {
529            3 => {
530                let reshaped = xs.reshape((dims[0], dims[1], self.num_attention_heads, ()))?;
531
532                reshaped.transpose(1, 2)?.contiguous()?.reshape((
533                    (),
534                    reshaped.dim(1)?,
535                    reshaped.dim(D::Minus1)?,
536                ))
537            }
538            shape => {
539                bail!("Invalid shape for transpose_for_scores. Expected 3 dimensions, got {shape}")
540            }
541        }
542    }
543
544    fn disentangled_attention_bias(
545        &self,
546        query_layer: Tensor,
547        key_layer: Tensor,
548        relative_pos: Option<&Tensor>,
549        rel_embeddings: Tensor,
550        scale_factor: usize,
551    ) -> Result<Tensor> {
552        let mut relative_pos = relative_pos.map_or(
553            build_relative_position(
554                query_layer.dim(D::Minus2)?,
555                key_layer.dim(D::Minus2)?,
556                &self.device,
557                Some(self.position_buckets),
558                Some(self.max_relative_positions),
559            )?,
560            |pos| pos.clone(),
561        );
562
563        relative_pos = match relative_pos.dims().len() {
564            2 => relative_pos.unsqueeze(0)?.unsqueeze(0)?,
565            3 => relative_pos.unsqueeze(1)?,
566            other => {
567                bail!("Relative position ids must be of dim 2 or 3 or 4. Got dim of size {other}")
568            }
569        };
570
571        let att_span = self.pos_ebd_size;
572
573        let rel_embeddings = rel_embeddings
574            .narrow(0, 0, (att_span * 2) as usize)?
575            .unsqueeze(0)?;
576
577        let mut pos_query_layer: Option<Tensor> = None;
578        let mut pos_key_layer: Option<Tensor> = None;
579
580        let repeat_with = query_layer.dim(0)? / self.num_attention_heads;
581        if self.share_att_key {
582            pos_query_layer = Some(
583                self.transpose_for_scores(&self.query_proj.forward(&rel_embeddings)?)?
584                    .repeat(repeat_with)?,
585            );
586
587            pos_key_layer = Some(
588                self.transpose_for_scores(&self.key_proj.forward(&rel_embeddings)?)?
589                    .repeat(repeat_with)?,
590            )
591        } else {
592            if self.config.pos_att_type.iter().any(|s| s == "c2p") {
593                pos_key_layer = Some(
594                    self.transpose_for_scores(
595                        &self
596                            .pos_key_proj
597                            .as_ref()
598                            .context(
599                                "Need pos_key_proj when share_att_key is false or not specified",
600                            )?
601                            .forward(&rel_embeddings)?,
602                    )?
603                    .repeat(repeat_with)?,
604                )
605            }
606            if self.config.pos_att_type.iter().any(|s| s == "p2c") {
607                pos_query_layer = Some(self.transpose_for_scores(&self
608                    .pos_query_proj
609                    .as_ref()
610                    .context("Need a pos_query_proj when share_att_key is false or not specified")?
611                    .forward(&rel_embeddings)?)?.repeat(repeat_with)?)
612            }
613        }
614
615        let mut score = Tensor::new(&[0 as f32], &self.device)?;
616
617        if self.config.pos_att_type.iter().any(|s| s == "c2p") {
618            let pos_key_layer = pos_key_layer.context("c2p without pos_key_layer")?;
619
620            let scale = Tensor::new(
621                &[(pos_key_layer.dim(D::Minus1)? * scale_factor) as f32],
622                &self.device,
623            )?
624            .sqrt()?;
625
626            let mut c2p_att = query_layer.matmul(&pos_key_layer.t()?)?;
627
628            let c2p_pos = relative_pos
629                .broadcast_add(&Tensor::new(&[att_span as i64], &self.device)?)?
630                .clamp(0 as f32, (att_span * 2 - 1) as f32)?;
631
632            c2p_att = c2p_att.gather(
633                &c2p_pos
634                    .squeeze(0)?
635                    .expand(&[
636                        query_layer.dim(0)?,
637                        query_layer.dim(1)?,
638                        relative_pos.dim(D::Minus1)?,
639                    ])?
640                    .contiguous()?,
641                D::Minus1,
642            )?;
643
644            score = score.broadcast_add(
645                &c2p_att.broadcast_div(scale.to_dtype(c2p_att.dtype())?.as_ref())?,
646            )?;
647        }
648
649        if self.config.pos_att_type.iter().any(|s| s == "p2c") {
650            let pos_query_layer = pos_query_layer.context("p2c without pos_key_layer")?;
651
652            let scale = Tensor::new(
653                &[(pos_query_layer.dim(D::Minus1)? * scale_factor) as f32],
654                &self.device,
655            )?
656            .sqrt()?;
657
658            let r_pos = {
659                if key_layer.dim(D::Minus2)? != query_layer.dim(D::Minus2)? {
660                    build_relative_position(
661                        key_layer.dim(D::Minus2)?,
662                        key_layer.dim(D::Minus2)?,
663                        &self.device,
664                        Some(self.position_buckets),
665                        Some(self.max_relative_positions),
666                    )?
667                    .unsqueeze(0)?
668                } else {
669                    relative_pos
670                }
671            };
672
673            let p2c_pos = r_pos
674                .to_dtype(DType::F32)?
675                .neg()?
676                .broadcast_add(&Tensor::new(&[att_span as f32], &self.device)?)?
677                .clamp(0f32, (att_span * 2 - 1) as f32)?;
678
679            let p2c_att = key_layer
680                .matmul(&pos_query_layer.t()?)?
681                .gather(
682                    &p2c_pos
683                        .squeeze(0)?
684                        .expand(&[
685                            query_layer.dim(0)?,
686                            key_layer.dim(D::Minus2)?,
687                            key_layer.dim(D::Minus2)?,
688                        ])?
689                        .contiguous()?
690                        .to_dtype(DType::U32)?,
691                    D::Minus1,
692                )?
693                .t()?;
694
695            score =
696                score.broadcast_add(&p2c_att.broadcast_div(&scale.to_dtype(p2c_att.dtype())?)?)?;
697        }
698
699        Ok(score)
700    }
701}
702
703// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L270
704pub struct DebertaV2Attention {
705    dsa: DebertaV2DisentangledSelfAttention,
706    output: DebertaV2SelfOutput,
707}
708
709impl DebertaV2Attention {
710    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
711        let dsa = DebertaV2DisentangledSelfAttention::load(vb.pp("attention.self"), config)?;
712        let output = DebertaV2SelfOutput::load(vb.pp("attention.output"), config)?;
713        Ok(Self { dsa, output })
714    }
715
716    fn forward(
717        &self,
718        hidden_states: &Tensor,
719        attention_mask: &Tensor,
720        query_states: Option<&Tensor>,
721        relative_pos: Option<&Tensor>,
722        rel_embeddings: Option<&Tensor>,
723    ) -> Result<Tensor> {
724        let self_output = self.dsa.forward(
725            hidden_states,
726            attention_mask,
727            query_states,
728            relative_pos,
729            rel_embeddings,
730        )?;
731
732        self.output
733            .forward(&self_output, query_states.unwrap_or(hidden_states))
734    }
735}
736
737// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L255
738pub struct DebertaV2SelfOutput {
739    dense: candle_nn::Linear,
740    layer_norm: LayerNorm,
741    dropout: StableDropout,
742}
743
744impl DebertaV2SelfOutput {
745    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
746        let dense = candle_nn::linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
747        let layer_norm = candle_nn::layer_norm(
748            config.hidden_size,
749            config.layer_norm_eps,
750            vb.pp("LayerNorm"),
751        )?;
752        let dropout = StableDropout::new(config.hidden_dropout_prob);
753        Ok(Self {
754            dense,
755            layer_norm,
756            dropout,
757        })
758    }
759
760    pub fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
761        let mut hidden_states = self.dense.forward(hidden_states)?;
762        hidden_states = self.dropout.forward(&hidden_states)?;
763        self.layer_norm
764            .forward(&hidden_states.broadcast_add(input_tensor)?)
765    }
766}
767
768// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L307
769pub struct DebertaV2Intermediate {
770    dense: candle_nn::Linear,
771    intermediate_act: HiddenActLayer,
772}
773
774impl DebertaV2Intermediate {
775    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
776        let dense = candle_nn::linear(
777            config.hidden_size,
778            config.intermediate_size,
779            vb.pp("intermediate.dense"),
780        )?;
781        let intermediate_act = HiddenActLayer::new(config.hidden_act);
782        Ok(Self {
783            dense,
784            intermediate_act,
785        })
786    }
787
788    pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
789        self.intermediate_act
790            .forward(&self.dense.forward(hidden_states)?)
791    }
792}
793
794// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L323
795pub struct DebertaV2Output {
796    dense: candle_nn::Linear,
797    layer_norm: LayerNorm,
798    dropout: StableDropout,
799}
800
801impl DebertaV2Output {
802    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
803        let dense = candle_nn::linear(
804            config.intermediate_size,
805            config.hidden_size,
806            vb.pp("output.dense"),
807        )?;
808        let layer_norm = candle_nn::layer_norm(
809            config.hidden_size,
810            config.layer_norm_eps,
811            vb.pp("output.LayerNorm"),
812        )?;
813        let dropout = StableDropout::new(config.hidden_dropout_prob);
814        Ok(Self {
815            dense,
816            layer_norm,
817            dropout,
818        })
819    }
820
821    pub fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
822        let mut hidden_states = self.dense.forward(hidden_states)?;
823        hidden_states = self.dropout.forward(&hidden_states)?;
824        hidden_states = {
825            let to_norm = hidden_states.broadcast_add(input_tensor)?;
826            self.layer_norm.forward(&to_norm)?
827        };
828        Ok(hidden_states)
829    }
830}
831
832// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L339
833pub struct DebertaV2Layer {
834    attention: DebertaV2Attention,
835    intermediate: DebertaV2Intermediate,
836    output: DebertaV2Output,
837}
838
839impl DebertaV2Layer {
840    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
841        let attention = DebertaV2Attention::load(vb.clone(), config)?;
842        let intermediate = DebertaV2Intermediate::load(vb.clone(), config)?;
843        let output = DebertaV2Output::load(vb.clone(), config)?;
844        Ok(Self {
845            attention,
846            intermediate,
847            output,
848        })
849    }
850
851    fn forward(
852        &self,
853        hidden_states: &Tensor,
854        attention_mask: &Tensor,
855        query_states: Option<&Tensor>,
856        relative_pos: Option<&Tensor>,
857        rel_embeddings: Option<&Tensor>,
858    ) -> Result<Tensor> {
859        let attention_output = self.attention.forward(
860            hidden_states,
861            attention_mask,
862            query_states,
863            relative_pos,
864            rel_embeddings,
865        )?;
866
867        let intermediate_output = self.intermediate.forward(&attention_output)?;
868
869        let layer_output = self
870            .output
871            .forward(&intermediate_output, &attention_output)?;
872
873        Ok(layer_output)
874    }
875}
876
877// TODO: In order to fully test ConvLayer a model needs to be found has a configuration where `conv_kernel_size` exists and is > 0
878// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L373
879pub struct ConvLayer {
880    _conv_act: String,
881    _conv: Conv1d,
882    _layer_norm: LayerNorm,
883    _dropout: StableDropout,
884    _config: Config,
885}
886
887impl ConvLayer {
888    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
889        let config = config.clone();
890        let kernel_size = config.conv_kernel_size.unwrap_or(3);
891        let groups = config.conv_groups.unwrap_or(1);
892        let conv_act: String = config.conv_act.clone().unwrap_or("tanh".to_string());
893
894        let conv_conf = Conv1dConfig {
895            padding: (kernel_size - 1) / 2,
896            groups,
897            ..Default::default()
898        };
899
900        let conv = conv1d(
901            config.hidden_size,
902            config.hidden_size,
903            kernel_size,
904            conv_conf,
905            vb.pp("conv"),
906        )?;
907
908        let layer_norm = layer_norm(
909            config.hidden_size,
910            config.layer_norm_eps,
911            vb.pp("LayerNorm"),
912        )?;
913
914        let dropout = StableDropout::new(config.hidden_dropout_prob);
915
916        Ok(Self {
917            _conv_act: conv_act,
918            _conv: conv,
919            _layer_norm: layer_norm,
920            _dropout: dropout,
921            _config: config,
922        })
923    }
924
925    pub fn forward(
926        &self,
927        _hidden_states: &Tensor,
928        _residual_states: &Tensor,
929        _input_mask: &Tensor,
930    ) -> Result<Tensor> {
931        todo!("Need a model that contains a conv layer to test against.")
932    }
933}
934
935// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L409
936pub struct DebertaV2Encoder {
937    layer: Vec<DebertaV2Layer>,
938    relative_attention: bool,
939    max_relative_positions: isize,
940    position_buckets: isize,
941    rel_embeddings: Option<Embedding>,
942    norm_rel_ebd: String,
943    layer_norm: Option<LayerNorm>,
944    conv: Option<ConvLayer>,
945    device: Device,
946}
947
948impl DebertaV2Encoder {
949    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
950        let layer = (0..config.num_hidden_layers)
951            .map(|index| DebertaV2Layer::load(vb.pp(format!("layer.{index}")), config))
952            .collect::<Result<Vec<_>>>()?;
953
954        let relative_attention = config.relative_attention;
955        let mut max_relative_positions = config.max_relative_positions;
956
957        let position_buckets = config.position_buckets.unwrap_or(-1);
958
959        let mut rel_embeddings: Option<Embedding> = None;
960
961        if relative_attention {
962            if max_relative_positions < 1 {
963                max_relative_positions = config.max_position_embeddings as isize;
964            }
965
966            let mut pos_ebd_size = max_relative_positions * 2;
967
968            if position_buckets > 0 {
969                pos_ebd_size = position_buckets * 2;
970            }
971
972            rel_embeddings = Some(embedding(
973                pos_ebd_size as usize,
974                config.hidden_size,
975                vb.pp("rel_embeddings"),
976            )?);
977        }
978
979        // NOTE: The Python code assumes that the config attribute "norm_rel_ebd" is an array of some kind, but most examples have it as a string.
980        // So it might need to be updated at some point.
981        let norm_rel_ebd = match config.norm_rel_ebd.as_ref() {
982            Some(nre) => nre.trim().to_string(),
983            None => "none".to_string(),
984        };
985
986        let layer_norm: Option<LayerNorm> = if norm_rel_ebd == "layer_norm" {
987            Some(layer_norm(
988                config.hidden_size,
989                config.layer_norm_eps,
990                vb.pp("LayerNorm"),
991            )?)
992        } else {
993            None
994        };
995
996        let conv: Option<ConvLayer> = if config.conv_kernel_size.unwrap_or(0) > 0 {
997            Some(ConvLayer::load(vb.pp("conv"), config)?)
998        } else {
999            None
1000        };
1001
1002        Ok(Self {
1003            layer,
1004            relative_attention,
1005            max_relative_positions,
1006            position_buckets,
1007            rel_embeddings,
1008            norm_rel_ebd,
1009            layer_norm,
1010            conv,
1011            device: vb.device().clone(),
1012        })
1013    }
1014
1015    pub fn forward(
1016        &self,
1017        hidden_states: &Tensor,
1018        attention_mask: &Tensor,
1019        query_states: Option<&Tensor>,
1020        relative_pos: Option<&Tensor>,
1021    ) -> Result<Tensor> {
1022        let input_mask = if attention_mask.dims().len() <= 2 {
1023            attention_mask.clone()
1024        } else {
1025            attention_mask
1026                .sum_keepdim(attention_mask.rank() - 2)?
1027                .gt(0.)?
1028        };
1029
1030        let attention_mask = self.get_attention_mask(attention_mask.clone())?;
1031
1032        let relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)?;
1033
1034        let mut next_kv: Tensor = hidden_states.clone();
1035        let rel_embeddings = self.get_rel_embedding()?;
1036        let mut output_states = next_kv.to_owned();
1037        let mut query_states: Option<Tensor> = query_states.cloned();
1038
1039        for (i, layer_module) in self.layer.iter().enumerate() {
1040            // NOTE: The original python code branches here if this model is being
1041            // used for training vs. inferencing. For now, we will only handle the
1042            // inferencing side of things
1043
1044            output_states = layer_module.forward(
1045                next_kv.as_ref(),
1046                &attention_mask,
1047                query_states.as_ref(),
1048                relative_pos.as_ref(),
1049                rel_embeddings.as_ref(),
1050            )?;
1051
1052            if i == 0 {
1053                if let Some(conv) = &self.conv {
1054                    output_states = conv.forward(hidden_states, &output_states, &input_mask)?;
1055                }
1056            }
1057
1058            if query_states.is_some() {
1059                query_states = Some(output_states.clone());
1060            } else {
1061                next_kv = output_states.clone();
1062            }
1063        }
1064
1065        Ok(output_states)
1066    }
1067
1068    fn get_attention_mask(&self, mut attention_mask: Tensor) -> Result<Tensor> {
1069        match attention_mask.dims().len() {
1070            0..=2 => {
1071                let extended_attention_mask = attention_mask.unsqueeze(1)?.unsqueeze(2)?;
1072                attention_mask = extended_attention_mask.broadcast_mul(
1073                    &extended_attention_mask
1074                        .squeeze(D::Minus2)?
1075                        .unsqueeze(D::Minus1)?,
1076                )?;
1077            }
1078            3 => attention_mask = attention_mask.unsqueeze(1)?,
1079            len => bail!("Unsupported attentiom mask size length: {len}"),
1080        }
1081
1082        Ok(attention_mask)
1083    }
1084
1085    fn get_rel_pos(
1086        &self,
1087        hidden_states: &Tensor,
1088        query_states: Option<&Tensor>,
1089        relative_pos: Option<&Tensor>,
1090    ) -> Result<Option<Tensor>> {
1091        if self.relative_attention && relative_pos.is_none() {
1092            let q = if let Some(query_states) = query_states {
1093                query_states.dim(D::Minus2)?
1094            } else {
1095                hidden_states.dim(D::Minus2)?
1096            };
1097
1098            return Ok(Some(build_relative_position(
1099                q,
1100                hidden_states.dim(D::Minus2)?,
1101                &self.device,
1102                Some(self.position_buckets),
1103                Some(self.max_relative_positions),
1104            )?));
1105        }
1106
1107        if relative_pos.is_some() {
1108            Ok(relative_pos.cloned())
1109        } else {
1110            Ok(None)
1111        }
1112    }
1113    fn get_rel_embedding(&self) -> Result<Option<Tensor>> {
1114        if !self.relative_attention {
1115            return Ok(None);
1116        }
1117
1118        let rel_embeddings = self
1119            .rel_embeddings
1120            .as_ref()
1121            .context("self.rel_embeddings not present when using relative_attention")?
1122            .embeddings()
1123            .clone();
1124
1125        if !self.norm_rel_ebd.contains("layer_norm") {
1126            return Ok(Some(rel_embeddings));
1127        }
1128
1129        let layer_normed_embeddings = self
1130            .layer_norm
1131            .as_ref()
1132            .context("DebertaV2Encoder layer_norm is None when norm_rel_ebd contains layer_norm")?
1133            .forward(&rel_embeddings)?;
1134
1135        Ok(Some(layer_normed_embeddings))
1136    }
1137}
1138
1139// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L991
1140pub struct DebertaV2Model {
1141    embeddings: DebertaV2Embeddings,
1142    encoder: DebertaV2Encoder,
1143    z_steps: usize,
1144    pub device: Device,
1145}
1146
1147impl DebertaV2Model {
1148    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
1149        let vb = vb.clone();
1150        let embeddings = DebertaV2Embeddings::load(vb.pp("embeddings"), config)?;
1151        let encoder = DebertaV2Encoder::load(vb.pp("encoder"), config)?;
1152        let z_steps: usize = 0;
1153
1154        Ok(Self {
1155            embeddings,
1156            encoder,
1157            z_steps,
1158            device: vb.device().clone(),
1159        })
1160    }
1161
1162    pub fn forward(
1163        &self,
1164        input_ids: &Tensor,
1165        token_type_ids: Option<Tensor>,
1166        attention_mask: Option<Tensor>,
1167    ) -> Result<Tensor> {
1168        let input_ids_shape = input_ids.shape();
1169
1170        let attention_mask = match attention_mask {
1171            Some(mask) => mask,
1172            None => Tensor::ones(input_ids_shape, DType::I64, &self.device)?,
1173        };
1174
1175        let token_type_ids = match token_type_ids {
1176            Some(ids) => ids,
1177            None => Tensor::zeros(input_ids_shape, DType::U32, &self.device)?,
1178        };
1179
1180        let embedding_output = self.embeddings.forward(
1181            Some(input_ids),
1182            Some(&token_type_ids),
1183            None,
1184            Some(&attention_mask),
1185            None,
1186        )?;
1187
1188        let encoder_output =
1189            self.encoder
1190                .forward(&embedding_output, &attention_mask, None, None)?;
1191
1192        if self.z_steps > 1 {
1193            todo!("Complete DebertaV2Model forward() when z_steps > 1 -- Needs a model to test this situation.")
1194        }
1195
1196        Ok(encoder_output)
1197    }
1198}
1199
/// One token-classification (NER) result record.
///
/// NOTE(review): this type is not constructed in this file. The field meanings below are
/// inferred from their names; confirm them against the code that builds these items.
#[derive(Debug)]
pub struct NERItem {
    /// Predicted entity label (presumably a value from `id2label`).
    pub entity: String,
    /// Text of the token/word the prediction refers to.
    pub word: String,
    /// Confidence of the prediction (presumably a softmax probability; confirm).
    pub score: f32,
    /// Presumably the start offset of `word` in the input text; confirm.
    pub start: usize,
    /// Presumably the end offset of `word` in the input text; confirm.
    pub end: usize,
    /// Presumably the token position in the model input; confirm.
    pub index: usize,
}
1209
/// One sequence-classification result: a label and its score.
///
/// NOTE(review): not constructed in this file; `score` is presumably a probability for
/// `label`. Confirm against the caller.
#[derive(Debug)]
pub struct TextClassificationItem {
    pub label: String,
    pub score: f32,
}
1215
/// Token-classification (NER) model: the DeBERTa-v2 backbone followed by dropout and a
/// linear classifier applied to the encoder output.
pub struct DebertaV2NERModel {
    /// Device the model weights were loaded on.
    pub device: Device,
    deberta: DebertaV2Model,
    dropout: candle_nn::Dropout,
    classifier: candle_nn::Linear,
}
1222
1223fn id2label_len(config: &Config, id2label: Option<HashMap<u32, String>>) -> Result<usize> {
1224    let id2label_len = match (&config.id2label, id2label) {
1225        (None, None) => bail!("Id2Label is either not present in the model configuration or not passed into DebertaV2NERModel::load as a parameter"),
1226        (None, Some(id2label_p)) => id2label_p.len(),
1227        (Some(id2label_c), None) => id2label_c.len(),
1228        (Some(id2label_c), Some(id2label_p)) => {
1229          if *id2label_c == id2label_p {
1230            id2label_c.len()
1231          } else {
1232            bail!("Id2Label is both present in the model configuration and provided as a parameter, and they are different.")
1233          }
1234        }
1235    };
1236    Ok(id2label_len)
1237}
1238
1239impl DebertaV2NERModel {
1240    pub fn load(vb: VarBuilder, config: &Config, id2label: Option<Id2Label>) -> Result<Self> {
1241        let id2label_len = id2label_len(config, id2label)?;
1242
1243        let deberta = DebertaV2Model::load(vb.clone(), config)?;
1244        let dropout = candle_nn::Dropout::new(config.hidden_dropout_prob as f32);
1245        let classifier: candle_nn::Linear = candle_nn::linear_no_bias(
1246            config.hidden_size,
1247            id2label_len,
1248            vb.root().pp("classifier"),
1249        )?;
1250
1251        Ok(Self {
1252            device: vb.device().clone(),
1253            deberta,
1254            dropout,
1255            classifier,
1256        })
1257    }
1258
1259    pub fn forward(
1260        &self,
1261        input_ids: &Tensor,
1262        token_type_ids: Option<Tensor>,
1263        attention_mask: Option<Tensor>,
1264    ) -> Result<Tensor> {
1265        let output = self
1266            .deberta
1267            .forward(input_ids, token_type_ids, attention_mask)?;
1268        let output = self.dropout.forward(&output, false)?;
1269        self.classifier.forward(&output)
1270    }
1271}
1272
1273pub struct DebertaV2SeqClassificationModel {
1274    pub device: Device,
1275    deberta: DebertaV2Model,
1276    dropout: StableDropout,
1277    pooler: DebertaV2ContextPooler,
1278    classifier: candle_nn::Linear,
1279}
1280
1281impl DebertaV2SeqClassificationModel {
1282    pub fn load(vb: VarBuilder, config: &Config, id2label: Option<Id2Label>) -> Result<Self> {
1283        let id2label_len = id2label_len(config, id2label)?;
1284        let deberta = DebertaV2Model::load(vb.clone(), config)?;
1285        let pooler = DebertaV2ContextPooler::load(vb.clone(), config)?;
1286        let output_dim = pooler.output_dim()?;
1287        let classifier = candle_nn::linear(output_dim, id2label_len, vb.root().pp("classifier"))?;
1288        let dropout = match config.cls_dropout {
1289            Some(cls_dropout) => StableDropout::new(cls_dropout),
1290            None => StableDropout::new(config.hidden_dropout_prob),
1291        };
1292
1293        Ok(Self {
1294            device: vb.device().clone(),
1295            deberta,
1296            dropout,
1297            pooler,
1298            classifier,
1299        })
1300    }
1301
1302    pub fn forward(
1303        &self,
1304        input_ids: &Tensor,
1305        token_type_ids: Option<Tensor>,
1306        attention_mask: Option<Tensor>,
1307    ) -> Result<Tensor> {
1308        let encoder_layer = self
1309            .deberta
1310            .forward(input_ids, token_type_ids, attention_mask)?;
1311        let pooled_output = self.pooler.forward(&encoder_layer)?;
1312        let pooled_output = self.dropout.forward(&pooled_output)?;
1313        self.classifier.forward(&pooled_output)
1314    }
1315}
1316
1317pub struct DebertaV2ContextPooler {
1318    dense: candle_nn::Linear,
1319    dropout: StableDropout,
1320    config: Config,
1321}
1322
1323// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L49
1324impl DebertaV2ContextPooler {
1325    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
1326        let pooler_hidden_size = config
1327            .pooler_hidden_size
1328            .context("config.pooler_hidden_size is required for DebertaV2ContextPooler")?;
1329
1330        let pooler_dropout = config
1331            .pooler_dropout
1332            .context("config.pooler_dropout is required for DebertaV2ContextPooler")?;
1333
1334        let dense = candle_nn::linear(
1335            pooler_hidden_size,
1336            pooler_hidden_size,
1337            vb.root().pp("pooler.dense"),
1338        )?;
1339
1340        let dropout = StableDropout::new(pooler_dropout);
1341
1342        Ok(Self {
1343            dense,
1344            dropout,
1345            config: config.clone(),
1346        })
1347    }
1348
1349    pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
1350        let context_token = hidden_states.narrow(1, 0, 1)?.squeeze(1)?;
1351        let context_token = self.dropout.forward(&context_token)?;
1352
1353        let pooled_output = self.dense.forward(&context_token.contiguous()?)?;
1354        let pooler_hidden_act = self
1355            .config
1356            .pooler_hidden_act
1357            .context("Could not obtain pooler hidden act from config")?;
1358
1359        HiddenActLayer::new(pooler_hidden_act).forward(&pooled_output)
1360    }
1361
1362    pub fn output_dim(&self) -> Result<usize> {
1363        self.config.pooler_hidden_size.context("DebertaV2ContextPooler cannot return output_dim (pooler_hidden_size) since it is not specified in the model config")
1364    }
1365}
1366
1367// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L557
1368pub(crate) fn build_relative_position(
1369    query_size: usize,
1370    key_size: usize,
1371    device: &Device,
1372    bucket_size: Option<isize>,
1373    max_position: Option<isize>,
1374) -> Result<Tensor> {
1375    let q_ids = Tensor::arange(0, query_size as i64, device)?.unsqueeze(0)?;
1376    let k_ids: Tensor = Tensor::arange(0, key_size as i64, device)?.unsqueeze(D::Minus1)?;
1377    let mut rel_pos_ids = k_ids.broadcast_sub(&q_ids)?;
1378    let bucket_size = bucket_size.unwrap_or(-1);
1379    let max_position = max_position.unwrap_or(-1);
1380
1381    if bucket_size > 0 && max_position > 0 {
1382        rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position, device)?;
1383    }
1384
1385    rel_pos_ids = rel_pos_ids.to_dtype(DType::I64)?;
1386    rel_pos_ids = rel_pos_ids.narrow(0, 0, query_size)?;
1387    rel_pos_ids.unsqueeze(0)
1388}
1389
// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L542
/// Maps raw relative positions to log-scaled bucket positions (port of the Python
/// `make_log_bucket_position`).
///
/// Let `mid = bucket_size / 2`. Positions with `|rel| <= mid` are returned as-is
/// (converted to `F32`). Larger distances become
/// `sign(rel) * (ceil(log(|rel| / mid) / log((max_position - 1) / mid) * (mid - 1)) + mid)`.
/// The result is an `F32` tensor with the same shape as `relative_pos`.
///
/// NOTE(review): `relative_pos` is assumed to be an `I64` tensor, as produced by
/// `build_relative_position`. `on_true` is cast to its dtype while `on_false` is cast to
/// `I64`, and `where_cond` needs both branches to share a dtype. Confirm before passing
/// other dtypes.
pub(crate) fn make_log_bucket_position(
    relative_pos: Tensor,
    bucket_size: isize,
    max_position: isize,
    device: &Device,
) -> Result<Tensor> {
    let sign = relative_pos.to_dtype(DType::F32)?.sign()?;

    let mid = bucket_size / 2;

    let lt_mid = relative_pos.lt(mid as i64)?;
    let gt_neg_mid = relative_pos.gt(-mid as i64)?;

    // Element-wise AND of the two masks (-mid < rel < mid), computed as an F32 product
    // and converted back to a U8 condition for `where_cond`.
    let condition = lt_mid
        .to_dtype(candle::DType::F32)?
        .mul(&gt_neg_mid.to_dtype(candle::DType::F32)?)?
        .to_dtype(DType::U8)?;

    // Inside (-mid, mid) the absolute position is replaced by `mid - 1`. That value is
    // <= mid, so those entries take the "keep relative_pos" branch at the end.
    let on_true = Tensor::new(&[(mid - 1) as u32], device)?
        .broadcast_as(relative_pos.shape())?
        .to_dtype(relative_pos.dtype())?;

    // Elsewhere use |relative_pos| (computed through F32, then back to I64).
    let on_false = relative_pos
        .to_dtype(DType::F32)?
        .abs()?
        .to_dtype(DType::I64)?;

    let abs_pos = condition.where_cond(&on_true, &on_false)?;

    let mid_as_tensor = Tensor::from_slice(&[mid as f32], (1,), device)?;

    // log_pos = ceil(log(abs_pos / mid) / log((max_position - 1) / mid) * (mid - 1)) + mid
    let log_pos = {
        let first_log = abs_pos
            .to_dtype(DType::F32)?
            .broadcast_div(&mid_as_tensor)?
            .log()?;

        let second_log =
            Tensor::from_slice(&[((max_position as f32 - 1.0) / mid as f32)], (1,), device)?
                .log()?;

        let first_div_second = first_log.broadcast_div(&second_log)?;

        let to_ceil = first_div_second
            .broadcast_mul(Tensor::from_slice(&[(mid - 1) as f32], (1,), device)?.as_ref())?;

        let ceil = to_ceil.ceil()?;

        ceil.broadcast_add(&mid_as_tensor)?
    };

    // Keep the raw position where abs_pos <= mid, otherwise the signed log bucket.
    Ok({
        let abs_pos_lte_mid = abs_pos.to_dtype(DType::F32)?.broadcast_le(&mid_as_tensor)?;
        // Same-dtype conversion; effectively a no-op kept from the original port.
        let relative_pos = relative_pos.to_dtype(relative_pos.dtype())?;
        let log_pos_mul_sign = log_pos.broadcast_mul(&sign.to_dtype(DType::F32)?)?;
        abs_pos_lte_mid.where_cond(&relative_pos.to_dtype(DType::F32)?, &log_pos_mul_sign)?
    })
}