candle_transformers/models/chinese_clip/
vision_model.rs

//! Chinese contrastive Language-Image Pre-Training (vision model)
//!
//! Chinese contrastive Language-Image Pre-Training (Chinese-CLIP) is an architecture trained on
//! pairs of images and related Chinese texts. This module contains its vision transformer.
//!
//! - 💻 [Chinese-CLIP](https://github.com/OFA-Sys/Chinese-CLIP)
//! - 💻 [GH](https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py)
8
9use candle::{Context, DType, IndexOp, Module, Result, Shape, Tensor, D};
10use candle_nn as nn;
11
12use super::{Activation, EncoderConfig};
13
14#[derive(Clone, Debug)]
15pub struct ChineseClipVisionConfig {
16    pub hidden_size: usize,
17    pub intermediate_size: usize,
18    pub projection_dim: usize,
19    pub num_hidden_layers: usize,
20    pub num_attention_heads: usize,
21    pub num_channels: usize,
22    pub image_size: usize,
23    pub patch_size: usize,
24    pub hidden_act: Activation,
25    pub layer_norm_eps: f64,
26    pub attention_dropout: f32,
27    pub initializer_range: f32,
28    pub initializer_factor: f32,
29}
30
31impl Default for ChineseClipVisionConfig {
32    fn default() -> Self {
33        ChineseClipVisionConfig {
34            hidden_size: 768,
35            intermediate_size: 3072,
36            projection_dim: 512,
37            num_hidden_layers: 12,
38            num_attention_heads: 12,
39            num_channels: 3,
40            image_size: 224,
41            patch_size: 32,
42            hidden_act: Activation::QuickGelu,
43            layer_norm_eps: 1e-5,
44            attention_dropout: 0.0,
45            initializer_range: 0.02,
46            initializer_factor: 1.0,
47        }
48    }
49}
50
51impl ChineseClipVisionConfig {
52    /// [referer](https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16/blob/main/config.json)
53    pub fn clip_vit_base_patch16() -> Self {
54        Self {
55            hidden_size: 768,
56            intermediate_size: 3072,
57            projection_dim: 512,
58            num_hidden_layers: 12,
59            num_attention_heads: 12,
60            num_channels: 3,
61            image_size: 224,
62            patch_size: 16,
63            hidden_act: Activation::QuickGelu,
64            layer_norm_eps: 1e-5,
65            attention_dropout: 0.0,
66            initializer_range: 0.02,
67            initializer_factor: 1.0,
68        }
69    }
70}
71
72#[derive(Clone, Debug)]
73pub struct ChineseClipVisionEmbeddings {
74    patch_embedding: nn::Conv2d,
75    position_ids: Tensor,
76    class_embedding: Tensor,
77    position_embedding: nn::Embedding,
78}
79
80impl ChineseClipVisionEmbeddings {
81    pub fn new(var: nn::VarBuilder, config: &ChineseClipVisionConfig) -> Result<Self> {
82        let embed_dim = config.hidden_size;
83        // originally nn.Parameter
84        let class_embedding = if var.contains_tensor("class_embedding") {
85            var.get(embed_dim, "class_embedding")?
86        } else {
87            Tensor::randn(0f32, 1f32, embed_dim, var.device())?
88        };
89
90        let num_patches = (config.image_size / config.patch_size).pow(2);
91        let num_positions = num_patches + 1;
92        let position_ids = Tensor::arange(0, num_positions as i64, var.device())?;
93
94        let conv2dconfig = nn::Conv2dConfig {
95            stride: config.patch_size,
96            ..Default::default()
97        };
98        let position_embedding =
99            nn::embedding(num_positions, embed_dim, var.pp("position_embedding"))?;
100        let patch_embedding = nn::conv2d_no_bias(
101            config.num_channels,
102            embed_dim,
103            config.patch_size,
104            conv2dconfig,
105            var.pp("patch_embedding"),
106        )?;
107        Ok(Self {
108            patch_embedding,
109            position_ids,
110            class_embedding,
111            position_embedding,
112        })
113    }
114}
115
116impl Module for ChineseClipVisionEmbeddings {
117    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
118        let batch_size = xs.shape().dims();
119        let patch_embeds = self
120            .patch_embedding
121            .forward(xs)?
122            .flatten_from(2)?
123            .transpose(1, 2)?;
124        let shape = Shape::from((batch_size[0], 1, self.class_embedding.dim(D::Minus1)?));
125        let class_embeds = self.class_embedding.expand(shape)?;
126        let embeddings = Tensor::cat(&[class_embeds, patch_embeds], 1)?;
127        let position_embedding = self.position_embedding.forward(&self.position_ids)?;
128        embeddings.broadcast_add(&position_embedding)
129    }
130}
131
132#[derive(Clone, Debug)]
133struct ChineseClipVisionAttention {
134    k_proj: nn::Linear,
135    v_proj: nn::Linear,
136    q_proj: nn::Linear,
137    out_proj: nn::Linear,
138    head_dim: usize,
139    scale: f64,
140    num_attention_heads: usize,
141}
142
143impl ChineseClipVisionAttention {
144    fn new(var: nn::VarBuilder, config: &EncoderConfig) -> Result<Self> {
145        let embed_dim = config.embed_dim();
146        let num_attention_heads = config.num_attention_heads();
147        let k_proj = nn::linear(embed_dim, embed_dim, var.pp("k_proj"))?;
148        let v_proj = nn::linear(embed_dim, embed_dim, var.pp("v_proj"))?;
149        let q_proj = nn::linear(embed_dim, embed_dim, var.pp("q_proj"))?;
150        let out_proj = nn::linear(embed_dim, embed_dim, var.pp("out_proj"))?;
151        let head_dim = embed_dim / num_attention_heads;
152        let scale = (head_dim as f64).powf(-0.5);
153
154        Ok(ChineseClipVisionAttention {
155            k_proj,
156            v_proj,
157            q_proj,
158            out_proj,
159            head_dim,
160            scale,
161            num_attention_heads,
162        })
163    }
164
165    fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> {
166        xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
167            .transpose(1, 2)?
168            .contiguous()
169    }
170
171    fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
172        let in_dtype = xs.dtype();
173        let (bsz, seq_len, embed_dim) = xs.dims3()?;
174
175        let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
176        let query_states = self
177            .shape(&(self.q_proj.forward(xs)? * self.scale)?, seq_len, bsz)?
178            .reshape(proj_shape)?
179            .to_dtype(DType::F32)?;
180        let key_states = self
181            .shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
182            .reshape(proj_shape)?
183            .to_dtype(DType::F32)?;
184        let value_states = self
185            .shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
186            .reshape(proj_shape)?
187            .to_dtype(DType::F32)?;
188
189        let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
190
191        let src_len = key_states.dim(1)?;
192
193        let attn_weights = if let Some(causal_attention_mask) = causal_attention_mask {
194            attn_weights
195                .reshape((bsz, self.num_attention_heads, seq_len, src_len))?
196                .broadcast_add(causal_attention_mask)?
197                .reshape((bsz * self.num_attention_heads, seq_len, src_len))?
198        } else {
199            attn_weights
200        };
201
202        let attn_weights = nn::ops::softmax(&attn_weights, D::Minus1)?;
203
204        let attn_output = attn_weights.matmul(&value_states)?.to_dtype(in_dtype)?;
205        let attn_output = attn_output
206            .reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
207            .transpose(1, 2)?
208            .reshape((bsz, seq_len, embed_dim))?;
209        self.out_proj.forward(&attn_output)
210    }
211}
212
213#[derive(Clone, Debug)]
214struct ChineseClipVisionMlp {
215    fc1: nn::Linear,
216    fc2: nn::Linear,
217    activation: Activation,
218}
219
220impl ChineseClipVisionMlp {
221    fn new(var: nn::VarBuilder, config: &EncoderConfig) -> Result<Self> {
222        let fc1 = nn::linear(
223            config.embed_dim(),
224            config.intermediate_size(),
225            var.pp("fc1"),
226        )?;
227        let fc2 = nn::linear(
228            config.intermediate_size(),
229            config.embed_dim(),
230            var.pp("fc2"),
231        )?;
232
233        Ok(ChineseClipVisionMlp {
234            fc1,
235            fc2,
236            activation: config.activation(),
237        })
238    }
239}
240
241impl ChineseClipVisionMlp {
242    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
243        let xs = self.fc1.forward(xs)?;
244        self.fc2.forward(&self.activation.forward(&xs)?)
245    }
246}
247
248#[derive(Clone, Debug)]
249struct ChineseClipVisionEncoderLayer {
250    self_attn: ChineseClipVisionAttention,
251    layer_norm1: nn::LayerNorm,
252    mlp: ChineseClipVisionMlp,
253    layer_norm2: nn::LayerNorm,
254}
255
256impl ChineseClipVisionEncoderLayer {
257    fn new(var: nn::VarBuilder, config: &EncoderConfig) -> Result<Self> {
258        let self_attn = ChineseClipVisionAttention::new(var.pp("self_attn"), config)?;
259        let layer_norm1 = nn::layer_norm(
260            config.embed_dim(),
261            config.layer_norm_eps(),
262            var.pp("layer_norm1"),
263        )?;
264        let mlp = ChineseClipVisionMlp::new(var.pp("mlp"), config)?;
265        let layer_norm2 = nn::layer_norm(
266            config.embed_dim(),
267            config.layer_norm_eps(),
268            var.pp("layer_norm2"),
269        )?;
270
271        Ok(ChineseClipVisionEncoderLayer {
272            self_attn,
273            layer_norm1,
274            mlp,
275            layer_norm2,
276        })
277    }
278
279    fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
280        let residual = xs;
281        let xs = self.layer_norm1.forward(xs)?;
282        let xs = self.self_attn.forward(&xs, causal_attention_mask)?;
283        let xs = (xs + residual)?;
284
285        let residual = &xs;
286        let xs = self.layer_norm2.forward(&xs)?;
287        let xs = self.mlp.forward(&xs)?;
288        xs + residual
289    }
290}
291
292#[derive(Clone, Debug)]
293pub struct ChineseClipVisionEncoder {
294    layers: Vec<ChineseClipVisionEncoderLayer>,
295}
296
297impl ChineseClipVisionEncoder {
298    pub fn new(var: nn::VarBuilder, config: &EncoderConfig) -> Result<Self> {
299        let vs = var.pp("layers");
300        let mut layers: Vec<ChineseClipVisionEncoderLayer> = Vec::new();
301        for index in 0..config.num_hidden_layers() {
302            let layer = ChineseClipVisionEncoderLayer::new(vs.pp(index.to_string()), config)?;
303            layers.push(layer)
304        }
305        Ok(ChineseClipVisionEncoder { layers })
306    }
307
308    pub fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
309        let mut xs = xs.clone();
310        for layer in self.layers.iter() {
311            xs = layer.forward(&xs, causal_attention_mask)?;
312        }
313        Ok(xs)
314    }
315
316    // required by LLaVA
317    pub fn output_hidden_states(
318        &self,
319        xs: &Tensor,
320        causal_attention_mask: Option<&Tensor>,
321    ) -> Result<Vec<Tensor>> {
322        let mut xs = xs.clone();
323        let mut hidden_states = Vec::new();
324        for layer in self.layers.iter() {
325            xs = layer.forward(&xs, causal_attention_mask)?;
326            hidden_states.push(xs.clone());
327        }
328        Ok(hidden_states)
329    }
330}
331
332#[derive(Clone, Debug)]
333pub struct ChineseClipVisionTransformer {
334    embeddings: ChineseClipVisionEmbeddings,
335    encoder: ChineseClipVisionEncoder,
336    pre_layer_norm: nn::LayerNorm,
337    final_layer_norm: nn::LayerNorm,
338}
339
340impl ChineseClipVisionTransformer {
341    pub fn new(var: nn::VarBuilder, config: &ChineseClipVisionConfig) -> Result<Self> {
342        let embed_dim = config.hidden_size;
343        let embeddings = ChineseClipVisionEmbeddings::new(var.pp("embeddings"), config)?;
344        let pre_layer_norm =
345            nn::layer_norm(embed_dim, config.layer_norm_eps, var.pp("pre_layrnorm"))?;
346        let encoder = ChineseClipVisionEncoder::new(
347            var.pp("encoder"),
348            &EncoderConfig::Vision(config.clone()),
349        )?;
350        let final_layer_norm =
351            nn::layer_norm(embed_dim, config.layer_norm_eps, var.pp("post_layernorm"))?;
352        Ok(Self {
353            embeddings,
354            encoder,
355            final_layer_norm,
356            pre_layer_norm,
357        })
358    }
359    // required by LLaVA
360    pub fn output_hidden_states(&self, pixel_values: &Tensor) -> Result<Vec<Tensor>> {
361        let hidden_states = pixel_values
362            .apply(&self.embeddings)?
363            .apply(&self.pre_layer_norm)?;
364
365        let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
366        let encoder_outputs = result.last().context("no last")?;
367        let pooled_output = encoder_outputs.i((.., 0, ..))?;
368        result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
369        Ok(result)
370    }
371}
372
373impl Module for ChineseClipVisionTransformer {
374    fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
375        let hidden_states = pixel_values
376            .apply(&self.embeddings)?
377            .apply(&self.pre_layer_norm)?;
378
379        let encoder_outputs = self.encoder.forward(&hidden_states, None)?;
380
381        // referer: https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L787
382        let pooled_output = encoder_outputs.i((.., 0, ..))?;
383        self.final_layer_norm.forward(&pooled_output)
384    }
385}