// candle_transformers/models/mobileclip.rs

//! Mobile CLIP model, combining a lightweight vision encoder with a text encoder.
//!
//! A mobile-optimized CLIP implementation that uses:
//! - FastViT as the vision encoder
//! - OpenCLIP text encoder
//! - Projection layers to align the feature spaces
//!
//! See model details at:
//! - [MobileCLIP](https://arxiv.org/abs/2311.17049)
//! - [FastViT](https://arxiv.org/abs/2303.14189)
//! - [OpenCLIP](https://github.com/mlfoundations/open_clip)
//!
//! References:
//! - [MobileVLM](https://huggingface.co/mobileVLM)
//! - [MetaCLIP](https://arxiv.org/abs/2309.16671)
//!
16
17use super::fastvit;
18use super::openclip::text_model;
19use candle::{Result, Tensor, D};
20use candle_nn::{Func, VarBuilder};
21
22#[derive(Clone, Debug)]
23pub struct MobileClipModel {
24    text_model: text_model::OpenClipTextTransformer,
25    vision_model: Func<'static>,
26    text_projection: Tensor,
27    logit_scale: Tensor,
28}
29
30#[derive(Clone, Debug)]
31pub struct MobileClipConfig {
32    pub text_config: text_model::Config,
33    pub vision_config: fastvit::Config,
34    pub image_size: usize,
35}
36
37impl MobileClipConfig {
38    pub fn s1() -> Self {
39        let text_config = text_model::Config::vit_base_patch32();
40        let vision_config = fastvit::Config::mci1();
41        Self {
42            text_config,
43            vision_config,
44            image_size: 256,
45        }
46    }
47    pub fn s2() -> Self {
48        let text_config = text_model::Config::vit_base_patch32();
49        let vision_config = fastvit::Config::mci2();
50        Self {
51            text_config,
52            vision_config,
53            image_size: 256,
54        }
55    }
56}
57
58impl MobileClipModel {
59    pub fn new(vs: VarBuilder, c: &MobileClipConfig) -> Result<Self> {
60        let vision_model = fastvit::fastvit(&c.vision_config, 512, vs.pp("visual.trunk"))?;
61        let text_model = text_model::OpenClipTextTransformer::new(vs.pp("text"), &c.text_config)?;
62        let text_projection = vs.get(
63            (c.text_config.embed_dim, c.text_config.projection_dim),
64            "text.text_projection",
65        )?;
66        let logit_scale = vs.get(&[], "logit_scale")?;
67        Ok(Self {
68            text_model,
69            vision_model,
70            text_projection,
71            logit_scale,
72        })
73    }
74
75    pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
76        input_ids
77            .apply(&self.text_model)?
78            .matmul(&self.text_projection)
79    }
80
81    pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
82        pixel_values.apply(&self.vision_model)
83    }
84
85    pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> {
86        let image_features = self.get_image_features(pixel_values)?;
87        let text_features = self.get_text_features(input_ids)?;
88        let image_features_normalized = div_l2_norm(&image_features)?;
89        let text_features_normalized = div_l2_norm(&text_features)?;
90        let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?;
91        let logit_scale = self.logit_scale.exp()?;
92        let logits_per_text = logits_per_text.broadcast_mul(&logit_scale)?;
93        let logits_per_image = logits_per_text.t()?;
94        Ok((logits_per_text, logits_per_image))
95    }
96}
97
98pub fn div_l2_norm(v: &Tensor) -> Result<Tensor> {
99    let l2_norm = v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?;
100    v.broadcast_div(&l2_norm)
101}