candle_transformers/models/chinese_clip/
text_model.rs

1//! Chinese contrastive Language-Image Pre-Training
2//!
3//! Chinese contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
4//! pairs of images with related texts.
5//!
6//! - 💻 [Chinese-CLIP](https://github.com/OFA-Sys/Chinese-CLIP)
7//! - 💻 [HF](https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py)
8
9use candle::{DType, Device, IndexOp, Module, Result, Tensor};
10use candle_nn as nn;
11
12use super::Activation;
13
/// How position information is injected into the text model.
///
/// Mirrors the upstream `position_embedding_type` option, whose accepted values are
/// `"absolute"`, `"relative_key"` and `"relative_key_query"`:
///
/// - `Absolute`: an embedding looked up by absolute position is added to the token embeddings.
/// - `RelativeKey`: see
///   [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
/// - `RelativeKeyQuery`: see *Method 4* in
///   [Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
///
/// `PartialEq`/`Eq` are derived so that callers can compare variants directly.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum PositionEmbeddingType {
    Absolute,
    RelativeKey,
    RelativeKeyQuery,
}
25
26#[derive(Clone, Debug)]
27pub struct ChineseClipTextConfig {
28    pub vocab_size: usize,
29    pub hidden_size: usize,
30    pub num_hidden_layers: usize,
31    pub num_attention_heads: usize,
32    pub intermediate_size: usize,
33    pub hidden_act: Activation,
34    pub hidden_dropout_prob: f32,
35    pub attention_probs_dropout_prob: f64,
36    pub max_position_embeddings: usize,
37    pub type_vocab_size: usize,
38    pub initializer_range: f64,
39    pub initializer_factor: f64,
40    pub layer_norm_eps: f64,
41    pub pad_token_id: usize,
42    pub position_embedding_type: PositionEmbeddingType,
43    pub use_cache: bool,
44}
45
46impl Default for ChineseClipTextConfig {
47    fn default() -> Self {
48        Self {
49            vocab_size: 30522,
50            hidden_size: 768,
51            num_hidden_layers: 12,
52            num_attention_heads: 12,
53            intermediate_size: 3072,
54            hidden_act: Activation::Gelu,
55            hidden_dropout_prob: 0.1,
56            attention_probs_dropout_prob: 0.1,
57            max_position_embeddings: 512,
58            type_vocab_size: 2,
59            initializer_range: 0.02,
60            initializer_factor: 1.0,
61            layer_norm_eps: 1e-12,
62            pad_token_id: 0,
63            position_embedding_type: PositionEmbeddingType::Absolute,
64            use_cache: true,
65        }
66    }
67}
68
69impl ChineseClipTextConfig {
70    /// [referer](https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16/blob/main/config.json)
71    pub fn clip_vit_base_patch16() -> Self {
72        Self {
73            vocab_size: 21128,
74            hidden_size: 768,
75            num_hidden_layers: 12,
76            num_attention_heads: 12,
77            intermediate_size: 3072,
78            hidden_act: Activation::Gelu,
79            hidden_dropout_prob: 0.1,
80            attention_probs_dropout_prob: 0.1,
81            max_position_embeddings: 512,
82            type_vocab_size: 2,
83            initializer_range: 0.02,
84            initializer_factor: 1.0,
85            layer_norm_eps: 1e-12,
86            pad_token_id: 0,
87            position_embedding_type: PositionEmbeddingType::Absolute,
88            use_cache: true,
89        }
90    }
91}
92
93#[derive(Clone, Debug)]
94pub struct ChineseClipTextEmbeddings {
95    word_embeddings: nn::Embedding,
96    position_embeddings: nn::Embedding,
97    token_type_embeddings: nn::Embedding,
98    layer_norm: nn::LayerNorm,
99    dropout: nn::Dropout,
100    position_embedding_type: PositionEmbeddingType,
101    position_ids: Tensor,
102    token_type_ids: Tensor,
103}
104
105impl ChineseClipTextEmbeddings {
106    pub fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
107        let word_embeddings = nn::embedding(
108            config.vocab_size,
109            config.hidden_size,
110            var.pp("word_embeddings"),
111        )?;
112        let position_embeddings = nn::embedding(
113            config.max_position_embeddings,
114            config.hidden_size,
115            var.pp("position_embeddings"),
116        )?;
117        let token_type_embeddings = nn::embedding(
118            config.type_vocab_size,
119            config.hidden_size,
120            var.pp("token_type_embeddings"),
121        )?;
122        let layer_norm = nn::layer_norm::<f64>(
123            config.hidden_size,
124            config.layer_norm_eps,
125            var.pp("LayerNorm"),
126        )?;
127        let dropout = nn::Dropout::new(config.hidden_dropout_prob);
128        let position_ids =
129            Tensor::arange(0u32, config.max_position_embeddings as u32, var.device())?
130                .unsqueeze(0)?;
131        let token_type_ids = Tensor::zeros(position_ids.shape(), DType::I64, var.device())?;
132
133        Ok(Self {
134            word_embeddings,
135            position_embeddings,
136            token_type_embeddings,
137            layer_norm,
138            dropout,
139            position_embedding_type: config.position_embedding_type.clone(),
140            position_ids,
141            token_type_ids,
142        })
143    }
144
145    fn forward(&self, xs: &Tensor, token_type_ids: Option<&Tensor>) -> Result<Tensor> {
146        let (_batch_size, seq_length) = xs.dims2()?;
147        let position_ids = (0..seq_length as u32).collect::<Vec<_>>();
148        let position_ids = self.position_ids.index_select(
149            &Tensor::new(&position_ids[..], self.position_ids.device())?,
150            1,
151        )?;
152
153        let word_embeddings = self.word_embeddings.forward(xs)?;
154
155        let token_type_ids = match token_type_ids {
156            Some(token_type_ids) => token_type_ids,
157            None => &self.token_type_ids.i((.., 0..seq_length))?,
158        };
159        let token_type_ids = token_type_ids.expand(xs.shape())?;
160        let token_type_embeddings = self.token_type_embeddings.forward(&token_type_ids)?;
161
162        let embeddings = (&word_embeddings + token_type_embeddings)?;
163        let embeddings = match self.position_embedding_type {
164            PositionEmbeddingType::Absolute => {
165                let position_embeddings = self.position_embeddings.forward(&position_ids)?;
166                let position_embeddings = position_embeddings.expand(embeddings.shape())?;
167                (embeddings + position_embeddings)?
168            }
169            _ => embeddings,
170        };
171        let embeddings = self.layer_norm.forward(&embeddings)?;
172        let embeddings = self.dropout.forward(&embeddings, false)?;
173        Ok(embeddings)
174    }
175}
176
177/// Copied from [`crate::models::bert::BertSelfOutput`] to [`ChineseClipTextSelfOutput`]
178#[derive(Clone, Debug)]
179struct ChineseClipTextSelfOutput {
180    dense: nn::Linear,
181    layer_norm: nn::LayerNorm,
182    dropout: nn::Dropout,
183    span: tracing::Span,
184}
185
186impl ChineseClipTextSelfOutput {
187    fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
188        let dense = nn::linear(config.hidden_size, config.hidden_size, var.pp("dense"))?;
189        let layer_norm = nn::layer_norm(
190            config.hidden_size,
191            config.layer_norm_eps,
192            var.pp("LayerNorm"),
193        )?;
194        let dropout = nn::Dropout::new(config.hidden_dropout_prob);
195        Ok(Self {
196            dense,
197            layer_norm,
198            dropout,
199            span: tracing::span!(tracing::Level::TRACE, "self-out"),
200        })
201    }
202
203    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
204        let _enter = self.span.enter();
205        let hidden_states = self.dense.forward(hidden_states)?;
206        let hidden_states = self.dropout.forward(&hidden_states, false)?;
207        self.layer_norm.forward(&(hidden_states + input_tensor)?)
208    }
209}
210
211/// Copied from [`crate::models::bert::BertSelfAttention`] to [`ChineseClipTextSelfAttention`]
212#[derive(Clone, Debug)]
213struct ChineseClipTextSelfAttention {
214    query: nn::Linear,
215    key: nn::Linear,
216    value: nn::Linear,
217    dropout: nn::Dropout,
218    num_attention_heads: usize,
219    attention_head_size: usize,
220    span: tracing::Span,
221    span_softmax: tracing::Span,
222}
223
224impl ChineseClipTextSelfAttention {
225    fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
226        let attention_head_size = config.hidden_size / config.num_attention_heads;
227        let all_head_size = config.num_attention_heads * attention_head_size;
228        let dropout = nn::Dropout::new(config.hidden_dropout_prob);
229        let hidden_size = config.hidden_size;
230        let query = nn::linear(hidden_size, all_head_size, var.pp("query"))?;
231        let value = nn::linear(hidden_size, all_head_size, var.pp("value"))?;
232        let key = nn::linear(hidden_size, all_head_size, var.pp("key"))?;
233        Ok(Self {
234            query,
235            key,
236            value,
237            dropout,
238            num_attention_heads: config.num_attention_heads,
239            attention_head_size,
240            span: tracing::span!(tracing::Level::TRACE, "self-attn"),
241            span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
242        })
243    }
244
245    fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
246        let mut new_x_shape = xs.dims().to_vec();
247        new_x_shape.pop();
248        new_x_shape.push(self.num_attention_heads);
249        new_x_shape.push(self.attention_head_size);
250        let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
251        xs.contiguous()
252    }
253
254    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
255        let _enter = self.span.enter();
256        let query_layer = self.query.forward(hidden_states)?;
257        let key_layer = self.key.forward(hidden_states)?;
258        let value_layer = self.value.forward(hidden_states)?;
259
260        let query_layer = self.transpose_for_scores(&query_layer)?;
261        let key_layer = self.transpose_for_scores(&key_layer)?;
262        let value_layer = self.transpose_for_scores(&value_layer)?;
263
264        let attention_scores = query_layer.matmul(&key_layer.t()?)?;
265        let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
266        let attention_scores = attention_scores.broadcast_add(attention_mask)?;
267        let attention_probs = {
268            let _enter_sm = self.span_softmax.enter();
269            nn::ops::softmax(&attention_scores, candle::D::Minus1)?
270        };
271        let attention_probs = self.dropout.forward(&attention_probs, false)?;
272
273        let context_layer = attention_probs.matmul(&value_layer)?;
274        let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
275        let context_layer = context_layer.flatten_from(candle::D::Minus2)?;
276        Ok(context_layer)
277    }
278}
279
280/// Copied from [`crate::models::bert::BertAttention`] to [`ChineseClipTextAttention`]
281#[derive(Clone, Debug)]
282struct ChineseClipTextAttention {
283    self_attention: ChineseClipTextSelfAttention,
284    self_output: ChineseClipTextSelfOutput,
285    span: tracing::Span,
286}
287
288impl ChineseClipTextAttention {
289    fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
290        let self_attention = ChineseClipTextSelfAttention::new(var.pp("self"), config)?;
291        let self_output = ChineseClipTextSelfOutput::new(var.pp("output"), config)?;
292        Ok(Self {
293            self_attention,
294            self_output,
295            span: tracing::span!(tracing::Level::TRACE, "attn"),
296        })
297    }
298
299    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
300        let _enter = self.span.enter();
301        let self_outputs = self.self_attention.forward(hidden_states, attention_mask)?;
302        let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
303        Ok(attention_output)
304    }
305}
306
307type HiddenActLayer = Activation;
308
309/// Copied from [`crate::models::bert::BertIntermediate`] to [`ChineseClipTextIntermediate`]
310#[derive(Clone, Debug)]
311struct ChineseClipTextIntermediate {
312    dense: nn::Linear,
313    intermediate_act: HiddenActLayer,
314    span: tracing::Span,
315}
316
317impl ChineseClipTextIntermediate {
318    fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
319        let dense = nn::linear(
320            config.hidden_size,
321            config.intermediate_size,
322            var.pp("dense"),
323        )?;
324        Ok(Self {
325            dense,
326            intermediate_act: config.hidden_act,
327            span: tracing::span!(tracing::Level::TRACE, "inter"),
328        })
329    }
330}
331
332impl Module for ChineseClipTextIntermediate {
333    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
334        let _enter = self.span.enter();
335        let hidden_states = self.dense.forward(hidden_states)?;
336        let ys = self.intermediate_act.forward(&hidden_states)?;
337        Ok(ys)
338    }
339}
340
341/// Copied from [`crate::models::bert::BertOutput`] to [`ChineseClipTextOutput`]
342#[derive(Clone, Debug)]
343struct ChineseClipTextOutput {
344    dense: nn::Linear,
345    layer_norm: nn::LayerNorm,
346    dropout: nn::Dropout,
347    span: tracing::Span,
348}
349
350impl ChineseClipTextOutput {
351    fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
352        let dense = nn::linear(
353            config.intermediate_size,
354            config.hidden_size,
355            var.pp("dense"),
356        )?;
357        let layer_norm = nn::layer_norm(
358            config.hidden_size,
359            config.layer_norm_eps,
360            var.pp("LayerNorm"),
361        )?;
362        let dropout = nn::Dropout::new(config.hidden_dropout_prob);
363        Ok(Self {
364            dense,
365            layer_norm,
366            dropout,
367            span: tracing::span!(tracing::Level::TRACE, "out"),
368        })
369    }
370
371    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
372        let _enter = self.span.enter();
373        let hidden_states = self.dense.forward(hidden_states)?;
374        let hidden_states = self.dropout.forward(&hidden_states, false)?;
375        self.layer_norm.forward(&(hidden_states + input_tensor)?)
376    }
377}
378
379/// Copied from [`crate::models::bert::BertLayer`] to [`ChineseClipTextLayer`]
380#[derive(Clone, Debug)]
381struct ChineseClipTextLayer {
382    attention: ChineseClipTextAttention,
383    intermediate: ChineseClipTextIntermediate,
384    output: ChineseClipTextOutput,
385    span: tracing::Span,
386}
387
388impl ChineseClipTextLayer {
389    fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
390        let attention = ChineseClipTextAttention::new(var.pp("attention"), config)?;
391        let intermediate = ChineseClipTextIntermediate::new(var.pp("intermediate"), config)?;
392        let output = ChineseClipTextOutput::new(var.pp("output"), config)?;
393        Ok(Self {
394            attention,
395            intermediate,
396            output,
397            span: tracing::span!(tracing::Level::TRACE, "layer"),
398        })
399    }
400
401    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
402        let _enter = self.span.enter();
403        let attention_output = self.attention.forward(hidden_states, attention_mask)?;
404        // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
405        let intermediate_output = self.intermediate.forward(&attention_output)?;
406        let layer_output = self
407            .output
408            .forward(&intermediate_output, &attention_output)?;
409        Ok(layer_output)
410    }
411}
412
413#[derive(Clone, Debug)]
414struct Tanh;
415
416impl Tanh {
417    pub fn new() -> Self {
418        Self {}
419    }
420}
421impl Module for Tanh {
422    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
423        xs.tanh()
424    }
425}
426
427#[derive(Clone, Debug)]
428struct ChineseClipTextPooler {
429    dense: nn::Linear,
430    activation: Tanh,
431}
432
433impl ChineseClipTextPooler {
434    pub fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
435        let dense = nn::linear(config.hidden_size, config.hidden_size, var.pp("dense"))?;
436        let activation = Tanh::new();
437        Ok(Self { dense, activation })
438    }
439}
440
441impl Module for ChineseClipTextPooler {
442    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
443        let first_token_tensor = hidden_states.i((.., 0))?;
444        let pooled_output = self.dense.forward(&first_token_tensor)?;
445        let pooled_output = self.activation.forward(&pooled_output)?;
446        Ok(pooled_output)
447    }
448}
449
450#[derive(Clone, Debug)]
451struct ChineseClipTextEncoder {
452    layers: Vec<ChineseClipTextLayer>,
453    span: tracing::Span,
454}
455
456impl ChineseClipTextEncoder {
457    fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
458        let layers = (0..config.num_hidden_layers)
459            .map(|index| ChineseClipTextLayer::new(var.pp(format!("layer.{index}")), config))
460            .collect::<Result<Vec<_>>>()?;
461        let span = tracing::span!(tracing::Level::TRACE, "encoder");
462        Ok(ChineseClipTextEncoder { layers, span })
463    }
464
465    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
466        let _enter = self.span.enter();
467        let mut hidden_states = hidden_states.clone();
468        // Use a loop rather than a fold as it's easier to modify when adding debug/...
469        for layer in self.layers.iter() {
470            hidden_states = layer.forward(&hidden_states, attention_mask)?
471        }
472        Ok(hidden_states)
473    }
474}
475
476#[derive(Clone, Debug)]
477pub struct ChineseClipTextTransformer {
478    embeddings: ChineseClipTextEmbeddings,
479    encoder: ChineseClipTextEncoder,
480    pooler: Option<ChineseClipTextPooler>,
481    pub device: Device,
482    span: tracing::Span,
483}
484
485impl ChineseClipTextTransformer {
486    pub fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
487        let embeddings = ChineseClipTextEmbeddings::new(var.pp("embeddings"), config)?;
488        let encoder = ChineseClipTextEncoder::new(var.pp("encoder"), config)?;
489        // see: https://github.com/huggingface/transformers/blob/e40bb4845e0eefb52ec1e9cac9c2446ab36aef81/src/transformers/models/chinese_clip/modeling_chinese_clip.py#L1362
490        // In the original Python version of the code, the pooler is not used, and there are no parameters for the pooler in the weight file.
491        let pooler = if var.contains_tensor("pooler") {
492            Some(ChineseClipTextPooler::new(var.pp("pooler"), config)?)
493        } else {
494            None
495        };
496        Ok(Self {
497            embeddings,
498            encoder,
499            pooler,
500            device: var.device().clone(),
501            span: tracing::span!(tracing::Level::TRACE, "model"),
502        })
503    }
504
505    pub fn forward(
506        &self,
507        input_ids: &Tensor,
508        token_type_ids: Option<&Tensor>,
509        attention_mask: Option<&Tensor>,
510    ) -> Result<Tensor> {
511        let _enter = self.span.enter();
512        let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
513        let attention_mask = match attention_mask {
514            Some(attention_mask) => attention_mask.clone(),
515            None => input_ids.ones_like()?,
516        };
517        // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
518        let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?;
519        let encoder_outputs = self.encoder.forward(&embedding_output, &attention_mask)?;
520        let encoder_output = encoder_outputs.i((.., 0, ..))?;
521        let pooled_output = match &self.pooler {
522            Some(pooler) => pooler.forward(&encoder_output)?,
523            None => encoder_output,
524        };
525
526        Ok(pooled_output)
527    }
528}
529
530fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<Tensor> {
531    let attention_mask = match attention_mask.rank() {
532        3 => attention_mask.unsqueeze(1)?,
533        2 => attention_mask.unsqueeze(1)?.unsqueeze(1)?,
534        _ => candle::bail!("Wrong shape for input_ids or attention_mask"),
535    };
536    let attention_mask = attention_mask.to_dtype(dtype)?;
537    // torch.finfo(dtype).min
538    (attention_mask.ones_like()? - &attention_mask)?
539        .broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
540}