1use std::collections::HashMap;
2
3use crate::models::{
4 clip::{text_model::Activation, vision_model::ClipVisionConfig},
5 llama::{Config, LlamaEosToks},
6};
7use serde::{Deserialize, Serialize};
8
/// Model configuration in the layout of the original (non-Hugging-Face) LLaVA
/// checkpoints' `config.json`.
///
/// NOTE(review): the "original layout" reading is inferred from the `mm_*` field
/// names and the `hf` flag — confirm against the upstream LLaVA repo.
///
/// It can also be produced from a Hugging Face config via
/// `HFLLaVAConfig::to_llava_config`, which sets `hf` to `true`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLaVAConfig {
    pub architectures: Vec<String>,
    pub bos_token_id: usize,
    pub eos_token_id: usize,
    pub hidden_size: usize,
    /// Image preprocessing mode; `"square"` when the field is absent from the JSON.
    #[serde(default = "default_image_aspect_ratio")]
    pub image_aspect_ratio: String,
    pub image_crop_resolution: usize,
    /// Candidate resolution pairs.
    /// NOTE(review): axis order of each pair is not established here — confirm.
    pub image_grid_pinpoints: Vec<(u32, u32)>,
    pub image_split_resolution: usize,
    pub intermediate_size: usize,
    pub max_position_embeddings: usize,
    /// Width of the features fed into the multimodal projector.
    pub mm_hidden_size: usize,
    /// Patch merge strategy; `"flat"` when the field is absent from the JSON.
    #[serde(default = "default_mm_patch_merge_type")]
    pub mm_patch_merge_type: String,
    pub mm_projector_type: String,
    pub mm_use_im_start_end: bool,
    pub mm_vision_select_feature: String,
    /// Signed layer index.
    /// NOTE(review): presumably negative values count from the last layer — confirm.
    pub mm_vision_select_layer: isize,
    pub mm_vision_tower: Option<String>,
    pub model_type: String,
    pub num_attention_heads: usize,
    pub num_hidden_layers: usize,
    pub num_key_value_heads: usize,
    pub pad_token_id: usize,
    /// Stored as `f32`; widened to `f64` when building the Llama config.
    pub rms_norm_eps: f32,
    pub rope_theta: f32,
    pub tokenizer_model_max_length: Option<usize>,
    pub torch_dtype: String,
    pub use_cache: bool,
    pub vocab_size: usize,
    /// Image placeholder token index; `-200` when absent from the JSON.
    #[serde(default = "default_image_token_index")]
    pub image_token_index: isize,
    /// `true` when this config was converted from a Hugging Face config;
    /// `false` when absent from the JSON.
    #[serde(default = "default_hf")]
    pub hf: bool,
    /// `None` is treated as `false` by `to_llama_config`.
    pub tie_word_embeddings: Option<bool>,
}
48
/// Serde default for `LLaVAConfig::hf`.
///
/// A config deserialized directly starts out as non-HF. The HF conversion path
/// (`HFLLaVAConfig::to_llava_config`) sets the flag to `true` explicitly.
fn default_hf() -> bool {
    bool::default()
}

/// Serde default for `LLaVAConfig::image_token_index`.
///
/// The value matches the `IMAGE_TOKEN_INDEX` constant of the reference LLaVA code.
/// Because it is negative, it can never equal a real (non-negative) vocabulary id.
fn default_image_token_index() -> isize {
    const IMAGE_TOKEN_INDEX: isize = -200;
    IMAGE_TOKEN_INDEX
}

/// Serde default for `LLaVAConfig::mm_patch_merge_type`.
fn default_mm_patch_merge_type() -> String {
    String::from("flat")
}

