1use super::Config;
2use crate::quantized_nn::{layer_norm, linear, linear_no_bias, Embedding, Linear};
3pub use crate::quantized_var_builder::VarBuilder;
4use candle::{Device, IndexOp, Result, Tensor, D};
5use candle_nn::{Conv1d, Conv1dConfig, LayerNorm, Module};
6
7fn conv1d(
8 in_channels: usize,
9 out_channels: usize,
10 kernel_size: usize,
11 config: Conv1dConfig,
12 vb: VarBuilder,
13) -> Result<Conv1d> {
14 let weight = vb
15 .get((out_channels, in_channels, kernel_size), "weight")?
16 .dequantize(vb.device())?;
17 let bias = vb.get(out_channels, "bias")?.dequantize(vb.device())?;
18 Ok(Conv1d::new(weight, Some(bias), config))
19}
20
21#[derive(Debug, Clone)]
23struct MultiHeadAttention {
24 query: Linear,
25 key: Linear,
26 value: Linear,
27 out: Linear,
28 n_head: usize,
29 span: tracing::Span,
30 softmax_span: tracing::Span,
31 matmul_span: tracing::Span,
32 kv_cache: Option<(Tensor, Tensor)>,
33}
34
35impl MultiHeadAttention {
36 fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
37 let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn");
38 let softmax_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-softmax");
39 let matmul_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-matmul");
40 let query = linear(n_state, n_state, vb.pp("q_proj"))?;
41 let value = linear(n_state, n_state, vb.pp("v_proj"))?;
42 let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
43 let out = linear(n_state, n_state, vb.pp("out_proj"))?;
44 Ok(Self {
45 query,
46 key,
47 value,
48 out,
49 n_head,
50 span,
51 softmax_span,
52 matmul_span,
53 kv_cache: None,
54 })
55 }
56
57 fn forward(
58 &mut self,
59 x: &Tensor,
60 xa: Option<&Tensor>,
61 mask: Option<&Tensor>,
62 flush_cache: bool,
63 ) -> Result<Tensor> {
64 let _enter = self.span.enter();
65 let q = self.query.forward(x)?;
66 let (k, v) = match xa {
67 None => {
68 let k = self.key.forward(x)?;
69 let v = self.value.forward(x)?;
70 (k, v)
71 }
72 Some(x) => {
73 if flush_cache {
74 self.kv_cache = None;
75 }
76 if let Some((k, v)) = &self.kv_cache {
77 (k.clone(), v.clone())
78 } else {
79 let k = self.key.forward(x)?;
80 let v = self.value.forward(x)?;
81 self.kv_cache = Some((k.clone(), v.clone()));
82 (k, v)
83 }
84 }
85 };
86 let wv = self.qkv_attention(&q, &k, &v, mask)?;
87 let out = self.out.forward(&wv)?;
88 Ok(out)
89 }
90
91 fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
92 let (n_batch, n_ctx, n_state) = x.dims3()?;
93 let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
94 x.reshape(target_dims)?.transpose(1, 2)
95 }
96
97 fn qkv_attention(
98 &self,
99 q: &Tensor,
100 k: &Tensor,
101 v: &Tensor,
102 mask: Option<&Tensor>,
103 ) -> Result<Tensor> {
104 let (_, n_ctx, n_state) = q.dims3()?;
105 let scale = ((n_state / self.n_head) as f64).powf(-0.25);
106 let q = (self.reshape_head(q)? * scale)?;
107 let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
108 let v = self.reshape_head(v)?.contiguous()?;
109 let mut qk = {
110 let _enter = self.matmul_span.enter();
111 q.matmul(&k)?
112 };
113 if let Some(mask) = mask {
114 let mask = mask.i((0..n_ctx, 0..n_ctx))?;
115 qk = qk.broadcast_add(&mask)?
116 }
117 let w = {
118 let _enter = self.softmax_span.enter();
119 candle_nn::ops::softmax_last_dim(&qk)?
120 };
121 let wv = {
122 let _enter = self.matmul_span.enter();
123 w.matmul(&v)?
124 }
125 .transpose(1, 2)?
126 .flatten_from(2)?;
127 Ok(wv)
128 }
129
130 fn reset_kv_cache(&mut self) {
131 self.kv_cache = None;
132 }
133}
134
135#[derive(Debug, Clone)]
137struct ResidualAttentionBlock {
138 attn: MultiHeadAttention,
139 attn_ln: LayerNorm,
140 cross_attn: Option<(MultiHeadAttention, LayerNorm)>,
141 mlp_linear1: Linear,
142 mlp_linear2: Linear,
143 mlp_ln: LayerNorm,
144 span: tracing::Span,
145}
146
147impl ResidualAttentionBlock {
148 fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
149 let span = tracing::span!(tracing::Level::TRACE, "residual-attn");
150 let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
151 let attn_ln = layer_norm(n_state, 1e-5, vb.pp("self_attn_layer_norm"))?;
152 let cross_attn = if ca {
153 let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?;
154 let cross_attn_ln = layer_norm(n_state, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
155 Some((cross_attn, cross_attn_ln))
156 } else {
157 None
158 };
159 let n_mlp = n_state * 4;
160 let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?;
161 let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?;
162 let mlp_ln = layer_norm(n_state, 1e-5, vb.pp("final_layer_norm"))?;
163 Ok(Self {
164 attn,
165 attn_ln,
166 cross_attn,
167 mlp_linear1,
168 mlp_linear2,
169 mlp_ln,
170 span,
171 })
172 }
173
174 fn forward(
175 &mut self,
176 x: &Tensor,
177 xa: Option<&Tensor>,
178 mask: Option<&Tensor>,
179 flush_kv_cache: bool,
180 ) -> Result<Tensor> {
181 let _enter = self.span.enter();
182 let attn = self
183 .attn
184 .forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?;
185 let mut x = (x + attn)?;
186 if let Some((attn, ln)) = &mut self.cross_attn {
187 x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?;
188 }
189 let mlp = x
190 .apply(&self.mlp_ln)?
191 .apply(&self.mlp_linear1)?
192 .gelu()?
193 .apply(&self.mlp_linear2)?;
194 x + mlp
195 }
196
197 fn reset_kv_cache(&mut self) {
198 self.attn.reset_kv_cache();
199 if let Some((attn, _)) = &mut self.cross_attn {
200 attn.reset_kv_cache();
201 }
202 }
203}
204
205fn sinusoids(length: usize, channels: usize, device: &Device) -> Result<Tensor> {
206 let max_timescale = 10000f32;
207 let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
208 let inv_timescales: Vec<_> = (0..channels / 2)
209 .map(|i| (i as f32 * (-log_timescale_increment)).exp())
210 .collect();
211 let inv_timescales = Tensor::new(inv_timescales.as_slice(), device)?.unsqueeze(0)?;
212 let arange = Tensor::arange(0, length as u32, device)?
213 .to_dtype(candle::DType::F32)?
214 .unsqueeze(1)?;
215 let sh = (length, channels / 2);
216 let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?;
217 let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?;
218 Ok(sincos)
219}
220
221#[derive(Debug, Clone)]
223pub struct AudioEncoder {
224 conv1: Conv1d,
225 conv2: Conv1d,
226 positional_embedding: Tensor,
227 blocks: Vec<ResidualAttentionBlock>,
228 ln_post: LayerNorm,
229 span: tracing::Span,
230 conv1_span: tracing::Span,
231 conv2_span: tracing::Span,
232}
233
234impl AudioEncoder {
235 fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
236 let span = tracing::span!(tracing::Level::TRACE, "audio-encoder");
237 let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1");
238 let conv2_span = tracing::span!(tracing::Level::TRACE, "conv2");
239 let n_state = cfg.d_model;
240 let n_head = cfg.encoder_attention_heads;
241 let n_ctx = cfg.max_source_positions;
242 let cfg1 = Conv1dConfig {
243 padding: 1,
244 stride: 1,
245 groups: 1,
246 dilation: 1,
247 };
248 let cfg2 = Conv1dConfig {
249 padding: 1,
250 stride: 2,
251 groups: 1,
252 dilation: 1,
253 };
254 let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
255 let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
256 let positional_embedding = sinusoids(n_ctx, n_state, vb.device())?;
257 let blocks = (0..cfg.encoder_layers)
258 .map(|i| {
259 ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(format!("layers.{i}")))
260 })
261 .collect::<Result<Vec<_>>>()?;
262 let ln_post = layer_norm(n_state, 1e-5, vb.pp("layer_norm"))?;
263 Ok(Self {
264 conv1,
265 conv2,
266 positional_embedding,
267 blocks,
268 ln_post,
269 conv1_span,
270 conv2_span,
271 span,
272 })
273 }
274
275 pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
276 let _enter = self.span.enter();
277 let x = {
278 let _enter = self.conv1_span.enter();
279 self.conv1.forward(x)?.gelu()?
280 };
281 let x = {
282 let _enter = self.conv2_span.enter();
283 self.conv2.forward(&x)?.gelu()?
284 };
285 let x = x.transpose(1, 2)?;
286 let (_bsize, seq_len, _hidden) = x.dims3()?;
287 let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
288 let mut x = x.broadcast_add(&positional_embedding)?;
289 for block in self.blocks.iter_mut() {
290 x = block.forward(&x, None, None, flush_kv_cache)?
291 }
292 let x = self.ln_post.forward(&x)?;
293 Ok(x)
294 }
295
296 pub fn reset_kv_cache(&mut self) {
297 for block in self.blocks.iter_mut() {
298 block.reset_kv_cache();
299 }
300 }
301}
302
303#[derive(Debug, Clone)]
305pub struct TextDecoder {
306 token_embedding: Embedding,
307 positional_embedding: Tensor,
308 blocks: Vec<ResidualAttentionBlock>,
309 ln: LayerNorm,
310 mask: Tensor,
311 span: tracing::Span,
312 span_final: tracing::Span,
313}
314
315impl TextDecoder {
316 fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
317 let span = tracing::span!(tracing::Level::TRACE, "text-decoder");
318 let span_final = tracing::span!(tracing::Level::TRACE, "text-decoder-final");
319 let n_state = cfg.d_model;
320 let n_head = cfg.decoder_attention_heads;
321 let n_ctx = cfg.max_target_positions;
322 let token_embedding = Embedding::new(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?;
323 let positional_embedding = vb
324 .get((n_ctx, n_state), "embed_positions.weight")?
325 .dequantize(vb.device())?;
326 let blocks = (0..cfg.decoder_layers)
327 .map(|i| {
328 ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(format!("layers.{i}")))
329 })
330 .collect::<Result<Vec<_>>>()?;
331 let ln = layer_norm(n_state, 1e-5, vb.pp("layer_norm"))?;
332 let mask: Vec<_> = (0..n_ctx)
333 .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
334 .collect();
335 let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?;
336 Ok(Self {
337 token_embedding,
338 positional_embedding,
339 blocks,
340 ln,
341 mask,
342 span,
343 span_final,
344 })
345 }
346
347 pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
348 let _enter = self.span.enter();
349 let last = x.dim(D::Minus1)?;
350 let token_embedding = self.token_embedding.forward(x)?;
351 let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
352 let mut x = token_embedding.broadcast_add(&positional_embedding)?;
353 for block in self.blocks.iter_mut() {
354 x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?;
355 }
356 self.ln.forward(&x)
357 }
358
359 pub fn final_linear(&self, x: &Tensor) -> Result<Tensor> {
360 let b_size = x.dim(0)?;
361 let w = self.token_embedding.embeddings().broadcast_left(b_size)?;
362 let logits = {
363 let _enter = self.span_final.enter();
364 x.matmul(&w.t()?)?
365 };
366 Ok(logits)
367 }
368
369 pub fn reset_kv_cache(&mut self) {
370 for block in self.blocks.iter_mut() {
371 block.reset_kv_cache();
372 }
373 }
374}
375
376#[derive(Debug, Clone)]
378pub struct Whisper {
379 pub encoder: AudioEncoder,
380 pub decoder: TextDecoder,
381 pub config: Config,
382}
383
384impl Whisper {
385 pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> {
386 let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?;
387 let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?;
388 Ok(Self {
389 encoder,
390 decoder,
391 config,
392 })
393 }
394
395 pub fn reset_kv_cache(&mut self) {
396 self.encoder.reset_kv_cache();
397 self.decoder.reset_kv_cache();
398 }
399}