candle_transformers/models/
quantized_blip_text.rs

1//! Quantized BLIP text module implementation.
2//!
3//! Provides the text decoder portion of the BLIP model with 8-bit quantization.
4//! Uses a BERT-style transformer architecture for text processing.
5//!
6//! Key components:
7//! - Text embeddings layer with position embeddings
8//! - Multi-head self attention layers
9//! - Cross-attention for vision-text fusion
10//! - Layer normalization and feed-forward layers
11//! - Quantized linear transformations
12//!
13//! References:
14//! - [BLIP Paper](https://arxiv.org/abs/2201.12086)
15//! - [Hugging Face Implementation](https://huggingface.co/docs/transformers/model_doc/blip)
16//!
17
18use crate::models::with_tracing::QMatMul;
19use crate::quantized_nn::{layer_norm, linear, Embedding, Linear};
20pub use crate::quantized_var_builder::VarBuilder;
21use candle::{Module, Result, Tensor, D};
22use candle_nn::LayerNorm;
23
24pub type Config = super::blip_text::Config;
25
26#[derive(Debug, Clone)]
27struct TextEmbeddings {
28    word_embedddings: Embedding,
29    position_embeddings: Embedding,
30    layer_norm: LayerNorm,
31    position_ids: Tensor,
32}
33
34impl TextEmbeddings {
35    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
36        let word_embedddings =
37            Embedding::new(cfg.vocab_size, cfg.hidden_size, vb.pp("word_embeddings"))?;
38        let position_embeddings = Embedding::new(
39            cfg.max_position_embeddings,
40            cfg.hidden_size,
41            vb.pp("position_embeddings"),
42        )?;
43        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
44        let position_ids =
45            Tensor::arange(0, cfg.max_position_embeddings as u32, vb.device())?.unsqueeze(0)?;
46        Ok(Self {
47            word_embedddings,
48            position_embeddings,
49            layer_norm,
50            position_ids,
51        })
52    }
53
54    fn forward(&self, xs: &Tensor, past_kv_len: usize) -> Result<Tensor> {
55        let seq_len = xs.dim(1)?;
56        let position_ids = self.position_ids.narrow(1, past_kv_len, seq_len)?;
57        let embeddings = self.word_embedddings.forward(xs)?;
58        let position_embeddings = self.position_embeddings.forward(&position_ids)?;
59        (embeddings + position_embeddings)?.apply(&self.layer_norm)
60    }
61}
62
63#[derive(Debug, Clone)]
64struct TextSelfAttention {
65    query: Linear,
66    key: Linear,
67    value: Linear,
68    attention_head_size: usize,
69    num_attention_heads: usize,
70    attention_scale: f64,
71    kv_cache: Option<(Tensor, Tensor)>,
72}
73
74impl TextSelfAttention {
75    fn new(cfg: &Config, is_cross_attention: bool, vb: VarBuilder) -> Result<Self> {
76        let num_attention_heads = cfg.num_attention_heads;
77        let attention_head_size = cfg.hidden_size / num_attention_heads;
78        let all_head_size = cfg.num_attention_heads * attention_head_size;
79        let query = linear(cfg.hidden_size, all_head_size, vb.pp("query"))?;
80        let in_size = if is_cross_attention {
81            cfg.encoder_hidden_size
82        } else {
83            cfg.hidden_size
84        };
85        let key = linear(in_size, all_head_size, vb.pp("key"))?;
86        let value = linear(in_size, all_head_size, vb.pp("value"))?;
87        let attention_scale = 1f64 / (attention_head_size as f64).sqrt();
88        Ok(Self {
89            query,
90            key,
91            value,
92            attention_head_size,
93            num_attention_heads,
94            attention_scale,
95            kv_cache: None,
96        })
97    }
98
99    fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
100        let (b_size, seq_len, _) = xs.dims3()?;
101        xs.reshape((
102            b_size,
103            seq_len,
104            self.num_attention_heads,
105            self.attention_head_size,
106        ))?
107        .permute((0, 2, 1, 3))
108    }
109
110    fn reset_kv_cache(&mut self) {
111        self.kv_cache = None
112    }
113
114    fn forward(
115        &mut self,
116        xs: &Tensor,
117        encoder_hidden_states: Option<&Tensor>,
118        attention_mask: Option<&Tensor>,
119    ) -> Result<Tensor> {
120        let query = self
121            .transpose_for_scores(&self.query.forward(xs)?)?
122            .contiguous()?;
123        let (key, value) = match encoder_hidden_states {
124            None => {
125                let key = self.transpose_for_scores(&self.key.forward(xs)?)?;
126                let value = self.transpose_for_scores(&self.value.forward(xs)?)?;
127                let (key, value) = match &self.kv_cache {
128                    None => (key, value),
129                    Some((prev_key, prev_value)) => {
130                        let key = Tensor::cat(&[prev_key, &key], 2)?;
131                        let value = Tensor::cat(&[prev_value, &value], 2)?;
132                        (key, value)
133                    }
134                };
135                self.kv_cache = Some((key.clone(), value.clone()));
136                (key, value)
137            }
138            Some(xs) => {
139                let key = self.transpose_for_scores(&self.key.forward(xs)?)?;
140                let value = self.transpose_for_scores(&self.value.forward(xs)?)?;
141                // no kv-cache in this case, but the results could probably be memoized.
142                (key, value)
143            }
144        };
145        let key = key.contiguous()?;
146        let value = value.contiguous()?;
147        let attention_scores = query.matmul(&key.t()?)?;
148        let attention_scores = (attention_scores * self.attention_scale)?;
149        let attention_scores = match attention_mask {
150            Some(mask) => attention_scores.broadcast_add(mask)?,
151            None => attention_scores,
152        };
153        let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
154        attention_probs
155            .matmul(&value)?
156            .permute((0, 2, 1, 3))?
157            .flatten_from(D::Minus2)
158    }
159}
160
161#[derive(Debug, Clone)]
162struct TextSelfOutput {
163    dense: Linear,
164    layer_norm: LayerNorm,
165}
166
167impl TextSelfOutput {
168    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
169        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
170        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
171        Ok(Self { dense, layer_norm })
172    }
173
174    fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
175        (xs.apply(&self.dense) + input_tensor)?.apply(&self.layer_norm)
176    }
177}
178
179#[derive(Debug, Clone)]
180struct TextAttention {
181    self_: TextSelfAttention,
182    output: TextSelfOutput,
183}
184
185impl TextAttention {
186    fn new(cfg: &Config, is_cross_attention: bool, vb: VarBuilder) -> Result<Self> {
187        let self_ = TextSelfAttention::new(cfg, is_cross_attention, vb.pp("self"))?;
188        let output = TextSelfOutput::new(cfg, vb.pp("output"))?;
189        Ok(Self { self_, output })
190    }
191
192    fn reset_kv_cache(&mut self) {
193        self.self_.reset_kv_cache()
194    }
195
196    fn forward(
197        &mut self,
198        xs: &Tensor,
199        encoder_hidden_states: Option<&Tensor>,
200        attention_mask: Option<&Tensor>,
201    ) -> Result<Tensor> {
202        let self_outputs = self
203            .self_
204            .forward(xs, encoder_hidden_states, attention_mask)?;
205        self.output.forward(&self_outputs, xs)
206    }
207}
208
209#[derive(Debug, Clone)]
210struct TextIntermediate {
211    dense: Linear,
212    intermediate_act_fn: candle_nn::Activation,
213}
214
215impl TextIntermediate {
216    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
217        let dense = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("dense"))?;
218        Ok(Self {
219            dense,
220            intermediate_act_fn: cfg.hidden_act,
221        })
222    }
223}
224
225impl Module for TextIntermediate {
226    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
227        xs.apply(&self.dense)?.apply(&self.intermediate_act_fn)
228    }
229}
230
231#[derive(Debug, Clone)]
232struct TextOutput {
233    dense: Linear,
234    layer_norm: LayerNorm,
235}
236
237impl TextOutput {
238    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
239        let dense = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("dense"))?;
240        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
241        Ok(Self { dense, layer_norm })
242    }
243
244    fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
245        (xs.apply(&self.dense)? + input_tensor)?.apply(&self.layer_norm)
246    }
247}
248
249#[derive(Debug, Clone)]
250struct TextLayer {
251    attention: TextAttention,
252    cross_attention: Option<TextAttention>,
253    intermediate: TextIntermediate,
254    output: TextOutput,
255}
256
257impl TextLayer {
258    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
259        let attention = TextAttention::new(cfg, false, vb.pp("attention"))?;
260        let cross_attention = if cfg.is_decoder {
261            Some(TextAttention::new(cfg, true, vb.pp("crossattention"))?)
262        } else {
263            None
264        };
265        let intermediate = TextIntermediate::new(cfg, vb.pp("intermediate"))?;
266        let output = TextOutput::new(cfg, vb.pp("output"))?;
267        Ok(Self {
268            attention,
269            cross_attention,
270            intermediate,
271            output,
272        })
273    }
274
275    fn reset_kv_cache(&mut self) {
276        self.attention.reset_kv_cache();
277        if let Some(ca) = &mut self.cross_attention {
278            ca.reset_kv_cache()
279        }
280    }
281
282    fn forward(
283        &mut self,
284        xs: &Tensor,
285        encoder_hidden_states: &Tensor,
286        attention_mask: &Tensor,
287    ) -> Result<Tensor> {
288        let attention_output = self.attention.forward(xs, None, Some(attention_mask))?;
289        let attention_output = match &mut self.cross_attention {
290            Some(ca) => ca.forward(&attention_output, Some(encoder_hidden_states), None)?,
291            None => candle::bail!("expected some cross-attn"),
292        };
293        let intermediate_output = self.intermediate.forward(&attention_output)?;
294        self.output.forward(&intermediate_output, &attention_output)
295    }
296}
297
298#[derive(Debug, Clone)]
299struct TextEncoder {
300    layers: Vec<TextLayer>,
301}
302
303impl TextEncoder {
304    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
305        let vb = vb.pp("layer");
306        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
307        for i in 0..cfg.num_hidden_layers {
308            let layer = TextLayer::new(cfg, vb.pp(i))?;
309            layers.push(layer)
310        }
311        Ok(Self { layers })
312    }
313
314    fn reset_kv_cache(&mut self) {
315        self.layers.iter_mut().for_each(|l| l.reset_kv_cache())
316    }
317
318    fn forward(
319        &mut self,
320        xs: &Tensor,
321        encoder_hidden_states: &Tensor,
322        attention_mask: &Tensor,
323    ) -> Result<Tensor> {
324        let mut xs = xs.clone();
325        for layer in self.layers.iter_mut() {
326            xs = layer.forward(&xs, encoder_hidden_states, attention_mask)?
327        }
328        Ok(xs)
329    }
330}
331
332#[derive(Debug, Clone)]
333pub struct TextPooler {
334    dense: Linear,
335}
336
337impl TextPooler {
338    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
339        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
340        Ok(Self { dense })
341    }
342}
343
344impl Module for TextPooler {
345    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
346        xs.narrow(D::Minus1, 0, 1)?
347            .squeeze(D::Minus1)?
348            .apply(&self.dense)?
349            .tanh()
350    }
351}
352
353#[derive(Debug, Clone)]
354struct TextPredictionHeadTransform {
355    dense: Linear,
356    transform_act_fn: candle_nn::Activation,
357    layer_norm: LayerNorm,
358}
359
360impl TextPredictionHeadTransform {
361    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
362        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
363        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
364        Ok(Self {
365            dense,
366            transform_act_fn: cfg.hidden_act,
367            layer_norm,
368        })
369    }
370}
371
372impl Module for TextPredictionHeadTransform {
373    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
374        xs.apply(&self.dense)?
375            .apply(&self.transform_act_fn)?
376            .apply(&self.layer_norm)
377    }
378}
379
380#[derive(Debug, Clone)]
381struct TextLMPredictionHead {
382    transform: TextPredictionHeadTransform,
383    decoder: Linear,
384}
385
386impl TextLMPredictionHead {
387    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
388        let transform = TextPredictionHeadTransform::new(cfg, vb.pp("transform"))?;
389        let weight = QMatMul::new(cfg.hidden_size, cfg.vocab_size, vb.pp("decoder"))?;
390        let bias = vb.get(cfg.vocab_size, "bias")?.dequantize(vb.device())?;
391        let decoder = Linear::from_weights(weight, Some(bias));
392        Ok(Self { transform, decoder })
393    }
394}
395
396impl Module for TextLMPredictionHead {
397    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
398        xs.apply(&self.transform)?.apply(&self.decoder)
399    }
400}
401
402#[derive(Debug, Clone)]
403struct TextOnlyMLMHead {
404    predictions: TextLMPredictionHead,
405}
406
407impl TextOnlyMLMHead {
408    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
409        let predictions = TextLMPredictionHead::new(cfg, vb.pp("predictions"))?;
410        Ok(Self { predictions })
411    }
412}
413
414impl Module for TextOnlyMLMHead {
415    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
416        self.predictions.forward(xs)
417    }
418}
419
420#[derive(Debug, Clone)]
421struct TextModel {
422    embeddings: TextEmbeddings,
423    encoder: TextEncoder,
424    past_kv_len: usize,
425    // We do not need the pooler for caption generation
426}
427
428impl TextModel {
429    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
430        let embeddings = TextEmbeddings::new(cfg, vb.pp("embeddings"))?;
431        let encoder = TextEncoder::new(cfg, vb.pp("encoder"))?;
432        Ok(Self {
433            embeddings,
434            encoder,
435            past_kv_len: 0,
436        })
437    }
438
439    fn forward(
440        &mut self,
441        input_ids: &Tensor,
442        encoder_hidden_states: &Tensor,
443        attention_mask: &Tensor,
444    ) -> Result<Tensor> {
445        let (_b_sz, seq_len) = input_ids.dims2()?;
446        let embedding_output = self.embeddings.forward(input_ids, self.past_kv_len)?;
447        let sequence_output =
448            self.encoder
449                .forward(&embedding_output, encoder_hidden_states, attention_mask)?;
450        self.past_kv_len += seq_len;
451        // We're interested in the sequence-output rather than the pooled-output.
452        Ok(sequence_output)
453    }
454
455    fn reset_kv_cache(&mut self) {
456        self.past_kv_len = 0;
457        self.encoder.reset_kv_cache();
458    }
459}
460
461#[derive(Debug, Clone)]
462pub struct TextLMHeadModel {
463    bert: TextModel,
464    cls: TextOnlyMLMHead,
465}
466
467impl TextLMHeadModel {
468    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
469        let bert = TextModel::new(cfg, vb.pp("bert"))?;
470        let cls = TextOnlyMLMHead::new(cfg, vb.pp("cls"))?;
471        Ok(Self { bert, cls })
472    }
473
474    pub fn forward(
475        &mut self,
476        input_ids: &Tensor,
477        encoder_hidden_states: &Tensor,
478    ) -> Result<Tensor> {
479        let seq_len = input_ids.dim(1)?;
480        let mask: Vec<_> = (0..seq_len)
481            .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
482            .collect();
483        let mask = Tensor::from_vec(mask, (seq_len, seq_len), input_ids.device())?;
484        let sequence_output = self.bert.forward(input_ids, encoder_hidden_states, &mask)?;
485        let prediction_scores = self.cls.forward(&sequence_output)?;
486        // return_logits is false so we don't discard the last sequence element.
487        Ok(prediction_scores)
488    }
489
490    pub fn reset_kv_cache(&mut self) {
491        self.bert.reset_kv_cache()
492    }
493}