// candle_transformers/models/jina_bert.rs

//! # JinaBERT inference implementation
//!
//! Based on the Hugging Face implementation of Jina BERT and its variants.
//!
//! See: [Jina Embeddings on HuggingFace](https://huggingface.co/jinaai/jina-embeddings-v2-base-en)

7use super::with_tracing::{linear, linear_no_bias, Embedding, Linear};
8use candle::{DType, Device, IndexOp, Result, Tensor, D};
9use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder};
10use serde::Deserialize;
11
12pub const DTYPE: DType = DType::F32;
13
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
15#[serde(rename_all = "lowercase")]
16pub enum PositionEmbeddingType {
17    Absolute,
18    Alibi,
19}
20
21// https://huggingface.co/jinaai/jina-bert-implementation/blob/main/configuration_bert.py
22#[derive(Debug, Clone, PartialEq, Deserialize)]
23pub struct Config {
24    pub vocab_size: usize,
25    pub hidden_size: usize,
26    pub num_hidden_layers: usize,
27    pub num_attention_heads: usize,
28    pub intermediate_size: usize,
29    pub hidden_act: candle_nn::Activation,
30    pub max_position_embeddings: usize,
31    pub type_vocab_size: usize,
32    pub initializer_range: f64,
33    pub layer_norm_eps: f64,
34    pub pad_token_id: usize,
35    pub position_embedding_type: PositionEmbeddingType,
36}
37
38impl Config {
39    pub fn v2_base() -> Self {
40        // https://huggingface.co/jinaai/jina-embeddings-v2-base-en/blob/main/config.json
41        Self {
42            vocab_size: 30528,
43            hidden_size: 768,
44            num_hidden_layers: 12,
45            num_attention_heads: 12,
46            intermediate_size: 3072,
47            hidden_act: candle_nn::Activation::Gelu,
48            max_position_embeddings: 8192,
49            type_vocab_size: 2,
50            initializer_range: 0.02,
51            layer_norm_eps: 1e-12,
52            pad_token_id: 0,
53            position_embedding_type: PositionEmbeddingType::Alibi,
54        }
55    }
56
57    #[allow(clippy::too_many_arguments)]
58    pub fn new(
59        vocab_size: usize,
60        hidden_size: usize,
61        num_hidden_layers: usize,
62        num_attention_heads: usize,
63        intermediate_size: usize,
64        hidden_act: candle_nn::Activation,
65        max_position_embeddings: usize,
66        type_vocab_size: usize,
67        initializer_range: f64,
68        layer_norm_eps: f64,
69        pad_token_id: usize,
70        position_embedding_type: PositionEmbeddingType,
71    ) -> Self {
72        Config {
73            vocab_size,
74            hidden_size,
75            num_hidden_layers,
76            num_attention_heads,
77            intermediate_size,
78            hidden_act,
79            max_position_embeddings,
80            type_vocab_size,
81            initializer_range,
82            layer_norm_eps,
83            pad_token_id,
84            position_embedding_type,
85        }
86    }
87}
88
89#[derive(Clone, Debug)]
90struct BertEmbeddings {
91    word_embeddings: Embedding,
92    // no position_embeddings as we only support alibi.
93    token_type_embeddings: Embedding,
94    layer_norm: LayerNorm,
95    span: tracing::Span,
96}
97
98impl BertEmbeddings {
99    fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
100        let word_embeddings =
101            Embedding::new(cfg.vocab_size, cfg.hidden_size, vb.pp("word_embeddings"))?;
102        let token_type_embeddings = Embedding::new(
103            cfg.type_vocab_size,
104            cfg.hidden_size,
105            vb.pp("token_type_embeddings"),
106        )?;
107        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
108        Ok(Self {
109            word_embeddings,
110            token_type_embeddings,
111            layer_norm,
112            span: tracing::span!(tracing::Level::TRACE, "embeddings"),
113        })
114    }
115}
116
117impl Module for BertEmbeddings {
118    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
119        let _enter = self.span.enter();
120        let (b_size, seq_len) = input_ids.dims2()?;
121        let input_embeddings = self.word_embeddings.forward(input_ids)?;
122        let token_type_embeddings = Tensor::zeros(seq_len, DType::U32, input_ids.device())?
123            .broadcast_left(b_size)?
124            .apply(&self.token_type_embeddings)?;
125        let embeddings = (&input_embeddings + token_type_embeddings)?;
126        let embeddings = self.layer_norm.forward(&embeddings)?;
127        Ok(embeddings)
128    }
129}
130
131#[derive(Clone, Debug)]
132struct BertSelfAttention {
133    query: Linear,
134    key: Linear,
135    value: Linear,
136    num_attention_heads: usize,
137    attention_head_size: usize,
138    span: tracing::Span,
139    span_softmax: tracing::Span,
140}
141
142impl BertSelfAttention {
143    fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
144        let attention_head_size = cfg.hidden_size / cfg.num_attention_heads;
145        let all_head_size = cfg.num_attention_heads * attention_head_size;
146        let hidden_size = cfg.hidden_size;
147        let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
148        let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
149        let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
150        Ok(Self {
151            query,
152            key,
153            value,
154            num_attention_heads: cfg.num_attention_heads,
155            attention_head_size,
156            span: tracing::span!(tracing::Level::TRACE, "self-attn"),
157            span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
158        })
159    }
160
161    fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
162        let mut x_shape = xs.dims().to_vec();
163        x_shape.pop();
164        x_shape.push(self.num_attention_heads);
165        x_shape.push(self.attention_head_size);
166        xs.reshape(x_shape)?.transpose(1, 2)?.contiguous()
167    }
168
169    fn forward(&self, xs: &Tensor, bias: &Tensor) -> Result<Tensor> {
170        let _enter = self.span.enter();
171        let query_layer = self.query.forward(xs)?;
172        let key_layer = self.key.forward(xs)?;
173        let value_layer = self.value.forward(xs)?;
174
175        let query_layer = self.transpose_for_scores(&query_layer)?;
176        let key_layer = self.transpose_for_scores(&key_layer)?;
177        let value_layer = self.transpose_for_scores(&value_layer)?;
178
179        let attention_scores = query_layer.matmul(&key_layer.t()?)?;
180        let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
181        let attention_scores = attention_scores.broadcast_add(bias)?;
182        let attention_probs = {
183            let _enter_sm = self.span_softmax.enter();
184            candle_nn::ops::softmax_last_dim(&attention_scores)?
185        };
186        let context_layer = attention_probs.matmul(&value_layer)?;
187        let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
188        let context_layer = context_layer.flatten_from(D::Minus2)?;
189        Ok(context_layer)
190    }
191}
192
193#[derive(Clone, Debug)]
194struct BertSelfOutput {
195    dense: Linear,
196    layer_norm: LayerNorm,
197    span: tracing::Span,
198}
199
200impl BertSelfOutput {
201    fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
202        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
203        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
204        Ok(Self {
205            dense,
206            layer_norm,
207            span: tracing::span!(tracing::Level::TRACE, "self-out"),
208        })
209    }
210
211    fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
212        let _enter = self.span.enter();
213        let xs = self.dense.forward(xs)?;
214        self.layer_norm.forward(&(xs + input_tensor)?)
215    }
216}
217
218#[derive(Clone, Debug)]
219struct BertAttention {
220    self_attention: BertSelfAttention,
221    self_output: BertSelfOutput,
222    span: tracing::Span,
223}
224
225impl BertAttention {
226    fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
227        let self_attention = BertSelfAttention::new(vb.pp("self"), cfg)?;
228        let self_output = BertSelfOutput::new(vb.pp("output"), cfg)?;
229        Ok(Self {
230            self_attention,
231            self_output,
232            span: tracing::span!(tracing::Level::TRACE, "attn"),
233        })
234    }
235
236    fn forward(&self, xs: &Tensor, bias: &Tensor) -> Result<Tensor> {
237        let _enter = self.span.enter();
238        let self_outputs = self.self_attention.forward(xs, bias)?;
239        let attention_output = self.self_output.forward(&self_outputs, xs)?;
240        Ok(attention_output)
241    }
242}
243
244#[derive(Clone, Debug)]
245struct BertGLUMLP {
246    gated_layers: Linear,
247    act: candle_nn::Activation,
248    wo: Linear,
249    layernorm: LayerNorm,
250    intermediate_size: usize,
251}
252
253impl BertGLUMLP {
254    fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
255        let gated_layers = linear_no_bias(
256            cfg.hidden_size,
257            cfg.intermediate_size * 2,
258            vb.pp("gated_layers"),
259        )?;
260        let act = candle_nn::Activation::Gelu; // geglu
261        let wo = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("wo"))?;
262        let layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("layernorm"))?;
263        Ok(Self {
264            gated_layers,
265            act,
266            wo,
267            layernorm,
268            intermediate_size: cfg.intermediate_size,
269        })
270    }
271}
272
273impl Module for BertGLUMLP {
274    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
275        let residual = xs;
276        let xs = xs.apply(&self.gated_layers)?;
277        let gated = xs.narrow(D::Minus1, 0, self.intermediate_size)?;
278        let non_gated = xs.narrow(D::Minus1, self.intermediate_size, self.intermediate_size)?;
279        let xs = (gated.apply(&self.act) * non_gated)?.apply(&self.wo);
280        (xs + residual)?.apply(&self.layernorm)
281    }
282}
283
284#[derive(Clone, Debug)]
285struct BertLayer {
286    attention: BertAttention,
287    mlp: BertGLUMLP,
288    span: tracing::Span,
289}
290
291impl BertLayer {
292    fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
293        let attention = BertAttention::new(vb.pp("attention"), cfg)?;
294        let mlp = BertGLUMLP::new(vb.pp("mlp"), cfg)?;
295        Ok(Self {
296            attention,
297            mlp,
298            span: tracing::span!(tracing::Level::TRACE, "layer"),
299        })
300    }
301
302    fn forward(&self, xs: &Tensor, bias: &Tensor) -> Result<Tensor> {
303        let _enter = self.span.enter();
304        self.attention.forward(xs, bias)?.apply(&self.mlp)
305    }
306}
307
308fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
309    let n_heads = cfg.num_attention_heads;
310    let seq_len = cfg.max_position_embeddings;
311    let alibi_bias = Tensor::arange(0, seq_len as i64, &Device::Cpu)?.to_dtype(DType::F32)?;
312    let alibi_bias = {
313        let a1 = alibi_bias.reshape((1, seq_len))?;
314        let a2 = alibi_bias.reshape((seq_len, 1))?;
315        a1.broadcast_sub(&a2)?.abs()?.broadcast_left(n_heads)?
316    };
317    let mut n_heads2 = 1;
318    while n_heads2 < n_heads {
319        n_heads2 *= 2
320    }
321    let slopes = (1..=n_heads2)
322        .map(|v| -1f32 / 2f32.powf((v * 8) as f32 / n_heads2 as f32))
323        .collect::<Vec<_>>();
324    let slopes = if n_heads2 == n_heads {
325        slopes
326    } else {
327        slopes
328            .iter()
329            .skip(1)
330            .step_by(2)
331            .chain(slopes.iter().step_by(2))
332            .take(n_heads)
333            .cloned()
334            .collect::<Vec<f32>>()
335    };
336    let slopes = Tensor::new(slopes, &Device::Cpu)?.reshape((1, (), 1, 1))?;
337    alibi_bias.to_dtype(DType::F32)?.broadcast_mul(&slopes)
338}
339
340#[derive(Clone, Debug)]
341struct BertEncoder {
342    alibi: Tensor,
343    layers: Vec<BertLayer>,
344    span: tracing::Span,
345}
346
347impl BertEncoder {
348    fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
349        if cfg.position_embedding_type != PositionEmbeddingType::Alibi {
350            candle::bail!("only alibi is supported as a position-embedding-type")
351        }
352        let layers = (0..cfg.num_hidden_layers)
353            .map(|index| BertLayer::new(vb.pp(format!("layer.{index}")), cfg))
354            .collect::<Result<Vec<_>>>()?;
355        let span = tracing::span!(tracing::Level::TRACE, "encoder");
356        let alibi = build_alibi_bias(cfg)?.to_device(vb.device())?;
357        Ok(Self {
358            alibi,
359            layers,
360            span,
361        })
362    }
363}
364
365impl Module for BertEncoder {
366    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
367        let _enter = self.span.enter();
368        let seq_len = xs.dim(1)?;
369        let alibi_bias = self.alibi.i((.., .., ..seq_len, ..seq_len))?;
370        let mut xs = xs.clone();
371        for layer in self.layers.iter() {
372            xs = layer.forward(&xs, &alibi_bias)?
373        }
374        Ok(xs)
375    }
376}
377
378#[derive(Clone, Debug)]
379pub struct BertModel {
380    embeddings: BertEmbeddings,
381    encoder: BertEncoder,
382    pub device: Device,
383    span: tracing::Span,
384}
385
386impl BertModel {
387    pub fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
388        let embeddings = BertEmbeddings::new(vb.pp("embeddings"), cfg)?;
389        let encoder = BertEncoder::new(vb.pp("encoder"), cfg)?;
390        Ok(Self {
391            embeddings,
392            encoder,
393            device: vb.device().clone(),
394            span: tracing::span!(tracing::Level::TRACE, "model"),
395        })
396    }
397}
398
399impl Module for BertModel {
400    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
401        let _enter = self.span.enter();
402        let embedding_output = self.embeddings.forward(input_ids)?;
403        let sequence_output = self.encoder.forward(&embedding_output)?;
404        Ok(sequence_output)
405    }
406}