candle_transformers/models/
blip.rs

1//! Based on the BLIP paper from Salesforce Research.
2//!
3//! The blip-image-captioning model can generate captions for an input image.
4//!
5//! - ⚡ [Interactive Wasm Example](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning)
6//! - 💻 [GH Link](https://github.com/salesforce/BLIP)
7//! - 🤗 [HF Link](https://huggingface.co/Salesforce/blip-image-captioning-base)
8//! - 📝 [Paper](https://arxiv.org/abs/2201.12086)
9//!
10
11use super::blip_text;
12use super::with_tracing::{conv2d, linear, Conv2d, Linear};
13use candle::{Module, Result, Tensor, D};
14use candle_nn::{layer_norm, Conv2dConfig, LayerNorm, VarBuilder};
15use serde::Deserialize;
16
17#[derive(Debug, Clone, Deserialize)]
18pub struct VisionConfig {
19    pub hidden_size: usize,
20    pub intermediate_size: usize,
21    pub projection_dim: usize,
22    pub num_hidden_layers: usize,
23    pub num_attention_heads: usize,
24    pub image_size: usize,
25    pub patch_size: usize,
26    pub hidden_act: candle_nn::Activation,
27    pub layer_norm_eps: f64,
28}
29
30#[derive(Debug, Clone, Deserialize)]
31pub struct Config {
32    pub text_config: blip_text::Config,
33    pub vision_config: VisionConfig,
34    pub projection_dim: usize,
35    pub image_text_hidden_size: usize,
36}
37
38impl Config {
39    pub fn image_captioning_large() -> Self {
40        let text_config = blip_text::Config {
41            vocab_size: 30524,
42            hidden_size: 768,
43            encoder_hidden_size: 1024,
44            intermediate_size: 3072,
45            projection_dim: 768,
46            num_hidden_layers: 12,
47            num_attention_heads: 12,
48            max_position_embeddings: 512,
49            hidden_act: candle_nn::Activation::Gelu,
50            layer_norm_eps: 1e-12,
51            is_decoder: true,
52        };
53        let vision_config = VisionConfig {
54            hidden_size: 1024,
55            intermediate_size: 4096,
56            projection_dim: 512,
57            num_hidden_layers: 24,
58            num_attention_heads: 16,
59            image_size: 384,
60            patch_size: 16,
61            hidden_act: candle_nn::Activation::Gelu,
62            layer_norm_eps: 1e-5,
63        };
64        Self {
65            text_config,
66            vision_config,
67            projection_dim: 512,
68            image_text_hidden_size: 256,
69        }
70    }
71}
72
73#[derive(Debug, Clone)]
74struct VisionEmbeddings {
75    class_embedding: Tensor,
76    patch_embedding: Conv2d,
77    position_embedding: Tensor,
78}
79
80impl VisionEmbeddings {
81    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
82        let class_embedding = vb.get((1, 1, cfg.hidden_size), "class_embedding")?;
83        let conv_cfg = Conv2dConfig {
84            stride: cfg.patch_size,
85            ..Default::default()
86        };
87        let patch_embedding = conv2d(
88            3,
89            cfg.hidden_size,
90            cfg.patch_size,
91            conv_cfg,
92            vb.pp("patch_embedding"),
93        )?;
94        let num_patches1 = cfg.image_size / cfg.patch_size;
95        let num_patches = num_patches1 * num_patches1;
96        let num_positions = num_patches + 1;
97        let position_embedding =
98            vb.get((1, num_positions, cfg.hidden_size), "position_embedding")?;
99        Ok(Self {
100            class_embedding,
101            patch_embedding,
102            position_embedding,
103        })
104    }
105}
106
107impl Module for VisionEmbeddings {
108    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
109        let target_dtype = xs.dtype();
110        let b_size = xs.dim(0)?;
111        let patch_embeds = xs.apply(&self.patch_embedding)?.flatten_from(2)?.t()?;
112        let d = self.class_embedding.dim(D::Minus1)?;
113        let class_embeds = self
114            .class_embedding
115            .broadcast_as((b_size, 1, d))?
116            .to_dtype(target_dtype)?;
117        let embeddings = Tensor::cat(&[&class_embeds, &patch_embeds], 1)?;
118        let position_embedding = self.position_embedding.narrow(1, 0, embeddings.dim(1)?)?;
119        embeddings.broadcast_add(&position_embedding)
120    }
121}
122
123#[derive(Debug, Clone)]
124struct Attention {
125    qkv: Linear,
126    projection: Linear,
127    scale: f64,
128    num_heads: usize,
129}
130
131impl Attention {
132    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
133        let embed_dim = cfg.hidden_size;
134        let num_heads = cfg.num_attention_heads;
135        let head_dim = embed_dim / num_heads;
136        let scale = 1f64 / (head_dim as f64).sqrt();
137        let qkv = linear(embed_dim, 3 * embed_dim, vb.pp("qkv"))?;
138        let projection = linear(embed_dim, embed_dim, vb.pp("projection"))?;
139        Ok(Self {
140            qkv,
141            projection,
142            scale,
143            num_heads,
144        })
145    }
146
147    fn forward(&self, xs: &Tensor, attn_mask: Option<&Tensor>) -> Result<Tensor> {
148        let (b_sz, tgt_len, embed_dim) = xs.dims3()?;
149        let mixed_qkv = xs
150            .apply(&self.qkv)?
151            .reshape((b_sz, tgt_len, 3, self.num_heads, embed_dim / self.num_heads))?
152            .permute((2, 0, 3, 1, 4))?;
153        let query = mixed_qkv.get(0)?;
154        let key = mixed_qkv.get(1)?;
155        let value = mixed_qkv.get(2)?;
156        let attention_scores = query.matmul(&key.t()?)?;
157        let attention_scores = (attention_scores * self.scale)?;
158        let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
159        let attention_probs = match attn_mask {
160            None => attention_probs,
161            Some(attn_mask) => (attention_probs * attn_mask)?,
162        };
163        attention_probs
164            .matmul(&value)?
165            .permute((0, 2, 1, 3))?
166            .flatten_from(D::Minus2)?
167            .apply(&self.projection)
168    }
169}
170
171#[derive(Debug, Clone)]
172#[allow(clippy::upper_case_acronyms)]
173struct MLP {
174    activation_fn: candle_nn::Activation,
175    fc1: Linear,
176    fc2: Linear,
177}
178
179impl MLP {
180    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
181        let fc1 = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("fc1"))?;
182        let fc2 = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?;
183        Ok(Self {
184            activation_fn: cfg.hidden_act,
185            fc1,
186            fc2,
187        })
188    }
189}
190
191impl Module for MLP {
192    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
193        xs.apply(&self.fc1)?
194            .apply(&self.activation_fn)?
195            .apply(&self.fc2)
196    }
197}
198
199#[derive(Debug, Clone)]
200struct EncoderLayer {
201    self_attn: Attention,
202    layer_norm1: LayerNorm,
203    mlp: MLP,
204    layer_norm2: LayerNorm,
205}
206
207impl EncoderLayer {
208    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
209        let embed_dim = cfg.hidden_size;
210        let self_attn = Attention::new(cfg, vb.pp("self_attn"))?;
211        let layer_norm1 = layer_norm(embed_dim, cfg.layer_norm_eps, vb.pp("layer_norm1"))?;
212        let layer_norm2 = layer_norm(embed_dim, cfg.layer_norm_eps, vb.pp("layer_norm2"))?;
213        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
214        Ok(Self {
215            self_attn,
216            layer_norm1,
217            mlp,
218            layer_norm2,
219        })
220    }
221
222    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
223        let residual = xs;
224        let xs = xs.apply(&self.layer_norm1)?;
225        let xs = self.self_attn.forward(&xs, attention_mask)?;
226        let xs = (xs + residual)?;
227
228        let residual = &xs;
229        let xs = xs.apply(&self.layer_norm2)?.apply(&self.mlp)?;
230        xs + residual
231    }
232}
233
234#[derive(Debug, Clone)]
235struct Encoder {
236    layers: Vec<EncoderLayer>,
237}
238
239impl Encoder {
240    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
241        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
242        let vb = vb.pp("layers");
243        for i in 0..cfg.num_hidden_layers {
244            let layer = EncoderLayer::new(cfg, vb.pp(i))?;
245            layers.push(layer)
246        }
247        Ok(Self { layers })
248    }
249
250    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
251        let mut xs = xs.clone();
252        for layer in self.layers.iter() {
253            xs = layer.forward(&xs, attention_mask)?
254        }
255        Ok(xs)
256    }
257}
258
259#[derive(Debug, Clone)]
260pub struct VisionModel {
261    embeddings: VisionEmbeddings,
262    encoder: Encoder,
263    post_layernorm: LayerNorm,
264}
265
266impl VisionModel {
267    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
268        let embeddings = VisionEmbeddings::new(cfg, vb.pp("embeddings"))?;
269        let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
270        let post_layernorm =
271            layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("post_layernorm"))?;
272        Ok(Self {
273            embeddings,
274            encoder,
275            post_layernorm,
276        })
277    }
278}
279
280impl Module for VisionModel {
281    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
282        let xs = xs.apply(&self.embeddings)?;
283        let encoder_outputs = self.encoder.forward(&xs, None)?;
284        // Return the last hidden state rather than pooled outputs.
285        encoder_outputs.apply(&self.post_layernorm)
286    }
287}
288
289#[derive(Debug, Clone)]
290pub struct BlipForConditionalGeneration {
291    vision_model: VisionModel,
292    text_decoder: blip_text::TextLMHeadModel,
293}
294
295impl BlipForConditionalGeneration {
296    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
297        let vision_model = VisionModel::new(&cfg.vision_config, vb.pp("vision_model"))?;
298        let text_decoder =
299            blip_text::TextLMHeadModel::new(&cfg.text_config, vb.pp("text_decoder"))?;
300        Ok(Self {
301            vision_model,
302            text_decoder,
303        })
304    }
305
306    pub fn vision_model(&self) -> &VisionModel {
307        &self.vision_model
308    }
309
310    pub fn text_decoder(&mut self) -> &mut blip_text::TextLMHeadModel {
311        &mut self.text_decoder
312    }
313
314    pub fn reset_kv_cache(&mut self) {
315        self.text_decoder.reset_kv_cache();
316    }
317}