candle_transformers/models/
mobileclip.rs1use super::fastvit;
18use super::openclip::text_model;
19use candle::{Result, Tensor, D};
20use candle_nn::{Func, VarBuilder};
21
22#[derive(Clone, Debug)]
23pub struct MobileClipModel {
24 text_model: text_model::OpenClipTextTransformer,
25 vision_model: Func<'static>,
26 text_projection: Tensor,
27 logit_scale: Tensor,
28}
29
30#[derive(Clone, Debug)]
31pub struct MobileClipConfig {
32 pub text_config: text_model::Config,
33 pub vision_config: fastvit::Config,
34 pub image_size: usize,
35}
36
37impl MobileClipConfig {
38 pub fn s1() -> Self {
39 let text_config = text_model::Config::vit_base_patch32();
40 let vision_config = fastvit::Config::mci1();
41 Self {
42 text_config,
43 vision_config,
44 image_size: 256,
45 }
46 }
47 pub fn s2() -> Self {
48 let text_config = text_model::Config::vit_base_patch32();
49 let vision_config = fastvit::Config::mci2();
50 Self {
51 text_config,
52 vision_config,
53 image_size: 256,
54 }
55 }
56}
57
58impl MobileClipModel {
59 pub fn new(vs: VarBuilder, c: &MobileClipConfig) -> Result<Self> {
60 let vision_model = fastvit::fastvit(&c.vision_config, 512, vs.pp("visual.trunk"))?;
61 let text_model = text_model::OpenClipTextTransformer::new(vs.pp("text"), &c.text_config)?;
62 let text_projection = vs.get(
63 (c.text_config.embed_dim, c.text_config.projection_dim),
64 "text.text_projection",
65 )?;
66 let logit_scale = vs.get(&[], "logit_scale")?;
67 Ok(Self {
68 text_model,
69 vision_model,
70 text_projection,
71 logit_scale,
72 })
73 }
74
75 pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
76 input_ids
77 .apply(&self.text_model)?
78 .matmul(&self.text_projection)
79 }
80
81 pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
82 pixel_values.apply(&self.vision_model)
83 }
84
85 pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> {
86 let image_features = self.get_image_features(pixel_values)?;
87 let text_features = self.get_text_features(input_ids)?;
88 let image_features_normalized = div_l2_norm(&image_features)?;
89 let text_features_normalized = div_l2_norm(&text_features)?;
90 let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?;
91 let logit_scale = self.logit_scale.exp()?;
92 let logits_per_text = logits_per_text.broadcast_mul(&logit_scale)?;
93 let logits_per_image = logits_per_text.t()?;
94 Ok((logits_per_text, logits_per_image))
95 }
96}
97
98pub fn div_l2_norm(v: &Tensor) -> Result<Tensor> {
99 let l2_norm = v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?;
100 v.broadcast_div(&l2_norm)
101}