/// Serde default for `LLaVAConfig::image_aspect_ratio`.
fn default_image_aspect_ratio() -> String {
    String::from("square")
}
64
65impl LLaVAConfig {
66 pub fn to_llama_config(&self) -> Config {
67 Config {
68 hidden_size: self.hidden_size,
69 intermediate_size: self.intermediate_size,
70 vocab_size: self.vocab_size,
71 num_hidden_layers: self.num_hidden_layers,
72 num_attention_heads: self.num_attention_heads,
73 num_key_value_heads: self.num_key_value_heads,
74 rms_norm_eps: self.rms_norm_eps as f64,
75 rope_theta: self.rope_theta,
76 bos_token_id: Some(self.bos_token_id as u32),
77 eos_token_id: Some(LlamaEosToks::Single(self.eos_token_id as u32)),
78 use_flash_attn: false,
79 rope_scaling: None, max_position_embeddings: self.max_position_embeddings,
81 tie_word_embeddings: self.tie_word_embeddings.unwrap_or(false),
82 }
83 }
84}
85
/// Language-model section (`text_config`) of a Hugging Face LLaVA config.
///
/// Fields carrying a serde default may be omitted from the JSON. The fallback
/// values come from the `default_*` helpers below.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct HFLLaVATextConfig {
    pub architectures: Vec<String>,
    /// Defaults to 4096 when absent.
    #[serde(default = "default_hidden_size")]
    pub hidden_size: usize,
    /// Defaults to 11008 when absent.
    #[serde(default = "default_intermediate_size")]
    pub intermediate_size: usize,
    /// Defaults to 4096 when absent.
    #[serde(default = "default_max_length")]
    pub max_length: usize,
    pub max_position_embeddings: usize,
    pub model_type: String,
    /// Defaults to 32 when absent.
    #[serde(default = "default_num_attention_heads")]
    pub num_attention_heads: usize,
    /// Defaults to 32 when absent.
    #[serde(default = "default_num_hidden_layers")]
    pub num_hidden_layers: usize,
    /// Defaults to 32 when absent.
    #[serde(default = "default_num_key_value_heads")]
    pub num_key_value_heads: usize,
    pub pad_token_id: usize,
    pub rms_norm_eps: f32,
    /// Defaults to 10000.0 when absent.
    #[serde(default = "default_rope_theta")]
    pub rope_theta: f32,
    pub torch_dtype: String,
    /// Defaults to `true` when absent.
    #[serde(default = "default_use_cache")]
    pub use_cache: bool,
    pub vocab_size: usize,
}
112
// Fallbacks for fields that a Hugging Face `text_config` may omit.
// NOTE(review): the shape values (hidden/intermediate size, layer and head
// counts, RoPE theta) appear to mirror the `transformers` `LlamaConfig` defaults —
// confirm before relying on that.
const TEXT_DEFAULT_HIDDEN_SIZE: usize = 4096;
const TEXT_DEFAULT_INTERMEDIATE_SIZE: usize = 11008;
const TEXT_DEFAULT_MAX_LENGTH: usize = 4096;
const TEXT_DEFAULT_NUM_ATTENTION_HEADS: usize = 32;
const TEXT_DEFAULT_NUM_HIDDEN_LAYERS: usize = 32;
const TEXT_DEFAULT_NUM_KEY_VALUE_HEADS: usize = 32;
const TEXT_DEFAULT_ROPE_THETA: f32 = 10_000.0;
const TEXT_DEFAULT_USE_CACHE: bool = true;

/// Serde default for `HFLLaVATextConfig::hidden_size`.
fn default_hidden_size() -> usize {
    TEXT_DEFAULT_HIDDEN_SIZE
}

/// Serde default for `HFLLaVATextConfig::intermediate_size`.
fn default_intermediate_size() -> usize {
    TEXT_DEFAULT_INTERMEDIATE_SIZE
}

/// Serde default for `max_length`.
///
/// Shared by `HFLLaVATextConfig` and `HFGenerationConfig`.
fn default_max_length() -> usize {
    TEXT_DEFAULT_MAX_LENGTH
}

/// Serde default for `HFLLaVATextConfig::num_attention_heads`.
fn default_num_attention_heads() -> usize {
    TEXT_DEFAULT_NUM_ATTENTION_HEADS
}

/// Serde default for `HFLLaVATextConfig::num_hidden_layers`.
fn default_num_hidden_layers() -> usize {
    TEXT_DEFAULT_NUM_HIDDEN_LAYERS
}

/// Serde default for `HFLLaVATextConfig::num_key_value_heads`.
///
/// Equal to the attention-head default, i.e. no grouped-query sharing.
fn default_num_key_value_heads() -> usize {
    TEXT_DEFAULT_NUM_KEY_VALUE_HEADS
}

/// Serde default for `HFLLaVATextConfig::rope_theta`.
fn default_rope_theta() -> f32 {
    TEXT_DEFAULT_ROPE_THETA
}

