1use super::blip_text;
12use super::with_tracing::{conv2d, linear, Conv2d, Linear};
13use candle::{Module, Result, Tensor, D};
14use candle_nn::{layer_norm, Conv2dConfig, LayerNorm, VarBuilder};
15use serde::Deserialize;
16
17#[derive(Debug, Clone, Deserialize)]
18pub struct VisionConfig {
19 pub hidden_size: usize,
20 pub intermediate_size: usize,
21 pub projection_dim: usize,
22 pub num_hidden_layers: usize,
23 pub num_attention_heads: usize,
24 pub image_size: usize,
25 pub patch_size: usize,
26 pub hidden_act: candle_nn::Activation,
27 pub layer_norm_eps: f64,
28}
29
30#[derive(Debug, Clone, Deserialize)]
31pub struct Config {
32 pub text_config: blip_text::Config,
33 pub vision_config: VisionConfig,
34 pub projection_dim: usize,
35 pub image_text_hidden_size: usize,
36}
37
38impl Config {
39 pub fn image_captioning_large() -> Self {
40 let text_config = blip_text::Config {
41 vocab_size: 30524,
42 hidden_size: 768,
43 encoder_hidden_size: 1024,
44 intermediate_size: 3072,
45 projection_dim: 768,
46 num_hidden_layers: 12,
47 num_attention_heads: 12,
48 max_position_embeddings: 512,
49 hidden_act: candle_nn::Activation::Gelu,
50 layer_norm_eps: 1e-12,
51 is_decoder: true,
52 };
53 let vision_config = VisionConfig {
54 hidden_size: 1024,
55 intermediate_size: 4096,
56 projection_dim: 512,
57 num_hidden_layers: 24,
58 num_attention_heads: 16,
59 image_size: 384,
60 patch_size: 16,
61 hidden_act: candle_nn::Activation::Gelu,
62 layer_norm_eps: 1e-5,
63 };
64 Self {
65 text_config,
66 vision_config,
67 projection_dim: 512,
68 image_text_hidden_size: 256,
69 }
70 }
71}
72
73#[derive(Debug, Clone)]
74struct VisionEmbeddings {
75 class_embedding: Tensor,
76 patch_embedding: Conv2d,
77 position_embedding: Tensor,
78}
79
80impl VisionEmbeddings {
81 fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
82 let class_embedding = vb.get((1, 1, cfg.hidden_size), "class_embedding")?;
83 let conv_cfg = Conv2dConfig {
84 stride: cfg.patch_size,
85 ..Default::default()
86 };
87 let patch_embedding = conv2d(
88 3,
89 cfg.hidden_size,
90 cfg.patch_size,
91 conv_cfg,
92 vb.pp("patch_embedding"),
93 )?;
94 let num_patches1 = cfg.image_size / cfg.patch_size;
95 let num_patches = num_patches1 * num_patches1;
96 let num_positions = num_patches + 1;
97 let position_embedding =
98 vb.get((1, num_positions, cfg.hidden_size), "position_embedding")?;
99 Ok(Self {
100 class_embedding,
101 patch_embedding,
102 position_embedding,
103 })
104 }
105}
106
107impl Module for VisionEmbeddings {
108 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
109 let target_dtype = xs.dtype();
110 let b_size = xs.dim(0)?;
111 let patch_embeds = xs.apply(&self.patch_embedding)?.flatten_from(2)?.t()?;
112 let d = self.class_embedding.dim(D::Minus1)?;
113 let class_embeds = self
114 .class_embedding
115 .broadcast_as((b_size, 1, d))?
116 .to_dtype(target_dtype)?;
117 let embeddings = Tensor::cat(&[&class_embeds, &patch_embeds], 1)?;
118 let position_embedding = self.position_embedding.narrow(1, 0, embeddings.dim(1)?)?;
119 embeddings.broadcast_add(&position_embedding)
120 }
121}
122
123#[derive(Debug, Clone)]
124struct Attention {
125 qkv: Linear,
126 projection: Linear,
127 scale: f64,
128 num_heads: usize,
129}
130
131impl Attention {
132 fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
133 let embed_dim = cfg.hidden_size;
134 let num_heads = cfg.num_attention_heads;
135 let head_dim = embed_dim / num_heads;
136 let scale = 1f64 / (head_dim as f64).sqrt();
137 let qkv = linear(embed_dim, 3 * embed_dim, vb.pp("qkv"))?;
138 let projection = linear(embed_dim, embed_dim, vb.pp("projection"))?;
139 Ok(Self {
140 qkv,
141 projection,
142 scale,
143 num_heads,
144 })
145 }
146
147 fn forward(&self, xs: &Tensor, attn_mask: Option<&Tensor>) -> Result<Tensor> {
148 let (b_sz, tgt_len, embed_dim) = xs.dims3()?;
149 let mixed_qkv = xs
150 .apply(&self.qkv)?
151 .reshape((b_sz, tgt_len, 3, self.num_heads, embed_dim / self.num_heads))?
152 .permute((2, 0, 3, 1, 4))?;
153 let query = mixed_qkv.get(0)?;
154 let key = mixed_qkv.get(1)?;
155 let value = mixed_qkv.get(2)?;
156 let attention_scores = query.matmul(&key.t()?)?;
157 let attention_scores = (attention_scores * self.scale)?;
158 let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
159 let attention_probs = match attn_mask {
160 None => attention_probs,
161 Some(attn_mask) => (attention_probs * attn_mask)?,
162 };
163 attention_probs
164 .matmul(&value)?
165 .permute((0, 2, 1, 3))?
166 .flatten_from(D::Minus2)?
167 .apply(&self.projection)
168 }
169}
170
171#[derive(Debug, Clone)]
172#[allow(clippy::upper_case_acronyms)]
173struct MLP {
174 activation_fn: candle_nn::Activation,
175 fc1: Linear,
176 fc2: Linear,
177}
178
179impl MLP {
180 fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
181 let fc1 = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("fc1"))?;
182 let fc2 = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?;
183 Ok(Self {
184 activation_fn: cfg.hidden_act,
185 fc1,
186 fc2,
187 })
188 }
189}
190
191impl Module for MLP {
192 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
193 xs.apply(&self.fc1)?
194 .apply(&self.activation_fn)?
195 .apply(&self.fc2)
196 }
197}
198
199#[derive(Debug, Clone)]
200struct EncoderLayer {
201 self_attn: Attention,
202 layer_norm1: LayerNorm,
203 mlp: MLP,
204 layer_norm2: LayerNorm,
205}
206
207impl EncoderLayer {
208 fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
209 let embed_dim = cfg.hidden_size;
210 let self_attn = Attention::new(cfg, vb.pp("self_attn"))?;
211 let layer_norm1 = layer_norm(embed_dim, cfg.layer_norm_eps, vb.pp("layer_norm1"))?;
212 let layer_norm2 = layer_norm(embed_dim, cfg.layer_norm_eps, vb.pp("layer_norm2"))?;
213 let mlp = MLP::new(cfg, vb.pp("mlp"))?;
214 Ok(Self {
215 self_attn,
216 layer_norm1,
217 mlp,
218 layer_norm2,
219 })
220 }
221
222 fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
223 let residual = xs;
224 let xs = xs.apply(&self.layer_norm1)?;
225 let xs = self.self_attn.forward(&xs, attention_mask)?;
226 let xs = (xs + residual)?;
227
228 let residual = &xs;
229 let xs = xs.apply(&self.layer_norm2)?.apply(&self.mlp)?;
230 xs + residual
231 }
232}
233
234#[derive(Debug, Clone)]
235struct Encoder {
236 layers: Vec<EncoderLayer>,
237}
238
239impl Encoder {
240 fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
241 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
242 let vb = vb.pp("layers");
243 for i in 0..cfg.num_hidden_layers {
244 let layer = EncoderLayer::new(cfg, vb.pp(i))?;
245 layers.push(layer)
246 }
247 Ok(Self { layers })
248 }
249
250 fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
251 let mut xs = xs.clone();
252 for layer in self.layers.iter() {
253 xs = layer.forward(&xs, attention_mask)?
254 }
255 Ok(xs)
256 }
257}
258
259#[derive(Debug, Clone)]
260pub struct VisionModel {
261 embeddings: VisionEmbeddings,
262 encoder: Encoder,
263 post_layernorm: LayerNorm,
264}
265
266impl VisionModel {
267 fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
268 let embeddings = VisionEmbeddings::new(cfg, vb.pp("embeddings"))?;
269 let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
270 let post_layernorm =
271 layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("post_layernorm"))?;
272 Ok(Self {
273 embeddings,
274 encoder,
275 post_layernorm,
276 })
277 }
278}
279
280impl Module for VisionModel {
281 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
282 let xs = xs.apply(&self.embeddings)?;
283 let encoder_outputs = self.encoder.forward(&xs, None)?;
284 encoder_outputs.apply(&self.post_layernorm)
286 }
287}
288
289#[derive(Debug, Clone)]
290pub struct BlipForConditionalGeneration {
291 vision_model: VisionModel,
292 text_decoder: blip_text::TextLMHeadModel,
293}
294
295impl BlipForConditionalGeneration {
296 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
297 let vision_model = VisionModel::new(&cfg.vision_config, vb.pp("vision_model"))?;
298 let text_decoder =
299 blip_text::TextLMHeadModel::new(&cfg.text_config, vb.pp("text_decoder"))?;
300 Ok(Self {
301 vision_model,
302 text_decoder,
303 })
304 }
305
306 pub fn vision_model(&self) -> &VisionModel {
307 &self.vision_model
308 }
309
310 pub fn text_decoder(&mut self) -> &mut blip_text::TextLMHeadModel {
311 &mut self.text_decoder
312 }
313
314 pub fn reset_kv_cache(&mut self) {
315 self.text_decoder.reset_kv_cache();
316 }
317}