1use candle::{DType, Device, IndexOp, Result, Tensor, D};
10use candle_nn as nn;
11use candle_nn::Module;
12
13use super::EncoderConfig;
14
15#[derive(Debug, Clone, Copy)]
16pub enum Activation {
17 QuickGelu,
18}
19
20impl Module for Activation {
21 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
22 match self {
23 Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
24 }
25 }
26}
27
28#[derive(Debug, Clone)]
29pub struct ClipTextConfig {
30 pub vocab_size: usize,
31 pub embed_dim: usize,
32 pub activation: Activation,
33 pub intermediate_size: usize,
34 pub max_position_embeddings: usize,
35 pub pad_with: Option<String>,
36 pub num_hidden_layers: usize,
37 pub num_attention_heads: usize,
38 #[allow(dead_code)]
39 pub projection_dim: usize,
40}
41
42impl ClipTextConfig {
43 pub fn vit_base_patch32() -> Self {
46 Self {
47 vocab_size: 49408,
48 embed_dim: 512,
49 intermediate_size: 2048,
50 max_position_embeddings: 77,
51 pad_with: None,
52 num_hidden_layers: 12,
53 num_attention_heads: 8,
54 projection_dim: 512,
55 activation: Activation::QuickGelu,
56 }
57 }
58}
59
60#[derive(Clone, Debug)]
63struct ClipTextEmbeddings {
64 token_embedding: candle_nn::Embedding,
65 position_embedding: candle_nn::Embedding,
66 position_ids: Tensor,
67}
68
69impl ClipTextEmbeddings {
70 fn new(vs: candle_nn::VarBuilder, c: &ClipTextConfig) -> Result<Self> {
71 let token_embedding =
72 candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
73 let position_embedding: nn::Embedding = candle_nn::embedding(
74 c.max_position_embeddings,
75 c.embed_dim,
76 vs.pp("position_embedding"),
77 )?;
78 let position_ids =
79 Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
80 Ok(Self {
81 token_embedding,
82 position_embedding,
83 position_ids,
84 })
85 }
86}
87
88impl Module for ClipTextEmbeddings {
89 fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
90 let seq_length = input_ids.dim(D::Minus1)?;
91 let inputs_embeds = self.token_embedding.forward(input_ids)?;
92 let position_ids = self.position_ids.narrow(1, 0, seq_length)?;
93 let position_embedding = self.position_embedding.forward(&position_ids)?;
94 inputs_embeds.broadcast_add(&position_embedding)
95 }
96}
97
98#[derive(Clone, Debug)]
99struct ClipAttention {
100 k_proj: candle_nn::Linear,
101 v_proj: candle_nn::Linear,
102 q_proj: candle_nn::Linear,
103 out_proj: candle_nn::Linear,
104 head_dim: usize,
105 scale: f64,
106 num_attention_heads: usize,
107}
108
109impl ClipAttention {
110 fn new(vs: candle_nn::VarBuilder, c: &EncoderConfig) -> Result<Self> {
111 let embed_dim = c.embed_dim();
112 let num_attention_heads = c.num_attention_heads();
113 let k_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("k_proj"))?;
114 let v_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("v_proj"))?;
115 let q_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("q_proj"))?;
116 let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?;
117 let head_dim = embed_dim / num_attention_heads;
118 let scale = (head_dim as f64).powf(-0.5);
119
120 Ok(ClipAttention {
121 k_proj,
122 v_proj,
123 q_proj,
124 out_proj,
125 head_dim,
126 scale,
127 num_attention_heads,
128 })
129 }
130
131 fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> {
132 xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
133 .transpose(1, 2)?
134 .contiguous()
135 }
136
137 fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
138 let in_dtype = xs.dtype();
139 let (bsz, seq_len, embed_dim) = xs.dims3()?;
140
141 let query_states = (self.q_proj.forward(xs)? * self.scale)?;
142 let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
143 let query_states = self
144 .shape(&query_states, seq_len, bsz)?
145 .reshape(proj_shape)?
146 .to_dtype(DType::F32)?;
147 let key_states = self
148 .shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
149 .reshape(proj_shape)?
150 .to_dtype(DType::F32)?;
151 let value_states = self
152 .shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
153 .reshape(proj_shape)?
154 .to_dtype(DType::F32)?;
155 let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
156
157 let src_len = key_states.dim(1)?;
158
159 let attn_weights = if let Some(causal_attention_mask) = causal_attention_mask {
160 attn_weights
161 .reshape((bsz, self.num_attention_heads, seq_len, src_len))?
162 .broadcast_add(causal_attention_mask)?
163 .reshape((bsz * self.num_attention_heads, seq_len, src_len))?
164 } else {
165 attn_weights
166 };
167
168 let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
169
170 let attn_output = attn_weights.matmul(&value_states)?.to_dtype(in_dtype)?;
171 let attn_output = attn_output
172 .reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
173 .transpose(1, 2)?
174 .reshape((bsz, seq_len, embed_dim))?;
175 self.out_proj.forward(&attn_output)
176 }
177}
178
179#[derive(Clone, Debug)]
180struct ClipMlp {
181 fc1: candle_nn::Linear,
182 fc2: candle_nn::Linear,
183 activation: Activation,
184}
185
186impl ClipMlp {
187 fn new(vs: candle_nn::VarBuilder, c: &EncoderConfig) -> Result<Self> {
188 let fc1 = candle_nn::linear(c.embed_dim(), c.intermediate_size(), vs.pp("fc1"))?;
189 let fc2 = candle_nn::linear(c.intermediate_size(), c.embed_dim(), vs.pp("fc2"))?;
190
191 Ok(ClipMlp {
192 fc1,
193 fc2,
194 activation: c.activation(),
195 })
196 }
197}
198
199impl ClipMlp {
200 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
201 let xs = self.fc1.forward(xs)?;
202 self.fc2.forward(&self.activation.forward(&xs)?)
203 }
204}
205
206#[derive(Clone, Debug)]
207struct ClipEncoderLayer {
208 self_attn: ClipAttention,
209 layer_norm1: candle_nn::LayerNorm,
210 mlp: ClipMlp,
211 layer_norm2: candle_nn::LayerNorm,
212}
213
214impl ClipEncoderLayer {
215 fn new(vs: candle_nn::VarBuilder, c: &EncoderConfig) -> Result<Self> {
216 let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?;
217 let layer_norm1 = candle_nn::layer_norm(c.embed_dim(), 1e-5, vs.pp("layer_norm1"))?;
218 let mlp = ClipMlp::new(vs.pp("mlp"), c)?;
219 let layer_norm2 = candle_nn::layer_norm(c.embed_dim(), 1e-5, vs.pp("layer_norm2"))?;
220
221 Ok(ClipEncoderLayer {
222 self_attn,
223 layer_norm1,
224 mlp,
225 layer_norm2,
226 })
227 }
228
229 fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
230 let residual = xs;
231 let xs = self.layer_norm1.forward(xs)?;
232 let xs = self.self_attn.forward(&xs, causal_attention_mask)?;
233 let xs = (xs + residual)?;
234
235 let residual = &xs;
236 let xs = self.layer_norm2.forward(&xs)?;
237 let xs = self.mlp.forward(&xs)?;
238 xs + residual
239 }
240}
241
242#[derive(Clone, Debug)]
243pub struct ClipEncoder {
244 layers: Vec<ClipEncoderLayer>,
245}
246
247impl ClipEncoder {
248 pub fn new(vs: candle_nn::VarBuilder, c: &EncoderConfig) -> Result<Self> {
249 let vs = vs.pp("layers");
250 let mut layers: Vec<ClipEncoderLayer> = Vec::new();
251 for index in 0..c.num_hidden_layers() {
252 let layer = ClipEncoderLayer::new(vs.pp(index.to_string()), c)?;
253 layers.push(layer)
254 }
255 Ok(ClipEncoder { layers })
256 }
257
258 pub fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
259 let mut xs = xs.clone();
260 for layer in self.layers.iter() {
261 xs = layer.forward(&xs, causal_attention_mask)?;
262 }
263 Ok(xs)
264 }
265 pub fn output_hidden_states(
267 &self,
268 xs: &Tensor,
269 causal_attention_mask: Option<&Tensor>,
270 ) -> Result<Vec<Tensor>> {
271 let mut xs = xs.clone();
272 let mut hidden_states = Vec::new();
273 for layer in self.layers.iter() {
274 xs = layer.forward(&xs, causal_attention_mask)?;
275 hidden_states.push(xs.clone());
276 }
277 Ok(hidden_states)
278 }
279}
280
281#[derive(Clone, Debug)]
283pub struct ClipTextTransformer {
284 embeddings: ClipTextEmbeddings,
285 encoder: ClipEncoder,
286 final_layer_norm: candle_nn::LayerNorm,
287}
288
289impl ClipTextTransformer {
290 pub fn new(vs: candle_nn::VarBuilder, c: &ClipTextConfig) -> Result<Self> {
291 let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?;
292 let encoder = ClipEncoder::new(vs.pp("encoder"), &EncoderConfig::Text(c.clone()))?;
293 let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?;
294 Ok(ClipTextTransformer {
295 embeddings,
296 encoder,
297 final_layer_norm,
298 })
299 }
300
301 fn build_causal_attention_mask(
303 bsz: usize,
304 seq_len: usize,
305 mask_after: usize,
306 device: &Device,
307 ) -> Result<Tensor> {
308 let mask: Vec<_> = (0..seq_len)
309 .flat_map(|i| {
310 (0..seq_len).map(move |j| {
311 if j > i || j > mask_after {
312 f32::MIN
313 } else {
314 0.
315 }
316 })
317 })
318 .collect();
319 let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
320 mask.broadcast_as((bsz, 1, seq_len, seq_len))
321 }
322
323 pub fn forward_with_mask(&self, input_ids: &Tensor, mask_after: usize) -> Result<Tensor> {
324 let (bsz, seq_len) = input_ids.dims2()?;
325 let input_ids = self.embeddings.forward(input_ids)?;
326 let causal_attention_mask =
327 Self::build_causal_attention_mask(bsz, seq_len, mask_after, input_ids.device())?;
328 let input_ids = self
329 .encoder
330 .forward(&input_ids, Some(&causal_attention_mask))?;
331 self.final_layer_norm.forward(&input_ids)
332 }
333}
334
335impl Module for ClipTextTransformer {
336 fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
337 let output = self.forward_with_mask(input_ids, usize::MAX)?;
338 let sequence_max_indices = input_ids.argmax(D::Minus1)?.to_dtype(DType::I64)?;
339
340 let mut indices = Vec::new();
341 for (batch_idx, &seq_idx) in sequence_max_indices.to_vec1::<i64>()?.iter().enumerate() {
342 let index = output.i((batch_idx, seq_idx as usize))?.unsqueeze(0)?;
343 indices.push(index);
344 }
345 Tensor::cat(&indices, 0)
346 }
347}