candle_transformers/models/
colpali.rs

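//! ColPali model for text/image similarity scoring.
//!
//! ColPali combines a vision encoder with an efficient language model for retrieving content.
//! The [`Model`] defined here wraps a PaliGemma backbone and projects its outputs to
//! 128-dimensional, L2-normalized embeddings.
//!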
5
6use candle::{Module, Result, Tensor};
7use candle_nn::VarBuilder;
8
9use super::paligemma;
10use candle_nn::{linear, Linear};
11
12pub struct Model {
13    pub model: paligemma::Model,
14    pub custom_text_projection: Linear,
15}
16
17impl Model {
18    pub fn new(config: &paligemma::Config, vb: VarBuilder) -> Result<Self> {
19        let model = paligemma::Model::new(config, vb.pp("model"))?;
20        let custom_text_projection = linear(
21            config.text_config.hidden_size,
22            128,
23            vb.pp("custom_text_proj"),
24        )?;
25
26        Ok(Self {
27            model,
28            custom_text_projection,
29        })
30    }
31
32    pub fn forward_images(&mut self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<Tensor> {
33        let outputs = self
34            .model
35            .setup_without_projection(pixel_values, input_ids)?;
36        let outputs = self.custom_text_projection.forward(&outputs)?;
37        let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?;
38        Ok(outputs)
39    }
40
41    pub fn forward_text(&mut self, input_ids: &Tensor) -> Result<Tensor> {
42        let outputs = self.model.forward_without_projection(input_ids)?;
43        let outputs = self.custom_text_projection.forward(&outputs)?;
44        let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?;
45        Ok(outputs)
46    }
47}