1use crate::models::with_tracing::QMatMul;
19use crate::quantized_nn::{layer_norm, linear, Embedding, Linear};
20pub use crate::quantized_var_builder::VarBuilder;
21use candle::{Module, Result, Tensor, D};
22use candle_nn::LayerNorm;
23
/// Configuration of the quantized BLIP text decoder; this is the same type as the
/// non-quantized `super::blip_text::Config`.
pub type Config = super::blip_text::Config;
25
26#[derive(Debug, Clone)]
27struct TextEmbeddings {
28 word_embedddings: Embedding,
29 position_embeddings: Embedding,
30 layer_norm: LayerNorm,
31 position_ids: Tensor,
32}
33
34impl TextEmbeddings {
35 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
36 let word_embedddings =
37 Embedding::new(cfg.vocab_size, cfg.hidden_size, vb.pp("word_embeddings"))?;
38 let position_embeddings = Embedding::new(
39 cfg.max_position_embeddings,
40 cfg.hidden_size,
41 vb.pp("position_embeddings"),
42 )?;
43 let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
44 let position_ids =
45 Tensor::arange(0, cfg.max_position_embeddings as u32, vb.device())?.unsqueeze(0)?;
46 Ok(Self {
47 word_embedddings,
48 position_embeddings,
49 layer_norm,
50 position_ids,
51 })
52 }
53
54 fn forward(&self, xs: &Tensor, past_kv_len: usize) -> Result<Tensor> {
55 let seq_len = xs.dim(1)?;
56 let position_ids = self.position_ids.narrow(1, past_kv_len, seq_len)?;
57 let embeddings = self.word_embedddings.forward(xs)?;
58 let position_embeddings = self.position_embeddings.forward(&position_ids)?;
59 (embeddings + position_embeddings)?.apply(&self.layer_norm)
60 }
61}
62
63#[derive(Debug, Clone)]
64struct TextSelfAttention {
65 query: Linear,
66 key: Linear,
67 value: Linear,
68 attention_head_size: usize,
69 num_attention_heads: usize,
70 attention_scale: f64,
71 kv_cache: Option<(Tensor, Tensor)>,
72}
73
74impl TextSelfAttention {
75 fn new(cfg: &Config, is_cross_attention: bool, vb: VarBuilder) -> Result<Self> {
76 let num_attention_heads = cfg.num_attention_heads;
77 let attention_head_size = cfg.hidden_size / num_attention_heads;
78 let all_head_size = cfg.num_attention_heads * attention_head_size;
79 let query = linear(cfg.hidden_size, all_head_size, vb.pp("query"))?;
80 let in_size = if is_cross_attention {
81 cfg.encoder_hidden_size
82 } else {
83 cfg.hidden_size
84 };
85 let key = linear(in_size, all_head_size, vb.pp("key"))?;
86 let value = linear(in_size, all_head_size, vb.pp("value"))?;
87 let attention_scale = 1f64 / (attention_head_size as f64).sqrt();
88 Ok(Self {
89 query,
90 key,
91 value,
92 attention_head_size,
93 num_attention_heads,
94 attention_scale,
95 kv_cache: None,
96 })
97 }
98
99 fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
100 let (b_size, seq_len, _) = xs.dims3()?;
101 xs.reshape((
102 b_size,
103 seq_len,
104 self.num_attention_heads,
105 self.attention_head_size,
106 ))?
107 .permute((0, 2, 1, 3))
108 }
109
110 fn reset_kv_cache(&mut self) {
111 self.kv_cache = None
112 }
113
114 fn forward(
115 &mut self,
116 xs: &Tensor,
117 encoder_hidden_states: Option<&Tensor>,
118 attention_mask: Option<&Tensor>,
119 ) -> Result<Tensor> {
120 let query = self
121 .transpose_for_scores(&self.query.forward(xs)?)?
122 .contiguous()?;
123 let (key, value) = match encoder_hidden_states {
124 None => {
125 let key = self.transpose_for_scores(&self.key.forward(xs)?)?;
126 let value = self.transpose_for_scores(&self.value.forward(xs)?)?;
127 let (key, value) = match &self.kv_cache {
128 None => (key, value),
129 Some((prev_key, prev_value)) => {
130 let key = Tensor::cat(&[prev_key, &key], 2)?;
131 let value = Tensor::cat(&[prev_value, &value], 2)?;
132 (key, value)
133 }
134 };
135 self.kv_cache = Some((key.clone(), value.clone()));
136 (key, value)
137 }
138 Some(xs) => {
139 let key = self.transpose_for_scores(&self.key.forward(xs)?)?;
140 let value = self.transpose_for_scores(&self.value.forward(xs)?)?;
141 (key, value)
143 }
144 };
145 let key = key.contiguous()?;
146 let value = value.contiguous()?;
147 let attention_scores = query.matmul(&key.t()?)?;
148 let attention_scores = (attention_scores * self.attention_scale)?;
149 let attention_scores = match attention_mask {
150 Some(mask) => attention_scores.broadcast_add(mask)?,
151 None => attention_scores,
152 };
153 let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
154 attention_probs
155 .matmul(&value)?
156 .permute((0, 2, 1, 3))?
157 .flatten_from(D::Minus2)
158 }
159}
160
161#[derive(Debug, Clone)]
162struct TextSelfOutput {
163 dense: Linear,
164 layer_norm: LayerNorm,
165}
166
167impl TextSelfOutput {
168 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
169 let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
170 let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
171 Ok(Self { dense, layer_norm })
172 }
173
174 fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
175 (xs.apply(&self.dense) + input_tensor)?.apply(&self.layer_norm)
176 }
177}
178
179#[derive(Debug, Clone)]
180struct TextAttention {
181 self_: TextSelfAttention,
182 output: TextSelfOutput,
183}
184
185impl TextAttention {
186 fn new(cfg: &Config, is_cross_attention: bool, vb: VarBuilder) -> Result<Self> {
187 let self_ = TextSelfAttention::new(cfg, is_cross_attention, vb.pp("self"))?;
188 let output = TextSelfOutput::new(cfg, vb.pp("output"))?;
189 Ok(Self { self_, output })
190 }
191
192 fn reset_kv_cache(&mut self) {
193 self.self_.reset_kv_cache()
194 }
195
196 fn forward(
197 &mut self,
198 xs: &Tensor,
199 encoder_hidden_states: Option<&Tensor>,
200 attention_mask: Option<&Tensor>,
201 ) -> Result<Tensor> {
202 let self_outputs = self
203 .self_
204 .forward(xs, encoder_hidden_states, attention_mask)?;
205 self.output.forward(&self_outputs, xs)
206 }
207}
208
209#[derive(Debug, Clone)]
210struct TextIntermediate {
211 dense: Linear,
212 intermediate_act_fn: candle_nn::Activation,
213}
214
215impl TextIntermediate {
216 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
217 let dense = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("dense"))?;
218 Ok(Self {
219 dense,
220 intermediate_act_fn: cfg.hidden_act,
221 })
222 }
223}
224
225impl Module for TextIntermediate {
226 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
227 xs.apply(&self.dense)?.apply(&self.intermediate_act_fn)
228 }
229}
230
231#[derive(Debug, Clone)]
232struct TextOutput {
233 dense: Linear,
234 layer_norm: LayerNorm,
235}
236
237impl TextOutput {
238 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
239 let dense = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("dense"))?;
240 let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
241 Ok(Self { dense, layer_norm })
242 }
243
244 fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
245 (xs.apply(&self.dense)? + input_tensor)?.apply(&self.layer_norm)
246 }
247}
248
249#[derive(Debug, Clone)]
250struct TextLayer {
251 attention: TextAttention,
252 cross_attention: Option<TextAttention>,
253 intermediate: TextIntermediate,
254 output: TextOutput,
255}
256
257impl TextLayer {
258 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
259 let attention = TextAttention::new(cfg, false, vb.pp("attention"))?;
260 let cross_attention = if cfg.is_decoder {
261 Some(TextAttention::new(cfg, true, vb.pp("crossattention"))?)
262 } else {
263 None
264 };
265 let intermediate = TextIntermediate::new(cfg, vb.pp("intermediate"))?;
266 let output = TextOutput::new(cfg, vb.pp("output"))?;
267 Ok(Self {
268 attention,
269 cross_attention,
270 intermediate,
271 output,
272 })
273 }
274
275 fn reset_kv_cache(&mut self) {
276 self.attention.reset_kv_cache();
277 if let Some(ca) = &mut self.cross_attention {
278 ca.reset_kv_cache()
279 }
280 }
281
282 fn forward(
283 &mut self,
284 xs: &Tensor,
285 encoder_hidden_states: &Tensor,
286 attention_mask: &Tensor,
287 ) -> Result<Tensor> {
288 let attention_output = self.attention.forward(xs, None, Some(attention_mask))?;
289 let attention_output = match &mut self.cross_attention {
290 Some(ca) => ca.forward(&attention_output, Some(encoder_hidden_states), None)?,
291 None => candle::bail!("expected some cross-attn"),
292 };
293 let intermediate_output = self.intermediate.forward(&attention_output)?;
294 self.output.forward(&intermediate_output, &attention_output)
295 }
296}
297
298#[derive(Debug, Clone)]
299struct TextEncoder {
300 layers: Vec<TextLayer>,
301}
302
303impl TextEncoder {
304 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
305 let vb = vb.pp("layer");
306 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
307 for i in 0..cfg.num_hidden_layers {
308 let layer = TextLayer::new(cfg, vb.pp(i))?;
309 layers.push(layer)
310 }
311 Ok(Self { layers })
312 }
313
314 fn reset_kv_cache(&mut self) {
315 self.layers.iter_mut().for_each(|l| l.reset_kv_cache())
316 }
317
318 fn forward(
319 &mut self,
320 xs: &Tensor,
321 encoder_hidden_states: &Tensor,
322 attention_mask: &Tensor,
323 ) -> Result<Tensor> {
324 let mut xs = xs.clone();
325 for layer in self.layers.iter_mut() {
326 xs = layer.forward(&xs, encoder_hidden_states, attention_mask)?
327 }
328 Ok(xs)
329 }
330}
331
332#[derive(Debug, Clone)]
333pub struct TextPooler {
334 dense: Linear,
335}
336
337impl TextPooler {
338 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
339 let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
340 Ok(Self { dense })
341 }
342}
343
344impl Module for TextPooler {
345 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
346 xs.narrow(D::Minus1, 0, 1)?
347 .squeeze(D::Minus1)?
348 .apply(&self.dense)?
349 .tanh()
350 }
351}
352
353#[derive(Debug, Clone)]
354struct TextPredictionHeadTransform {
355 dense: Linear,
356 transform_act_fn: candle_nn::Activation,
357 layer_norm: LayerNorm,
358}
359
360impl TextPredictionHeadTransform {
361 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
362 let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
363 let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
364 Ok(Self {
365 dense,
366 transform_act_fn: cfg.hidden_act,
367 layer_norm,
368 })
369 }
370}
371
372impl Module for TextPredictionHeadTransform {
373 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
374 xs.apply(&self.dense)?
375 .apply(&self.transform_act_fn)?
376 .apply(&self.layer_norm)
377 }
378}
379
380#[derive(Debug, Clone)]
381struct TextLMPredictionHead {
382 transform: TextPredictionHeadTransform,
383 decoder: Linear,
384}
385
386impl TextLMPredictionHead {
387 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
388 let transform = TextPredictionHeadTransform::new(cfg, vb.pp("transform"))?;
389 let weight = QMatMul::new(cfg.hidden_size, cfg.vocab_size, vb.pp("decoder"))?;
390 let bias = vb.get(cfg.vocab_size, "bias")?.dequantize(vb.device())?;
391 let decoder = Linear::from_weights(weight, Some(bias));
392 Ok(Self { transform, decoder })
393 }
394}
395
396impl Module for TextLMPredictionHead {
397 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
398 xs.apply(&self.transform)?.apply(&self.decoder)
399 }
400}
401
402#[derive(Debug, Clone)]
403struct TextOnlyMLMHead {
404 predictions: TextLMPredictionHead,
405}
406
407impl TextOnlyMLMHead {
408 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
409 let predictions = TextLMPredictionHead::new(cfg, vb.pp("predictions"))?;
410 Ok(Self { predictions })
411 }
412}
413
414impl Module for TextOnlyMLMHead {
415 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
416 self.predictions.forward(xs)
417 }
418}
419
420#[derive(Debug, Clone)]
421struct TextModel {
422 embeddings: TextEmbeddings,
423 encoder: TextEncoder,
424 past_kv_len: usize,
425 }
427
428impl TextModel {
429 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
430 let embeddings = TextEmbeddings::new(cfg, vb.pp("embeddings"))?;
431 let encoder = TextEncoder::new(cfg, vb.pp("encoder"))?;
432 Ok(Self {
433 embeddings,
434 encoder,
435 past_kv_len: 0,
436 })
437 }
438
439 fn forward(
440 &mut self,
441 input_ids: &Tensor,
442 encoder_hidden_states: &Tensor,
443 attention_mask: &Tensor,
444 ) -> Result<Tensor> {
445 let (_b_sz, seq_len) = input_ids.dims2()?;
446 let embedding_output = self.embeddings.forward(input_ids, self.past_kv_len)?;
447 let sequence_output =
448 self.encoder
449 .forward(&embedding_output, encoder_hidden_states, attention_mask)?;
450 self.past_kv_len += seq_len;
451 Ok(sequence_output)
453 }
454
455 fn reset_kv_cache(&mut self) {
456 self.past_kv_len = 0;
457 self.encoder.reset_kv_cache();
458 }
459}
460
461#[derive(Debug, Clone)]
462pub struct TextLMHeadModel {
463 bert: TextModel,
464 cls: TextOnlyMLMHead,
465}
466
467impl TextLMHeadModel {
468 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
469 let bert = TextModel::new(cfg, vb.pp("bert"))?;
470 let cls = TextOnlyMLMHead::new(cfg, vb.pp("cls"))?;
471 Ok(Self { bert, cls })
472 }
473
474 pub fn forward(
475 &mut self,
476 input_ids: &Tensor,
477 encoder_hidden_states: &Tensor,
478 ) -> Result<Tensor> {
479 let seq_len = input_ids.dim(1)?;
480 let mask: Vec<_> = (0..seq_len)
481 .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
482 .collect();
483 let mask = Tensor::from_vec(mask, (seq_len, seq_len), input_ids.device())?;
484 let sequence_output = self.bert.forward(input_ids, encoder_hidden_states, &mask)?;
485 let prediction_scores = self.cls.forward(&sequence_output)?;
486 Ok(prediction_scores)
488 }
489
490 pub fn reset_kv_cache(&mut self) {
491 self.bert.reset_kv_cache()
492 }
493}