candle_transformers/models/
siglip.rs

1//! Siglip model implementation.
2//!
3//! Siglip architecture combining vision and language for zero-shot tasks.
4//!
5//! References:
6//! - 🤗 [Model Card](https://huggingface.co/google/siglip-base-patch16-224)
7//!
8
9use crate::models::clip::div_l2_norm;
10use candle::{IndexOp, Module, Result, Tensor, D};
11use candle_nn::{layer_norm, linear, LayerNorm, Linear, VarBuilder};
12
13fn default_text_vocab_size() -> usize {
14    32000
15}
16
17fn default_text_hidden_size() -> usize {
18    768
19}
20
21fn default_text_intermediate_size() -> usize {
22    3072
23}
24
25fn default_text_num_hidden_layers() -> usize {
26    12
27}
28
29fn default_text_num_attention_heads() -> usize {
30    12
31}
32
33fn default_text_max_position_embeddings() -> usize {
34    64
35}
36
37fn default_text_layer_norm_eps() -> f64 {
38    1e-6
39}
40
41fn default_text_pad_token_id() -> u32 {
42    1
43}
44
45fn default_text_bos_token_id() -> u32 {
46    49406
47}
48
49fn default_text_eos_token_id() -> u32 {
50    49407
51}
52
53fn default_text_hidden_act() -> candle_nn::Activation {
54    candle_nn::Activation::GeluPytorchTanh
55}
56
57// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L27
58#[derive(serde::Deserialize, Clone, Debug)]
59pub struct TextConfig {
60    #[serde(default = "default_text_vocab_size")]
61    pub vocab_size: usize,
62    #[serde(default = "default_text_hidden_size")]
63    pub hidden_size: usize,
64    #[serde(default = "default_text_intermediate_size")]
65    pub intermediate_size: usize,
66    #[serde(default = "default_text_num_hidden_layers")]
67    pub num_hidden_layers: usize,
68    #[serde(default = "default_text_num_attention_heads")]
69    pub num_attention_heads: usize,
70    #[serde(default = "default_text_max_position_embeddings")]
71    pub max_position_embeddings: usize,
72    #[serde(default = "default_text_hidden_act")]
73    pub hidden_act: candle_nn::Activation,
74    #[serde(default = "default_text_layer_norm_eps")]
75    pub layer_norm_eps: f64,
76    #[serde(default = "default_text_pad_token_id")]
77    pub pad_token_id: u32,
78    #[serde(default = "default_text_bos_token_id")]
79    pub bos_token_id: u32,
80    #[serde(default = "default_text_eos_token_id")]
81    pub eos_token_id: u32,
82}
83
84fn default_vision_hidden_size() -> usize {
85    768
86}
87
88fn default_vision_intermediate_size() -> usize {
89    3072
90}
91
92fn default_vision_num_hidden_layers() -> usize {
93    12
94}
95
96fn default_vision_num_attention_heads() -> usize {
97    12
98}
99
100fn default_vision_num_channels() -> usize {
101    3
102}
103
104fn default_vision_image_size() -> usize {
105    224
106}
107
108fn default_vision_batch_size() -> usize {
109    16
110}
111
112fn default_vision_layer_norm_eps() -> f64 {
113    1e-6
114}
115
116fn default_vision_hidden_act() -> candle_nn::Activation {
117    candle_nn::Activation::GeluPytorchTanh
118}
119
120// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L132
121#[derive(serde::Deserialize, Clone, Debug)]
122pub struct VisionConfig {
123    #[serde(default = "default_vision_hidden_size")]
124    pub hidden_size: usize,
125    #[serde(default = "default_vision_intermediate_size")]
126    pub intermediate_size: usize,
127    #[serde(default = "default_vision_num_hidden_layers")]
128    pub num_hidden_layers: usize,
129    #[serde(default = "default_vision_num_attention_heads")]
130    pub num_attention_heads: usize,
131    #[serde(default = "default_vision_num_channels")]
132    pub num_channels: usize,
133    #[serde(default = "default_vision_image_size")]
134    pub image_size: usize,
135    #[serde(default = "default_vision_batch_size")]
136    pub patch_size: usize,
137    #[serde(default = "default_vision_hidden_act")]
138    pub hidden_act: candle_nn::Activation,
139    #[serde(default = "default_vision_layer_norm_eps")]
140    pub layer_norm_eps: f64,
141}
142
143trait TransformerConfig {
144    fn hidden_size(&self) -> usize;
145    fn intermediate_size(&self) -> usize;
146    fn num_attention_heads(&self) -> usize;
147    fn num_hidden_layers(&self) -> usize;
148    fn layer_norm_eps(&self) -> f64;
149    fn hidden_act(&self) -> candle_nn::Activation;
150}
151
152impl TransformerConfig for TextConfig {
153    fn hidden_size(&self) -> usize {
154        self.hidden_size
155    }
156    fn intermediate_size(&self) -> usize {
157        self.intermediate_size
158    }
159    fn num_attention_heads(&self) -> usize {
160        self.num_attention_heads
161    }
162    fn num_hidden_layers(&self) -> usize {
163        self.num_hidden_layers
164    }
165    fn layer_norm_eps(&self) -> f64 {
166        self.layer_norm_eps
167    }
168    fn hidden_act(&self) -> candle_nn::Activation {
169        self.hidden_act
170    }
171}
172
173impl TransformerConfig for VisionConfig {
174    fn hidden_size(&self) -> usize {
175        self.hidden_size
176    }
177    fn intermediate_size(&self) -> usize {
178        self.intermediate_size
179    }
180    fn num_attention_heads(&self) -> usize {
181        self.num_attention_heads
182    }
183    fn num_hidden_layers(&self) -> usize {
184        self.num_hidden_layers
185    }
186    fn layer_norm_eps(&self) -> f64 {
187        self.layer_norm_eps
188    }
189    fn hidden_act(&self) -> candle_nn::Activation {
190        self.hidden_act
191    }
192}
193
194impl VisionConfig {
195    pub fn paligemma_3b_224() -> Self {
196        Self {
197            // https://huggingface.co/google/paligemma-3b-pt-224/blob/main/config.json
198            patch_size: 14,
199            num_attention_heads: 16,
200            num_hidden_layers: 27,
201            hidden_size: 1152,
202            intermediate_size: 4304,
203            image_size: 224, // num_image_tokens: (224 / 14)^2 = 256
204            // Default values.
205            num_channels: 3,
206            hidden_act: candle_nn::Activation::GeluPytorchTanh,
207            layer_norm_eps: 1e-6,
208        }
209    }
210
211    pub fn paligemma_3b_448() -> Self {
212        Self {
213            // https://huggingface.co/google/paligemma-3b-pt-448/blob/main/config.json
214            patch_size: 14,
215            num_attention_heads: 16,
216            num_hidden_layers: 27,
217            hidden_size: 1152,
218            intermediate_size: 4304,
219            image_size: 448, // num_image_tokens: (448 / 14)^2 = 1024
220            // Default values.
221            num_channels: 3,
222            hidden_act: candle_nn::Activation::GeluPytorchTanh,
223            layer_norm_eps: 1e-6,
224        }
225    }
226
227    pub fn paligemma_3b_896() -> Self {
228        Self {
229            // https://huggingface.co/google/paligemma-3b-pt-448/blob/main/config.json
230            patch_size: 14,
231            num_attention_heads: 16,
232            num_hidden_layers: 27,
233            hidden_size: 1152,
234            intermediate_size: 4304,
235            image_size: 896, // num_image_tokens: (896 / 14)^2 = 4096
236            // Default values.
237            num_channels: 3,
238            hidden_act: candle_nn::Activation::GeluPytorchTanh,
239            layer_norm_eps: 1e-6,
240        }
241    }
242
243    pub fn num_patches(&self) -> usize {
244        (self.image_size / self.patch_size).pow(2)
245    }
246}
247
248// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L228
249#[derive(serde::Deserialize, Clone, Debug)]
250pub struct Config {
251    pub text_config: TextConfig,
252    pub vision_config: VisionConfig,
253}
254
255impl Config {
256    pub fn base_patch16_224() -> Self {
257        let text_config = TextConfig {
258            // https://huggingface.co/google/siglip-base-patch16-224/blob/main/config.json
259            hidden_size: 768,
260            intermediate_size: 3072,
261            num_attention_heads: 12,
262            vocab_size: 32000,
263            // Default values.
264            pad_token_id: 1,
265            bos_token_id: 49406,
266            eos_token_id: 49407,
267            layer_norm_eps: 1e-6,
268            hidden_act: candle_nn::Activation::GeluPytorchTanh,
269            max_position_embeddings: 64,
270            num_hidden_layers: 12,
271        };
272        let vision_config = VisionConfig {
273            patch_size: 16,
274            // Default values.
275            hidden_size: 768,
276            intermediate_size: 3072,
277            num_hidden_layers: 12,
278            num_attention_heads: 12,
279            num_channels: 3,
280            image_size: 224,
281            hidden_act: candle_nn::Activation::GeluPytorchTanh,
282            layer_norm_eps: 1e-6,
283        };
284        Self {
285            text_config,
286            vision_config,
287        }
288    }
289}
290
291#[derive(Clone, Debug)]
292struct MultiheadAttention {
293    q_proj: Linear,
294    k_proj: Linear,
295    v_proj: Linear,
296    out_proj: Linear,
297    num_heads: usize,
298}
299
300impl MultiheadAttention {
301    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
302        let h = cfg.hidden_size;
303        let num_heads = cfg.num_attention_heads;
304        let w_in_proj = vb.get((3 * h, h), "in_proj_weight")?.chunk(3, 0)?;
305        let b_in_proj = vb.get(3 * h, "in_proj_bias")?.chunk(3, 0)?;
306        let q_proj = Linear::new(w_in_proj[0].clone(), Some(b_in_proj[0].clone()));
307        let k_proj = Linear::new(w_in_proj[1].clone(), Some(b_in_proj[1].clone()));
308        let v_proj = Linear::new(w_in_proj[2].clone(), Some(b_in_proj[2].clone()));
309        let out_proj = linear(h, h, vb.pp("out_proj"))?;
310        Ok(Self {
311            q_proj,
312            k_proj,
313            v_proj,
314            out_proj,
315            num_heads,
316        })
317    }
318
319    fn separate_heads(&self, x: &Tensor) -> Result<Tensor> {
320        let (b, n, c) = x.dims3()?;
321        x.reshape((b, n, self.num_heads, c / self.num_heads))?
322            .transpose(1, 2)?
323            .contiguous()
324    }
325
326    fn recombine_heads(&self, x: &Tensor) -> Result<Tensor> {
327        let (b, n_heads, n_tokens, c_per_head) = x.dims4()?;
328        x.transpose(1, 2)?
329            .reshape((b, n_tokens, n_heads * c_per_head))
330    }
331
332    fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
333        let q = self.q_proj.forward(&q.contiguous()?)?;
334        let k = self.k_proj.forward(&k.contiguous()?)?;
335        let v = self.v_proj.forward(&v.contiguous()?)?;
336
337        let q = self.separate_heads(&q)?;
338        let k = self.separate_heads(&k)?;
339        let v = self.separate_heads(&v)?;
340
341        let (_, _, _, c_per_head) = q.dims4()?;
342        let attn = (q.matmul(&k.t()?)? / (c_per_head as f64).sqrt())?;
343        let attn = candle_nn::ops::softmax_last_dim(&attn)?;
344
345        let out = attn.matmul(&v)?;
346        self.recombine_heads(&out)?.apply(&self.out_proj)
347    }
348}
349
350#[derive(Debug, Clone)]
351struct MultiheadAttentionPoolingHead {
352    probe: Tensor,
353    attention: MultiheadAttention,
354    layernorm: LayerNorm,
355    mlp: Mlp,
356}
357
358impl MultiheadAttentionPoolingHead {
359    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
360        let mlp = Mlp::new(cfg, vb.pp("mlp"))?;
361        let layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("layernorm"))?;
362        let probe = vb.get((1, 1, cfg.hidden_size), "probe")?;
363        let attention = MultiheadAttention::new(cfg, vb.pp("attention"))?;
364        Ok(Self {
365            probe,
366            attention,
367            layernorm,
368            mlp,
369        })
370    }
371}
372
373impl Module for MultiheadAttentionPoolingHead {
374    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
375        let batch_size = xs.dim(0)?;
376        let probe = self.probe.repeat((batch_size, 1, 1))?;
377        let xs = self.attention.forward(&probe, xs, xs)?;
378        let residual = &xs;
379        let xs = xs.apply(&self.layernorm)?.apply(&self.mlp)?;
380        (xs + residual)?.i((.., 0))
381    }
382}
383
384#[derive(Debug, Clone)]
385struct Attention {
386    q_proj: Linear,
387    k_proj: Linear,
388    v_proj: Linear,
389    out_proj: Linear,
390    num_heads: usize,
391    head_dim: usize,
392    scale: f64,
393}
394
395impl Attention {
396    fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
397        let embed_dim = cfg.hidden_size();
398        let q_proj = linear(embed_dim, embed_dim, vb.pp("q_proj"))?;
399        let k_proj = linear(embed_dim, embed_dim, vb.pp("k_proj"))?;
400        let v_proj = linear(embed_dim, embed_dim, vb.pp("v_proj"))?;
401        let out_proj = linear(embed_dim, embed_dim, vb.pp("out_proj"))?;
402        let num_heads = cfg.num_attention_heads();
403        let head_dim = embed_dim / num_heads;
404        Ok(Self {
405            q_proj,
406            k_proj,
407            v_proj,
408            out_proj,
409            num_heads,
410            head_dim,
411            scale: (head_dim as f64).powf(-0.5),
412        })
413    }
414
415    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
416        let (batch_size, q_len, _) = xs.dims3()?;
417        let query_states = xs.apply(&self.q_proj)?;
418        let key_states = xs.apply(&self.k_proj)?;
419        let value_states = xs.apply(&self.v_proj)?;
420
421        let shape = (batch_size, q_len, self.num_heads, self.head_dim);
422        let query_states = query_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
423        let key_states = key_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
424        let value_states = value_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
425
426        let attn_weights = (query_states.matmul(&key_states.t()?)? * self.scale)?;
427        let attn_weights = match attention_mask {
428            None => attn_weights,
429            Some(mask) => attn_weights.broadcast_add(mask)?,
430        };
431        // The original implementation upcasts to f32 but candle_nn::ops::softmax should handle this properly.
432        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
433        let attn_outputs = attn_weights
434            .matmul(&value_states)?
435            .transpose(1, 2)?
436            .reshape((batch_size, q_len, ()))?
437            .apply(&self.out_proj)?;
438        Ok(attn_outputs)
439    }
440}
441
442// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/modeling_siglip.py#L599
443#[derive(Debug, Clone)]
444struct Mlp {
445    fc1: Linear,
446    fc2: Linear,
447    activation_fn: candle_nn::Activation,
448}
449
450impl Mlp {
451    fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
452        let hidden_size = cfg.hidden_size();
453        let intermediate_size = cfg.intermediate_size();
454        let fc1 = candle_nn::linear(hidden_size, intermediate_size, vb.pp("fc1"))?;
455        let fc2 = candle_nn::linear(intermediate_size, hidden_size, vb.pp("fc2"))?;
456        Ok(Self {
457            fc1,
458            fc2,
459            activation_fn: cfg.hidden_act(),
460        })
461    }
462}
463
464impl Module for Mlp {
465    fn forward(&self, xs: &candle::Tensor) -> Result<candle::Tensor> {
466        xs.apply(&self.fc1)?
467            .apply(&self.activation_fn)?
468            .apply(&self.fc2)
469    }
470}
471
472// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/modeling_siglip.py#L614
473#[derive(Debug, Clone)]
474struct EncoderLayer {
475    self_attn: Attention,
476    layer_norm1: LayerNorm,
477    mlp: Mlp,
478    layer_norm2: LayerNorm,
479}
480
481impl EncoderLayer {
482    fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
483        let hidden_size = cfg.hidden_size();
484        let layer_norm_eps = cfg.layer_norm_eps();
485        let self_attn = Attention::new(cfg, vb.pp("self_attn"))?;
486        let layer_norm1 = layer_norm(hidden_size, layer_norm_eps, vb.pp("layer_norm1"))?;
487        let mlp = Mlp::new(cfg, vb.pp("mlp"))?;
488        let layer_norm2 = layer_norm(hidden_size, layer_norm_eps, vb.pp("layer_norm2"))?;
489        Ok(Self {
490            self_attn,
491            layer_norm1,
492            mlp,
493            layer_norm2,
494        })
495    }
496
497    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
498        let residual = xs;
499        let xs = xs.apply(&self.layer_norm1)?;
500        let xs = self.self_attn.forward(&xs, attention_mask)?;
501        let xs = (residual + xs)?;
502        let residual = &xs;
503        let xs = xs.apply(&self.layer_norm2)?.apply(&self.mlp)?;
504        let xs = (xs + residual)?;
505        Ok(xs)
506    }
507}
508
509#[derive(Debug, Clone)]
510struct Encoder {
511    layers: Vec<EncoderLayer>,
512}
513
514impl Encoder {
515    fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
516        let mut layers = vec![];
517        let vb = vb.pp("layers");
518        for layer_idx in 0..cfg.num_hidden_layers() {
519            let layer = EncoderLayer::new(cfg, vb.pp(layer_idx))?;
520            layers.push(layer)
521        }
522        Ok(Self { layers })
523    }
524
525    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
526        let mut xs = xs.clone();
527        for layer in self.layers.iter() {
528            xs = layer.forward(&xs, attention_mask)?
529        }
530        Ok(xs)
531    }
532}
533
534#[derive(Debug, Clone)]
535struct VisionEmbeddings {
536    patch_embedding: candle_nn::Conv2d,
537    position_embedding: Tensor,
538    patch_size: usize,
539    base_num_patches_per_side: usize,
540}
541
542impl VisionEmbeddings {
543    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
544        let conv2d_cfg = candle_nn::Conv2dConfig {
545            stride: cfg.patch_size,
546            ..Default::default()
547        };
548        let patch_embedding = candle_nn::conv2d(
549            cfg.num_channels,
550            cfg.hidden_size,
551            cfg.patch_size,
552            conv2d_cfg,
553            vb.pp("patch_embedding"),
554        )?;
555        let num_patches_per_side = cfg.image_size / cfg.patch_size;
556        let embedder = candle_nn::embedding(
557            num_patches_per_side.pow(2),
558            cfg.hidden_size(),
559            vb.pp("position_embedding"),
560        )?;
561        let position_embedding = embedder.embeddings();
562        let position_embedding = position_embedding
563            .reshape((
564                1,
565                num_patches_per_side,
566                num_patches_per_side,
567                cfg.hidden_size(),
568            ))?
569            .permute((0, 3, 1, 2))?;
570        Ok(Self {
571            patch_embedding,
572            position_embedding,
573            patch_size: cfg.patch_size,
574            base_num_patches_per_side: num_patches_per_side,
575        })
576    }
577}
578
579impl Module for VisionEmbeddings {
580    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
581        //embed tokens
582        let (_batch, _channels, _height, _width) = xs.dims4()?;
583        let embeddings = xs.apply(&self.patch_embedding)?;
584        // interpolate position embeddings for the current image size (if needed)
585        let num_patches_h = _height / self.patch_size;
586        let num_patches_w = _width / self.patch_size;
587        let resized_position_embedding = if num_patches_w == self.base_num_patches_per_side
588            && num_patches_h == self.base_num_patches_per_side
589        {
590            self.position_embedding.clone()
591        } else {
592            self.position_embedding
593                .interpolate2d(num_patches_h, num_patches_w)?
594        };
595        // Add position embeddings to tokens and flatten from 2D patches to 1D sequence
596        let embeddings = embeddings
597            .broadcast_add(&resized_position_embedding)?
598            .flatten_from(2)?
599            .transpose(1, 2)?;
600        Ok(embeddings)
601    }
602}
603
604#[derive(Debug, Clone)]
605struct VisionTransformer {
606    embeddings: VisionEmbeddings,
607    encoder: Encoder,
608    post_layernorm: LayerNorm,
609    head: Option<MultiheadAttentionPoolingHead>,
610}
611
612impl VisionTransformer {
613    fn new(cfg: &VisionConfig, use_head: bool, vb: VarBuilder) -> Result<Self> {
614        let embeddings = VisionEmbeddings::new(cfg, vb.pp("embeddings"))?;
615        let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
616        let post_layernorm =
617            layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("post_layernorm"))?;
618        let head = if use_head {
619            Some(MultiheadAttentionPoolingHead::new(cfg, vb.pp("head"))?)
620        } else {
621            None
622        };
623        Ok(Self {
624            embeddings,
625            encoder,
626            post_layernorm,
627            head,
628        })
629    }
630}
631
632impl Module for VisionTransformer {
633    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
634        let xs = xs.apply(&self.embeddings)?;
635        let xs = self.encoder.forward(&xs, None)?;
636        let xs = xs.apply(&self.post_layernorm)?;
637        match self.head.as_ref() {
638            None => Ok(xs),
639            Some(h) => xs.apply(h),
640        }
641    }
642}
643
644#[derive(Debug, Clone)]
645pub struct VisionModel {
646    vision_model: VisionTransformer,
647}
648
649impl VisionModel {
650    pub fn new(cfg: &VisionConfig, use_head: bool, vb: VarBuilder) -> Result<Self> {
651        let vision_model = VisionTransformer::new(cfg, use_head, vb)?;
652        Ok(Self { vision_model })
653    }
654}
655
656impl Module for VisionModel {
657    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
658        xs.apply(&self.vision_model)
659    }
660}
661
662#[derive(Debug, Clone)]
663struct TextEmbeddings {
664    token_embedding: candle_nn::Embedding,
665    position_embedding: candle_nn::Embedding,
666    position_ids: Tensor,
667}
668
669impl TextEmbeddings {
670    fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
671        let token_embedding =
672            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("token_embedding"))?;
673        let position_embedding = candle_nn::embedding(
674            cfg.max_position_embeddings,
675            cfg.hidden_size,
676            vb.pp("position_embedding"),
677        )?;
678        let position_ids =
679            Tensor::arange(0u32, cfg.max_position_embeddings as u32, vb.device())?.unsqueeze(0)?;
680        Ok(Self {
681            token_embedding,
682            position_embedding,
683            position_ids,
684        })
685    }
686}
687
688impl Module for TextEmbeddings {
689    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
690        let seq_length = input_ids.dim(D::Minus1)?;
691        let inputs_embeds = self.token_embedding.forward(input_ids)?;
692        let position_ids = self.position_ids.narrow(1, 0, seq_length)?;
693        let position_embedding = self.position_embedding.forward(&position_ids)?;
694        inputs_embeds.broadcast_add(&position_embedding)
695    }
696}
697
698#[derive(Debug, Clone)]
699pub struct TextTransformer {
700    embeddings: TextEmbeddings,
701    encoder: Encoder,
702    final_layer_norm: LayerNorm,
703    pub head: Linear,
704}
705
706impl TextTransformer {
707    fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
708        let embeddings = TextEmbeddings::new(cfg, vb.pp("embeddings"))?;
709        let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
710        let final_layer_norm = layer_norm(
711            cfg.hidden_size,
712            cfg.layer_norm_eps,
713            vb.pp("final_layer_norm"),
714        )?;
715        let head = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("head"))?;
716        Ok(Self {
717            embeddings,
718            encoder,
719            final_layer_norm,
720            head,
721        })
722    }
723}
724impl Module for TextTransformer {
725    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
726        let (_bsz, seq_len) = input_ids.dims2()?;
727        let input_ids = self.embeddings.forward(input_ids)?;
728        let input_ids = self.encoder.forward(&input_ids, None)?;
729        let last_hidden_state = self.final_layer_norm.forward(&input_ids)?;
730        last_hidden_state
731            .i((.., seq_len - 1, ..))?
732            .contiguous()?
733            .apply(&self.head)
734    }
735}
736
737#[derive(Debug, Clone)]
738pub struct TextModel {
739    pub text_model: TextTransformer,
740}
741
742impl TextModel {
743    pub fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
744        let text_model = TextTransformer::new(cfg, vb)?;
745        Ok(Self { text_model })
746    }
747}
748
749impl Module for TextModel {
750    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
751        xs.apply(&self.text_model)
752    }
753}
754
755#[derive(Clone, Debug)]
756pub struct Model {
757    text_model: TextModel,
758    vision_model: VisionModel,
759    logit_bias: Tensor,
760    logit_scale: Tensor,
761}
762
763impl Model {
764    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
765        let text_model = TextModel::new(&cfg.text_config, vb.pp("text_model"))?;
766        let vision_model = VisionModel::new(&cfg.vision_config, true, vb.pp("vision_model"))?;
767        let logit_scale = vb.get(&[1], "logit_scale")?;
768        let logit_bias = vb.get(&[1], "logit_bias")?;
769        Ok(Self {
770            text_model,
771            vision_model,
772            logit_bias,
773            logit_scale,
774        })
775    }
776
777    pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
778        input_ids.apply(&self.text_model)
779    }
780
781    pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
782        pixel_values.apply(&self.vision_model)
783    }
784
785    pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> {
786        let image_features = self.get_image_features(pixel_values)?;
787        let text_features = self.get_text_features(input_ids)?;
788        let image_features_normalized = div_l2_norm(&image_features)?;
789        let text_features_normalized = div_l2_norm(&text_features)?;
790        let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?;
791        let logit_scale = self.logit_scale.exp()?;
792        let logits_per_text = logits_per_text
793            .broadcast_mul(&logit_scale)?
794            .broadcast_add(&self.logit_bias)?;
795        let logits_per_image = logits_per_text.t()?;
796        Ok((logits_per_text, logits_per_image))
797    }
798}