1use crate::models::clip::div_l2_norm;
10use candle::{IndexOp, Module, Result, Tensor, D};
11use candle_nn::{layer_norm, linear, LayerNorm, Linear, VarBuilder};
12
/// Serde fallback for `TextConfig::vocab_size`.
fn default_text_vocab_size() -> usize {
    32_000
}
16
/// Serde fallback for `TextConfig::hidden_size`.
fn default_text_hidden_size() -> usize {
    const HIDDEN_SIZE: usize = 768;
    HIDDEN_SIZE
}
20
/// Serde fallback for `TextConfig::intermediate_size`.
fn default_text_intermediate_size() -> usize {
    3_072
}
24
/// Serde fallback for `TextConfig::num_hidden_layers`.
fn default_text_num_hidden_layers() -> usize {
    const NUM_LAYERS: usize = 12;
    NUM_LAYERS
}
28
/// Serde fallback for `TextConfig::num_attention_heads`.
fn default_text_num_attention_heads() -> usize {
    const NUM_HEADS: usize = 12;
    NUM_HEADS
}
32
/// Serde fallback for `TextConfig::max_position_embeddings`.
fn default_text_max_position_embeddings() -> usize {
    const MAX_POSITIONS: usize = 64;
    MAX_POSITIONS
}
36
/// Serde fallback for `TextConfig::layer_norm_eps`.
fn default_text_layer_norm_eps() -> f64 {
    const EPS: f64 = 1e-6;
    EPS
}
40
/// Serde fallback for `TextConfig::pad_token_id`.
fn default_text_pad_token_id() -> u32 {
    const PAD: u32 = 1;
    PAD
}
44
/// Serde fallback for `TextConfig::bos_token_id`.
fn default_text_bos_token_id() -> u32 {
    49_406
}
48
/// Serde fallback for `TextConfig::eos_token_id`.
fn default_text_eos_token_id() -> u32 {
    49_407
}
52
53fn default_text_hidden_act() -> candle_nn::Activation {
54 candle_nn::Activation::GeluPytorchTanh
55}
56
57#[derive(serde::Deserialize, Clone, Debug)]
59pub struct TextConfig {
60 #[serde(default = "default_text_vocab_size")]
61 pub vocab_size: usize,
62 #[serde(default = "default_text_hidden_size")]
63 pub hidden_size: usize,
64 #[serde(default = "default_text_intermediate_size")]
65 pub intermediate_size: usize,
66 #[serde(default = "default_text_num_hidden_layers")]
67 pub num_hidden_layers: usize,
68 #[serde(default = "default_text_num_attention_heads")]
69 pub num_attention_heads: usize,
70 #[serde(default = "default_text_max_position_embeddings")]
71 pub max_position_embeddings: usize,
72 #[serde(default = "default_text_hidden_act")]
73 pub hidden_act: candle_nn::Activation,
74 #[serde(default = "default_text_layer_norm_eps")]
75 pub layer_norm_eps: f64,
76 #[serde(default = "default_text_pad_token_id")]
77 pub pad_token_id: u32,
78 #[serde(default = "default_text_bos_token_id")]
79 pub bos_token_id: u32,
80 #[serde(default = "default_text_eos_token_id")]
81 pub eos_token_id: u32,
82}
83
/// Serde fallback for `VisionConfig::hidden_size`.
fn default_vision_hidden_size() -> usize {
    const HIDDEN_SIZE: usize = 768;
    HIDDEN_SIZE
}
87
/// Serde fallback for `VisionConfig::intermediate_size`.
fn default_vision_intermediate_size() -> usize {
    3_072
}
91
/// Serde fallback for `VisionConfig::num_hidden_layers`.
fn default_vision_num_hidden_layers() -> usize {
    const NUM_LAYERS: usize = 12;
    NUM_LAYERS
}
95
/// Serde fallback for `VisionConfig::num_attention_heads`.
fn default_vision_num_attention_heads() -> usize {
    const NUM_HEADS: usize = 12;
    NUM_HEADS
}
99
/// Serde fallback for `VisionConfig::num_channels`.
fn default_vision_num_channels() -> usize {
    const CHANNELS: usize = 3;
    CHANNELS
}
103
/// Serde fallback for `VisionConfig::image_size`.
fn default_vision_image_size() -> usize {
    const IMAGE_SIZE: usize = 224;
    IMAGE_SIZE
}
107
/// Serde fallback for `VisionConfig::patch_size`.
///
/// Despite the `batch_size` in its name, this value is wired to the `patch_size`
/// field through that field's `#[serde(default = ...)]` attribute.
fn default_vision_batch_size() -> usize {
    const PATCH_SIZE: usize = 16;
    PATCH_SIZE
}
111
/// Serde fallback for `VisionConfig::layer_norm_eps`.
fn default_vision_layer_norm_eps() -> f64 {
    const EPS: f64 = 1e-6;
    EPS
}
115
116fn default_vision_hidden_act() -> candle_nn::Activation {
117 candle_nn::Activation::GeluPytorchTanh
118}
119
120#[derive(serde::Deserialize, Clone, Debug)]
122pub struct VisionConfig {
123 #[serde(default = "default_vision_hidden_size")]
124 pub hidden_size: usize,
125 #[serde(default = "default_vision_intermediate_size")]
126 pub intermediate_size: usize,
127 #[serde(default = "default_vision_num_hidden_layers")]
128 pub num_hidden_layers: usize,
129 #[serde(default = "default_vision_num_attention_heads")]
130 pub num_attention_heads: usize,
131 #[serde(default = "default_vision_num_channels")]
132 pub num_channels: usize,
133 #[serde(default = "default_vision_image_size")]
134 pub image_size: usize,
135 #[serde(default = "default_vision_batch_size")]
136 pub patch_size: usize,
137 #[serde(default = "default_vision_hidden_act")]
138 pub hidden_act: candle_nn::Activation,
139 #[serde(default = "default_vision_layer_norm_eps")]
140 pub layer_norm_eps: f64,
141}
142
143trait TransformerConfig {
144 fn hidden_size(&self) -> usize;
145 fn intermediate_size(&self) -> usize;
146 fn num_attention_heads(&self) -> usize;
147 fn num_hidden_layers(&self) -> usize;
148 fn layer_norm_eps(&self) -> f64;
149 fn hidden_act(&self) -> candle_nn::Activation;
150}
151
152impl TransformerConfig for TextConfig {
153 fn hidden_size(&self) -> usize {
154 self.hidden_size
155 }
156 fn intermediate_size(&self) -> usize {
157 self.intermediate_size
158 }
159 fn num_attention_heads(&self) -> usize {
160 self.num_attention_heads
161 }
162 fn num_hidden_layers(&self) -> usize {
163 self.num_hidden_layers
164 }
165 fn layer_norm_eps(&self) -> f64 {
166 self.layer_norm_eps
167 }
168 fn hidden_act(&self) -> candle_nn::Activation {
169 self.hidden_act
170 }
171}
172
173impl TransformerConfig for VisionConfig {
174 fn hidden_size(&self) -> usize {
175 self.hidden_size
176 }
177 fn intermediate_size(&self) -> usize {
178 self.intermediate_size
179 }
180 fn num_attention_heads(&self) -> usize {
181 self.num_attention_heads
182 }
183 fn num_hidden_layers(&self) -> usize {
184 self.num_hidden_layers
185 }
186 fn layer_norm_eps(&self) -> f64 {
187 self.layer_norm_eps
188 }
189 fn hidden_act(&self) -> candle_nn::Activation {
190 self.hidden_act
191 }
192}
193
194impl VisionConfig {
195 pub fn paligemma_3b_224() -> Self {
196 Self {
197 patch_size: 14,
199 num_attention_heads: 16,
200 num_hidden_layers: 27,
201 hidden_size: 1152,
202 intermediate_size: 4304,
203 image_size: 224, num_channels: 3,
206 hidden_act: candle_nn::Activation::GeluPytorchTanh,
207 layer_norm_eps: 1e-6,
208 }
209 }
210
211 pub fn paligemma_3b_448() -> Self {
212 Self {
213 patch_size: 14,
215 num_attention_heads: 16,
216 num_hidden_layers: 27,
217 hidden_size: 1152,
218 intermediate_size: 4304,
219 image_size: 448, num_channels: 3,
222 hidden_act: candle_nn::Activation::GeluPytorchTanh,
223 layer_norm_eps: 1e-6,
224 }
225 }
226
227 pub fn paligemma_3b_896() -> Self {
228 Self {
229 patch_size: 14,
231 num_attention_heads: 16,
232 num_hidden_layers: 27,
233 hidden_size: 1152,
234 intermediate_size: 4304,
235 image_size: 896, num_channels: 3,
238 hidden_act: candle_nn::Activation::GeluPytorchTanh,
239 layer_norm_eps: 1e-6,
240 }
241 }
242
243 pub fn num_patches(&self) -> usize {
244 (self.image_size / self.patch_size).pow(2)
245 }
246}
247
248#[derive(serde::Deserialize, Clone, Debug)]
250pub struct Config {
251 pub text_config: TextConfig,
252 pub vision_config: VisionConfig,
253}
254
255impl Config {
256 pub fn base_patch16_224() -> Self {
257 let text_config = TextConfig {
258 hidden_size: 768,
260 intermediate_size: 3072,
261 num_attention_heads: 12,
262 vocab_size: 32000,
263 pad_token_id: 1,
265 bos_token_id: 49406,
266 eos_token_id: 49407,
267 layer_norm_eps: 1e-6,
268 hidden_act: candle_nn::Activation::GeluPytorchTanh,
269 max_position_embeddings: 64,
270 num_hidden_layers: 12,
271 };
272 let vision_config = VisionConfig {
273 patch_size: 16,
274 hidden_size: 768,
276 intermediate_size: 3072,
277 num_hidden_layers: 12,
278 num_attention_heads: 12,
279 num_channels: 3,
280 image_size: 224,
281 hidden_act: candle_nn::Activation::GeluPytorchTanh,
282 layer_norm_eps: 1e-6,
283 };
284 Self {
285 text_config,
286 vision_config,
287 }
288 }
289}
290
291#[derive(Clone, Debug)]
292struct MultiheadAttention {
293 q_proj: Linear,
294 k_proj: Linear,
295 v_proj: Linear,
296 out_proj: Linear,
297 num_heads: usize,
298}
299
300impl MultiheadAttention {
301 fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
302 let h = cfg.hidden_size;
303 let num_heads = cfg.num_attention_heads;
304 let w_in_proj = vb.get((3 * h, h), "in_proj_weight")?.chunk(3, 0)?;
305 let b_in_proj = vb.get(3 * h, "in_proj_bias")?.chunk(3, 0)?;
306 let q_proj = Linear::new(w_in_proj[0].clone(), Some(b_in_proj[0].clone()));
307 let k_proj = Linear::new(w_in_proj[1].clone(), Some(b_in_proj[1].clone()));
308 let v_proj = Linear::new(w_in_proj[2].clone(), Some(b_in_proj[2].clone()));
309 let out_proj = linear(h, h, vb.pp("out_proj"))?;
310 Ok(Self {
311 q_proj,
312 k_proj,
313 v_proj,
314 out_proj,
315 num_heads,
316 })
317 }
318
319 fn separate_heads(&self, x: &Tensor) -> Result<Tensor> {
320 let (b, n, c) = x.dims3()?;
321 x.reshape((b, n, self.num_heads, c / self.num_heads))?
322 .transpose(1, 2)?
323 .contiguous()
324 }
325
326 fn recombine_heads(&self, x: &Tensor) -> Result<Tensor> {
327 let (b, n_heads, n_tokens, c_per_head) = x.dims4()?;
328 x.transpose(1, 2)?
329 .reshape((b, n_tokens, n_heads * c_per_head))
330 }
331
332 fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
333 let q = self.q_proj.forward(&q.contiguous()?)?;
334 let k = self.k_proj.forward(&k.contiguous()?)?;
335 let v = self.v_proj.forward(&v.contiguous()?)?;
336
337 let q = self.separate_heads(&q)?;
338 let k = self.separate_heads(&k)?;
339 let v = self.separate_heads(&v)?;
340
341 let (_, _, _, c_per_head) = q.dims4()?;
342 let attn = (q.matmul(&k.t()?)? / (c_per_head as f64).sqrt())?;
343 let attn = candle_nn::ops::softmax_last_dim(&attn)?;
344
345 let out = attn.matmul(&v)?;
346 self.recombine_heads(&out)?.apply(&self.out_proj)
347 }
348}
349
350#[derive(Debug, Clone)]
351struct MultiheadAttentionPoolingHead {
352 probe: Tensor,
353 attention: MultiheadAttention,
354 layernorm: LayerNorm,
355 mlp: Mlp,
356}
357
358impl MultiheadAttentionPoolingHead {
359 fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
360 let mlp = Mlp::new(cfg, vb.pp("mlp"))?;
361 let layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("layernorm"))?;
362 let probe = vb.get((1, 1, cfg.hidden_size), "probe")?;
363 let attention = MultiheadAttention::new(cfg, vb.pp("attention"))?;
364 Ok(Self {
365 probe,
366 attention,
367 layernorm,
368 mlp,
369 })
370 }
371}
372
373impl Module for MultiheadAttentionPoolingHead {
374 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
375 let batch_size = xs.dim(0)?;
376 let probe = self.probe.repeat((batch_size, 1, 1))?;
377 let xs = self.attention.forward(&probe, xs, xs)?;
378 let residual = &xs;
379 let xs = xs.apply(&self.layernorm)?.apply(&self.mlp)?;
380 (xs + residual)?.i((.., 0))
381 }
382}
383
384#[derive(Debug, Clone)]
385struct Attention {
386 q_proj: Linear,
387 k_proj: Linear,
388 v_proj: Linear,
389 out_proj: Linear,
390 num_heads: usize,
391 head_dim: usize,
392 scale: f64,
393}
394
395impl Attention {
396 fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
397 let embed_dim = cfg.hidden_size();
398 let q_proj = linear(embed_dim, embed_dim, vb.pp("q_proj"))?;
399 let k_proj = linear(embed_dim, embed_dim, vb.pp("k_proj"))?;
400 let v_proj = linear(embed_dim, embed_dim, vb.pp("v_proj"))?;
401 let out_proj = linear(embed_dim, embed_dim, vb.pp("out_proj"))?;
402 let num_heads = cfg.num_attention_heads();
403 let head_dim = embed_dim / num_heads;
404 Ok(Self {
405 q_proj,
406 k_proj,
407 v_proj,
408 out_proj,
409 num_heads,
410 head_dim,
411 scale: (head_dim as f64).powf(-0.5),
412 })
413 }
414
415 fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
416 let (batch_size, q_len, _) = xs.dims3()?;
417 let query_states = xs.apply(&self.q_proj)?;
418 let key_states = xs.apply(&self.k_proj)?;
419 let value_states = xs.apply(&self.v_proj)?;
420
421 let shape = (batch_size, q_len, self.num_heads, self.head_dim);
422 let query_states = query_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
423 let key_states = key_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
424 let value_states = value_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
425
426 let attn_weights = (query_states.matmul(&key_states.t()?)? * self.scale)?;
427 let attn_weights = match attention_mask {
428 None => attn_weights,
429 Some(mask) => attn_weights.broadcast_add(mask)?,
430 };
431 let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
433 let attn_outputs = attn_weights
434 .matmul(&value_states)?
435 .transpose(1, 2)?
436 .reshape((batch_size, q_len, ()))?
437 .apply(&self.out_proj)?;
438 Ok(attn_outputs)
439 }
440}
441
442#[derive(Debug, Clone)]
444struct Mlp {
445 fc1: Linear,
446 fc2: Linear,
447 activation_fn: candle_nn::Activation,
448}
449
450impl Mlp {
451 fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
452 let hidden_size = cfg.hidden_size();
453 let intermediate_size = cfg.intermediate_size();
454 let fc1 = candle_nn::linear(hidden_size, intermediate_size, vb.pp("fc1"))?;
455 let fc2 = candle_nn::linear(intermediate_size, hidden_size, vb.pp("fc2"))?;
456 Ok(Self {
457 fc1,
458 fc2,
459 activation_fn: cfg.hidden_act(),
460 })
461 }
462}
463
464impl Module for Mlp {
465 fn forward(&self, xs: &candle::Tensor) -> Result<candle::Tensor> {
466 xs.apply(&self.fc1)?
467 .apply(&self.activation_fn)?
468 .apply(&self.fc2)
469 }
470}
471
472#[derive(Debug, Clone)]
474struct EncoderLayer {
475 self_attn: Attention,
476 layer_norm1: LayerNorm,
477 mlp: Mlp,
478 layer_norm2: LayerNorm,
479}
480
481impl EncoderLayer {
482 fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
483 let hidden_size = cfg.hidden_size();
484 let layer_norm_eps = cfg.layer_norm_eps();
485 let self_attn = Attention::new(cfg, vb.pp("self_attn"))?;
486 let layer_norm1 = layer_norm(hidden_size, layer_norm_eps, vb.pp("layer_norm1"))?;
487 let mlp = Mlp::new(cfg, vb.pp("mlp"))?;
488 let layer_norm2 = layer_norm(hidden_size, layer_norm_eps, vb.pp("layer_norm2"))?;
489 Ok(Self {
490 self_attn,
491 layer_norm1,
492 mlp,
493 layer_norm2,
494 })
495 }
496
497 fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
498 let residual = xs;
499 let xs = xs.apply(&self.layer_norm1)?;
500 let xs = self.self_attn.forward(&xs, attention_mask)?;
501 let xs = (residual + xs)?;
502 let residual = &xs;
503 let xs = xs.apply(&self.layer_norm2)?.apply(&self.mlp)?;
504 let xs = (xs + residual)?;
505 Ok(xs)
506 }
507}
508
509#[derive(Debug, Clone)]
510struct Encoder {
511 layers: Vec<EncoderLayer>,
512}
513
514impl Encoder {
515 fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
516 let mut layers = vec![];
517 let vb = vb.pp("layers");
518 for layer_idx in 0..cfg.num_hidden_layers() {
519 let layer = EncoderLayer::new(cfg, vb.pp(layer_idx))?;
520 layers.push(layer)
521 }
522 Ok(Self { layers })
523 }
524
525 fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
526 let mut xs = xs.clone();
527 for layer in self.layers.iter() {
528 xs = layer.forward(&xs, attention_mask)?
529 }
530 Ok(xs)
531 }
532}
533
534#[derive(Debug, Clone)]
535struct VisionEmbeddings {
536 patch_embedding: candle_nn::Conv2d,
537 position_embedding: Tensor,
538 patch_size: usize,
539 base_num_patches_per_side: usize,
540}
541
542impl VisionEmbeddings {
543 fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
544 let conv2d_cfg = candle_nn::Conv2dConfig {
545 stride: cfg.patch_size,
546 ..Default::default()
547 };
548 let patch_embedding = candle_nn::conv2d(
549 cfg.num_channels,
550 cfg.hidden_size,
551 cfg.patch_size,
552 conv2d_cfg,
553 vb.pp("patch_embedding"),
554 )?;
555 let num_patches_per_side = cfg.image_size / cfg.patch_size;
556 let embedder = candle_nn::embedding(
557 num_patches_per_side.pow(2),
558 cfg.hidden_size(),
559 vb.pp("position_embedding"),
560 )?;
561 let position_embedding = embedder.embeddings();
562 let position_embedding = position_embedding
563 .reshape((
564 1,
565 num_patches_per_side,
566 num_patches_per_side,
567 cfg.hidden_size(),
568 ))?
569 .permute((0, 3, 1, 2))?;
570 Ok(Self {
571 patch_embedding,
572 position_embedding,
573 patch_size: cfg.patch_size,
574 base_num_patches_per_side: num_patches_per_side,
575 })
576 }
577}
578
579impl Module for VisionEmbeddings {
580 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
581 let (_batch, _channels, _height, _width) = xs.dims4()?;
583 let embeddings = xs.apply(&self.patch_embedding)?;
584 let num_patches_h = _height / self.patch_size;
586 let num_patches_w = _width / self.patch_size;
587 let resized_position_embedding = if num_patches_w == self.base_num_patches_per_side
588 && num_patches_h == self.base_num_patches_per_side
589 {
590 self.position_embedding.clone()
591 } else {
592 self.position_embedding
593 .interpolate2d(num_patches_h, num_patches_w)?
594 };
595 let embeddings = embeddings
597 .broadcast_add(&resized_position_embedding)?
598 .flatten_from(2)?
599 .transpose(1, 2)?;
600 Ok(embeddings)
601 }
602}
603
604#[derive(Debug, Clone)]
605struct VisionTransformer {
606 embeddings: VisionEmbeddings,
607 encoder: Encoder,
608 post_layernorm: LayerNorm,
609 head: Option<MultiheadAttentionPoolingHead>,
610}
611
612impl VisionTransformer {
613 fn new(cfg: &VisionConfig, use_head: bool, vb: VarBuilder) -> Result<Self> {
614 let embeddings = VisionEmbeddings::new(cfg, vb.pp("embeddings"))?;
615 let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
616 let post_layernorm =
617 layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("post_layernorm"))?;
618 let head = if use_head {
619 Some(MultiheadAttentionPoolingHead::new(cfg, vb.pp("head"))?)
620 } else {
621 None
622 };
623 Ok(Self {
624 embeddings,
625 encoder,
626 post_layernorm,
627 head,
628 })
629 }
630}
631
632impl Module for VisionTransformer {
633 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
634 let xs = xs.apply(&self.embeddings)?;
635 let xs = self.encoder.forward(&xs, None)?;
636 let xs = xs.apply(&self.post_layernorm)?;
637 match self.head.as_ref() {
638 None => Ok(xs),
639 Some(h) => xs.apply(h),
640 }
641 }
642}
643
644#[derive(Debug, Clone)]
645pub struct VisionModel {
646 vision_model: VisionTransformer,
647}
648
649impl VisionModel {
650 pub fn new(cfg: &VisionConfig, use_head: bool, vb: VarBuilder) -> Result<Self> {
651 let vision_model = VisionTransformer::new(cfg, use_head, vb)?;
652 Ok(Self { vision_model })
653 }
654}
655
656impl Module for VisionModel {
657 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
658 xs.apply(&self.vision_model)
659 }
660}
661
662#[derive(Debug, Clone)]
663struct TextEmbeddings {
664 token_embedding: candle_nn::Embedding,
665 position_embedding: candle_nn::Embedding,
666 position_ids: Tensor,
667}
668
669impl TextEmbeddings {
670 fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
671 let token_embedding =
672 candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("token_embedding"))?;
673 let position_embedding = candle_nn::embedding(
674 cfg.max_position_embeddings,
675 cfg.hidden_size,
676 vb.pp("position_embedding"),
677 )?;
678 let position_ids =
679 Tensor::arange(0u32, cfg.max_position_embeddings as u32, vb.device())?.unsqueeze(0)?;
680 Ok(Self {
681 token_embedding,
682 position_embedding,
683 position_ids,
684 })
685 }
686}
687
688impl Module for TextEmbeddings {
689 fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
690 let seq_length = input_ids.dim(D::Minus1)?;
691 let inputs_embeds = self.token_embedding.forward(input_ids)?;
692 let position_ids = self.position_ids.narrow(1, 0, seq_length)?;
693 let position_embedding = self.position_embedding.forward(&position_ids)?;
694 inputs_embeds.broadcast_add(&position_embedding)
695 }
696}
697
698#[derive(Debug, Clone)]
699pub struct TextTransformer {
700 embeddings: TextEmbeddings,
701 encoder: Encoder,
702 final_layer_norm: LayerNorm,
703 pub head: Linear,
704}
705
706impl TextTransformer {
707 fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
708 let embeddings = TextEmbeddings::new(cfg, vb.pp("embeddings"))?;
709 let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
710 let final_layer_norm = layer_norm(
711 cfg.hidden_size,
712 cfg.layer_norm_eps,
713 vb.pp("final_layer_norm"),
714 )?;
715 let head = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("head"))?;
716 Ok(Self {
717 embeddings,
718 encoder,
719 final_layer_norm,
720 head,
721 })
722 }
723}
724impl Module for TextTransformer {
725 fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
726 let (_bsz, seq_len) = input_ids.dims2()?;
727 let input_ids = self.embeddings.forward(input_ids)?;
728 let input_ids = self.encoder.forward(&input_ids, None)?;
729 let last_hidden_state = self.final_layer_norm.forward(&input_ids)?;
730 last_hidden_state
731 .i((.., seq_len - 1, ..))?
732 .contiguous()?
733 .apply(&self.head)
734 }
735}
736
737#[derive(Debug, Clone)]
738pub struct TextModel {
739 pub text_model: TextTransformer,
740}
741
742impl TextModel {
743 pub fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
744 let text_model = TextTransformer::new(cfg, vb)?;
745 Ok(Self { text_model })
746 }
747}
748
749impl Module for TextModel {
750 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
751 xs.apply(&self.text_model)
752 }
753}
754
755#[derive(Clone, Debug)]
756pub struct Model {
757 text_model: TextModel,
758 vision_model: VisionModel,
759 logit_bias: Tensor,
760 logit_scale: Tensor,
761}
762
763impl Model {
764 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
765 let text_model = TextModel::new(&cfg.text_config, vb.pp("text_model"))?;
766 let vision_model = VisionModel::new(&cfg.vision_config, true, vb.pp("vision_model"))?;
767 let logit_scale = vb.get(&[1], "logit_scale")?;
768 let logit_bias = vb.get(&[1], "logit_bias")?;
769 Ok(Self {
770 text_model,
771 vision_model,
772 logit_bias,
773 logit_scale,
774 })
775 }
776
777 pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
778 input_ids.apply(&self.text_model)
779 }
780
781 pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
782 pixel_values.apply(&self.vision_model)
783 }
784
785 pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> {
786 let image_features = self.get_image_features(pixel_values)?;
787 let text_features = self.get_text_features(input_ids)?;
788 let image_features_normalized = div_l2_norm(&image_features)?;
789 let text_features_normalized = div_l2_norm(&text_features)?;
790 let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?;
791 let logit_scale = self.logit_scale.exp()?;
792 let logits_per_text = logits_per_text
793 .broadcast_mul(&logit_scale)?
794 .broadcast_add(&self.logit_bias)?;
795 let logits_per_image = logits_per_text.t()?;
796 Ok((logits_per_text, logits_per_image))
797 }
798}