// candle_transformers/models/pixtral/llava.rs

use candle::{DType, Module, Result, Tensor};
use candle_nn::{linear, Linear, VarBuilder};

use super::vision_model;
use crate::models::mistral;

7#[derive(serde::Deserialize, Debug, Clone)]
8pub struct Config {
9    pub projector_hidden_act: candle_nn::Activation,
10    pub text_config: mistral::Config,
11    pub vision_config: vision_model::Config,
12    pub image_token_index: usize,
13    pub image_seq_length: usize,
14}
15
16#[derive(Debug, Clone)]
17pub struct MultiModalProjector {
18    linear_1: Linear,
19    act: candle_nn::Activation,
20    linear_2: Linear,
21}
22
23impl MultiModalProjector {
24    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
25        let (hidden_v, hidden_t) = (cfg.vision_config.hidden_size, cfg.text_config.hidden_size);
26        let linear_1 = linear(hidden_v, hidden_t, vb.pp("linear_1"))?;
27        let linear_2 = linear(hidden_t, hidden_t, vb.pp("linear_2"))?;
28        Ok(Self {
29            linear_1,
30            act: cfg.projector_hidden_act,
31            linear_2,
32        })
33    }
34}
35
36impl Module for MultiModalProjector {
37    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
38        xs.apply(&self.linear_1)?
39            .apply(&self.act)?
40            .apply(&self.linear_2)
41    }
42}
43
44#[derive(Debug, Clone)]
45pub struct Model {
46    pub multi_modal_projector: MultiModalProjector,
47    pub language_model: mistral::Model,
48    pub vision_tower: vision_model::Model,
49    pub patch_size: usize,
50    pub dtype: candle::DType,
51    pub pos: usize,
52}
53
54impl Model {
55    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
56        let language_model = mistral::Model::new(&cfg.text_config, vb.pp("language_model"))?;
57        let vision_tower = vision_model::Model::new(
58            &cfg.vision_config,
59            vb.pp("vision_tower").to_dtype(candle::DType::F32),
60        )?;
61        let multi_modal_projector = MultiModalProjector::new(
62            cfg,
63            vb.pp("multi_modal_projector").to_dtype(candle::DType::F32),
64        )?;
65        Ok(Self {
66            multi_modal_projector,
67            language_model,
68            vision_tower,
69            patch_size: cfg.vision_config.patch_size,
70            dtype: vb.dtype(),
71            pos: 0,
72        })
73    }
74
75    pub fn clear_kv_cache(&mut self) {
76        self.language_model.clear_kv_cache();
77        self.pos = 0;
78    }
79
80    pub fn encode_image(&self, image: &Tensor) -> Result<Tensor> {
81        let image_embeds = self.vision_tower.forward(image)?;
82        self.multi_modal_projector.forward(&image_embeds)
83    }
84
85    pub fn lm_forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
86        let (_, seq_len) = input_ids.dims2()?;
87        let logits = self.language_model.forward(input_ids, self.pos)?;
88        self.pos += seq_len;
89        Ok(logits)
90    }
91
92    pub fn lm_forward_embeds(&mut self, xs: &Tensor) -> Result<Tensor> {
93        let (_, seq_len, _) = xs.dims3()?;
94        let logits = self.language_model.forward_embeds(xs, None, self.pos)?;
95        self.pos += seq_len;
96        Ok(logits)
97    }
98}