candle_transformers/models/
paligemma.rs1use crate::models::{gemma, siglip};
18use candle::{Module, Result, Tensor};
19use candle_nn::{linear, Linear, VarBuilder};
20
21#[derive(serde::Deserialize, Clone, Debug)]
22pub struct Config {
23 pub vision_config: siglip::VisionConfig,
24 pub text_config: gemma::Config,
25 pub projection_dim: usize,
26}
27
28impl Config {
29 pub fn paligemma_3b_224() -> Self {
30 Self {
32 vision_config: siglip::VisionConfig::paligemma_3b_224(),
33 text_config: gemma::Config {
34 hidden_size: 2048,
35 intermediate_size: 16384,
36 num_attention_heads: 8,
37 num_hidden_layers: 18,
38 num_key_value_heads: 1,
39 vocab_size: 257216,
40 rope_theta: 10000.,
42 head_dim: 256,
43 hidden_act: Some(candle_nn::Activation::GeluPytorchTanh),
44 hidden_activation: None,
45 attention_bias: false,
46 max_position_embeddings: 8192,
47 rms_norm_eps: 1e-6,
48 },
49 projection_dim: 2048,
50 }
51 }
52
53 pub fn paligemma_3b_448() -> Self {
54 Self {
55 vision_config: siglip::VisionConfig::paligemma_3b_448(),
56 text_config: gemma::Config {
57 hidden_size: 2048,
58 intermediate_size: 16384,
59 num_attention_heads: 8,
60 num_hidden_layers: 18,
61 num_key_value_heads: 1,
62 rope_theta: 10000.,
64 head_dim: 256,
65 hidden_act: Some(candle_nn::Activation::GeluPytorchTanh),
66 hidden_activation: None,
67 attention_bias: false,
68 max_position_embeddings: 8192,
69 rms_norm_eps: 1e-6,
70 vocab_size: 257216,
71 },
72 projection_dim: 2048,
73 }
74 }
75}
76
77#[derive(Clone, Debug)]
78pub struct MultiModalProjector {
79 linear: Linear,
80}
81
82impl MultiModalProjector {
83 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
84 let linear = linear(
85 cfg.vision_config.hidden_size,
86 cfg.projection_dim,
87 vb.pp("linear"),
88 )?;
89 Ok(Self { linear })
90 }
91}
92
93impl Module for MultiModalProjector {
94 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
95 xs.apply(&self.linear)
96 }
97}
98
99#[derive(Clone, Debug)]
100pub struct Model {
101 pos: usize,
102 vision_tower: siglip::VisionModel,
103 multi_modal_projector: MultiModalProjector,
104 language_model: gemma::Model,
105}
106
107impl Model {
108 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
109 let vision_tower = siglip::VisionModel::new(
110 &cfg.vision_config,
111 false,
112 vb.pp("vision_tower.vision_model"),
113 )?;
114 let multi_modal_projector = MultiModalProjector::new(cfg, vb.pp("multi_modal_projector"))?;
115 let language_model = gemma::Model::new(false, &cfg.text_config, vb.pp("language_model"))?;
116 Ok(Self {
117 pos: 0,
118 language_model,
119 vision_tower,
120 multi_modal_projector,
121 })
122 }
123
124 pub fn setup(&mut self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<Tensor> {
125 self.clear_kv_cache();
126 let image_features = self
127 .vision_tower
128 .forward(pixel_values)?
129 .apply(&self.multi_modal_projector)?;
130 let image_features = crate::models::clip::div_l2_norm(&image_features)?;
131 let text_features = self.language_model.embed_tokens().forward(input_ids)?;
132 let input_embeds = Tensor::cat(&[image_features, text_features], 1)?;
133 self.pos = input_embeds.dim(1)?;
134 self.language_model.forward_embeds(&input_embeds, None, 0)
135 }
136
137 pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
138 let pos = self.pos;
139 let seq_len = input_ids.dim(1)?;
140 self.pos = pos + seq_len;
141 self.language_model.forward(input_ids, pos)
142 }
143
144 pub fn forward_without_projection(&mut self, input_ids: &Tensor) -> Result<Tensor> {
145 self.clear_kv_cache();
146 let input_embeds = self.language_model.embed_tokens().forward(input_ids)?;
147 self.language_model
148 .forward_embeds_without_projection(&input_embeds, None, 0)
149 }
150 pub fn setup_without_projection(
151 &mut self,
152 pixel_values: &Tensor,
153 input_ids: &Tensor,
154 ) -> Result<Tensor> {
155 self.clear_kv_cache();
156 let image_features = self
157 .vision_tower
158 .forward(pixel_values)?
159 .apply(&self.multi_modal_projector)?;
160 let image_features = crate::models::clip::div_l2_norm(&image_features)?;
161 let text_features = self.language_model.embed_tokens().forward(input_ids)?;
162 let input_embeds = Tensor::cat(&[image_features, text_features], 1)?;
163 self.language_model
164 .forward_embeds_without_projection(&input_embeds, None, 0)
165 }
166 pub fn clear_kv_cache(&mut self) {
167 self.pos = 0;
168 self.language_model.clear_kv_cache()
169 }
170}