candle_transformers/models/
stella_en_v5.rs

1//! Stella v5 model implementation.
2//!
3//! Stella is a dense text embedding model optimized for retrieval and similarity tasks.
4//! This implementation provides support for multiple embedding dimensions.
5//!
6//! Key characteristics:
7//! - Dense text embeddings optimized for similarity search
8//! - Multiple output dimension support (256 to 8192)
9//! - Grouped query attention (GQA)
10//! - RMSNorm for layer normalization
11//! - Rotary positional embeddings (RoPE)
12//!
13//! References:
14//! - [MRL Framework](https://arxiv.org/abs/2205.13147)
15//! - [Model Card](https://huggingface.co/dunzhang/stella_en_1.5B_v5)
16//!
17
18use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm};
19use candle::{DType, Device, Error, IndexOp, Module, Result, Tensor, D};
20use candle_nn::{layer_norm, Activation, LayerNorm, VarBuilder};
21use std::sync::Arc;
22
23// internal representation for identifying which model is being used
24#[derive(Debug, Copy, Clone, PartialEq, serde::Deserialize)]
25pub enum ModelVariant {
26    Large, // 1.5B
27    Small, // 400M
28}
29
30impl Default for ModelVariant {
31    fn default() -> Self {
32        Self::Large
33    }
34}
35
36// Same as `qwen2` family of models with the exception being the `embed_head`
37// The final `output` causal modelling head is swapped with a learned `dense` layer, `embed_head`
38#[derive(Debug, Default, Clone, PartialEq, serde::Deserialize)]
39pub struct Config {
40    pub variant: ModelVariant,
41    pub vocab_size: usize,
42    pub hidden_size: usize,
43    pub intermediate_size: usize,
44    pub num_hidden_layers: usize,
45    pub num_attention_heads: usize,
46    pub max_position_embeddings: usize,
47    pub rope_theta: f64,
48    pub embed_head: EmbedHead,
49    pub norm_eps: f64,             // RMSNorm for 1.5B || LayerNorm for 400M
50    pub activation_fn: Activation, // Silu for 1.5B || Gelu for 400M
51    // Unique to 1.5B
52    pub num_key_value_heads: usize,
53    // Unique to 400M
54    pub type_vocab_size: usize,
55    pub scaling_factor: f64,
56}
57
58// Excerpt from `stella` model card:
59// `Stella_en_1.5B_v5` models have been trained on [MRL](https://arxiv.org/abs/2205.13147) enabling multiple output dimensions
60// Embed head represents the config for various embedding dims supported
61#[derive(Debug, Default, Clone, PartialEq, serde::Deserialize)]
62pub struct EmbedHead {
63    pub in_features: usize,
64    pub out_features: usize,
65}
66
/// The embedding head dimensions `stella` is trained on.
///
/// As the [model-card](https://huggingface.co/dunzhang/stella_en_1.5B_v5#introduction) suggests,
/// `Dim1024` is good enough for most cases, which is why it is the default.
#[derive(Debug, Default, Clone, Copy)]
pub enum EmbedDim {
    /// 256-dimensional embeddings.
    Dim256,
    /// 768-dimensional embeddings.
    Dim768,
    /// 1024-dimensional embeddings (default).
    #[default]
    Dim1024,
    /// 2048-dimensional embeddings.
    Dim2048,
    /// 4096-dimensional embeddings.
    Dim4096,
    /// 6144-dimensional embeddings.
    Dim6144,
    /// 8192-dimensional embeddings.
    Dim8192,
}
85
86impl EmbedDim {
87    pub fn config(&self, in_features: usize) -> EmbedHead {
88        EmbedHead {
89            in_features,
90            out_features: match &self {
91                Self::Dim256 => 256,
92                Self::Dim768 => 768,
93                Self::Dim1024 => 1024,
94                Self::Dim2048 => 2048,
95                Self::Dim4096 => 4096,
96                Self::Dim6144 => 6144,
97                Self::Dim8192 => 8192,
98            },
99        }
100    }
101}
102
103// Initialize a new `stella_en` model - with 400M variant or 1.5B variant
104impl Config {
105    /// Initialize a new `stella_en_1.5B_v5`` model with given embedding dim
106    pub fn new_1_5_b_v5(embed_dim: EmbedDim) -> Self {
107        // Representing config.json at https://huggingface.co/dunzhang/stella_en_1.5B_v5/blob/main/config.json
108        // Removed `sliding_window` related config which is basically being carried forward from `qwen2` but not used here
109        Self {
110            variant: ModelVariant::Large,
111            activation_fn: candle_nn::Activation::Silu,
112            vocab_size: 151646,
113            hidden_size: 1536,
114            intermediate_size: 8960,
115            num_hidden_layers: 28,
116            num_attention_heads: 12,
117            num_key_value_heads: 2,
118            max_position_embeddings: 131072,
119            rope_theta: 1000000.,
120            norm_eps: 1e-06,
121            embed_head: embed_dim.config(1536),
122            ..Default::default()
123        }
124    }
125
126    /// Initialize new `stella_en_400M_v5`
127    pub fn new_400_m_v5(embed_dim: EmbedDim) -> Self {
128        Self {
129            variant: ModelVariant::Small,
130            vocab_size: 30528,
131            hidden_size: 1024,
132            intermediate_size: 4096,
133            num_hidden_layers: 24,
134            num_attention_heads: 16,
135            max_position_embeddings: 8192,
136            type_vocab_size: 2,
137            norm_eps: 1e-12,
138            scaling_factor: 2.0,
139            rope_theta: 160000.0,
140            activation_fn: Activation::Gelu,
141            embed_head: embed_dim.config(1024),
142            ..Default::default()
143        }
144    }
145}
146
147#[derive(Debug, Clone)]
148struct RotaryEmbedding {
149    sin: Tensor,
150    cos: Tensor,
151}
152
153impl RotaryEmbedding {
154    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
155        let dim = cfg.hidden_size / cfg.num_attention_heads;
156        // Factoring in `scaling factor` for `400M` variant
157        let max_seq_len = if cfg.scaling_factor == 0. {
158            cfg.max_position_embeddings
159        } else {
160            ((cfg.max_position_embeddings as f64) * cfg.scaling_factor) as usize
161        };
162
163        // let rot_dim = if cfg.variant == ModelVariant::Small { dim / 2 } else { dim };
164        let inv_freq: Vec<_> = (0..dim)
165            .step_by(2)
166            .map(|i| {
167                // Scaled rope_theta for 400M variant
168                let rope_theta = if cfg.scaling_factor == 0. {
169                    cfg.rope_theta
170                } else {
171                    cfg.rope_theta * cfg.scaling_factor
172                };
173                let mut freq = 1. / rope_theta.powf(i as f64 / dim as f64);
174
175                if cfg.scaling_factor != 0. {
176                    freq /= cfg.scaling_factor.powf(2.0 / (dim as f64))
177                }
178
179                freq as f32
180            })
181            .collect();
182
183        let inv_freq_len = inv_freq.len();
184        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
185
186        // Calculate position embeddings with scaled sequence length
187        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
188            .to_dtype(dtype)?
189            .reshape((max_seq_len, 1))?;
190        let freqs = t.matmul(&inv_freq)?;
191        // if cfg.variant == ModelVariant::Small {
192        //     freqs = Tensor::cat(&[&freqs, &freqs], 1)?
193        // }
194
195        Ok(Self {
196            sin: freqs.sin()?,
197            cos: freqs.cos()?,
198        })
199    }
200
201    // TODO: re-visit this
202    fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
203        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
204        let cos = self.cos.narrow(0, 0, seq_len)?;
205        let sin = self.sin.narrow(0, 0, seq_len)?;
206
207        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
208        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
209        Ok((q_embed, k_embed))
210    }
211}
212
213#[derive(Debug, Clone)]
214#[allow(clippy::upper_case_acronyms)]
215struct MLP {
216    variant: ModelVariant,
217    gate_proj: Linear,
218    up_proj: Option<Linear>, // `up_proj` only for 1.5B variant
219    down_proj: Linear,
220    act_fn: Activation,
221}
222
223impl MLP {
224    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
225        let hidden_sz = cfg.hidden_size;
226        let intermediate_sz = cfg.intermediate_size;
227
228        let (gate_proj, up_proj, down_proj) = match cfg.variant {
229            ModelVariant::Large => (
230                linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?,
231                Some(linear_no_bias(
232                    hidden_sz,
233                    intermediate_sz,
234                    vb.pp("up_proj"),
235                )?),
236                linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?,
237            ),
238            ModelVariant::Small => (
239                linear_no_bias(hidden_sz, intermediate_sz * 2, vb.pp("up_gate_proj"))?,
240                None,
241                linear(intermediate_sz, hidden_sz, vb.pp("down_proj"))?,
242            ),
243        };
244
245        Ok(Self {
246            variant: cfg.variant,
247            gate_proj,
248            up_proj,
249            down_proj,
250            act_fn: cfg.activation_fn,
251        })
252    }
253}
254
255impl Module for MLP {
256    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
257        let up = self.gate_proj.forward(xs)?;
258
259        let (lhs, rhs) = match self.variant {
260            ModelVariant::Large => {
261                let lhs = up.apply(&self.act_fn)?;
262                let rhs = xs.apply(self.up_proj.as_ref().unwrap())?;
263
264                (lhs, rhs)
265            }
266            ModelVariant::Small => {
267                // Get the dimensions
268                let (_batch_size, _seq_len, hidden_dim) = up.dims3()?;
269                let split_size = hidden_dim / 2;
270
271                // Split along the last dimension (hidden_dim)
272                let up_states = up.narrow(2, 0, split_size)?;
273                let gate = up.narrow(2, split_size, split_size)?.apply(&self.act_fn)?;
274
275                (up_states, gate)
276            }
277        };
278
279        (lhs * rhs)?.apply(&self.down_proj)
280    }
281}
282
283#[derive(Debug, Clone)]
284struct Attention {
285    qkv_proj: Linear,
286    o_proj: Linear,
287    num_heads: usize,
288    num_kv_heads: usize,
289    num_kv_groups: usize,
290    head_dim: usize,
291    hidden_size: usize,
292    rotary_emb: Arc<RotaryEmbedding>,
293    variant: ModelVariant,
294}
295
296impl Attention {
297    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
298        let hidden_sz = cfg.hidden_size;
299        let num_heads = cfg.num_attention_heads;
300        let num_kv_heads = cfg.num_key_value_heads;
301        let num_kv_groups = if num_kv_heads > 0 {
302            num_heads / num_kv_heads
303        } else {
304            0
305        };
306        let head_dim = hidden_sz / num_heads;
307
308        let (qkv_proj, o_proj) = match cfg.variant {
309            ModelVariant::Large => {
310                // The 1.5B variant comes with separate `q, k, v` layers, let's merge it and standardize
311                // Weights
312                let q_w = vb
313                    .pp("q_proj")
314                    .get((num_heads * head_dim, hidden_sz), "weight")?;
315                let k_w = vb
316                    .pp("k_proj")
317                    .get((num_kv_heads * head_dim, hidden_sz), "weight")?;
318                let v_w = vb
319                    .pp("v_proj")
320                    .get((num_kv_heads * head_dim, hidden_sz), "weight")?;
321                // Biases
322                let q_b = vb.pp("q_proj").get(num_heads * head_dim, "bias")?;
323                let k_b = vb.pp("k_proj").get(num_kv_heads * head_dim, "bias")?;
324                let v_b = vb.pp("v_proj").get(num_kv_heads * head_dim, "bias")?;
325
326                let qkv_w = Tensor::cat(&[&q_w, &k_w, &v_w], 0)?;
327                let qkv_b = Tensor::cat(&[&q_b, &k_b, &v_b], 0)?;
328
329                (
330                    Linear::from_weights(qkv_w, Some(qkv_b)),
331                    linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?,
332                )
333            }
334            ModelVariant::Small => (
335                linear(hidden_sz, 3 * num_heads * head_dim, vb.pp("qkv_proj"))?,
336                linear(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?,
337            ),
338        };
339
340        Ok(Self {
341            qkv_proj,
342            o_proj,
343            num_heads,
344            num_kv_heads,
345            num_kv_groups,
346            head_dim,
347            hidden_size: hidden_sz,
348            rotary_emb,
349            variant: cfg.variant,
350        })
351    }
352
353    fn forward(&mut self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
354        let (b_sz, q_len, _) = xs.dims3()?;
355
356        let qkv = self.qkv_proj.forward(xs)?;
357
358        let n_kv_heads = match self.variant {
359            ModelVariant::Large => self.num_kv_heads,
360            ModelVariant::Small => self.num_heads,
361        };
362
363        let (query_states, key_states, value_states) = match self.variant {
364            ModelVariant::Large => {
365                let q_sz = self.num_heads * self.head_dim;
366                let kv_sz = n_kv_heads * self.head_dim;
367
368                let q = qkv.narrow(D::Minus1, 0, q_sz)?.reshape((
369                    b_sz,
370                    q_len,
371                    self.num_heads,
372                    self.head_dim,
373                ))?;
374                let k = qkv.narrow(D::Minus1, q_sz, kv_sz)?.reshape((
375                    b_sz,
376                    q_len,
377                    n_kv_heads,
378                    self.head_dim,
379                ))?;
380                let v = qkv.narrow(D::Minus1, q_sz + kv_sz, kv_sz)?.reshape((
381                    b_sz,
382                    q_len,
383                    n_kv_heads,
384                    self.head_dim,
385                ))?;
386
387                (q, k, v)
388            }
389            ModelVariant::Small => {
390                // Split into Q, K, V and reshape to match PyTorch shapes
391                let qkv = qkv.reshape((b_sz, q_len, 3, self.num_heads, self.head_dim))?;
392
393                (
394                    qkv.i((.., .., 0, .., ..))?,
395                    qkv.i((.., .., 1, .., ..))?,
396                    qkv.i((.., .., 2, .., ..))?,
397                )
398            }
399        };
400
401        let query_states = query_states.transpose(1, 2)?.contiguous()?;
402        let key_states = key_states.transpose(1, 2)?.contiguous()?;
403        let value_states = value_states.transpose(1, 2)?.contiguous()?;
404
405        let (query_states, key_states) = self
406            .rotary_emb
407            .apply_rotary_emb_qkv(&query_states, &key_states)?;
408
409        // The 1.5B is expected to have grouped query attention
410        let (key_states, value_states) = if self.variant == ModelVariant::Large {
411            (
412                crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?,
413                crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?,
414            )
415        } else {
416            (key_states, value_states)
417        };
418
419        let attn_output = {
420            let scale = 1f64 / f64::sqrt(self.head_dim as f64);
421            let attn_weights = query_states.matmul(&key_states.transpose(2, 3)?)?;
422            let attn_weights = (attn_weights * scale)?;
423
424            let attn_weights = match attention_mask {
425                None => attn_weights,
426                Some(mask) => attn_weights.broadcast_add(mask)?,
427            };
428            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
429
430            attn_weights.matmul(&value_states)?
431        };
432
433        attn_output
434            .transpose(1, 2)?
435            .reshape((b_sz, q_len, self.hidden_size))?
436            .apply(&self.o_proj)
437    }
438}
439
440#[derive(Debug, Clone)]
441enum NormType {
442    Layer(LayerNorm),
443    Rms(RmsNorm),
444}
445
446#[derive(Debug, Clone)]
447struct Layer {
448    variant: ModelVariant,
449    attention: Attention,
450    mlp: MLP,
451    // For 1.5B: this is `input_layernorm`
452    // For 400M: this is `output_layernorm`
453    layernorm: NormType,
454    post_attention_layernorm: NormType,
455}
456
457impl Layer {
458    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
459        let attention = Attention::new(
460            rotary_emb,
461            cfg,
462            vb.pp(if cfg.variant == ModelVariant::Large {
463                "self_attn"
464            } else {
465                "attention"
466            }),
467        )?;
468        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
469        let (layernorm, post_attention_layernorm) = match cfg.variant {
470            ModelVariant::Large => (
471                NormType::Rms(RmsNorm::new(
472                    cfg.hidden_size,
473                    cfg.norm_eps,
474                    vb.pp("input_layernorm"),
475                )?),
476                NormType::Rms(RmsNorm::new(
477                    cfg.hidden_size,
478                    cfg.norm_eps,
479                    vb.pp("post_attention_layernorm"),
480                )?),
481            ),
482            ModelVariant::Small => (
483                NormType::Layer(layer_norm(
484                    cfg.hidden_size,
485                    candle_nn::LayerNormConfig {
486                        eps: cfg.norm_eps,
487                        ..Default::default()
488                    },
489                    vb.pp("mlp_ln"),
490                )?),
491                NormType::Layer(layer_norm(
492                    cfg.hidden_size,
493                    candle_nn::LayerNormConfig {
494                        eps: cfg.norm_eps,
495                        ..Default::default()
496                    },
497                    vb.pp("attn_ln"),
498                )?),
499            ),
500        };
501
502        Ok(Self {
503            variant: cfg.variant,
504            attention,
505            mlp,
506            layernorm,
507            post_attention_layernorm,
508        })
509    }
510
511    fn forward(&mut self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
512        // Here, the application of normalizations and activation calculations differ
513        // For Large [1.5B]:
514        //  residual = x
515        //  state = other_layernorm(xs)
516        //  state = attention(state)
517        //  state += residual
518        //  residual = state
519        //  state = mlp(attention_layernorm(state))
520        //  -> residual + state
521        // For Small [400M]:
522        //  residual = x;
523        //  state = attention(x)
524        //  state += residual
525        //  state = attention_layernorm(state)
526        //  residual = state
527        //  state = mlp(state)
528        //  state += residual
529        //  -> other_layernorm(state)
530        let residual = xs;
531
532        match self.variant {
533            ModelVariant::Large => {
534                let (attn_ln, input_ln) = if let (NormType::Rms(attn_ln), NormType::Rms(input_ln)) =
535                    (&self.post_attention_layernorm, &self.layernorm)
536                {
537                    (attn_ln, input_ln)
538                } else {
539                    return Err(candle::error::Error::Msg(
540                        "Stella 1.5B expects RMSNorm".to_string(),
541                    ));
542                };
543
544                let xs = input_ln.forward(xs)?;
545                let xs = (self.attention.forward(&xs, attention_mask)? + residual)?;
546
547                let residual = &xs;
548                let xs = xs.apply(attn_ln)?.apply(&self.mlp)?;
549
550                residual + xs
551            }
552            ModelVariant::Small => {
553                let (attn_ln, output_ln) =
554                    if let (NormType::Layer(attn_ln), NormType::Layer(input_ln)) =
555                        (&self.post_attention_layernorm, &self.layernorm)
556                    {
557                        (attn_ln, input_ln)
558                    } else {
559                        return Err(candle::error::Error::Msg(
560                            "Stella 400M expects RMSNorm".to_string(),
561                        ));
562                    };
563
564                let xs = (self.attention.forward(xs, attention_mask)? + residual)?;
565                let xs = attn_ln.forward(&xs)?;
566
567                let residual = &xs;
568                let xs = (self.mlp.forward(&xs)? + residual)?;
569
570                output_ln.forward(&xs)
571            }
572        }
573    }
574}
575
576#[derive(Debug, Clone)]
577pub struct Embeddings {
578    variant: ModelVariant,
579    // For 1.5B: this is the `embed_tokens`
580    // For 400M: this is the `word_embeddings`
581    embeddings: candle_nn::Embedding,
582    // folloing are specifically for 400M
583    token_type_embeddings: Option<candle_nn::Embedding>,
584    layer_norm: Option<LayerNorm>,
585    position_ids: Option<Tensor>,
586}
587
588impl Embeddings {
589    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
590        let (embeddings, token_type_embeddings, layer_norm, position_ids) = match cfg.variant {
591            ModelVariant::Large => (
592                candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?,
593                None,
594                None,
595                None,
596            ),
597            ModelVariant::Small => {
598                let vb = vb.pp("embeddings");
599                let weight = vb.pp("LayerNorm").get_with_hints(
600                    cfg.hidden_size,
601                    "weight",
602                    candle_nn::Init::Const(1.0),
603                )?;
604                let bias = vb.pp("LayerNorm").get_with_hints(
605                    cfg.hidden_size,
606                    "bias",
607                    candle_nn::Init::Const(0.0),
608                )?;
609                let dev = bias.device().clone();
610
611                let layer_norm = candle_nn::LayerNorm::new(weight, bias, cfg.norm_eps);
612
613                (
614                    candle_nn::embedding(
615                        cfg.vocab_size,
616                        cfg.hidden_size,
617                        vb.pp("word_embeddings"),
618                    )?,
619                    Some(candle_nn::embedding(
620                        cfg.type_vocab_size,
621                        cfg.hidden_size,
622                        vb.pp("token_type_embeddings"),
623                    )?),
624                    Some(layer_norm),
625                    Some(Tensor::arange(
626                        0u32,
627                        cfg.max_position_embeddings as u32,
628                        &dev,
629                    )?),
630                )
631            }
632        };
633
634        Ok(Self {
635            variant: cfg.variant,
636            embeddings,
637            token_type_embeddings,
638            layer_norm,
639            position_ids,
640        })
641    }
642}
643
644impl Module for Embeddings {
645    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
646        let embd = self.embeddings.forward(xs)?;
647        // For 1.5B just forward the embeddings
648        if self.variant == ModelVariant::Large {
649            return Ok(embd);
650        }
651
652        let (token_type_embed, layer_norm, pos_ids) =
653            if let (Some(token_type_embd), Some(layer_norm), Some(position_ids)) = (
654                &self.token_type_embeddings,
655                &self.layer_norm,
656                &self.position_ids,
657            ) {
658                (token_type_embd, layer_norm, position_ids)
659            } else {
660                return Err(Error::Msg(
661                    "Stella 400M requires `token_type_embeddings`, `layer_norm` and `position_ids`"
662                        .to_string(),
663                ));
664            };
665
666        let (batch_size, seq_length) = xs.dims2()?;
667
668        let pos_ids = pos_ids
669            .as_ref()
670            .narrow(0, 0, seq_length)?
671            .expand((batch_size, seq_length))?;
672
673        layer_norm.forward(&embd.add(&token_type_embed.forward(&pos_ids.zeros_like()?)?)?)
674    }
675}
676
677#[derive(Debug, Clone)]
678pub struct Model {
679    embeddings: Embeddings,
680    layers: Vec<Layer>,
681    norm: Option<RmsNorm>,
682    device: Device,
683    dtype: DType,
684}
685
686impl Model {
687    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
688        let vb_m = match cfg.variant {
689            ModelVariant::Large => vb.pp("model"),
690            ModelVariant::Small => vb.pp("new"),
691        };
692        // let embed_tokens =
693        //     candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
694        let embeddings = Embeddings::new(cfg, vb_m.clone())?;
695        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
696        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
697        let vb_l = match cfg.variant {
698            ModelVariant::Large => vb_m.pp("layers"),
699            ModelVariant::Small => vb_m.pp("encoder").pp("layer"),
700        };
701        for layer_idx in 0..cfg.num_hidden_layers {
702            let layer = Layer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
703            layers.push(layer)
704        }
705        let norm = match cfg.variant {
706            ModelVariant::Large => Some(RmsNorm::new(
707                cfg.hidden_size,
708                cfg.norm_eps,
709                vb_m.pp("norm"),
710            )?),
711            ModelVariant::Small => None,
712        };
713        Ok(Self {
714            embeddings,
715            layers,
716            norm,
717            device: vb.device().clone(),
718            dtype: vb.dtype(),
719        })
720    }
721
722    fn prepare_attention_mask(&self, attn_mask: &Tensor) -> Result<Tensor> {
723        let (b_sz, sql_len) = attn_mask.dims2()?;
724        let mut mask: Vec<Tensor> = vec![];
725        for b in 0..b_sz {
726            mask.push(attn_mask.i((b, ..))?.expand((1, 1, sql_len, sql_len))?);
727        }
728        let mask = Tensor::cat(&mask, 0)?;
729        let on_true = mask.zeros_like()?.to_dtype(self.dtype)?;
730        let on_false = Tensor::new(f32::NEG_INFINITY, &self.device)?
731            .broadcast_as(mask.shape())?
732            .to_dtype(self.dtype)?;
733        mask.where_cond(&on_true, &on_false)
734    }
735
736    pub fn forward(&mut self, input_ids: &Tensor, mask: &Tensor) -> Result<Tensor> {
737        let (_, seq_len) = input_ids.dims2()?;
738        let attention_mask = if seq_len <= 1 {
739            None
740        } else {
741            // This is not a `causal language modelling` task, we'll need to prepare a `non-causal` attention
742            Some(self.prepare_attention_mask(mask)?)
743        };
744
745        let mut xs = self.embeddings.forward(input_ids)?;
746        for layer in self.layers.iter_mut() {
747            xs = layer.forward(&xs, attention_mask.as_ref())?
748        }
749
750        if let Some(n) = &self.norm {
751            xs.apply(n)
752        } else {
753            Ok(xs)
754        }
755    }
756}
757
758#[derive(Debug)]
759pub struct EmbeddingModel {
760    base_model: Model,
761    lm_head: Linear,
762}
763
764impl EmbeddingModel {
765    pub fn new(cfg: &Config, base_vb: VarBuilder, embed_vb: VarBuilder) -> Result<Self> {
766        let base_model = Model::new(cfg, base_vb.clone())?;
767        let lm_head = linear(
768            cfg.embed_head.in_features,
769            cfg.embed_head.out_features,
770            embed_vb.pp("linear"),
771        )?;
772
773        Ok(Self {
774            base_model,
775            lm_head,
776        })
777    }
778
779    pub fn forward(&mut self, input_ids: &Tensor, mask: &Tensor) -> Result<Tensor> {
780        let x = self.base_model.forward(input_ids, mask)?;
781        let x = self.pool(&x, mask)?;
782
783        // No matter what keeping the final activations as F32 helps with the accuracy
784        self.lm_head.forward(&x.to_dtype(DType::F32)?) // [B_sz, dim_size]
785    }
786
787    /// Same as forward pass but normalizes the output
788    pub fn forward_norm(&mut self, input_ids: &Tensor, mask: &Tensor) -> Result<Tensor> {
789        let x = self.forward(input_ids, mask)?;
790        // Normalize
791        x.broadcast_div(&x.sqr()?.sum_keepdim(1)?.sqrt()?)
792    }
793
794    fn pool(&self, x: &Tensor, mask: &Tensor) -> Result<Tensor> {
795        let mask = mask.to_dtype(x.dtype())?; // [B_Sz, Seq_len]
796        let (batch_size, seq_len, hidden_dim) = x.dims3()?;
797        // expanding the shape of the mask from [B_Sz, Seq_len] -> [B_Sz, Seq_len, Hidden_size]
798        let mask_expanded = mask
799            .unsqueeze(2)?
800            .broadcast_as((batch_size, seq_len, hidden_dim))?; // [B_Sz, Seq_len, Hidden_dim]
801
802        let x = (x * &mask_expanded)?;
803
804        // Sum
805        let sum_mask = mask
806            .sum(1)?
807            .unsqueeze(1)?
808            .expand((batch_size, hidden_dim))?;
809        x.sum(1)? / sum_mask
810    }
811}