// candle_transformers/models/whisper/quantized_model.rs

1use super::Config;
2use crate::quantized_nn::{layer_norm, linear, linear_no_bias, Embedding, Linear};
3pub use crate::quantized_var_builder::VarBuilder;
4use candle::{Device, IndexOp, Result, Tensor, D};
5use candle_nn::{Conv1d, Conv1dConfig, LayerNorm, Module};
6
7fn conv1d(
8    in_channels: usize,
9    out_channels: usize,
10    kernel_size: usize,
11    config: Conv1dConfig,
12    vb: VarBuilder,
13) -> Result<Conv1d> {
14    let weight = vb
15        .get((out_channels, in_channels, kernel_size), "weight")?
16        .dequantize(vb.device())?;
17    let bias = vb.get(out_channels, "bias")?.dequantize(vb.device())?;
18    Ok(Conv1d::new(weight, Some(bias), config))
19}
20
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
/// Multi-head attention with quantized projections.
///
/// Used both for self-attention (keys/values come from the query input) and
/// for cross-attention (keys/values come from a separate input `xa`, and are
/// cached in `kv_cache`).
#[derive(Debug, Clone)]
struct MultiHeadAttention {
    // Query projection (loaded with `linear`).
    query: Linear,
    // Key projection (loaded with `linear_no_bias`).
    key: Linear,
    // Value projection (loaded with `linear`).
    value: Linear,
    // Output projection applied after the heads are merged back together.
    out: Linear,
    // Number of heads; each head gets `n_state / n_head` features.
    n_head: usize,
    // Tracing span entered for the whole `forward` call.
    span: tracing::Span,
    // Tracing span entered around the softmax.
    softmax_span: tracing::Span,
    // Tracing span entered around the two attention matmuls.
    matmul_span: tracing::Span,
    // Cached (key, value) projections of the cross-attention input; only
    // filled when `forward` is called with `Some(xa)`.
    kv_cache: Option<(Tensor, Tensor)>,
}
34
35impl MultiHeadAttention {
36    fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
37        let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn");
38        let softmax_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-softmax");
39        let matmul_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-matmul");
40        let query = linear(n_state, n_state, vb.pp("q_proj"))?;
41        let value = linear(n_state, n_state, vb.pp("v_proj"))?;
42        let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
43        let out = linear(n_state, n_state, vb.pp("out_proj"))?;
44        Ok(Self {
45            query,
46            key,
47            value,
48            out,
49            n_head,
50            span,
51            softmax_span,
52            matmul_span,
53            kv_cache: None,
54        })
55    }
56
57    fn forward(
58        &mut self,
59        x: &Tensor,
60        xa: Option<&Tensor>,
61        mask: Option<&Tensor>,
62        flush_cache: bool,
63    ) -> Result<Tensor> {
64        let _enter = self.span.enter();
65        let q = self.query.forward(x)?;
66        let (k, v) = match xa {
67            None => {
68                let k = self.key.forward(x)?;
69                let v = self.value.forward(x)?;
70                (k, v)
71            }
72            Some(x) => {
73                if flush_cache {
74                    self.kv_cache = None;
75                }
76                if let Some((k, v)) = &self.kv_cache {
77                    (k.clone(), v.clone())
78                } else {
79                    let k = self.key.forward(x)?;
80                    let v = self.value.forward(x)?;
81                    self.kv_cache = Some((k.clone(), v.clone()));
82                    (k, v)
83                }
84            }
85        };
86        let wv = self.qkv_attention(&q, &k, &v, mask)?;
87        let out = self.out.forward(&wv)?;
88        Ok(out)
89    }
90
91    fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
92        let (n_batch, n_ctx, n_state) = x.dims3()?;
93        let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
94        x.reshape(target_dims)?.transpose(1, 2)
95    }
96
97    fn qkv_attention(
98        &self,
99        q: &Tensor,
100        k: &Tensor,
101        v: &Tensor,
102        mask: Option<&Tensor>,
103    ) -> Result<Tensor> {
104        let (_, n_ctx, n_state) = q.dims3()?;
105        let scale = ((n_state / self.n_head) as f64).powf(-0.25);
106        let q = (self.reshape_head(q)? * scale)?;
107        let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
108        let v = self.reshape_head(v)?.contiguous()?;
109        let mut qk = {
110            let _enter = self.matmul_span.enter();
111            q.matmul(&k)?
112        };
113        if let Some(mask) = mask {
114            let mask = mask.i((0..n_ctx, 0..n_ctx))?;
115            qk = qk.broadcast_add(&mask)?
116        }
117        let w = {
118            let _enter = self.softmax_span.enter();
119            candle_nn::ops::softmax_last_dim(&qk)?
120        };
121        let wv = {
122            let _enter = self.matmul_span.enter();
123            w.matmul(&v)?
124        }
125        .transpose(1, 2)?
126        .flatten_from(2)?;
127        Ok(wv)
128    }
129
130    fn reset_kv_cache(&mut self) {
131        self.kv_cache = None;
132    }
133}
134
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
/// Pre-norm transformer block.
///
/// The block runs self-attention, then optional cross-attention, then a
/// two-layer GELU MLP. Each sub-layer normalizes its input and adds its output
/// back to the residual stream.
#[derive(Debug, Clone)]
struct ResidualAttentionBlock {
    // Self-attention sub-layer.
    attn: MultiHeadAttention,
    // Layer norm applied to the input of the self-attention.
    attn_ln: LayerNorm,
    // Cross-attention plus the layer norm applied to its input. This is
    // `Some` when loaded with `ca == true` (the decoder does this, the
    // encoder does not).
    cross_attn: Option<(MultiHeadAttention, LayerNorm)>,
    // MLP expansion: n_state -> 4 * n_state.
    mlp_linear1: Linear,
    // MLP projection back: 4 * n_state -> n_state.
    mlp_linear2: Linear,
    // Layer norm applied to the input of the MLP.
    mlp_ln: LayerNorm,
    // Tracing span entered for the whole `forward` call.
    span: tracing::Span,
}
146
147impl ResidualAttentionBlock {
148    fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
149        let span = tracing::span!(tracing::Level::TRACE, "residual-attn");
150        let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
151        let attn_ln = layer_norm(n_state, 1e-5, vb.pp("self_attn_layer_norm"))?;
152        let cross_attn = if ca {
153            let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?;
154            let cross_attn_ln = layer_norm(n_state, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
155            Some((cross_attn, cross_attn_ln))
156        } else {
157            None
158        };
159        let n_mlp = n_state * 4;
160        let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?;
161        let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?;
162        let mlp_ln = layer_norm(n_state, 1e-5, vb.pp("final_layer_norm"))?;
163        Ok(Self {
164            attn,
165            attn_ln,
166            cross_attn,
167            mlp_linear1,
168            mlp_linear2,
169            mlp_ln,
170            span,
171        })
172    }
173
174    fn forward(
175        &mut self,
176        x: &Tensor,
177        xa: Option<&Tensor>,
178        mask: Option<&Tensor>,
179        flush_kv_cache: bool,
180    ) -> Result<Tensor> {
181        let _enter = self.span.enter();
182        let attn = self
183            .attn
184            .forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?;
185        let mut x = (x + attn)?;
186        if let Some((attn, ln)) = &mut self.cross_attn {
187            x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?;
188        }
189        let mlp = x
190            .apply(&self.mlp_ln)?
191            .apply(&self.mlp_linear1)?
192            .gelu()?
193            .apply(&self.mlp_linear2)?;
194        x + mlp
195    }
196
197    fn reset_kv_cache(&mut self) {
198        self.attn.reset_kv_cache();
199        if let Some((attn, _)) = &mut self.cross_attn {
200            attn.reset_kv_cache();
201        }
202    }
203}
204
205fn sinusoids(length: usize, channels: usize, device: &Device) -> Result<Tensor> {
206    let max_timescale = 10000f32;
207    let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
208    let inv_timescales: Vec<_> = (0..channels / 2)
209        .map(|i| (i as f32 * (-log_timescale_increment)).exp())
210        .collect();
211    let inv_timescales = Tensor::new(inv_timescales.as_slice(), device)?.unsqueeze(0)?;
212    let arange = Tensor::arange(0, length as u32, device)?
213        .to_dtype(candle::DType::F32)?
214        .unsqueeze(1)?;
215    let sh = (length, channels / 2);
216    let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?;
217    let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?;
218    Ok(sincos)
219}
220
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
/// Whisper audio encoder.
///
/// Consists of two 1-D convolutions (each followed by GELU), a fixed
/// sinusoidal positional embedding, a stack of self-attention blocks, and a
/// final layer norm.
#[derive(Debug, Clone)]
pub struct AudioEncoder {
    // Kernel 3, stride 1, padding 1; input channels = `num_mel_bins`.
    conv1: Conv1d,
    // Kernel 3, stride 2, padding 1; `d_model` -> `d_model`.
    conv2: Conv1d,
    // Table produced by `sinusoids(max_source_positions, d_model, ..)`; only
    // its first `seq_len` rows are used in `forward`.
    positional_embedding: Tensor,
    // Encoder layers; loaded without cross-attention.
    blocks: Vec<ResidualAttentionBlock>,
    // Layer norm applied to the output of the last block.
    ln_post: LayerNorm,
    // Tracing span entered for the whole `forward` call.
    span: tracing::Span,
    // Tracing span entered around conv1 + GELU.
    conv1_span: tracing::Span,
    // Tracing span entered around conv2 + GELU.
    conv2_span: tracing::Span,
}
233
234impl AudioEncoder {
235    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
236        let span = tracing::span!(tracing::Level::TRACE, "audio-encoder");
237        let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1");
238        let conv2_span = tracing::span!(tracing::Level::TRACE, "conv2");
239        let n_state = cfg.d_model;
240        let n_head = cfg.encoder_attention_heads;
241        let n_ctx = cfg.max_source_positions;
242        let cfg1 = Conv1dConfig {
243            padding: 1,
244            stride: 1,
245            groups: 1,
246            dilation: 1,
247        };
248        let cfg2 = Conv1dConfig {
249            padding: 1,
250            stride: 2,
251            groups: 1,
252            dilation: 1,
253        };
254        let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
255        let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
256        let positional_embedding = sinusoids(n_ctx, n_state, vb.device())?;
257        let blocks = (0..cfg.encoder_layers)
258            .map(|i| {
259                ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(format!("layers.{i}")))
260            })
261            .collect::<Result<Vec<_>>>()?;
262        let ln_post = layer_norm(n_state, 1e-5, vb.pp("layer_norm"))?;
263        Ok(Self {
264            conv1,
265            conv2,
266            positional_embedding,
267            blocks,
268            ln_post,
269            conv1_span,
270            conv2_span,
271            span,
272        })
273    }
274
275    pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
276        let _enter = self.span.enter();
277        let x = {
278            let _enter = self.conv1_span.enter();
279            self.conv1.forward(x)?.gelu()?
280        };
281        let x = {
282            let _enter = self.conv2_span.enter();
283            self.conv2.forward(&x)?.gelu()?
284        };
285        let x = x.transpose(1, 2)?;
286        let (_bsize, seq_len, _hidden) = x.dims3()?;
287        let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
288        let mut x = x.broadcast_add(&positional_embedding)?;
289        for block in self.blocks.iter_mut() {
290            x = block.forward(&x, None, None, flush_kv_cache)?
291        }
292        let x = self.ln_post.forward(&x)?;
293        Ok(x)
294    }
295
296    pub fn reset_kv_cache(&mut self) {
297        for block in self.blocks.iter_mut() {
298            block.reset_kv_cache();
299        }
300    }
301}
302
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
/// Whisper text decoder.
///
/// Token embeddings plus loaded positional embeddings feed a stack of blocks
/// with self- and cross-attention, followed by a final layer norm. Logits are
/// produced separately by `final_linear`, which reuses the token-embedding
/// matrix.
#[derive(Debug, Clone)]
pub struct TextDecoder {
    // Token embedding table; also used as the output projection in `final_linear`.
    token_embedding: Embedding,
    // Dequantized `embed_positions.weight`, shape (max_target_positions, d_model).
    positional_embedding: Tensor,
    // Decoder layers; loaded with cross-attention enabled.
    blocks: Vec<ResidualAttentionBlock>,
    // Layer norm applied after the last block.
    ln: LayerNorm,
    // Additive causal mask, shape (max_target_positions, max_target_positions):
    // 0 on and below the diagonal, -inf above it.
    mask: Tensor,
    // Tracing span entered for the whole `forward` call.
    span: tracing::Span,
    // Tracing span entered around the logits matmul in `final_linear`.
    span_final: tracing::Span,
}
314
315impl TextDecoder {
316    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
317        let span = tracing::span!(tracing::Level::TRACE, "text-decoder");
318        let span_final = tracing::span!(tracing::Level::TRACE, "text-decoder-final");
319        let n_state = cfg.d_model;
320        let n_head = cfg.decoder_attention_heads;
321        let n_ctx = cfg.max_target_positions;
322        let token_embedding = Embedding::new(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?;
323        let positional_embedding = vb
324            .get((n_ctx, n_state), "embed_positions.weight")?
325            .dequantize(vb.device())?;
326        let blocks = (0..cfg.decoder_layers)
327            .map(|i| {
328                ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(format!("layers.{i}")))
329            })
330            .collect::<Result<Vec<_>>>()?;
331        let ln = layer_norm(n_state, 1e-5, vb.pp("layer_norm"))?;
332        let mask: Vec<_> = (0..n_ctx)
333            .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
334            .collect();
335        let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?;
336        Ok(Self {
337            token_embedding,
338            positional_embedding,
339            blocks,
340            ln,
341            mask,
342            span,
343            span_final,
344        })
345    }
346
347    pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
348        let _enter = self.span.enter();
349        let last = x.dim(D::Minus1)?;
350        let token_embedding = self.token_embedding.forward(x)?;
351        let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
352        let mut x = token_embedding.broadcast_add(&positional_embedding)?;
353        for block in self.blocks.iter_mut() {
354            x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?;
355        }
356        self.ln.forward(&x)
357    }
358
359    pub fn final_linear(&self, x: &Tensor) -> Result<Tensor> {
360        let b_size = x.dim(0)?;
361        let w = self.token_embedding.embeddings().broadcast_left(b_size)?;
362        let logits = {
363            let _enter = self.span_final.enter();
364            x.matmul(&w.t()?)?
365        };
366        Ok(logits)
367    }
368
369    pub fn reset_kv_cache(&mut self) {
370        for block in self.blocks.iter_mut() {
371            block.reset_kv_cache();
372        }
373    }
374}
375
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
/// Quantized Whisper model: an audio encoder, a text decoder, and the
/// configuration both were loaded with.
#[derive(Debug, Clone)]
pub struct Whisper {
    /// Audio encoder, loaded from the `model.encoder` prefix.
    pub encoder: AudioEncoder,
    /// Text decoder, loaded from the `model.decoder` prefix.
    pub decoder: TextDecoder,
    /// Configuration passed to `Whisper::load`.
    pub config: Config,
}
383
384impl Whisper {
385    pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> {
386        let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?;
387        let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?;
388        Ok(Self {
389            encoder,
390            decoder,
391            config,
392        })
393    }
394
395    pub fn reset_kv_cache(&mut self) {
396        self.encoder.reset_kv_cache();
397        self.decoder.reset_kv_cache();
398    }
399}