1use candle::{DType, Device, IndexOp, Module, Result, Tensor};
10use candle_nn as nn;
11
12use super::Activation;
13
/// How positional information is injected into the text embeddings.
///
/// Only [`PositionEmbeddingType::Absolute`] is implemented in this module. The embedding
/// layer adds learned absolute position embeddings for it and skips them for every other
/// variant. The self-attention layer applies no relative-position term. The relative
/// variants therefore currently run without any positional information.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum PositionEmbeddingType {
    /// Learned absolute position embeddings added to the token embeddings.
    Absolute,
    /// Relative-key position scheme (not implemented by this module).
    RelativeKey,
    /// Relative key-query position scheme (not implemented by this module).
    RelativeKeyQuery,
}
25
/// Hyper-parameters of the Chinese-CLIP text encoder.
#[derive(Clone, Debug)]
pub struct ChineseClipTextConfig {
    /// Number of rows in the word-embedding table.
    pub vocab_size: usize,
    /// Width of the embeddings and of every hidden state.
    pub hidden_size: usize,
    /// Number of encoder layers that are stacked.
    pub num_hidden_layers: usize,
    /// Number of attention heads; each head has `hidden_size / num_attention_heads` channels.
    pub num_attention_heads: usize,
    /// Output width of the feed-forward intermediate projection.
    pub intermediate_size: usize,
    /// Activation applied after the intermediate projection.
    pub hidden_act: Activation,
    /// Probability handed to the dropout layers. Note the type is `f32`, unlike
    /// `attention_probs_dropout_prob`.
    pub hidden_dropout_prob: f32,
    /// Dropout probability intended for the attention probabilities (stored as `f64`).
    pub attention_probs_dropout_prob: f64,
    /// Rows of the position-embedding table; also the length of the precomputed
    /// position / token-type id tensors, so longer inputs are rejected.
    pub max_position_embeddings: usize,
    /// Rows of the token-type (segment) embedding table.
    pub type_vocab_size: usize,
    /// Not consumed by the layers defined in this module (kept for configuration parity).
    pub initializer_range: f64,
    /// Not consumed by the layers defined in this module (kept for configuration parity).
    pub initializer_factor: f64,
    /// Epsilon used by every LayerNorm of the text model.
    pub layer_norm_eps: f64,
    /// Not consumed by the layers defined in this module (kept for configuration parity).
    pub pad_token_id: usize,
    /// Positional scheme; only `Absolute` adds positional information in this module.
    pub position_embedding_type: PositionEmbeddingType,
    /// Not consumed by the layers defined in this module (kept for configuration parity).
    pub use_cache: bool,
}
45
46impl Default for ChineseClipTextConfig {
47 fn default() -> Self {
48 Self {
49 vocab_size: 30522,
50 hidden_size: 768,
51 num_hidden_layers: 12,
52 num_attention_heads: 12,
53 intermediate_size: 3072,
54 hidden_act: Activation::Gelu,
55 hidden_dropout_prob: 0.1,
56 attention_probs_dropout_prob: 0.1,
57 max_position_embeddings: 512,
58 type_vocab_size: 2,
59 initializer_range: 0.02,
60 initializer_factor: 1.0,
61 layer_norm_eps: 1e-12,
62 pad_token_id: 0,
63 position_embedding_type: PositionEmbeddingType::Absolute,
64 use_cache: true,
65 }
66 }
67}
68
69impl ChineseClipTextConfig {
70 pub fn clip_vit_base_patch16() -> Self {
72 Self {
73 vocab_size: 21128,
74 hidden_size: 768,
75 num_hidden_layers: 12,
76 num_attention_heads: 12,
77 intermediate_size: 3072,
78 hidden_act: Activation::Gelu,
79 hidden_dropout_prob: 0.1,
80 attention_probs_dropout_prob: 0.1,
81 max_position_embeddings: 512,
82 type_vocab_size: 2,
83 initializer_range: 0.02,
84 initializer_factor: 1.0,
85 layer_norm_eps: 1e-12,
86 pad_token_id: 0,
87 position_embedding_type: PositionEmbeddingType::Absolute,
88 use_cache: true,
89 }
90 }
91}
92
93#[derive(Clone, Debug)]
94pub struct ChineseClipTextEmbeddings {
95 word_embeddings: nn::Embedding,
96 position_embeddings: nn::Embedding,
97 token_type_embeddings: nn::Embedding,
98 layer_norm: nn::LayerNorm,
99 dropout: nn::Dropout,
100 position_embedding_type: PositionEmbeddingType,
101 position_ids: Tensor,
102 token_type_ids: Tensor,
103}
104
105impl ChineseClipTextEmbeddings {
106 pub fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
107 let word_embeddings = nn::embedding(
108 config.vocab_size,
109 config.hidden_size,
110 var.pp("word_embeddings"),
111 )?;
112 let position_embeddings = nn::embedding(
113 config.max_position_embeddings,
114 config.hidden_size,
115 var.pp("position_embeddings"),
116 )?;
117 let token_type_embeddings = nn::embedding(
118 config.type_vocab_size,
119 config.hidden_size,
120 var.pp("token_type_embeddings"),
121 )?;
122 let layer_norm = nn::layer_norm::<f64>(
123 config.hidden_size,
124 config.layer_norm_eps,
125 var.pp("LayerNorm"),
126 )?;
127 let dropout = nn::Dropout::new(config.hidden_dropout_prob);
128 let position_ids =
129 Tensor::arange(0u32, config.max_position_embeddings as u32, var.device())?
130 .unsqueeze(0)?;
131 let token_type_ids = Tensor::zeros(position_ids.shape(), DType::I64, var.device())?;
132
133 Ok(Self {
134 word_embeddings,
135 position_embeddings,
136 token_type_embeddings,
137 layer_norm,
138 dropout,
139 position_embedding_type: config.position_embedding_type.clone(),
140 position_ids,
141 token_type_ids,
142 })
143 }
144
145 fn forward(&self, xs: &Tensor, token_type_ids: Option<&Tensor>) -> Result<Tensor> {
146 let (_batch_size, seq_length) = xs.dims2()?;
147 let position_ids = (0..seq_length as u32).collect::<Vec<_>>();
148 let position_ids = self.position_ids.index_select(
149 &Tensor::new(&position_ids[..], self.position_ids.device())?,
150 1,
151 )?;
152
153 let word_embeddings = self.word_embeddings.forward(xs)?;
154
155 let token_type_ids = match token_type_ids {
156 Some(token_type_ids) => token_type_ids,
157 None => &self.token_type_ids.i((.., 0..seq_length))?,
158 };
159 let token_type_ids = token_type_ids.expand(xs.shape())?;
160 let token_type_embeddings = self.token_type_embeddings.forward(&token_type_ids)?;
161
162 let embeddings = (&word_embeddings + token_type_embeddings)?;
163 let embeddings = match self.position_embedding_type {
164 PositionEmbeddingType::Absolute => {
165 let position_embeddings = self.position_embeddings.forward(&position_ids)?;
166 let position_embeddings = position_embeddings.expand(embeddings.shape())?;
167 (embeddings + position_embeddings)?
168 }
169 _ => embeddings,
170 };
171 let embeddings = self.layer_norm.forward(&embeddings)?;
172 let embeddings = self.dropout.forward(&embeddings, false)?;
173 Ok(embeddings)
174 }
175}
176
177#[derive(Clone, Debug)]
179struct ChineseClipTextSelfOutput {
180 dense: nn::Linear,
181 layer_norm: nn::LayerNorm,
182 dropout: nn::Dropout,
183 span: tracing::Span,
184}
185
186impl ChineseClipTextSelfOutput {
187 fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
188 let dense = nn::linear(config.hidden_size, config.hidden_size, var.pp("dense"))?;
189 let layer_norm = nn::layer_norm(
190 config.hidden_size,
191 config.layer_norm_eps,
192 var.pp("LayerNorm"),
193 )?;
194 let dropout = nn::Dropout::new(config.hidden_dropout_prob);
195 Ok(Self {
196 dense,
197 layer_norm,
198 dropout,
199 span: tracing::span!(tracing::Level::TRACE, "self-out"),
200 })
201 }
202
203 fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
204 let _enter = self.span.enter();
205 let hidden_states = self.dense.forward(hidden_states)?;
206 let hidden_states = self.dropout.forward(&hidden_states, false)?;
207 self.layer_norm.forward(&(hidden_states + input_tensor)?)
208 }
209}
210
211#[derive(Clone, Debug)]
213struct ChineseClipTextSelfAttention {
214 query: nn::Linear,
215 key: nn::Linear,
216 value: nn::Linear,
217 dropout: nn::Dropout,
218 num_attention_heads: usize,
219 attention_head_size: usize,
220 span: tracing::Span,
221 span_softmax: tracing::Span,
222}
223
224impl ChineseClipTextSelfAttention {
225 fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
226 let attention_head_size = config.hidden_size / config.num_attention_heads;
227 let all_head_size = config.num_attention_heads * attention_head_size;
228 let dropout = nn::Dropout::new(config.hidden_dropout_prob);
229 let hidden_size = config.hidden_size;
230 let query = nn::linear(hidden_size, all_head_size, var.pp("query"))?;
231 let value = nn::linear(hidden_size, all_head_size, var.pp("value"))?;
232 let key = nn::linear(hidden_size, all_head_size, var.pp("key"))?;
233 Ok(Self {
234 query,
235 key,
236 value,
237 dropout,
238 num_attention_heads: config.num_attention_heads,
239 attention_head_size,
240 span: tracing::span!(tracing::Level::TRACE, "self-attn"),
241 span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
242 })
243 }
244
245 fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
246 let mut new_x_shape = xs.dims().to_vec();
247 new_x_shape.pop();
248 new_x_shape.push(self.num_attention_heads);
249 new_x_shape.push(self.attention_head_size);
250 let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
251 xs.contiguous()
252 }
253
254 fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
255 let _enter = self.span.enter();
256 let query_layer = self.query.forward(hidden_states)?;
257 let key_layer = self.key.forward(hidden_states)?;
258 let value_layer = self.value.forward(hidden_states)?;
259
260 let query_layer = self.transpose_for_scores(&query_layer)?;
261 let key_layer = self.transpose_for_scores(&key_layer)?;
262 let value_layer = self.transpose_for_scores(&value_layer)?;
263
264 let attention_scores = query_layer.matmul(&key_layer.t()?)?;
265 let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
266 let attention_scores = attention_scores.broadcast_add(attention_mask)?;
267 let attention_probs = {
268 let _enter_sm = self.span_softmax.enter();
269 nn::ops::softmax(&attention_scores, candle::D::Minus1)?
270 };
271 let attention_probs = self.dropout.forward(&attention_probs, false)?;
272
273 let context_layer = attention_probs.matmul(&value_layer)?;
274 let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
275 let context_layer = context_layer.flatten_from(candle::D::Minus2)?;
276 Ok(context_layer)
277 }
278}
279
280#[derive(Clone, Debug)]
282struct ChineseClipTextAttention {
283 self_attention: ChineseClipTextSelfAttention,
284 self_output: ChineseClipTextSelfOutput,
285 span: tracing::Span,
286}
287
288impl ChineseClipTextAttention {
289 fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
290 let self_attention = ChineseClipTextSelfAttention::new(var.pp("self"), config)?;
291 let self_output = ChineseClipTextSelfOutput::new(var.pp("output"), config)?;
292 Ok(Self {
293 self_attention,
294 self_output,
295 span: tracing::span!(tracing::Level::TRACE, "attn"),
296 })
297 }
298
299 fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
300 let _enter = self.span.enter();
301 let self_outputs = self.self_attention.forward(hidden_states, attention_mask)?;
302 let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
303 Ok(attention_output)
304 }
305}
306
307type HiddenActLayer = Activation;
308
309#[derive(Clone, Debug)]
311struct ChineseClipTextIntermediate {
312 dense: nn::Linear,
313 intermediate_act: HiddenActLayer,
314 span: tracing::Span,
315}
316
317impl ChineseClipTextIntermediate {
318 fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
319 let dense = nn::linear(
320 config.hidden_size,
321 config.intermediate_size,
322 var.pp("dense"),
323 )?;
324 Ok(Self {
325 dense,
326 intermediate_act: config.hidden_act,
327 span: tracing::span!(tracing::Level::TRACE, "inter"),
328 })
329 }
330}
331
332impl Module for ChineseClipTextIntermediate {
333 fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
334 let _enter = self.span.enter();
335 let hidden_states = self.dense.forward(hidden_states)?;
336 let ys = self.intermediate_act.forward(&hidden_states)?;
337 Ok(ys)
338 }
339}
340
341#[derive(Clone, Debug)]
343struct ChineseClipTextOutput {
344 dense: nn::Linear,
345 layer_norm: nn::LayerNorm,
346 dropout: nn::Dropout,
347 span: tracing::Span,
348}
349
350impl ChineseClipTextOutput {
351 fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
352 let dense = nn::linear(
353 config.intermediate_size,
354 config.hidden_size,
355 var.pp("dense"),
356 )?;
357 let layer_norm = nn::layer_norm(
358 config.hidden_size,
359 config.layer_norm_eps,
360 var.pp("LayerNorm"),
361 )?;
362 let dropout = nn::Dropout::new(config.hidden_dropout_prob);
363 Ok(Self {
364 dense,
365 layer_norm,
366 dropout,
367 span: tracing::span!(tracing::Level::TRACE, "out"),
368 })
369 }
370
371 fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
372 let _enter = self.span.enter();
373 let hidden_states = self.dense.forward(hidden_states)?;
374 let hidden_states = self.dropout.forward(&hidden_states, false)?;
375 self.layer_norm.forward(&(hidden_states + input_tensor)?)
376 }
377}
378
379#[derive(Clone, Debug)]
381struct ChineseClipTextLayer {
382 attention: ChineseClipTextAttention,
383 intermediate: ChineseClipTextIntermediate,
384 output: ChineseClipTextOutput,
385 span: tracing::Span,
386}
387
388impl ChineseClipTextLayer {
389 fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
390 let attention = ChineseClipTextAttention::new(var.pp("attention"), config)?;
391 let intermediate = ChineseClipTextIntermediate::new(var.pp("intermediate"), config)?;
392 let output = ChineseClipTextOutput::new(var.pp("output"), config)?;
393 Ok(Self {
394 attention,
395 intermediate,
396 output,
397 span: tracing::span!(tracing::Level::TRACE, "layer"),
398 })
399 }
400
401 fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
402 let _enter = self.span.enter();
403 let attention_output = self.attention.forward(hidden_states, attention_mask)?;
404 let intermediate_output = self.intermediate.forward(&attention_output)?;
406 let layer_output = self
407 .output
408 .forward(&intermediate_output, &attention_output)?;
409 Ok(layer_output)
410 }
411}
412
413#[derive(Clone, Debug)]
414struct Tanh;
415
416impl Tanh {
417 pub fn new() -> Self {
418 Self {}
419 }
420}
421impl Module for Tanh {
422 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
423 xs.tanh()
424 }
425}
426
427#[derive(Clone, Debug)]
428struct ChineseClipTextPooler {
429 dense: nn::Linear,
430 activation: Tanh,
431}
432
433impl ChineseClipTextPooler {
434 pub fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
435 let dense = nn::linear(config.hidden_size, config.hidden_size, var.pp("dense"))?;
436 let activation = Tanh::new();
437 Ok(Self { dense, activation })
438 }
439}
440
441impl Module for ChineseClipTextPooler {
442 fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
443 let first_token_tensor = hidden_states.i((.., 0))?;
444 let pooled_output = self.dense.forward(&first_token_tensor)?;
445 let pooled_output = self.activation.forward(&pooled_output)?;
446 Ok(pooled_output)
447 }
448}
449
450#[derive(Clone, Debug)]
451struct ChineseClipTextEncoder {
452 layers: Vec<ChineseClipTextLayer>,
453 span: tracing::Span,
454}
455
456impl ChineseClipTextEncoder {
457 fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
458 let layers = (0..config.num_hidden_layers)
459 .map(|index| ChineseClipTextLayer::new(var.pp(format!("layer.{index}")), config))
460 .collect::<Result<Vec<_>>>()?;
461 let span = tracing::span!(tracing::Level::TRACE, "encoder");
462 Ok(ChineseClipTextEncoder { layers, span })
463 }
464
465 fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
466 let _enter = self.span.enter();
467 let mut hidden_states = hidden_states.clone();
468 for layer in self.layers.iter() {
470 hidden_states = layer.forward(&hidden_states, attention_mask)?
471 }
472 Ok(hidden_states)
473 }
474}
475
476#[derive(Clone, Debug)]
477pub struct ChineseClipTextTransformer {
478 embeddings: ChineseClipTextEmbeddings,
479 encoder: ChineseClipTextEncoder,
480 pooler: Option<ChineseClipTextPooler>,
481 pub device: Device,
482 span: tracing::Span,
483}
484
485impl ChineseClipTextTransformer {
486 pub fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
487 let embeddings = ChineseClipTextEmbeddings::new(var.pp("embeddings"), config)?;
488 let encoder = ChineseClipTextEncoder::new(var.pp("encoder"), config)?;
489 let pooler = if var.contains_tensor("pooler") {
492 Some(ChineseClipTextPooler::new(var.pp("pooler"), config)?)
493 } else {
494 None
495 };
496 Ok(Self {
497 embeddings,
498 encoder,
499 pooler,
500 device: var.device().clone(),
501 span: tracing::span!(tracing::Level::TRACE, "model"),
502 })
503 }
504
505 pub fn forward(
506 &self,
507 input_ids: &Tensor,
508 token_type_ids: Option<&Tensor>,
509 attention_mask: Option<&Tensor>,
510 ) -> Result<Tensor> {
511 let _enter = self.span.enter();
512 let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
513 let attention_mask = match attention_mask {
514 Some(attention_mask) => attention_mask.clone(),
515 None => input_ids.ones_like()?,
516 };
517 let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?;
519 let encoder_outputs = self.encoder.forward(&embedding_output, &attention_mask)?;
520 let encoder_output = encoder_outputs.i((.., 0, ..))?;
521 let pooled_output = match &self.pooler {
522 Some(pooler) => pooler.forward(&encoder_output)?,
523 None => encoder_output,
524 };
525
526 Ok(pooled_output)
527 }
528}
529
530fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<Tensor> {
531 let attention_mask = match attention_mask.rank() {
532 3 => attention_mask.unsqueeze(1)?,
533 2 => attention_mask.unsqueeze(1)?.unsqueeze(1)?,
534 _ => candle::bail!("Wrong shape for input_ids or attention_mask"),
535 };
536 let attention_mask = attention_mask.to_dtype(dtype)?;
537 (attention_mask.ones_like()? - &attention_mask)?
539 .broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
540}