candle_transformers/models/llava/
mod.rs1pub mod config;
11pub mod utils;
12
13use crate::models::clip::vision_model::{ClipVisionConfig, ClipVisionTransformer};
14use crate::models::llama::{Cache, Llama};
15use crate::models::with_tracing::linear;
16
17use candle::{bail, Context, Device, IndexOp, Result, Tensor};
18use candle_nn::{seq, Activation, Module, Sequential, VarBuilder};
19use fancy_regex::Regex;
20use utils::get_anyres_image_grid_shape;
21
22use config::LLaVAConfig;
23
24fn mlp_gelu_match(mm_projector_type: &str) -> Option<usize> {
25 let mlp_gelu_regex = Regex::new(r"^mlp(\d+)x_gelu$").unwrap();
26
27 if let Ok(Some(captures)) = mlp_gelu_regex.captures(mm_projector_type) {
28 if let Some(match_str) = captures.get(1) {
29 let match_str = match_str.as_str();
30 match_str.parse::<usize>().ok()
31 } else {
32 None
33 }
34 } else {
35 None
36 }
37}
38
39fn unpad_image(tensor: &Tensor, original_size: &(u32, u32)) -> Result<Tensor> {
40 assert_eq!(tensor.dims().len(), 3);
41 let (original_width, original_height) = *original_size;
42 let tensor_dims = tensor.dims();
43 let current_height = tensor_dims[1];
44 let current_width = tensor_dims[2];
45 let original_aspect_ratio = (original_width as f32) / (original_height as f32);
46 let current_aspect_ratio = (current_width as f32) / (current_height as f32);
47 if original_aspect_ratio > current_aspect_ratio {
48 let scale_factor = (current_width as f32) / (original_width as f32);
49 let new_height = (original_height as f32 * scale_factor).floor() as usize;
50 let padding = (current_height - new_height) / 2;
51 tensor.i((.., padding..current_width - padding, ..))
52 } else {
53 let scale_factor = (current_height as f32) / (original_height as f32);
54 let new_width = (original_width as f32 * scale_factor).floor() as usize;
55 let padding = (current_width - new_width) / 2;
56 tensor.i((.., .., padding..current_width - padding))
57 }
58}
59
60pub struct IdentityMap {}
61
62impl Module for IdentityMap {
63 fn forward(&self, x: &Tensor) -> Result<Tensor> {
64 Ok(x.clone())
65 }
66}
67
68pub struct MMProjector {
69 pub modules: Sequential,
70}
71
72impl MMProjector {
73 pub fn load(vb: &VarBuilder, config: &LLaVAConfig) -> Result<Self> {
74 if config.mm_projector_type == "linear" {
75 let vb_prefix = if config.hf {
76 "multi_modal_projector.linear_1"
77 } else {
78 "model.mm_projector.0"
79 };
80 let linear = linear(config.mm_hidden_size, config.hidden_size, vb.pp(vb_prefix))?;
81 let modules = seq().add(linear);
82 Ok(Self { modules })
83 } else if let Some(mlp_depth) = mlp_gelu_match(&config.mm_projector_type) {
84 let modules = if config.hf {
85 let mut modules = seq().add(linear(
86 config.mm_hidden_size,
87 config.hidden_size,
88 vb.pp("multi_modal_projector.linear_1"),
89 )?);
90 for i in 1..mlp_depth {
91 modules = modules.add(Activation::Gelu).add(linear(
92 config.hidden_size,
93 config.hidden_size,
94 vb.pp(format!("multi_modal_projector.linear_{}", i + 1)),
95 )?);
96 }
97 modules
98 } else {
99 let mut modules = seq().add(linear(
100 config.mm_hidden_size,
101 config.hidden_size,
102 vb.pp("model.mm_projector.0"),
103 )?);
104 for i in 1..mlp_depth {
105 modules = modules.add(Activation::Gelu).add(linear(
106 config.hidden_size,
107 config.hidden_size,
108 vb.pp(format!("model.mm_projector.{}", i * 2)),
109 )?);
110 }
111 modules
112 };
113 Ok(Self { modules })
114 } else if config.mm_projector_type == "identity" {
115 Ok(Self {
116 modules: seq().add(IdentityMap {}),
117 })
118 } else {
119 bail!(
120 "Unsupported MM projector type: {}",
121 config.mm_projector_type
122 )
123 }
124 }
125
126 pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
127 self.modules.forward(x)
128 }
129}
130
131pub struct ClipVisionTower {
132 model: ClipVisionTransformer,
133 select_layer: isize,
134 select_feature_method: String,
135 pub config: ClipVisionConfig,
136}
137
138impl ClipVisionTower {
139 pub fn new(
140 vb: VarBuilder,
141 select_layer: isize,
142 select_feature_method: &str,
143 config: &Option<ClipVisionConfig>,
144 ) -> Result<Self> {
145 let config = if config.is_none() {
146 ClipVisionConfig::clip_vit_large_patch14_336()
147 } else {
148 config.clone().context("no config")?
149 };
150 let select_layer = match select_layer {
151 -1 | -2 => select_layer,
152 _ => bail!("Unsupported select layer: {}", select_layer),
153 };
154 let model = ClipVisionTransformer::new(vb, &config)?;
155 Ok(Self {
156 model,
157 select_layer,
158 select_feature_method: select_feature_method.to_string(),
159 config,
160 })
161 }
162
163 pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
164 let result = self.model.output_hidden_states(x)?;
165 let index = result.len() as isize + self.select_layer;
166 let result = result[index as usize].clone();
167 if self.select_feature_method == "cls_patch" {
168 Ok(result)
169 } else {
170 result.i((.., 1..))
171 }
172 }
173
174 pub fn num_patches_per_side(&self) -> usize {
175 self.config.image_size / self.config.patch_size
176 }
177}
178
179pub struct LLaVA {
180 pub clip_vision_tower: ClipVisionTower,
181 pub image_newline: Tensor,
182 pub mm_projector: MMProjector,
183 pub llama: Llama,
184 config: LLaVAConfig,
185 device: Device,
186}
187
188impl LLaVA {
189 pub fn load(
190 vb: VarBuilder,
191 config: &LLaVAConfig,
192 clip_vision_config: Option<ClipVisionConfig>,
193 ) -> Result<Self> {
194 let device = vb.device().clone();
195 let llama_config = config.to_llama_config();
196 let mm_projector = MMProjector::load(&vb, config)?;
197 let (clip_vision_tower, image_newline, llama) = if config.hf {
198 (
199 ClipVisionTower::new(
200 vb.pp("vision_tower.vision_model"),
201 config.mm_vision_select_layer,
202 &config.mm_vision_select_feature,
203 &clip_vision_config,
204 )?,
205 vb.get(&[config.hidden_size], "image_newline")?
206 .to_device(&device)?,
207 Llama::load(vb.pp("language_model"), &llama_config)?,
208 )
209 } else {
210 (
211 ClipVisionTower::new(
212 vb.pp("model.vision_tower.vision_tower.vision_model"),
213 config.mm_vision_select_layer,
214 &config.mm_vision_select_feature,
215 &clip_vision_config,
216 )?,
217 vb.get(&[config.hidden_size], "model.image_newline")?
218 .to_device(&device)?,
219 Llama::load(vb, &llama_config)?,
220 )
221 };
222 Ok(Self {
223 clip_vision_tower,
224 image_newline,
225 mm_projector,
226 llama,
227 config: (*config).clone(),
228 device,
229 })
230 }
231
232 pub fn encode_images(&self, x: &Tensor) -> Result<Tensor> {
233 let image_features = self.clip_vision_tower.forward(x)?;
234 let image_features = self.mm_projector.forward(&image_features)?;
235 Ok(image_features)
236 }
237 pub fn prepare_inputs_labels_for_multimodal(
239 &self,
240 input_ids: &Tensor,
241 images: &[Tensor],
242 image_sizes: &[(u32, u32)],
243 ) -> Result<Tensor> {
244 let concat_images = Tensor::cat(images, 0)?;
247 let image_features_together = self.encode_images(&concat_images)?;
248 let split_sizes = images
249 .iter()
250 .map(|x| x.shape().dims()[0])
251 .collect::<Vec<usize>>();
252 let mut index_pos = 0;
254 let mut image_features = Vec::new();
255 for split_size in split_sizes.iter() {
256 image_features.push(image_features_together.i(index_pos..index_pos + (*split_size))?);
257 index_pos += *split_size;
258 }
259 let mm_patch_merge_type = &self.config.mm_patch_merge_type;
260 let image_aspect_ratio = &self.config.image_aspect_ratio;
261
262 let image_features = if mm_patch_merge_type == "flat" {
263 image_features
264 .iter()
265 .map(|x| x.flatten(0, 1))
266 .collect::<Result<Vec<Tensor>>>()?
267 } else if mm_patch_merge_type.starts_with("spatial") {
268 let mut new_image_features = Vec::new();
269 for (image_idx, image_feature) in image_features.iter().enumerate() {
270 let new_image_feature = if image_feature.dims()[0] > 1 {
271 let base_image_feature = image_feature.get(0)?;
272 let patch_image_feature = image_feature.i(1..)?;
273 let height = self.clip_vision_tower.num_patches_per_side();
274 let width = height;
275 assert_eq!(height * width, base_image_feature.dims()[0]);
276 let image_size = image_sizes[image_idx];
277 let new_image_feature = if image_aspect_ratio == "anyres" {
278 let (num_patch_width, num_patch_height) = get_anyres_image_grid_shape(
279 image_size,
280 &self.config.image_grid_pinpoints,
281 self.clip_vision_tower.config.image_size as u32,
282 );
283 patch_image_feature.reshape((
284 num_patch_height as usize,
285 num_patch_width as usize,
286 height,
287 width,
288 (),
289 ))?
290 } else {
291 bail!("not implemented in original python LLaVA yet")
292 };
293 let new_image_feature = if mm_patch_merge_type.contains("unpad") {
294 let new_image_feature = new_image_feature
295 .permute((4, 0, 2, 1, 3))?
296 .flatten(1, 2)?
297 .flatten(2, 3)?;
298 let new_image_feature = unpad_image(&new_image_feature, &image_size)?;
299 let new_image_feature_dims = new_image_feature.dims();
300 let image_new_line = self
301 .image_newline
302 .reshape((self.config.hidden_size, 1, 1))?
303 .broadcast_as((
304 new_image_feature_dims[0],
305 new_image_feature_dims[1],
306 1,
307 ))?;
308 let new_image_feature =
309 Tensor::cat(&[new_image_feature, image_new_line], 2)?;
310 new_image_feature.flatten(1, 2)?.transpose(0, 1)?
311 } else {
312 new_image_feature.permute((0, 2, 1, 3, 4))?.flatten(0, 3)?
313 };
314 Tensor::cat(&[base_image_feature, new_image_feature], 0)?
315 } else {
316 let new_image_feature = image_feature.get(0)?;
317 if mm_patch_merge_type.contains("unpad") {
318 Tensor::cat(
319 &[new_image_feature, self.image_newline.clone().unsqueeze(0)?],
320 0,
321 )?
322 } else {
323 new_image_feature
324 }
325 };
326 new_image_features.push(new_image_feature);
327 }
328 new_image_features
329 } else {
330 bail!("Unexpected mm_patch_merge_type: {mm_patch_merge_type}")
331 };
332 let input_ids_vec = input_ids.squeeze(0)?.to_vec1::<i64>()?;
334 let mut image_indices = {
335 let mut image_indices = vec![0_i64];
336 image_indices.extend(
337 input_ids_vec
338 .iter()
339 .enumerate()
340 .filter_map(|(i, x)| {
341 if *x == self.config.image_token_index as i64 {
342 Some(i as i64)
343 } else {
344 None
345 }
346 })
347 .collect::<Vec<i64>>(),
348 );
349 image_indices
350 };
351 if image_indices.len() == 1 {
352 return self.llama.embed(input_ids);
354 }
355
356 let input_ids_noim = input_ids_vec
357 .iter()
358 .filter_map(|x| {
359 if *x != self.config.image_token_index as i64 {
360 Some(*x)
361 } else {
362 None
363 }
364 })
365 .collect::<Vec<i64>>();
366 let input_ids_noim_len = input_ids_noim.len();
367 image_indices.push((input_ids_noim_len) as i64);
368 let input_ids_noim = Tensor::from_vec(input_ids_noim, input_ids_noim_len, &self.device)?;
369 let cur_input_embeds = self.llama.embed(&input_ids_noim)?;
370 let input_embed_no_ims = {
372 let mut input_embeds = Vec::new();
373 for i in 0..image_indices.len() - 1 {
374 let start = (image_indices[i]) as usize;
375 let end = image_indices[i + 1] as usize;
376 input_embeds.push(cur_input_embeds.i((start..end, ..))?)
377 }
378 input_embeds
379 };
380
381 let mut cur_new_input_embeds = Vec::new();
382 for (i, image_feature) in image_features.iter().enumerate() {
383 cur_new_input_embeds.push(input_embed_no_ims[i].clone());
384 cur_new_input_embeds.push(image_feature.clone());
385 }
386 cur_new_input_embeds.push(input_embed_no_ims[image_features.len()].clone());
387 let new_input_embeds = Tensor::cat(&cur_new_input_embeds, 0)?;
388 let new_input_embeds =
390 if let Some(tokenizer_model_max_length) = self.config.tokenizer_model_max_length {
391 let (new_input_embeds_length, _) = new_input_embeds.shape().dims2()?;
392 if new_input_embeds_length > tokenizer_model_max_length {
393 new_input_embeds.i((..tokenizer_model_max_length, ..))?
394 } else {
395 new_input_embeds
396 }
397 } else {
398 new_input_embeds
399 };
400 new_input_embeds.unsqueeze(0)
401 }
402
403 pub fn forward(
404 &self,
405 input_embeds: &Tensor,
406 position_id: usize,
407 cache: &mut Cache,
408 ) -> Result<Tensor> {
409 self.llama
410 .forward_input_embed(input_embeds, position_id, cache)
411 }
412}