// candle_transformers/models/clip/text_model.rs

//! Contrastive Language-Image Pre-Training
//!
//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
//! pairs of images with related texts. This file implements the text-encoder half.
//!
//! - [GH](https://github.com/openai/CLIP)
//! - [Code](https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip)

9use candle::{DType, Device, IndexOp, Result, Tensor, D};
10use candle_nn as nn;
11use candle_nn::Module;
12
13use super::EncoderConfig;
14
/// Non-linearity applied between the two projections of every MLP block.
#[derive(Clone, Copy, Debug)]
pub enum Activation {
    /// "Quick GELU": `x * sigmoid(1.702 * x)`.
    QuickGelu,
}

20impl Module for Activation {
21    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
22        match self {
23            Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
24        }
25    }
26}
27
28#[derive(Debug, Clone)]
29pub struct ClipTextConfig {
30    pub vocab_size: usize,
31    pub embed_dim: usize,
32    pub activation: Activation,
33    pub intermediate_size: usize,
34    pub max_position_embeddings: usize,
35    pub pad_with: Option<String>,
36    pub num_hidden_layers: usize,
37    pub num_attention_heads: usize,
38    #[allow(dead_code)]
39    pub projection_dim: usize,
40}
41
42impl ClipTextConfig {
43    // The config details can be found in the "text_config" section of this json file:
44    // https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
45    pub fn vit_base_patch32() -> Self {
46        Self {
47            vocab_size: 49408,
48            embed_dim: 512,
49            intermediate_size: 2048,
50            max_position_embeddings: 77,
51            pad_with: None,
52            num_hidden_layers: 12,
53            num_attention_heads: 8,
54            projection_dim: 512,
55            activation: Activation::QuickGelu,
56        }
57    }
58}
59
60// ClipTextEmbeddings mostly based on the existing implementation in the stable diffision model.
61// TODO rewrite to be more similar to https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L142
62#[derive(Clone, Debug)]
63struct ClipTextEmbeddings {
64    token_embedding: candle_nn::Embedding,
65    position_embedding: candle_nn::Embedding,
66    position_ids: Tensor,
67}
68
69impl ClipTextEmbeddings {
70    fn new(vs: candle_nn::VarBuilder, c: &ClipTextConfig) -> Result<Self> {
71        let token_embedding =
72            candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
73        let position_embedding: nn::Embedding = candle_nn::embedding(
74            c.max_position_embeddings,
75            c.embed_dim,
76            vs.pp("position_embedding"),
77        )?;
78        let position_ids =
79            Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
80        Ok(Self {
81            token_embedding,
82            position_embedding,
83            position_ids,
84        })
85    }
86}
87
88impl Module for ClipTextEmbeddings {
89    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
90        let seq_length = input_ids.dim(D::Minus1)?;
91        let inputs_embeds = self.token_embedding.forward(input_ids)?;
92        let position_ids = self.position_ids.narrow(1, 0, seq_length)?;
93        let position_embedding = self.position_embedding.forward(&position_ids)?;
94        inputs_embeds.broadcast_add(&position_embedding)
95    }
96}
97
98#[derive(Clone, Debug)]
99struct ClipAttention {
100    k_proj: candle_nn::Linear,
101    v_proj: candle_nn::Linear,
102    q_proj: candle_nn::Linear,
103    out_proj: candle_nn::Linear,
104    head_dim: usize,
105    scale: f64,
106    num_attention_heads: usize,
107}
108
109impl ClipAttention {
110    fn new(vs: candle_nn::VarBuilder, c: &EncoderConfig) -> Result<Self> {
111        let embed_dim = c.embed_dim();
112        let num_attention_heads = c.num_attention_heads();
113        let k_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("k_proj"))?;
114        let v_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("v_proj"))?;
115        let q_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("q_proj"))?;
116        let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?;
117        let head_dim = embed_dim / num_attention_heads;
118        let scale = (head_dim as f64).powf(-0.5);
119
120        Ok(ClipAttention {
121            k_proj,
122            v_proj,
123            q_proj,
124            out_proj,
125            head_dim,
126            scale,
127            num_attention_heads,
128        })
129    }
130
131    fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> {
132        xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
133            .transpose(1, 2)?
134            .contiguous()
135    }
136
137    fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
138        let in_dtype = xs.dtype();
139        let (bsz, seq_len, embed_dim) = xs.dims3()?;
140
141        let query_states = (self.q_proj.forward(xs)? * self.scale)?;
142        let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
143        let query_states = self
144            .shape(&query_states, seq_len, bsz)?
145            .reshape(proj_shape)?
146            .to_dtype(DType::F32)?;
147        let key_states = self
148            .shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
149            .reshape(proj_shape)?
150            .to_dtype(DType::F32)?;
151        let value_states = self
152            .shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
153            .reshape(proj_shape)?
154            .to_dtype(DType::F32)?;
155        let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
156
157        let src_len = key_states.dim(1)?;
158
159        let attn_weights = if let Some(causal_attention_mask) = causal_attention_mask {
160            attn_weights
161                .reshape((bsz, self.num_attention_heads, seq_len, src_len))?
162                .broadcast_add(causal_attention_mask)?
163                .reshape((bsz * self.num_attention_heads, seq_len, src_len))?
164        } else {
165            attn_weights
166        };
167
168        let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
169
170        let attn_output = attn_weights.matmul(&value_states)?.to_dtype(in_dtype)?;
171        let attn_output = attn_output
172            .reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
173            .transpose(1, 2)?
174            .reshape((bsz, seq_len, embed_dim))?;
175        self.out_proj.forward(&attn_output)
176    }
177}
178
179#[derive(Clone, Debug)]
180struct ClipMlp {
181    fc1: candle_nn::Linear,
182    fc2: candle_nn::Linear,
183    activation: Activation,
184}
185
186impl ClipMlp {
187    fn new(vs: candle_nn::VarBuilder, c: &EncoderConfig) -> Result<Self> {
188        let fc1 = candle_nn::linear(c.embed_dim(), c.intermediate_size(), vs.pp("fc1"))?;
189        let fc2 = candle_nn::linear(c.intermediate_size(), c.embed_dim(), vs.pp("fc2"))?;
190
191        Ok(ClipMlp {
192            fc1,
193            fc2,
194            activation: c.activation(),
195        })
196    }
197}
198
199impl ClipMlp {
200    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
201        let xs = self.fc1.forward(xs)?;
202        self.fc2.forward(&self.activation.forward(&xs)?)
203    }
204}
205
206#[derive(Clone, Debug)]
207struct ClipEncoderLayer {
208    self_attn: ClipAttention,
209    layer_norm1: candle_nn::LayerNorm,
210    mlp: ClipMlp,
211    layer_norm2: candle_nn::LayerNorm,
212}
213
214impl ClipEncoderLayer {
215    fn new(vs: candle_nn::VarBuilder, c: &EncoderConfig) -> Result<Self> {
216        let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?;
217        let layer_norm1 = candle_nn::layer_norm(c.embed_dim(), 1e-5, vs.pp("layer_norm1"))?;
218        let mlp = ClipMlp::new(vs.pp("mlp"), c)?;
219        let layer_norm2 = candle_nn::layer_norm(c.embed_dim(), 1e-5, vs.pp("layer_norm2"))?;
220
221        Ok(ClipEncoderLayer {
222            self_attn,
223            layer_norm1,
224            mlp,
225            layer_norm2,
226        })
227    }
228
229    fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
230        let residual = xs;
231        let xs = self.layer_norm1.forward(xs)?;
232        let xs = self.self_attn.forward(&xs, causal_attention_mask)?;
233        let xs = (xs + residual)?;
234
235        let residual = &xs;
236        let xs = self.layer_norm2.forward(&xs)?;
237        let xs = self.mlp.forward(&xs)?;
238        xs + residual
239    }
240}
241
242#[derive(Clone, Debug)]
243pub struct ClipEncoder {
244    layers: Vec<ClipEncoderLayer>,
245}
246
247impl ClipEncoder {
248    pub fn new(vs: candle_nn::VarBuilder, c: &EncoderConfig) -> Result<Self> {
249        let vs = vs.pp("layers");
250        let mut layers: Vec<ClipEncoderLayer> = Vec::new();
251        for index in 0..c.num_hidden_layers() {
252            let layer = ClipEncoderLayer::new(vs.pp(index.to_string()), c)?;
253            layers.push(layer)
254        }
255        Ok(ClipEncoder { layers })
256    }
257
258    pub fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
259        let mut xs = xs.clone();
260        for layer in self.layers.iter() {
261            xs = layer.forward(&xs, causal_attention_mask)?;
262        }
263        Ok(xs)
264    }
265    // required by LLaVA
266    pub fn output_hidden_states(
267        &self,
268        xs: &Tensor,
269        causal_attention_mask: Option<&Tensor>,
270    ) -> Result<Vec<Tensor>> {
271        let mut xs = xs.clone();
272        let mut hidden_states = Vec::new();
273        for layer in self.layers.iter() {
274            xs = layer.forward(&xs, causal_attention_mask)?;
275            hidden_states.push(xs.clone());
276        }
277        Ok(hidden_states)
278    }
279}
280
281/// A CLIP transformer based model.
282#[derive(Clone, Debug)]
283pub struct ClipTextTransformer {
284    embeddings: ClipTextEmbeddings,
285    encoder: ClipEncoder,
286    final_layer_norm: candle_nn::LayerNorm,
287}
288
289impl ClipTextTransformer {
290    pub fn new(vs: candle_nn::VarBuilder, c: &ClipTextConfig) -> Result<Self> {
291        let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?;
292        let encoder = ClipEncoder::new(vs.pp("encoder"), &EncoderConfig::Text(c.clone()))?;
293        let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?;
294        Ok(ClipTextTransformer {
295            embeddings,
296            encoder,
297            final_layer_norm,
298        })
299    }
300
301    // TODO: rewrite to newer version
302    fn build_causal_attention_mask(
303        bsz: usize,
304        seq_len: usize,
305        mask_after: usize,
306        device: &Device,
307    ) -> Result<Tensor> {
308        let mask: Vec<_> = (0..seq_len)
309            .flat_map(|i| {
310                (0..seq_len).map(move |j| {
311                    if j > i || j > mask_after {
312                        f32::MIN
313                    } else {
314                        0.
315                    }
316                })
317            })
318            .collect();
319        let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
320        mask.broadcast_as((bsz, 1, seq_len, seq_len))
321    }
322
323    pub fn forward_with_mask(&self, input_ids: &Tensor, mask_after: usize) -> Result<Tensor> {
324        let (bsz, seq_len) = input_ids.dims2()?;
325        let input_ids = self.embeddings.forward(input_ids)?;
326        let causal_attention_mask =
327            Self::build_causal_attention_mask(bsz, seq_len, mask_after, input_ids.device())?;
328        let input_ids = self
329            .encoder
330            .forward(&input_ids, Some(&causal_attention_mask))?;
331        self.final_layer_norm.forward(&input_ids)
332    }
333}
334
335impl Module for ClipTextTransformer {
336    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
337        let output = self.forward_with_mask(input_ids, usize::MAX)?;
338        let sequence_max_indices = input_ids.argmax(D::Minus1)?.to_dtype(DType::I64)?;
339
340        let mut indices = Vec::new();
341        for (batch_idx, &seq_idx) in sequence_max_indices.to_vec1::<i64>()?.iter().enumerate() {
342            let index = output.i((batch_idx, seq_idx as usize))?.unsqueeze(0)?;
343            indices.push(index);
344        }
345        Tensor::cat(&indices, 0)
346    }
347}