candle_transformers/models/
colpali.rs1use candle::{Module, Result, Tensor};
7use candle_nn::VarBuilder;
8
9use super::paligemma;
10use candle_nn::{linear, Linear};
11
12pub struct Model {
13 pub model: paligemma::Model,
14 pub custom_text_projection: Linear,
15}
16
17impl Model {
18 pub fn new(config: &paligemma::Config, vb: VarBuilder) -> Result<Self> {
19 let model = paligemma::Model::new(config, vb.pp("model"))?;
20 let custom_text_projection = linear(
21 config.text_config.hidden_size,
22 128,
23 vb.pp("custom_text_proj"),
24 )?;
25
26 Ok(Self {
27 model,
28 custom_text_projection,
29 })
30 }
31
32 pub fn forward_images(&mut self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<Tensor> {
33 let outputs = self
34 .model
35 .setup_without_projection(pixel_values, input_ids)?;
36 let outputs = self.custom_text_projection.forward(&outputs)?;
37 let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?;
38 Ok(outputs)
39 }
40
41 pub fn forward_text(&mut self, input_ids: &Tensor) -> Result<Tensor> {
42 let outputs = self.model.forward_without_projection(input_ids)?;
43 let outputs = self.custom_text_projection.forward(&outputs)?;
44 let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?;
45 Ok(outputs)
46 }
47}