1use candle::{Context, DType, IndexOp, Module, Result, Shape, Tensor, D};
10use candle_nn as nn;
11
12use super::{Activation, EncoderConfig};
13
14#[derive(Clone, Debug)]
15pub struct ChineseClipVisionConfig {
16 pub hidden_size: usize,
17 pub intermediate_size: usize,
18 pub projection_dim: usize,
19 pub num_hidden_layers: usize,
20 pub num_attention_heads: usize,
21 pub num_channels: usize,
22 pub image_size: usize,
23 pub patch_size: usize,
24 pub hidden_act: Activation,
25 pub layer_norm_eps: f64,
26 pub attention_dropout: f32,
27 pub initializer_range: f32,
28 pub initializer_factor: f32,
29}
30
31impl Default for ChineseClipVisionConfig {
32 fn default() -> Self {
33 ChineseClipVisionConfig {
34 hidden_size: 768,
35 intermediate_size: 3072,
36 projection_dim: 512,
37 num_hidden_layers: 12,
38 num_attention_heads: 12,
39 num_channels: 3,
40 image_size: 224,
41 patch_size: 32,
42 hidden_act: Activation::QuickGelu,
43 layer_norm_eps: 1e-5,
44 attention_dropout: 0.0,
45 initializer_range: 0.02,
46 initializer_factor: 1.0,
47 }
48 }
49}
50
51impl ChineseClipVisionConfig {
52 pub fn clip_vit_base_patch16() -> Self {
54 Self {
55 hidden_size: 768,
56 intermediate_size: 3072,
57 projection_dim: 512,
58 num_hidden_layers: 12,
59 num_attention_heads: 12,
60 num_channels: 3,
61 image_size: 224,
62 patch_size: 16,
63 hidden_act: Activation::QuickGelu,
64 layer_norm_eps: 1e-5,
65 attention_dropout: 0.0,
66 initializer_range: 0.02,
67 initializer_factor: 1.0,
68 }
69 }
70}
71
72#[derive(Clone, Debug)]
73pub struct ChineseClipVisionEmbeddings {
74 patch_embedding: nn::Conv2d,
75 position_ids: Tensor,
76 class_embedding: Tensor,
77 position_embedding: nn::Embedding,
78}
79
80impl ChineseClipVisionEmbeddings {
81 pub fn new(var: nn::VarBuilder, config: &ChineseClipVisionConfig) -> Result<Self> {
82 let embed_dim = config.hidden_size;
83 let class_embedding = if var.contains_tensor("class_embedding") {
85 var.get(embed_dim, "class_embedding")?
86 } else {
87 Tensor::randn(0f32, 1f32, embed_dim, var.device())?
88 };
89
90 let num_patches = (config.image_size / config.patch_size).pow(2);
91 let num_positions = num_patches + 1;
92 let position_ids = Tensor::arange(0, num_positions as i64, var.device())?;
93
94 let conv2dconfig = nn::Conv2dConfig {
95 stride: config.patch_size,
96 ..Default::default()
97 };
98 let position_embedding =
99 nn::embedding(num_positions, embed_dim, var.pp("position_embedding"))?;
100 let patch_embedding = nn::conv2d_no_bias(
101 config.num_channels,
102 embed_dim,
103 config.patch_size,
104 conv2dconfig,
105 var.pp("patch_embedding"),
106 )?;
107 Ok(Self {
108 patch_embedding,
109 position_ids,
110 class_embedding,
111 position_embedding,
112 })
113 }
114}
115
116impl Module for ChineseClipVisionEmbeddings {
117 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
118 let batch_size = xs.shape().dims();
119 let patch_embeds = self
120 .patch_embedding
121 .forward(xs)?
122 .flatten_from(2)?
123 .transpose(1, 2)?;
124 let shape = Shape::from((batch_size[0], 1, self.class_embedding.dim(D::Minus1)?));
125 let class_embeds = self.class_embedding.expand(shape)?;
126 let embeddings = Tensor::cat(&[class_embeds, patch_embeds], 1)?;
127 let position_embedding = self.position_embedding.forward(&self.position_ids)?;
128 embeddings.broadcast_add(&position_embedding)
129 }
130}
131
132#[derive(Clone, Debug)]
133struct ChineseClipVisionAttention {
134 k_proj: nn::Linear,
135 v_proj: nn::Linear,
136 q_proj: nn::Linear,
137 out_proj: nn::Linear,
138 head_dim: usize,
139 scale: f64,
140 num_attention_heads: usize,
141}
142
143impl ChineseClipVisionAttention {
144 fn new(var: nn::VarBuilder, config: &EncoderConfig) -> Result<Self> {
145 let embed_dim = config.embed_dim();
146 let num_attention_heads = config.num_attention_heads();
147 let k_proj = nn::linear(embed_dim, embed_dim, var.pp("k_proj"))?;
148 let v_proj = nn::linear(embed_dim, embed_dim, var.pp("v_proj"))?;
149 let q_proj = nn::linear(embed_dim, embed_dim, var.pp("q_proj"))?;
150 let out_proj = nn::linear(embed_dim, embed_dim, var.pp("out_proj"))?;
151 let head_dim = embed_dim / num_attention_heads;
152 let scale = (head_dim as f64).powf(-0.5);
153
154 Ok(ChineseClipVisionAttention {
155 k_proj,
156 v_proj,
157 q_proj,
158 out_proj,
159 head_dim,
160 scale,
161 num_attention_heads,
162 })
163 }
164
165 fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> {
166 xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
167 .transpose(1, 2)?
168 .contiguous()
169 }
170
171 fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
172 let in_dtype = xs.dtype();
173 let (bsz, seq_len, embed_dim) = xs.dims3()?;
174
175 let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
176 let query_states = self
177 .shape(&(self.q_proj.forward(xs)? * self.scale)?, seq_len, bsz)?
178 .reshape(proj_shape)?
179 .to_dtype(DType::F32)?;
180 let key_states = self
181 .shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
182 .reshape(proj_shape)?
183 .to_dtype(DType::F32)?;
184 let value_states = self
185 .shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
186 .reshape(proj_shape)?
187 .to_dtype(DType::F32)?;
188
189 let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
190
191 let src_len = key_states.dim(1)?;
192
193 let attn_weights = if let Some(causal_attention_mask) = causal_attention_mask {
194 attn_weights
195 .reshape((bsz, self.num_attention_heads, seq_len, src_len))?
196 .broadcast_add(causal_attention_mask)?
197 .reshape((bsz * self.num_attention_heads, seq_len, src_len))?
198 } else {
199 attn_weights
200 };
201
202 let attn_weights = nn::ops::softmax(&attn_weights, D::Minus1)?;
203
204 let attn_output = attn_weights.matmul(&value_states)?.to_dtype(in_dtype)?;
205 let attn_output = attn_output
206 .reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
207 .transpose(1, 2)?
208 .reshape((bsz, seq_len, embed_dim))?;
209 self.out_proj.forward(&attn_output)
210 }
211}
212
213#[derive(Clone, Debug)]
214struct ChineseClipVisionMlp {
215 fc1: nn::Linear,
216 fc2: nn::Linear,
217 activation: Activation,
218}
219
220impl ChineseClipVisionMlp {
221 fn new(var: nn::VarBuilder, config: &EncoderConfig) -> Result<Self> {
222 let fc1 = nn::linear(
223 config.embed_dim(),
224 config.intermediate_size(),
225 var.pp("fc1"),
226 )?;
227 let fc2 = nn::linear(
228 config.intermediate_size(),
229 config.embed_dim(),
230 var.pp("fc2"),
231 )?;
232
233 Ok(ChineseClipVisionMlp {
234 fc1,
235 fc2,
236 activation: config.activation(),
237 })
238 }
239}
240
241impl ChineseClipVisionMlp {
242 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
243 let xs = self.fc1.forward(xs)?;
244 self.fc2.forward(&self.activation.forward(&xs)?)
245 }
246}
247
248#[derive(Clone, Debug)]
249struct ChineseClipVisionEncoderLayer {
250 self_attn: ChineseClipVisionAttention,
251 layer_norm1: nn::LayerNorm,
252 mlp: ChineseClipVisionMlp,
253 layer_norm2: nn::LayerNorm,
254}
255
256impl ChineseClipVisionEncoderLayer {
257 fn new(var: nn::VarBuilder, config: &EncoderConfig) -> Result<Self> {
258 let self_attn = ChineseClipVisionAttention::new(var.pp("self_attn"), config)?;
259 let layer_norm1 = nn::layer_norm(
260 config.embed_dim(),
261 config.layer_norm_eps(),
262 var.pp("layer_norm1"),
263 )?;
264 let mlp = ChineseClipVisionMlp::new(var.pp("mlp"), config)?;
265 let layer_norm2 = nn::layer_norm(
266 config.embed_dim(),
267 config.layer_norm_eps(),
268 var.pp("layer_norm2"),
269 )?;
270
271 Ok(ChineseClipVisionEncoderLayer {
272 self_attn,
273 layer_norm1,
274 mlp,
275 layer_norm2,
276 })
277 }
278
279 fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
280 let residual = xs;
281 let xs = self.layer_norm1.forward(xs)?;
282 let xs = self.self_attn.forward(&xs, causal_attention_mask)?;
283 let xs = (xs + residual)?;
284
285 let residual = &xs;
286 let xs = self.layer_norm2.forward(&xs)?;
287 let xs = self.mlp.forward(&xs)?;
288 xs + residual
289 }
290}
291
292#[derive(Clone, Debug)]
293pub struct ChineseClipVisionEncoder {
294 layers: Vec<ChineseClipVisionEncoderLayer>,
295}
296
297impl ChineseClipVisionEncoder {
298 pub fn new(var: nn::VarBuilder, config: &EncoderConfig) -> Result<Self> {
299 let vs = var.pp("layers");
300 let mut layers: Vec<ChineseClipVisionEncoderLayer> = Vec::new();
301 for index in 0..config.num_hidden_layers() {
302 let layer = ChineseClipVisionEncoderLayer::new(vs.pp(index.to_string()), config)?;
303 layers.push(layer)
304 }
305 Ok(ChineseClipVisionEncoder { layers })
306 }
307
308 pub fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
309 let mut xs = xs.clone();
310 for layer in self.layers.iter() {
311 xs = layer.forward(&xs, causal_attention_mask)?;
312 }
313 Ok(xs)
314 }
315
316 pub fn output_hidden_states(
318 &self,
319 xs: &Tensor,
320 causal_attention_mask: Option<&Tensor>,
321 ) -> Result<Vec<Tensor>> {
322 let mut xs = xs.clone();
323 let mut hidden_states = Vec::new();
324 for layer in self.layers.iter() {
325 xs = layer.forward(&xs, causal_attention_mask)?;
326 hidden_states.push(xs.clone());
327 }
328 Ok(hidden_states)
329 }
330}
331
332#[derive(Clone, Debug)]
333pub struct ChineseClipVisionTransformer {
334 embeddings: ChineseClipVisionEmbeddings,
335 encoder: ChineseClipVisionEncoder,
336 pre_layer_norm: nn::LayerNorm,
337 final_layer_norm: nn::LayerNorm,
338}
339
340impl ChineseClipVisionTransformer {
341 pub fn new(var: nn::VarBuilder, config: &ChineseClipVisionConfig) -> Result<Self> {
342 let embed_dim = config.hidden_size;
343 let embeddings = ChineseClipVisionEmbeddings::new(var.pp("embeddings"), config)?;
344 let pre_layer_norm =
345 nn::layer_norm(embed_dim, config.layer_norm_eps, var.pp("pre_layrnorm"))?;
346 let encoder = ChineseClipVisionEncoder::new(
347 var.pp("encoder"),
348 &EncoderConfig::Vision(config.clone()),
349 )?;
350 let final_layer_norm =
351 nn::layer_norm(embed_dim, config.layer_norm_eps, var.pp("post_layernorm"))?;
352 Ok(Self {
353 embeddings,
354 encoder,
355 final_layer_norm,
356 pre_layer_norm,
357 })
358 }
359 pub fn output_hidden_states(&self, pixel_values: &Tensor) -> Result<Vec<Tensor>> {
361 let hidden_states = pixel_values
362 .apply(&self.embeddings)?
363 .apply(&self.pre_layer_norm)?;
364
365 let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
366 let encoder_outputs = result.last().context("no last")?;
367 let pooled_output = encoder_outputs.i((.., 0, ..))?;
368 result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
369 Ok(result)
370 }
371}
372
373impl Module for ChineseClipVisionTransformer {
374 fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
375 let hidden_states = pixel_values
376 .apply(&self.embeddings)?
377 .apply(&self.pre_layer_norm)?;
378
379 let encoder_outputs = self.encoder.forward(&hidden_states, None)?;
380
381 let pooled_output = encoder_outputs.i((.., 0, ..))?;
383 self.final_layer_norm.forward(&pooled_output)
384 }
385}