candle_transformers/models/clip/vision_model.rs

//! Contrastive Language-Image Pre-Training
//!
//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
//! pairs of images and their related texts. This module contains the vision tower.
//!
//! - <https://github.com/openai/CLIP>
//! - <https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip>
8
9use candle::{Context, IndexOp, Result, Shape, Tensor, D};
10use candle_nn as nn;
11use candle_nn::Module;
12use nn::Conv2dConfig;
13
14use super::{
15    text_model::{Activation, ClipEncoder},
16    EncoderConfig,
17};
18
19#[derive(Debug, Clone)]
20pub struct ClipVisionConfig {
21    pub embed_dim: usize,
22    pub activation: Activation,
23    pub intermediate_size: usize,
24    pub num_hidden_layers: usize,
25    pub num_attention_heads: usize,
26    #[allow(dead_code)]
27    pub projection_dim: usize,
28    pub num_channels: usize,
29    pub image_size: usize,
30    pub patch_size: usize,
31}
32
33impl ClipVisionConfig {
34    // The config details can be found in the "vision_config" section of this json file:
35    // https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
36    pub fn vit_base_patch32() -> Self {
37        Self {
38            embed_dim: 768,
39            activation: Activation::QuickGelu,
40            intermediate_size: 3072,
41            num_hidden_layers: 12,
42            num_attention_heads: 12,
43            projection_dim: 512,
44            num_channels: 3,
45            image_size: 224,
46            patch_size: 32,
47        }
48    }
49    pub fn clip_vit_large_patch14_336() -> Self {
50        Self {
51            embed_dim: 1024,
52            activation: Activation::QuickGelu,
53            intermediate_size: 4096,
54            num_hidden_layers: 24,
55            num_attention_heads: 16,
56            projection_dim: 768,
57            num_channels: 3,
58            image_size: 336,
59            patch_size: 14,
60        }
61    }
62}
63
64// https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L112
65#[derive(Clone, Debug)]
66struct ClipVisionEmbeddings {
67    patch_embedding: candle_nn::Conv2d,
68    position_ids: Tensor,
69    class_embedding: Tensor,
70    position_embedding: candle_nn::Embedding,
71}
72
73impl ClipVisionEmbeddings {
74    fn new(vs: candle_nn::VarBuilder, c: &ClipVisionConfig) -> Result<Self> {
75        // originally nn.Parameter
76        let class_embedding = if vs.contains_tensor("class_embedding") {
77            vs.get(c.embed_dim, "class_embedding")?
78        } else {
79            Tensor::randn(0f32, 1f32, c.embed_dim, vs.device())?
80        };
81
82        let num_patches = (c.image_size / c.patch_size).pow(2);
83        let num_positions = num_patches + 1;
84        let position_ids = Tensor::arange(0, num_positions as i64, vs.device())?;
85
86        let conv2dconfig = Conv2dConfig {
87            stride: c.patch_size,
88            ..Default::default()
89        };
90        let position_embedding =
91            candle_nn::embedding(num_positions, c.embed_dim, vs.pp("position_embedding"))?;
92        let patch_embedding = candle_nn::conv2d_no_bias(
93            c.num_channels,
94            c.embed_dim,
95            c.patch_size,
96            conv2dconfig,
97            vs.pp("patch_embedding"),
98        )?;
99        Ok(Self {
100            patch_embedding,
101            position_ids,
102            class_embedding,
103            position_embedding,
104        })
105    }
106}
107
108impl Module for ClipVisionEmbeddings {
109    fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
110        let batch_size = pixel_values.shape().dims();
111        let patch_embeds = self
112            .patch_embedding
113            .forward(pixel_values)?
114            .flatten_from(2)?
115            .transpose(1, 2)?;
116        let shape = Shape::from((batch_size[0], 1, self.class_embedding.dim(D::Minus1)?));
117        let class_embeds = self.class_embedding.expand(shape)?;
118        let embeddings = Tensor::cat(&[class_embeds, patch_embeds], 1)?;
119        let position_embedding = self.position_embedding.forward(&self.position_ids)?;
120        embeddings.broadcast_add(&position_embedding)
121    }
122}
123
124// https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L743
125#[derive(Clone, Debug)]
126pub struct ClipVisionTransformer {
127    embeddings: ClipVisionEmbeddings,
128    encoder: ClipEncoder,
129    pre_layer_norm: candle_nn::LayerNorm,
130    final_layer_norm: candle_nn::LayerNorm,
131}
132
133impl ClipVisionTransformer {
134    pub fn new(vs: candle_nn::VarBuilder, c: &ClipVisionConfig) -> Result<Self> {
135        let embeddings = ClipVisionEmbeddings::new(vs.pp("embeddings"), c)?;
136        let pre_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("pre_layrnorm"))?;
137        let encoder = ClipEncoder::new(vs.pp("encoder"), &EncoderConfig::Vision(c.clone()))?;
138        let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("post_layernorm"))?;
139        Ok(Self {
140            embeddings,
141            encoder,
142            final_layer_norm,
143            pre_layer_norm,
144        })
145    }
146    // required by LLaVA
147    pub fn output_hidden_states(&self, pixel_values: &Tensor) -> Result<Vec<Tensor>> {
148        let hidden_states = pixel_values
149            .apply(&self.embeddings)?
150            .apply(&self.pre_layer_norm)?;
151        let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
152        let encoder_outputs = result.last().context("no last")?;
153        let pooled_output = encoder_outputs.i((.., 0, ..))?;
154        result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
155        Ok(result)
156    }
157}
158
159impl Module for ClipVisionTransformer {
160    fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
161        let hidden_states = pixel_values
162            .apply(&self.embeddings)?
163            .apply(&self.pre_layer_norm)?;
164
165        let encoder_outputs = self.encoder.forward(&hidden_states, None)?;
166        // https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L787
167        // pooled_output = encoder_outputs[:, 0, :]
168        let pooled_output = encoder_outputs.i((.., 0, ..))?;
169        self.final_layer_norm.forward(&pooled_output)
170    }
171}