candle_transformers/models/llava/config.rs

1use std::collections::HashMap;
2
3use crate::models::{
4    clip::{text_model::Activation, vision_model::ClipVisionConfig},
5    llama::{Config, LlamaEosToks},
6};
7use serde::{Deserialize, Serialize};
8
9// original config from liuhaotian/llava
10#[derive(Serialize, Deserialize, Debug, Clone)]
11pub struct LLaVAConfig {
12    pub architectures: Vec<String>,
13    pub bos_token_id: usize,
14    pub eos_token_id: usize,
15    pub hidden_size: usize,
16    #[serde(default = "default_image_aspect_ratio")]
17    pub image_aspect_ratio: String,
18    pub image_crop_resolution: usize,
19    pub image_grid_pinpoints: Vec<(u32, u32)>,
20    pub image_split_resolution: usize,
21    pub intermediate_size: usize,
22    pub max_position_embeddings: usize,
23    pub mm_hidden_size: usize,
24    #[serde(default = "default_mm_patch_merge_type")]
25    pub mm_patch_merge_type: String,
26    pub mm_projector_type: String,
27    pub mm_use_im_start_end: bool,
28    pub mm_vision_select_feature: String,
29    pub mm_vision_select_layer: isize,
30    pub mm_vision_tower: Option<String>,
31    pub model_type: String,
32    pub num_attention_heads: usize,
33    pub num_hidden_layers: usize,
34    pub num_key_value_heads: usize,
35    pub pad_token_id: usize,
36    pub rms_norm_eps: f32,
37    pub rope_theta: f32,
38    pub tokenizer_model_max_length: Option<usize>,
39    pub torch_dtype: String,
40    pub use_cache: bool,
41    pub vocab_size: usize,
42    #[serde(default = "default_image_token_index")]
43    pub image_token_index: isize,
44    #[serde(default = "default_hf")]
45    pub hf: bool,
46    pub tie_word_embeddings: Option<bool>,
47}
48
/// Serde default for `LLaVAConfig::hf`: a config without the key is treated as non-HF.
fn default_hf() -> bool {
    bool::default()
}
52
/// Serde default for `LLaVAConfig::image_token_index` when the key is absent.
fn default_image_token_index() -> isize {
    const IMAGE_TOKEN_INDEX: isize = -200;
    IMAGE_TOKEN_INDEX
}
56
/// Serde default for `LLaVAConfig::mm_patch_merge_type` when the key is absent.
fn default_mm_patch_merge_type() -> String {
    String::from("flat")
}
60
/// Serde default for `LLaVAConfig::image_aspect_ratio` when the key is absent.
fn default_image_aspect_ratio() -> String {
    String::from("square")
}
64
65impl LLaVAConfig {
66    pub fn to_llama_config(&self) -> Config {
67        Config {
68            hidden_size: self.hidden_size,
69            intermediate_size: self.intermediate_size,
70            vocab_size: self.vocab_size,
71            num_hidden_layers: self.num_hidden_layers,
72            num_attention_heads: self.num_attention_heads,
73            num_key_value_heads: self.num_key_value_heads,
74            rms_norm_eps: self.rms_norm_eps as f64,
75            rope_theta: self.rope_theta,
76            bos_token_id: Some(self.bos_token_id as u32),
77            eos_token_id: Some(LlamaEosToks::Single(self.eos_token_id as u32)),
78            use_flash_attn: false,
79            rope_scaling: None, // Assume we don't have LLaVA for Llama 3.1
80            max_position_embeddings: self.max_position_embeddings,
81            tie_word_embeddings: self.tie_word_embeddings.unwrap_or(false),
82        }
83    }
84}
85
86#[derive(Serialize, Deserialize, Debug, Clone)]
87pub struct HFLLaVATextConfig {
88    pub architectures: Vec<String>,
89    #[serde(default = "default_hidden_size")]
90    pub hidden_size: usize,
91    #[serde(default = "default_intermediate_size")]
92    pub intermediate_size: usize,
93    #[serde(default = "default_max_length")]
94    pub max_length: usize,
95    pub max_position_embeddings: usize,
96    pub model_type: String,
97    #[serde(default = "default_num_attention_heads")]
98    pub num_attention_heads: usize,
99    #[serde(default = "default_num_hidden_layers")]
100    pub num_hidden_layers: usize,
101    #[serde(default = "default_num_key_value_heads")]
102    pub num_key_value_heads: usize,
103    pub pad_token_id: usize,
104    pub rms_norm_eps: f32,
105    #[serde(default = "default_rope_theta")]
106    pub rope_theta: f32,
107    pub torch_dtype: String,
108    #[serde(default = "default_use_cache")]
109    pub use_cache: bool,
110    pub vocab_size: usize,
111}
112
/// Serde default for `HFLLaVATextConfig::num_hidden_layers`
/// (presumably the 7B Llama/Vicuna depth — confirm).
fn default_num_hidden_layers() -> usize {
    32_usize
}
116
/// Serde default for `HFLLaVATextConfig::use_cache`: caching is enabled unless the
/// config says otherwise.
fn default_use_cache() -> bool {
    let enabled = true;
    enabled
}
120
/// Serde default for `HFLLaVATextConfig::hidden_size`
/// (presumably the 7B Llama/Vicuna width — confirm).
fn default_hidden_size() -> usize {
    4_096
}
124
/// Serde default for `HFLLaVATextConfig::intermediate_size`
/// (presumably the 7B Llama/Vicuna MLP width — confirm).
fn default_intermediate_size() -> usize {
    11_008
}
128
/// Serde default for the `max_length` fields of `HFLLaVATextConfig` and
/// `HFGenerationConfig`.
fn default_max_length() -> usize {
    4_096
}
132
/// Serde default for `HFLLaVATextConfig::num_attention_heads`.
fn default_num_attention_heads() -> usize {
    32_usize
}
136
/// Serde default for `HFLLaVATextConfig::num_key_value_heads`; equal to the default
/// attention-head count, i.e. one key/value head per attention head.
fn default_num_key_value_heads() -> usize {
    32_usize
}
140
/// Serde default for `HFLLaVATextConfig::rope_theta` (the RoPE base).
fn default_rope_theta() -> f32 {
    1e4
}
144
145#[derive(Serialize, Deserialize, Debug, Clone)]
146pub struct HFLLaVAVisionConfig {
147    pub hidden_size: usize,
148    pub image_size: usize,
149    pub intermediate_size: usize,
150    pub model_type: String,
151    pub num_attention_heads: usize,
152    pub num_hidden_layers: usize,
153    pub patch_size: usize,
154    pub projection_dim: usize,
155    pub vocab_size: usize,
156}
157
158// config from llava-v1.6-vicuna-7b-hf
159#[derive(Serialize, Deserialize, Debug, Clone)]
160pub struct HFLLaVAConfig {
161    pub architectures: Vec<String>,
162    pub ignore_index: isize,
163    pub image_grid_pinpoints: Vec<(u32, u32)>,
164    pub image_token_index: isize,
165    pub model_type: String,
166    pub projector_hidden_act: String,
167    pub text_config: HFLLaVATextConfig,
168    pub torch_dtype: String,
169    pub use_image_newline_parameter: bool,
170    pub vision_config: HFLLaVAVisionConfig,
171    pub vision_feature_layer: isize,
172    pub vision_feature_select_strategy: String,
173    pub vocab_size: usize,
174}
175
176#[derive(Serialize, Deserialize, Debug, Clone)]
177pub struct HFGenerationConfig {
178    pub bos_token_id: usize,
179    pub eos_token_id: usize,
180    #[serde(default = "default_max_length")]
181    pub max_length: usize,
182    pub pad_token_id: usize,
183}
184
185#[derive(Serialize, Deserialize, Debug, Clone)]
186pub struct HFPreProcessorConfig {
187    pub aspect_ratio_setting: String,
188    pub crop_size: HashMap<String, usize>,
189    pub do_center_crop: bool,
190    pub do_convert_rgb: bool,
191    pub do_normalize: bool,
192    pub do_rescale: bool,
193    pub do_resize: bool,
194    pub image_mean: Vec<f32>,
195    pub image_std: Vec<f32>,
196    pub resample: u32,
197    pub rescale_factor: f32,
198    pub size: HashMap<String, f32>,
199}
200
201impl HFLLaVAConfig {
202    pub fn to_clip_vision_config(&self) -> ClipVisionConfig {
203        ClipVisionConfig {
204            embed_dim: self.vision_config.hidden_size,
205            activation: Activation::QuickGelu,
206            intermediate_size: self.vision_config.intermediate_size,
207            num_hidden_layers: self.vision_config.num_hidden_layers,
208            num_attention_heads: self.vision_config.num_attention_heads,
209            projection_dim: self.vision_config.projection_dim,
210            num_channels: 3,
211            image_size: self.vision_config.image_size,
212            patch_size: self.vision_config.patch_size,
213        }
214    }
215    fn map_projector_type(s: &str) -> String {
216        if s == "gelu" {
217            "mlp2x_gelu".to_string()
218        } else {
219            s.to_string()
220        }
221    }
222
223    fn map_select_feature(s: &str) -> String {
224        if s == "default" {
225            "patch".to_string()
226        } else {
227            "cls_patch".to_string()
228        }
229    }
230
231    pub fn to_llava_config(
232        &self,
233        generation_config: &HFGenerationConfig,
234        preprocessor_config: &HFPreProcessorConfig,
235    ) -> LLaVAConfig {
236        LLaVAConfig {
237            hf: true,
238            architectures: self.architectures.clone(),
239            bos_token_id: generation_config.bos_token_id,
240            eos_token_id: generation_config.eos_token_id,
241            hidden_size: self.text_config.hidden_size,
242            image_aspect_ratio: preprocessor_config.aspect_ratio_setting.clone(),
243            image_crop_resolution: 224,
244            image_grid_pinpoints: self.image_grid_pinpoints.clone(),
245            image_split_resolution: 224,
246            intermediate_size: self.text_config.intermediate_size,
247            max_position_embeddings: self.text_config.max_position_embeddings,
248            mm_hidden_size: 1024,
249            mm_patch_merge_type: "spatial_unpad".to_string(),
250            mm_projector_type: Self::map_projector_type(&self.projector_hidden_act),
251            mm_use_im_start_end: false,
252            mm_vision_select_feature: Self::map_select_feature(
253                &self.vision_feature_select_strategy,
254            ),
255            mm_vision_select_layer: self.vision_feature_layer,
256            mm_vision_tower: None,
257            model_type: self.model_type.clone(),
258            num_attention_heads: self.text_config.num_attention_heads,
259            num_hidden_layers: self.text_config.num_hidden_layers,
260            num_key_value_heads: self.text_config.num_key_value_heads,
261            pad_token_id: self.text_config.pad_token_id,
262            rms_norm_eps: self.text_config.rms_norm_eps,
263            rope_theta: self.text_config.rope_theta,
264            tokenizer_model_max_length: Some(4096),
265            torch_dtype: self.torch_dtype.clone(),
266            use_cache: self.text_config.use_cache,
267            vocab_size: self.vocab_size,
268            image_token_index: self.image_token_index,
269            tie_word_embeddings: None,
270        }
271    }
272}