candle_transformers/models/pixtral/
llava.rs1use candle::{Module, Result, Tensor};
2use candle_nn::{linear, Linear, VarBuilder};
3
4use super::vision_model;
5use crate::models::mistral;
6
/// Top-level model configuration, deserialized with serde.
///
/// Bundles the text-model (Mistral) config, the vision-encoder config and the
/// settings of the projector that connects them.
#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
    /// Activation applied between the two linear layers of [`MultiModalProjector`].
    pub projector_hidden_act: candle_nn::Activation,
    /// Configuration of the Mistral language model; its `hidden_size` is the
    /// projector's output width.
    pub text_config: mistral::Config,
    /// Configuration of the vision encoder; its `hidden_size` is the projector's
    /// input width and its `patch_size` is copied into `Model::patch_size`.
    pub vision_config: vision_model::Config,
    // NOTE(review): not read in this file — presumably the token id that marks
    // image positions in the prompt; confirm against the caller.
    pub image_token_index: usize,
    // NOTE(review): not read in this file — presumably the number of image
    // tokens per image; confirm against the caller.
    pub image_seq_length: usize,
}
15
/// Two-layer MLP that maps vision-encoder features to the language model's
/// hidden size: `linear_1` (vision hidden -> text hidden), then `act`, then
/// `linear_2` (text hidden -> text hidden).
#[derive(Debug, Clone)]
pub struct MultiModalProjector {
    /// First projection: vision hidden size -> text hidden size.
    linear_1: Linear,
    /// Non-linearity taken from `Config::projector_hidden_act`.
    act: candle_nn::Activation,
    /// Second projection: text hidden size -> text hidden size.
    linear_2: Linear,
}
22
23impl MultiModalProjector {
24 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
25 let (hidden_v, hidden_t) = (cfg.vision_config.hidden_size, cfg.text_config.hidden_size);
26 let linear_1 = linear(hidden_v, hidden_t, vb.pp("linear_1"))?;
27 let linear_2 = linear(hidden_t, hidden_t, vb.pp("linear_2"))?;
28 Ok(Self {
29 linear_1,
30 act: cfg.projector_hidden_act,
31 linear_2,
32 })
33 }
34}
35
36impl Module for MultiModalProjector {
37 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
38 xs.apply(&self.linear_1)?
39 .apply(&self.act)?
40 .apply(&self.linear_2)
41 }
42}
43
/// Vision-language model: a vision tower, a projector mapping its outputs to
/// the text hidden size, and a Mistral language model.
#[derive(Debug, Clone)]
pub struct Model {
    /// Maps vision-tower outputs to the language model's hidden size.
    pub multi_modal_projector: MultiModalProjector,
    /// The Mistral text decoder.
    pub language_model: mistral::Model,
    /// The vision encoder (its weights are loaded as F32 in `Model::new`).
    pub vision_tower: vision_model::Model,
    /// Patch size copied from `Config::vision_config`.
    pub patch_size: usize,
    /// Dtype of the `VarBuilder` passed to `Model::new`, i.e. the dtype the
    /// language model weights were loaded in.
    pub dtype: candle::DType,
    /// Number of sequence positions already fed to the language model; passed
    /// as the position offset of the next forward call and reset by
    /// `clear_kv_cache`.
    pub pos: usize,
}
53
54impl Model {
55 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
56 let language_model = mistral::Model::new(&cfg.text_config, vb.pp("language_model"))?;
57 let vision_tower = vision_model::Model::new(
58 &cfg.vision_config,
59 vb.pp("vision_tower").to_dtype(candle::DType::F32),
60 )?;
61 let multi_modal_projector = MultiModalProjector::new(
62 cfg,
63 vb.pp("multi_modal_projector").to_dtype(candle::DType::F32),
64 )?;
65 Ok(Self {
66 multi_modal_projector,
67 language_model,
68 vision_tower,
69 patch_size: cfg.vision_config.patch_size,
70 dtype: vb.dtype(),
71 pos: 0,
72 })
73 }
74
75 pub fn clear_kv_cache(&mut self) {
76 self.language_model.clear_kv_cache();
77 self.pos = 0;
78 }
79
80 pub fn encode_image(&self, image: &Tensor) -> Result<Tensor> {
81 let image_embeds = self.vision_tower.forward(image)?;
82 self.multi_modal_projector.forward(&image_embeds)
83 }
84
85 pub fn lm_forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
86 let (_, seq_len) = input_ids.dims2()?;
87 let logits = self.language_model.forward(input_ids, self.pos)?;
88 self.pos += seq_len;
89 Ok(logits)
90 }
91
92 pub fn lm_forward_embeds(&mut self, xs: &Tensor) -> Result<Tensor> {
93 let (_, seq_len, _) = xs.dims3()?;
94 let logits = self.language_model.forward_embeds(xs, None, self.pos)?;
95 self.pos += seq_len;
96 Ok(logits)
97 }
98}