candle_transformers/models/llava/
mod.rs

1//! The LLaVA (Large Language and Vision Assistant) model.
2//!
//! This provides the main model implementation, combining a vision tower (CLIP) with
//! a language model (Llama) for multimodal capabilities. The architecture implements the training-free projection technique.
5//!
//! - 💻 [GH Link](https://github.com/haotian-liu/LLaVA/tree/main)
//! - 📝 [Paper](https://arxiv.org/abs/2304.08485): Visual Instruction Tuning
8//!
9
10pub mod config;
11pub mod utils;
12
13use crate::models::clip::vision_model::{ClipVisionConfig, ClipVisionTransformer};
14use crate::models::llama::{Cache, Llama};
15use crate::models::with_tracing::linear;
16
17use candle::{bail, Context, Device, IndexOp, Result, Tensor};
18use candle_nn::{seq, Activation, Module, Sequential, VarBuilder};
19use fancy_regex::Regex;
20use utils::get_anyres_image_grid_shape;
21
22use config::LLaVAConfig;
23
24fn mlp_gelu_match(mm_projector_type: &str) -> Option<usize> {
25    let mlp_gelu_regex = Regex::new(r"^mlp(\d+)x_gelu$").unwrap();
26
27    if let Ok(Some(captures)) = mlp_gelu_regex.captures(mm_projector_type) {
28        if let Some(match_str) = captures.get(1) {
29            let match_str = match_str.as_str();
30            match_str.parse::<usize>().ok()
31        } else {
32            None
33        }
34    } else {
35        None
36    }
37}
38
39fn unpad_image(tensor: &Tensor, original_size: &(u32, u32)) -> Result<Tensor> {
40    assert_eq!(tensor.dims().len(), 3);
41    let (original_width, original_height) = *original_size;
42    let tensor_dims = tensor.dims();
43    let current_height = tensor_dims[1];
44    let current_width = tensor_dims[2];
45    let original_aspect_ratio = (original_width as f32) / (original_height as f32);
46    let current_aspect_ratio = (current_width as f32) / (current_height as f32);
47    if original_aspect_ratio > current_aspect_ratio {
48        let scale_factor = (current_width as f32) / (original_width as f32);
49        let new_height = (original_height as f32 * scale_factor).floor() as usize;
50        let padding = (current_height - new_height) / 2;
51        tensor.i((.., padding..current_width - padding, ..))
52    } else {
53        let scale_factor = (current_height as f32) / (original_height as f32);
54        let new_width = (original_width as f32 * scale_factor).floor() as usize;
55        let padding = (current_width - new_width) / 2;
56        tensor.i((.., .., padding..current_width - padding))
57    }
58}
59
60pub struct IdentityMap {}
61
62impl Module for IdentityMap {
63    fn forward(&self, x: &Tensor) -> Result<Tensor> {
64        Ok(x.clone())
65    }
66}
67
68pub struct MMProjector {
69    pub modules: Sequential,
70}
71
72impl MMProjector {
73    pub fn load(vb: &VarBuilder, config: &LLaVAConfig) -> Result<Self> {
74        if config.mm_projector_type == "linear" {
75            let vb_prefix = if config.hf {
76                "multi_modal_projector.linear_1"
77            } else {
78                "model.mm_projector.0"
79            };
80            let linear = linear(config.mm_hidden_size, config.hidden_size, vb.pp(vb_prefix))?;
81            let modules = seq().add(linear);
82            Ok(Self { modules })
83        } else if let Some(mlp_depth) = mlp_gelu_match(&config.mm_projector_type) {
84            let modules = if config.hf {
85                let mut modules = seq().add(linear(
86                    config.mm_hidden_size,
87                    config.hidden_size,
88                    vb.pp("multi_modal_projector.linear_1"),
89                )?);
90                for i in 1..mlp_depth {
91                    modules = modules.add(Activation::Gelu).add(linear(
92                        config.hidden_size,
93                        config.hidden_size,
94                        vb.pp(format!("multi_modal_projector.linear_{}", i + 1)),
95                    )?);
96                }
97                modules
98            } else {
99                let mut modules = seq().add(linear(
100                    config.mm_hidden_size,
101                    config.hidden_size,
102                    vb.pp("model.mm_projector.0"),
103                )?);
104                for i in 1..mlp_depth {
105                    modules = modules.add(Activation::Gelu).add(linear(
106                        config.hidden_size,
107                        config.hidden_size,
108                        vb.pp(format!("model.mm_projector.{}", i * 2)),
109                    )?);
110                }
111                modules
112            };
113            Ok(Self { modules })
114        } else if config.mm_projector_type == "identity" {
115            Ok(Self {
116                modules: seq().add(IdentityMap {}),
117            })
118        } else {
119            bail!(
120                "Unsupported MM projector type: {}",
121                config.mm_projector_type
122            )
123        }
124    }
125
126    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
127        self.modules.forward(x)
128    }
129}
130
131pub struct ClipVisionTower {
132    model: ClipVisionTransformer,
133    select_layer: isize,
134    select_feature_method: String,
135    pub config: ClipVisionConfig,
136}
137
138impl ClipVisionTower {
139    pub fn new(
140        vb: VarBuilder,
141        select_layer: isize,
142        select_feature_method: &str,
143        config: &Option<ClipVisionConfig>,
144    ) -> Result<Self> {
145        let config = if config.is_none() {
146            ClipVisionConfig::clip_vit_large_patch14_336()
147        } else {
148            config.clone().context("no config")?
149        };
150        let select_layer = match select_layer {
151            -1 | -2 => select_layer,
152            _ => bail!("Unsupported select layer: {}", select_layer),
153        };
154        let model = ClipVisionTransformer::new(vb, &config)?;
155        Ok(Self {
156            model,
157            select_layer,
158            select_feature_method: select_feature_method.to_string(),
159            config,
160        })
161    }
162
163    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
164        let result = self.model.output_hidden_states(x)?;
165        let index = result.len() as isize + self.select_layer;
166        let result = result[index as usize].clone();
167        if self.select_feature_method == "cls_patch" {
168            Ok(result)
169        } else {
170            result.i((.., 1..))
171        }
172    }
173
174    pub fn num_patches_per_side(&self) -> usize {
175        self.config.image_size / self.config.patch_size
176    }
177}
178
179pub struct LLaVA {
180    pub clip_vision_tower: ClipVisionTower,
181    pub image_newline: Tensor,
182    pub mm_projector: MMProjector,
183    pub llama: Llama,
184    config: LLaVAConfig,
185    device: Device,
186}
187
188impl LLaVA {
189    pub fn load(
190        vb: VarBuilder,
191        config: &LLaVAConfig,
192        clip_vision_config: Option<ClipVisionConfig>,
193    ) -> Result<Self> {
194        let device = vb.device().clone();
195        let llama_config = config.to_llama_config();
196        let mm_projector = MMProjector::load(&vb, config)?;
197        let (clip_vision_tower, image_newline, llama) = if config.hf {
198            (
199                ClipVisionTower::new(
200                    vb.pp("vision_tower.vision_model"),
201                    config.mm_vision_select_layer,
202                    &config.mm_vision_select_feature,
203                    &clip_vision_config,
204                )?,
205                vb.get(&[config.hidden_size], "image_newline")?
206                    .to_device(&device)?,
207                Llama::load(vb.pp("language_model"), &llama_config)?,
208            )
209        } else {
210            (
211                ClipVisionTower::new(
212                    vb.pp("model.vision_tower.vision_tower.vision_model"),
213                    config.mm_vision_select_layer,
214                    &config.mm_vision_select_feature,
215                    &clip_vision_config,
216                )?,
217                vb.get(&[config.hidden_size], "model.image_newline")?
218                    .to_device(&device)?,
219                Llama::load(vb, &llama_config)?,
220            )
221        };
222        Ok(Self {
223            clip_vision_tower,
224            image_newline,
225            mm_projector,
226            llama,
227            config: (*config).clone(),
228            device,
229        })
230    }
231
232    pub fn encode_images(&self, x: &Tensor) -> Result<Tensor> {
233        let image_features = self.clip_vision_tower.forward(x)?;
234        let image_features = self.mm_projector.forward(&image_features)?;
235        Ok(image_features)
236    }
237    // currently only for single image, 4 dim tensor
238    pub fn prepare_inputs_labels_for_multimodal(
239        &self,
240        input_ids: &Tensor,
241        images: &[Tensor],
242        image_sizes: &[(u32, u32)],
243    ) -> Result<Tensor> {
244        //TODO: process of multiple images/ new line
245        // 576: 336(input size)/14(patch size)=24 24*24+1(class)=577 577-1=576
246        let concat_images = Tensor::cat(images, 0)?;
247        let image_features_together = self.encode_images(&concat_images)?;
248        let split_sizes = images
249            .iter()
250            .map(|x| x.shape().dims()[0])
251            .collect::<Vec<usize>>();
252        // can be replaced by split
253        let mut index_pos = 0;
254        let mut image_features = Vec::new();
255        for split_size in split_sizes.iter() {
256            image_features.push(image_features_together.i(index_pos..index_pos + (*split_size))?);
257            index_pos += *split_size;
258        }
259        let mm_patch_merge_type = &self.config.mm_patch_merge_type;
260        let image_aspect_ratio = &self.config.image_aspect_ratio;
261
262        let image_features = if mm_patch_merge_type == "flat" {
263            image_features
264                .iter()
265                .map(|x| x.flatten(0, 1))
266                .collect::<Result<Vec<Tensor>>>()?
267        } else if mm_patch_merge_type.starts_with("spatial") {
268            let mut new_image_features = Vec::new();
269            for (image_idx, image_feature) in image_features.iter().enumerate() {
270                let new_image_feature = if image_feature.dims()[0] > 1 {
271                    let base_image_feature = image_feature.get(0)?;
272                    let patch_image_feature = image_feature.i(1..)?;
273                    let height = self.clip_vision_tower.num_patches_per_side();
274                    let width = height;
275                    assert_eq!(height * width, base_image_feature.dims()[0]);
276                    let image_size = image_sizes[image_idx];
277                    let new_image_feature = if image_aspect_ratio == "anyres" {
278                        let (num_patch_width, num_patch_height) = get_anyres_image_grid_shape(
279                            image_size,
280                            &self.config.image_grid_pinpoints,
281                            self.clip_vision_tower.config.image_size as u32,
282                        );
283                        patch_image_feature.reshape((
284                            num_patch_height as usize,
285                            num_patch_width as usize,
286                            height,
287                            width,
288                            (),
289                        ))?
290                    } else {
291                        bail!("not implemented in original python LLaVA yet")
292                    };
293                    let new_image_feature = if mm_patch_merge_type.contains("unpad") {
294                        let new_image_feature = new_image_feature
295                            .permute((4, 0, 2, 1, 3))?
296                            .flatten(1, 2)?
297                            .flatten(2, 3)?;
298                        let new_image_feature = unpad_image(&new_image_feature, &image_size)?;
299                        let new_image_feature_dims = new_image_feature.dims();
300                        let image_new_line = self
301                            .image_newline
302                            .reshape((self.config.hidden_size, 1, 1))?
303                            .broadcast_as((
304                                new_image_feature_dims[0],
305                                new_image_feature_dims[1],
306                                1,
307                            ))?;
308                        let new_image_feature =
309                            Tensor::cat(&[new_image_feature, image_new_line], 2)?;
310                        new_image_feature.flatten(1, 2)?.transpose(0, 1)?
311                    } else {
312                        new_image_feature.permute((0, 2, 1, 3, 4))?.flatten(0, 3)?
313                    };
314                    Tensor::cat(&[base_image_feature, new_image_feature], 0)?
315                } else {
316                    let new_image_feature = image_feature.get(0)?;
317                    if mm_patch_merge_type.contains("unpad") {
318                        Tensor::cat(
319                            &[new_image_feature, self.image_newline.clone().unsqueeze(0)?],
320                            0,
321                        )?
322                    } else {
323                        new_image_feature
324                    }
325                };
326                new_image_features.push(new_image_feature);
327            }
328            new_image_features
329        } else {
330            bail!("Unexpected mm_patch_merge_type: {mm_patch_merge_type}")
331        };
332        // can easily be replaced by nonzero if it is implemented in candle
333        let input_ids_vec = input_ids.squeeze(0)?.to_vec1::<i64>()?;
334        let mut image_indices = {
335            let mut image_indices = vec![0_i64];
336            image_indices.extend(
337                input_ids_vec
338                    .iter()
339                    .enumerate()
340                    .filter_map(|(i, x)| {
341                        if *x == self.config.image_token_index as i64 {
342                            Some(i as i64)
343                        } else {
344                            None
345                        }
346                    })
347                    .collect::<Vec<i64>>(),
348            );
349            image_indices
350        };
351        if image_indices.len() == 1 {
352            //no image, only [0],
353            return self.llama.embed(input_ids);
354        }
355
356        let input_ids_noim = input_ids_vec
357            .iter()
358            .filter_map(|x| {
359                if *x != self.config.image_token_index as i64 {
360                    Some(*x)
361                } else {
362                    None
363                }
364            })
365            .collect::<Vec<i64>>();
366        let input_ids_noim_len = input_ids_noim.len();
367        image_indices.push((input_ids_noim_len) as i64);
368        let input_ids_noim = Tensor::from_vec(input_ids_noim, input_ids_noim_len, &self.device)?;
369        let cur_input_embeds = self.llama.embed(&input_ids_noim)?;
370        // can be replace by split if it is implemented in candle
371        let input_embed_no_ims = {
372            let mut input_embeds = Vec::new();
373            for i in 0..image_indices.len() - 1 {
374                let start = (image_indices[i]) as usize;
375                let end = image_indices[i + 1] as usize;
376                input_embeds.push(cur_input_embeds.i((start..end, ..))?)
377            }
378            input_embeds
379        };
380
381        let mut cur_new_input_embeds = Vec::new();
382        for (i, image_feature) in image_features.iter().enumerate() {
383            cur_new_input_embeds.push(input_embed_no_ims[i].clone());
384            cur_new_input_embeds.push(image_feature.clone());
385        }
386        cur_new_input_embeds.push(input_embed_no_ims[image_features.len()].clone());
387        let new_input_embeds = Tensor::cat(&cur_new_input_embeds, 0)?;
388        //trancate
389        let new_input_embeds =
390            if let Some(tokenizer_model_max_length) = self.config.tokenizer_model_max_length {
391                let (new_input_embeds_length, _) = new_input_embeds.shape().dims2()?;
392                if new_input_embeds_length > tokenizer_model_max_length {
393                    new_input_embeds.i((..tokenizer_model_max_length, ..))?
394                } else {
395                    new_input_embeds
396                }
397            } else {
398                new_input_embeds
399            };
400        new_input_embeds.unsqueeze(0)
401    }
402
403    pub fn forward(
404        &self,
405        input_embeds: &Tensor,
406        position_id: usize,
407        cache: &mut Cache,
408    ) -> Result<Tensor> {
409        self.llama
410            .forward_input_embed(input_embeds, position_id, cache)
411    }
412}