candle_transformers/models/whisper/
model.rs

1use super::Config;
2use crate::models::with_tracing::{linear, linear_no_bias, Linear};
3use candle::{Device, IndexOp, Result, Tensor, D};
4use candle_nn::{embedding, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
5
6fn conv1d(
7    in_channels: usize,
8    out_channels: usize,
9    kernel_size: usize,
10    config: Conv1dConfig,
11    vb: VarBuilder,
12) -> Result<Conv1d> {
13    let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
14    let bias = vb.get(out_channels, "bias")?;
15    Ok(Conv1d::new(weight, Some(bias), config))
16}
17
18fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> {
19    let weight = vb.get(size, "weight")?;
20    let bias = vb.get(size, "bias")?;
21    Ok(LayerNorm::new(weight, bias, 1e-5))
22}
23
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
/// Multi-head attention layer, usable either as self-attention or as cross-attention
/// over a second tensor.
#[derive(Debug, Clone)]
struct MultiHeadAttention {
    /// Projection that turns the input into queries.
    query: Linear,
    /// Projection that produces keys; it is loaded without a bias term.
    key: Linear,
    /// Projection that produces values.
    value: Linear,
    /// Output projection applied to the merged heads.
    out: Linear,
    /// Number of heads the hidden dimension is split into.
    n_head: usize,
    /// Trace span entered for the whole forward pass.
    span: tracing::Span,
    /// Trace span entered around the softmax.
    softmax_span: tracing::Span,
    /// Trace span entered around each of the two attention matmuls.
    matmul_span: tracing::Span,
    /// Keys and values computed from the cross-attention source; filled on the first
    /// cross-attention call and reused until flushed or reset.
    kv_cache: Option<(Tensor, Tensor)>,
}
37
38impl MultiHeadAttention {
39    fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
40        let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn");
41        let softmax_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-softmax");
42        let matmul_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-matmul");
43        let query = linear(n_state, n_state, vb.pp("q_proj"))?;
44        let value = linear(n_state, n_state, vb.pp("v_proj"))?;
45        let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
46        let out = linear(n_state, n_state, vb.pp("out_proj"))?;
47        Ok(Self {
48            query,
49            key,
50            value,
51            out,
52            n_head,
53            span,
54            softmax_span,
55            matmul_span,
56            kv_cache: None,
57        })
58    }
59
60    fn forward(
61        &mut self,
62        x: &Tensor,
63        xa: Option<&Tensor>,
64        mask: Option<&Tensor>,
65        flush_cache: bool,
66    ) -> Result<Tensor> {
67        let _enter = self.span.enter();
68        let q = self.query.forward(x)?;
69        let (k, v) = match xa {
70            None => {
71                let k = self.key.forward(x)?;
72                let v = self.value.forward(x)?;
73                (k, v)
74            }
75            Some(x) => {
76                if flush_cache {
77                    self.kv_cache = None;
78                }
79                if let Some((k, v)) = &self.kv_cache {
80                    (k.clone(), v.clone())
81                } else {
82                    let k = self.key.forward(x)?;
83                    let v = self.value.forward(x)?;
84                    self.kv_cache = Some((k.clone(), v.clone()));
85                    (k, v)
86                }
87            }
88        };
89        let wv = self.qkv_attention(&q, &k, &v, mask)?;
90        let out = self.out.forward(&wv)?;
91        Ok(out)
92    }
93
94    fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
95        let (n_batch, n_ctx, n_state) = x.dims3()?;
96        let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
97        x.reshape(target_dims)?.transpose(1, 2)
98    }
99
100    fn qkv_attention(
101        &self,
102        q: &Tensor,
103        k: &Tensor,
104        v: &Tensor,
105        mask: Option<&Tensor>,
106    ) -> Result<Tensor> {
107        let (_, n_ctx, n_state) = q.dims3()?;
108        let scale = ((n_state / self.n_head) as f64).powf(-0.25);
109        let q = (self.reshape_head(q)? * scale)?;
110        let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
111        let v = self.reshape_head(v)?.contiguous()?;
112        let mut qk = {
113            let _enter = self.matmul_span.enter();
114            q.matmul(&k)?
115        };
116        if let Some(mask) = mask {
117            let mask = mask.i((0..n_ctx, 0..n_ctx))?;
118            qk = qk.broadcast_add(&mask)?
119        }
120        let w = {
121            let _enter = self.softmax_span.enter();
122            candle_nn::ops::softmax_last_dim(&qk)?
123        };
124        let wv = {
125            let _enter = self.matmul_span.enter();
126            w.matmul(&v)?
127        }
128        .transpose(1, 2)?
129        .flatten_from(2)?;
130        Ok(wv)
131    }
132
133    fn reset_kv_cache(&mut self) {
134        self.kv_cache = None;
135    }
136}
137
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
/// One transformer layer: self-attention, optional cross-attention and an MLP, each
/// preceded by a layer norm and added back onto its input.
#[derive(Debug, Clone)]
struct ResidualAttentionBlock {
    /// Self-attention sub-layer.
    attn: MultiHeadAttention,
    /// Layer norm applied to the input of the self-attention.
    attn_ln: LayerNorm,
    /// Cross-attention sub-layer with its layer norm; `None` when the block was loaded
    /// without cross-attention.
    cross_attn: Option<(MultiHeadAttention, LayerNorm)>,
    /// First MLP projection, widening from `n_state` to `4 * n_state`.
    mlp_linear1: Linear,
    /// Second MLP projection, narrowing back to `n_state`.
    mlp_linear2: Linear,
    /// Layer norm applied to the input of the MLP.
    mlp_ln: LayerNorm,
    /// Trace span entered for the whole forward pass.
    span: tracing::Span,
}
149
150impl ResidualAttentionBlock {
151    fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
152        let span = tracing::span!(tracing::Level::TRACE, "residual-attn");
153        let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
154        let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
155        let cross_attn = if ca {
156            let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?;
157            let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?;
158            Some((cross_attn, cross_attn_ln))
159        } else {
160            None
161        };
162        let n_mlp = n_state * 4;
163        let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?;
164        let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?;
165        let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?;
166        Ok(Self {
167            attn,
168            attn_ln,
169            cross_attn,
170            mlp_linear1,
171            mlp_linear2,
172            mlp_ln,
173            span,
174        })
175    }
176
177    fn forward(
178        &mut self,
179        x: &Tensor,
180        xa: Option<&Tensor>,
181        mask: Option<&Tensor>,
182        flush_kv_cache: bool,
183    ) -> Result<Tensor> {
184        let _enter = self.span.enter();
185        let attn = self
186            .attn
187            .forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?;
188        let mut x = (x + attn)?;
189        if let Some((attn, ln)) = &mut self.cross_attn {
190            x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?;
191        }
192        let mlp = self.mlp_linear2.forward(
193            &self
194                .mlp_linear1
195                .forward(&self.mlp_ln.forward(&x)?)?
196                .gelu()?,
197        )?;
198        x + mlp
199    }
200
201    fn reset_kv_cache(&mut self) {
202        self.attn.reset_kv_cache();
203        if let Some((attn, _)) = &mut self.cross_attn {
204            attn.reset_kv_cache();
205        }
206    }
207}
208
209fn sinusoids(length: usize, channels: usize, device: &Device) -> Result<Tensor> {
210    let max_timescale = 10000f32;
211    let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
212    let inv_timescales: Vec<_> = (0..channels / 2)
213        .map(|i| (i as f32 * (-log_timescale_increment)).exp())
214        .collect();
215    let inv_timescales = Tensor::new(inv_timescales.as_slice(), device)?.unsqueeze(0)?;
216    let arange = Tensor::arange(0, length as u32, device)?
217        .to_dtype(candle::DType::F32)?
218        .unsqueeze(1)?;
219    let sh = (length, channels / 2);
220    let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?;
221    let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?;
222    Ok(sincos)
223}
224
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
/// Audio encoder: two GELU-activated 1-D convolutions, an added sinusoidal position
/// table, a stack of attention blocks without cross-attention, and a final layer norm.
#[derive(Debug, Clone)]
pub struct AudioEncoder {
    /// First convolution (kernel 3, stride 1, padding 1) from mel bins to the model width.
    conv1: Conv1d,
    /// Second convolution (kernel 3, stride 2, padding 1) at the model width.
    conv2: Conv1d,
    /// Sinusoidal position table computed once at load time.
    positional_embedding: Tensor,
    /// Encoder layers, applied in order.
    blocks: Vec<ResidualAttentionBlock>,
    /// Layer norm applied after the last block.
    ln_post: LayerNorm,
    /// Trace span entered for the whole forward pass.
    span: tracing::Span,
    /// Trace span entered around the first convolution.
    conv1_span: tracing::Span,
    /// Trace span entered around the second convolution.
    conv2_span: tracing::Span,
}
237
238impl AudioEncoder {
239    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
240        let span = tracing::span!(tracing::Level::TRACE, "audio-encoder");
241        let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1");
242        let conv2_span = tracing::span!(tracing::Level::TRACE, "conv2");
243        let n_state = cfg.d_model;
244        let n_head = cfg.encoder_attention_heads;
245        let n_ctx = cfg.max_source_positions;
246        let cfg1 = Conv1dConfig {
247            padding: 1,
248            stride: 1,
249            groups: 1,
250            dilation: 1,
251        };
252        let cfg2 = Conv1dConfig {
253            padding: 1,
254            stride: 2,
255            groups: 1,
256            dilation: 1,
257        };
258        let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
259        let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
260        let positional_embedding = sinusoids(n_ctx, n_state, vb.device())?;
261        let blocks = (0..cfg.encoder_layers)
262            .map(|i| {
263                ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(format!("layers.{i}")))
264            })
265            .collect::<Result<Vec<_>>>()?;
266        let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?;
267        Ok(Self {
268            conv1,
269            conv2,
270            positional_embedding,
271            blocks,
272            ln_post,
273            conv1_span,
274            conv2_span,
275            span,
276        })
277    }
278
279    pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
280        let _enter = self.span.enter();
281        let x = {
282            let _enter = self.conv1_span.enter();
283            self.conv1.forward(x)?.gelu()?
284        };
285        let x = {
286            let _enter = self.conv2_span.enter();
287            self.conv2.forward(&x)?.gelu()?
288        };
289        let x = x.transpose(1, 2)?;
290        let (_bsize, seq_len, _hidden) = x.dims3()?;
291        let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
292        let mut x = x.broadcast_add(&positional_embedding)?;
293        for block in self.blocks.iter_mut() {
294            x = block.forward(&x, None, None, flush_kv_cache)?
295        }
296        let x = self.ln_post.forward(&x)?;
297        Ok(x)
298    }
299}
300
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
/// Text decoder: token and learned position embeddings, a stack of attention blocks
/// with cross-attention over the encoder output, and a final layer norm.
#[derive(Debug, Clone)]
pub struct TextDecoder {
    /// Token embedding table; its weights are also reused as the output projection.
    token_embedding: Embedding,
    /// Position table loaded from `embed_positions.weight`.
    positional_embedding: Tensor,
    /// Decoder layers, applied in order.
    blocks: Vec<ResidualAttentionBlock>,
    /// Layer norm applied after the last block.
    ln: LayerNorm,
    /// Square additive mask: negative infinity above the diagonal, zero elsewhere.
    mask: Tensor,
    /// Trace span entered for the whole forward pass.
    span: tracing::Span,
    /// Trace span entered around the output projection matmul.
    span_final: tracing::Span,
}
312
313impl TextDecoder {
314    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
315        let span = tracing::span!(tracing::Level::TRACE, "text-decoder");
316        let span_final = tracing::span!(tracing::Level::TRACE, "text-decoder-final");
317        let n_state = cfg.d_model;
318        let n_head = cfg.decoder_attention_heads;
319        let n_ctx = cfg.max_target_positions;
320        let token_embedding = embedding(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?;
321        let positional_embedding = vb.get((n_ctx, n_state), "embed_positions.weight")?;
322        let blocks = (0..cfg.decoder_layers)
323            .map(|i| {
324                ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(format!("layers.{i}")))
325            })
326            .collect::<Result<Vec<_>>>()?;
327        let ln = layer_norm(n_state, vb.pp("layer_norm"))?;
328        let mask: Vec<_> = (0..n_ctx)
329            .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
330            .collect();
331        let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?;
332        Ok(Self {
333            token_embedding,
334            positional_embedding,
335            blocks,
336            ln,
337            mask,
338            span,
339            span_final,
340        })
341    }
342
343    pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
344        let _enter = self.span.enter();
345        let last = x.dim(D::Minus1)?;
346        let token_embedding = self.token_embedding.forward(x)?;
347        let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
348        let mut x = token_embedding.broadcast_add(&positional_embedding)?;
349        for block in self.blocks.iter_mut() {
350            x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?;
351        }
352        self.ln.forward(&x)
353    }
354
355    pub fn final_linear(&self, x: &Tensor) -> Result<Tensor> {
356        let b_size = x.dim(0)?;
357        let w = self.token_embedding.embeddings().broadcast_left(b_size)?;
358        let logits = {
359            let _enter = self.span_final.enter();
360            x.matmul(&w.t()?)?
361        };
362        Ok(logits)
363    }
364
365    pub fn reset_kv_cache(&mut self) {
366        for block in self.blocks.iter_mut() {
367            block.reset_kv_cache();
368        }
369    }
370}
371
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
/// Complete model: the audio encoder, the text decoder and the configuration both were
/// loaded with.
#[derive(Debug, Clone)]
pub struct Whisper {
    /// Audio encoder, loaded from the `model.encoder` prefix.
    pub encoder: AudioEncoder,
    /// Text decoder, loaded from the `model.decoder` prefix.
    pub decoder: TextDecoder,
    /// Configuration the model was built from.
    pub config: Config,
}
379
380impl Whisper {
381    pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> {
382        let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?;
383        let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?;
384        Ok(Self {
385            encoder,
386            decoder,
387            config,
388        })
389    }
390
391    pub fn reset_kv_cache(&mut self) {
392        self.encoder
393            .blocks
394            .iter_mut()
395            .for_each(|b| b.reset_kv_cache());
396        self.decoder.reset_kv_cache();
397    }
398}