/// Serde default for `HFLLaVATextConfig::use_cache`.
fn default_use_cache() -> bool {
    TEXT_DEFAULT_USE_CACHE
}
144
/// Vision-tower section (`vision_config`) of a Hugging Face LLaVA config.
///
/// Converted to a CLIP vision config by `HFLLaVAConfig::to_clip_vision_config`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct HFLLaVAVisionConfig {
    /// Width of the vision tower's hidden states.
    pub hidden_size: usize,
    pub image_size: usize,
    pub intermediate_size: usize,
    pub model_type: String,
    pub num_attention_heads: usize,
    pub num_hidden_layers: usize,
    pub patch_size: usize,
    pub projection_dim: usize,
    pub vocab_size: usize,
}
157
/// Top-level Hugging Face LLaVA `config.json`.
///
/// NOTE(review): presumably the llava-hf / LLaVA-NeXT layout — confirm.
///
/// Combine it with [`HFGenerationConfig`] and [`HFPreProcessorConfig`] through
/// `to_llava_config` to obtain an [`LLaVAConfig`].
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct HFLLaVAConfig {
    pub architectures: Vec<String>,
    pub ignore_index: isize,
    pub image_grid_pinpoints: Vec<(u32, u32)>,
    pub image_token_index: isize,
    pub model_type: String,
    /// Projector activation name; `"gelu"` is mapped to LLaVA's `"mlp2x_gelu"`.
    pub projector_hidden_act: String,
    pub text_config: HFLLaVATextConfig,
    pub torch_dtype: String,
    pub use_image_newline_parameter: bool,
    pub vision_config: HFLLaVAVisionConfig,
    pub vision_feature_layer: isize,
    /// `"default"` maps to LLaVA's `"patch"`; any other value maps to `"cls_patch"`.
    pub vision_feature_select_strategy: String,
    pub vocab_size: usize,
}
175
/// Subset of a Hugging Face generation config.
///
/// NOTE(review): presumably `generation_config.json` — confirm. Its BOS/EOS token
/// ids are what `HFLLaVAConfig::to_llava_config` copies into [`LLaVAConfig`].
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct HFGenerationConfig {
    pub bos_token_id: usize,
    pub eos_token_id: usize,
    /// Defaults to 4096 when absent.
    #[serde(default = "default_max_length")]
    pub max_length: usize,
    pub pad_token_id: usize,
}
184
/// Subset of a Hugging Face image preprocessor config.
///
/// `HFLLaVAConfig::to_llava_config` reads `aspect_ratio_setting` from it.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct HFPreProcessorConfig {
    /// Copied into `LLaVAConfig::image_aspect_ratio`.
    pub aspect_ratio_setting: String,
    pub crop_size: HashMap<String, usize>,
    pub do_center_crop: bool,
    pub do_convert_rgb: bool,
    pub do_normalize: bool,
    pub do_rescale: bool,
    pub do_resize: bool,
    pub image_mean: Vec<f32>,
    pub image_std: Vec<f32>,
    pub resample: u32,
    pub rescale_factor: f32,
    /// NOTE(review): values are `f32` here but `usize` in `crop_size` — confirm
    /// that the difference is intentional.
    pub size: HashMap<String, f32>,
}
200
201impl HFLLaVAConfig {
202 pub fn to_clip_vision_config(&self) -> ClipVisionConfig {
203 ClipVisionConfig {
204 embed_dim: self.vision_config.hidden_size,
205 activation: Activation::QuickGelu,
206 intermediate_size: self.vision_config.intermediate_size,
207 num_hidden_layers: self.vision_config.num_hidden_layers,
208 num_attention_heads: self.vision_config.num_attention_heads,
209 projection_dim: self.vision_config.projection_dim,
210 num_channels: 3,
211 image_size: self.vision_config.image_size,
212 patch_size: self.vision_config.patch_size,
213 }
214 }
215 fn map_projector_type(s: &str) -> String {
216 if s == "gelu" {
217 "mlp2x_gelu".to_string()
218 } else {
219 s.to_string()
220 }
221 }
222
223 fn map_select_feature(s: &str) -> String {
224 if s == "default" {
225 "patch".to_string()
226 } else {
227 "cls_patch".to_string()
228 }
229 }
230
231 pub fn to_llava_config(
232 &self,
233 generation_config: &HFGenerationConfig,
234 preprocessor_config: &HFPreProcessorConfig,
235 ) -> LLaVAConfig {
236 LLaVAConfig {
237 hf: true,
238 architectures: self.architectures.clone(),
239 bos_token_id: generation_config.bos_token_id,
240 eos_token_id: generation_config.eos_token_id,
241 hidden_size: self.text_config.hidden_size,
242 image_aspect_ratio: preprocessor_config.aspect_ratio_setting.clone(),
243 image_crop_resolution: 224,
244 image_grid_pinpoints: self.image_grid_pinpoints.clone(),
245 image_split_resolution: 224,
246 intermediate_size: self.text_config.intermediate_size,
247 max_position_embeddings: self.text_config.max_position_embeddings,
248 mm_hidden_size: 1024,
249 mm_patch_merge_type: "spatial_unpad".to_string(),
250 mm_projector_type: Self::map_projector_type(&self.projector_hidden_act),
251 mm_use_im_start_end: false,
252 mm_vision_select_feature: Self::map_select_feature(
253 &self.vision_feature_select_strategy,
254 ),
255 mm_vision_select_layer: self.vision_feature_layer,
256 mm_vision_tower: None,
257 model_type: self.model_type.clone(),
258 num_attention_heads: self.text_config.num_attention_heads,
259 num_hidden_layers: self.text_config.num_hidden_layers,
260 num_key_value_heads: self.text_config.num_key_value_heads,
261 pad_token_id: self.text_config.pad_token_id,
262 rms_norm_eps: self.text_config.rms_norm_eps,
263 rope_theta: self.text_config.rope_theta,
264 tokenizer_model_max_length: Some(4096),
265 torch_dtype: self.torch_dtype.clone(),
266 use_cache: self.text_config.use_cache,
267 vocab_size: self.vocab_size,
268 image_token_index: self.image_token_index,
269 tie_word_embeddings: None,
270 }
271 }
272}