// candle_transformers/models/clip/
vision_model.rs1use candle::{Context, IndexOp, Result, Shape, Tensor, D};
10use candle_nn as nn;
11use candle_nn::Module;
12use nn::Conv2dConfig;
13
14use super::{
15 text_model::{Activation, ClipEncoder},
16 EncoderConfig,
17};
18
19#[derive(Debug, Clone)]
20pub struct ClipVisionConfig {
21 pub embed_dim: usize,
22 pub activation: Activation,
23 pub intermediate_size: usize,
24 pub num_hidden_layers: usize,
25 pub num_attention_heads: usize,
26 #[allow(dead_code)]
27 pub projection_dim: usize,
28 pub num_channels: usize,
29 pub image_size: usize,
30 pub patch_size: usize,
31}
32
33impl ClipVisionConfig {
34 pub fn vit_base_patch32() -> Self {
37 Self {
38 embed_dim: 768,
39 activation: Activation::QuickGelu,
40 intermediate_size: 3072,
41 num_hidden_layers: 12,
42 num_attention_heads: 12,
43 projection_dim: 512,
44 num_channels: 3,
45 image_size: 224,
46 patch_size: 32,
47 }
48 }
49 pub fn clip_vit_large_patch14_336() -> Self {
50 Self {
51 embed_dim: 1024,
52 activation: Activation::QuickGelu,
53 intermediate_size: 4096,
54 num_hidden_layers: 24,
55 num_attention_heads: 16,
56 projection_dim: 768,
57 num_channels: 3,
58 image_size: 336,
59 patch_size: 14,
60 }
61 }
62}
63
64#[derive(Clone, Debug)]
66struct ClipVisionEmbeddings {
67 patch_embedding: candle_nn::Conv2d,
68 position_ids: Tensor,
69 class_embedding: Tensor,
70 position_embedding: candle_nn::Embedding,
71}
72
73impl ClipVisionEmbeddings {
74 fn new(vs: candle_nn::VarBuilder, c: &ClipVisionConfig) -> Result<Self> {
75 let class_embedding = if vs.contains_tensor("class_embedding") {
77 vs.get(c.embed_dim, "class_embedding")?
78 } else {
79 Tensor::randn(0f32, 1f32, c.embed_dim, vs.device())?
80 };
81
82 let num_patches = (c.image_size / c.patch_size).pow(2);
83 let num_positions = num_patches + 1;
84 let position_ids = Tensor::arange(0, num_positions as i64, vs.device())?;
85
86 let conv2dconfig = Conv2dConfig {
87 stride: c.patch_size,
88 ..Default::default()
89 };
90 let position_embedding =
91 candle_nn::embedding(num_positions, c.embed_dim, vs.pp("position_embedding"))?;
92 let patch_embedding = candle_nn::conv2d_no_bias(
93 c.num_channels,
94 c.embed_dim,
95 c.patch_size,
96 conv2dconfig,
97 vs.pp("patch_embedding"),
98 )?;
99 Ok(Self {
100 patch_embedding,
101 position_ids,
102 class_embedding,
103 position_embedding,
104 })
105 }
106}
107
108impl Module for ClipVisionEmbeddings {
109 fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
110 let batch_size = pixel_values.shape().dims();
111 let patch_embeds = self
112 .patch_embedding
113 .forward(pixel_values)?
114 .flatten_from(2)?
115 .transpose(1, 2)?;
116 let shape = Shape::from((batch_size[0], 1, self.class_embedding.dim(D::Minus1)?));
117 let class_embeds = self.class_embedding.expand(shape)?;
118 let embeddings = Tensor::cat(&[class_embeds, patch_embeds], 1)?;
119 let position_embedding = self.position_embedding.forward(&self.position_ids)?;
120 embeddings.broadcast_add(&position_embedding)
121 }
122}
123
124#[derive(Clone, Debug)]
126pub struct ClipVisionTransformer {
127 embeddings: ClipVisionEmbeddings,
128 encoder: ClipEncoder,
129 pre_layer_norm: candle_nn::LayerNorm,
130 final_layer_norm: candle_nn::LayerNorm,
131}
132
133impl ClipVisionTransformer {
134 pub fn new(vs: candle_nn::VarBuilder, c: &ClipVisionConfig) -> Result<Self> {
135 let embeddings = ClipVisionEmbeddings::new(vs.pp("embeddings"), c)?;
136 let pre_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("pre_layrnorm"))?;
137 let encoder = ClipEncoder::new(vs.pp("encoder"), &EncoderConfig::Vision(c.clone()))?;
138 let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("post_layernorm"))?;
139 Ok(Self {
140 embeddings,
141 encoder,
142 final_layer_norm,
143 pre_layer_norm,
144 })
145 }
146 pub fn output_hidden_states(&self, pixel_values: &Tensor) -> Result<Vec<Tensor>> {
148 let hidden_states = pixel_values
149 .apply(&self.embeddings)?
150 .apply(&self.pre_layer_norm)?;
151 let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
152 let encoder_outputs = result.last().context("no last")?;
153 let pooled_output = encoder_outputs.i((.., 0, ..))?;
154 result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
155 Ok(result)
156 }
157}
158
159impl Module for ClipVisionTransformer {
160 fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
161 let hidden_states = pixel_values
162 .apply(&self.embeddings)?
163 .apply(&self.pre_layer_norm)?;
164
165 let encoder_outputs = self.encoder.forward(&hidden_states, None)?;
166 let pooled_output = encoder_outputs.i((.., 0, ..))?;
169 self.final_layer_norm.forward(&pooled_output)
170 }
171}