candle_transformers/models/
paligemma.rs

//! Multimodal multi-purpose model combining a Gemma-based language model with SigLIP image understanding
//!
//! See PaLiGemma details at:
//! - [Paper](https://arxiv.org/abs/2402.05257)
//! - [Google Blog Post](https://blog.research.google/2024/02/paligemma-scaling-language-image.html)
//!
//! The model is a multimodal combination of:
//! - SigLIP vision encoder
//! - Gemma language model
//! - A linear multi-modal projector mapping vision features to `projection_dim`
//!
//! References:
//! - [HuggingFace Implementation](https://huggingface.co/google/paligemma-3b)
//! - [Paper: PaLI-3 and Beyond: Scaling Language-Image Learning](https://arxiv.org/abs/2402.05257)
//!
16
17use crate::models::{gemma, siglip};
18use candle::{Module, Result, Tensor};
19use candle_nn::{linear, Linear, VarBuilder};
20
21#[derive(serde::Deserialize, Clone, Debug)]
22pub struct Config {
23    pub vision_config: siglip::VisionConfig,
24    pub text_config: gemma::Config,
25    pub projection_dim: usize,
26}
27
28impl Config {
29    pub fn paligemma_3b_224() -> Self {
30        // https://huggingface.co/google/paligemma-3b-pt-224/blob/main/config.json
31        Self {
32            vision_config: siglip::VisionConfig::paligemma_3b_224(),
33            text_config: gemma::Config {
34                hidden_size: 2048,
35                intermediate_size: 16384,
36                num_attention_heads: 8,
37                num_hidden_layers: 18,
38                num_key_value_heads: 1,
39                vocab_size: 257216,
40                // Default values.
41                rope_theta: 10000.,
42                head_dim: 256,
43                hidden_act: Some(candle_nn::Activation::GeluPytorchTanh),
44                hidden_activation: None,
45                attention_bias: false,
46                max_position_embeddings: 8192,
47                rms_norm_eps: 1e-6,
48            },
49            projection_dim: 2048,
50        }
51    }
52
53    pub fn paligemma_3b_448() -> Self {
54        Self {
55            vision_config: siglip::VisionConfig::paligemma_3b_448(),
56            text_config: gemma::Config {
57                hidden_size: 2048,
58                intermediate_size: 16384,
59                num_attention_heads: 8,
60                num_hidden_layers: 18,
61                num_key_value_heads: 1,
62                // Default values.
63                rope_theta: 10000.,
64                head_dim: 256,
65                hidden_act: Some(candle_nn::Activation::GeluPytorchTanh),
66                hidden_activation: None,
67                attention_bias: false,
68                max_position_embeddings: 8192,
69                rms_norm_eps: 1e-6,
70                vocab_size: 257216,
71            },
72            projection_dim: 2048,
73        }
74    }
75}
76
77#[derive(Clone, Debug)]
78pub struct MultiModalProjector {
79    linear: Linear,
80}
81
82impl MultiModalProjector {
83    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
84        let linear = linear(
85            cfg.vision_config.hidden_size,
86            cfg.projection_dim,
87            vb.pp("linear"),
88        )?;
89        Ok(Self { linear })
90    }
91}
92
93impl Module for MultiModalProjector {
94    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
95        xs.apply(&self.linear)
96    }
97}
98
99#[derive(Clone, Debug)]
100pub struct Model {
101    pos: usize,
102    vision_tower: siglip::VisionModel,
103    multi_modal_projector: MultiModalProjector,
104    language_model: gemma::Model,
105}
106
107impl Model {
108    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
109        let vision_tower = siglip::VisionModel::new(
110            &cfg.vision_config,
111            false,
112            vb.pp("vision_tower.vision_model"),
113        )?;
114        let multi_modal_projector = MultiModalProjector::new(cfg, vb.pp("multi_modal_projector"))?;
115        let language_model = gemma::Model::new(false, &cfg.text_config, vb.pp("language_model"))?;
116        Ok(Self {
117            pos: 0,
118            language_model,
119            vision_tower,
120            multi_modal_projector,
121        })
122    }
123
124    pub fn setup(&mut self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<Tensor> {
125        self.clear_kv_cache();
126        let image_features = self
127            .vision_tower
128            .forward(pixel_values)?
129            .apply(&self.multi_modal_projector)?;
130        let image_features = crate::models::clip::div_l2_norm(&image_features)?;
131        let text_features = self.language_model.embed_tokens().forward(input_ids)?;
132        let input_embeds = Tensor::cat(&[image_features, text_features], 1)?;
133        self.pos = input_embeds.dim(1)?;
134        self.language_model.forward_embeds(&input_embeds, None, 0)
135    }
136
137    pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
138        let pos = self.pos;
139        let seq_len = input_ids.dim(1)?;
140        self.pos = pos + seq_len;
141        self.language_model.forward(input_ids, pos)
142    }
143
144    pub fn forward_without_projection(&mut self, input_ids: &Tensor) -> Result<Tensor> {
145        self.clear_kv_cache();
146        let input_embeds = self.language_model.embed_tokens().forward(input_ids)?;
147        self.language_model
148            .forward_embeds_without_projection(&input_embeds, None, 0)
149    }
150    pub fn setup_without_projection(
151        &mut self,
152        pixel_values: &Tensor,
153        input_ids: &Tensor,
154    ) -> Result<Tensor> {
155        self.clear_kv_cache();
156        let image_features = self
157            .vision_tower
158            .forward(pixel_values)?
159            .apply(&self.multi_modal_projector)?;
160        let image_features = crate::models::clip::div_l2_norm(&image_features)?;
161        let text_features = self.language_model.embed_tokens().forward(input_ids)?;
162        let input_embeds = Tensor::cat(&[image_features, text_features], 1)?;
163        self.language_model
164            .forward_embeds_without_projection(&input_embeds, None, 0)
165    }
166    pub fn clear_kv_cache(&mut self) {
167        self.pos = 0;
168        self.language_model.clear_kv_cache()
169    }
170}