1use super::Config;
2use crate::models::with_tracing::{linear, linear_no_bias, Linear};
3use candle::{Device, IndexOp, Result, Tensor, D};
4use candle_nn::{embedding, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
5
6fn conv1d(
7 in_channels: usize,
8 out_channels: usize,
9 kernel_size: usize,
10 config: Conv1dConfig,
11 vb: VarBuilder,
12) -> Result<Conv1d> {
13 let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
14 let bias = vb.get(out_channels, "bias")?;
15 Ok(Conv1d::new(weight, Some(bias), config))
16}
17
18fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> {
19 let weight = vb.get(size, "weight")?;
20 let bias = vb.get(size, "bias")?;
21 Ok(LayerNorm::new(weight, bias, 1e-5))
22}
23
24#[derive(Debug, Clone)]
26struct MultiHeadAttention {
27 query: Linear,
28 key: Linear,
29 value: Linear,
30 out: Linear,
31 n_head: usize,
32 span: tracing::Span,
33 softmax_span: tracing::Span,
34 matmul_span: tracing::Span,
35 kv_cache: Option<(Tensor, Tensor)>,
36}
37
38impl MultiHeadAttention {
39 fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
40 let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn");
41 let softmax_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-softmax");
42 let matmul_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-matmul");
43 let query = linear(n_state, n_state, vb.pp("q_proj"))?;
44 let value = linear(n_state, n_state, vb.pp("v_proj"))?;
45 let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
46 let out = linear(n_state, n_state, vb.pp("out_proj"))?;
47 Ok(Self {
48 query,
49 key,
50 value,
51 out,
52 n_head,
53 span,
54 softmax_span,
55 matmul_span,
56 kv_cache: None,
57 })
58 }
59
60 fn forward(
61 &mut self,
62 x: &Tensor,
63 xa: Option<&Tensor>,
64 mask: Option<&Tensor>,
65 flush_cache: bool,
66 ) -> Result<Tensor> {
67 let _enter = self.span.enter();
68 let q = self.query.forward(x)?;
69 let (k, v) = match xa {
70 None => {
71 let k = self.key.forward(x)?;
72 let v = self.value.forward(x)?;
73 (k, v)
74 }
75 Some(x) => {
76 if flush_cache {
77 self.kv_cache = None;
78 }
79 if let Some((k, v)) = &self.kv_cache {
80 (k.clone(), v.clone())
81 } else {
82 let k = self.key.forward(x)?;
83 let v = self.value.forward(x)?;
84 self.kv_cache = Some((k.clone(), v.clone()));
85 (k, v)
86 }
87 }
88 };
89 let wv = self.qkv_attention(&q, &k, &v, mask)?;
90 let out = self.out.forward(&wv)?;
91 Ok(out)
92 }
93
94 fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
95 let (n_batch, n_ctx, n_state) = x.dims3()?;
96 let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
97 x.reshape(target_dims)?.transpose(1, 2)
98 }
99
100 fn qkv_attention(
101 &self,
102 q: &Tensor,
103 k: &Tensor,
104 v: &Tensor,
105 mask: Option<&Tensor>,
106 ) -> Result<Tensor> {
107 let (_, n_ctx, n_state) = q.dims3()?;
108 let scale = ((n_state / self.n_head) as f64).powf(-0.25);
109 let q = (self.reshape_head(q)? * scale)?;
110 let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
111 let v = self.reshape_head(v)?.contiguous()?;
112 let mut qk = {
113 let _enter = self.matmul_span.enter();
114 q.matmul(&k)?
115 };
116 if let Some(mask) = mask {
117 let mask = mask.i((0..n_ctx, 0..n_ctx))?;
118 qk = qk.broadcast_add(&mask)?
119 }
120 let w = {
121 let _enter = self.softmax_span.enter();
122 candle_nn::ops::softmax_last_dim(&qk)?
123 };
124 let wv = {
125 let _enter = self.matmul_span.enter();
126 w.matmul(&v)?
127 }
128 .transpose(1, 2)?
129 .flatten_from(2)?;
130 Ok(wv)
131 }
132
133 fn reset_kv_cache(&mut self) {
134 self.kv_cache = None;
135 }
136}
137
138#[derive(Debug, Clone)]
140struct ResidualAttentionBlock {
141 attn: MultiHeadAttention,
142 attn_ln: LayerNorm,
143 cross_attn: Option<(MultiHeadAttention, LayerNorm)>,
144 mlp_linear1: Linear,
145 mlp_linear2: Linear,
146 mlp_ln: LayerNorm,
147 span: tracing::Span,
148}
149
150impl ResidualAttentionBlock {
151 fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
152 let span = tracing::span!(tracing::Level::TRACE, "residual-attn");
153 let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
154 let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
155 let cross_attn = if ca {
156 let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?;
157 let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?;
158 Some((cross_attn, cross_attn_ln))
159 } else {
160 None
161 };
162 let n_mlp = n_state * 4;
163 let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?;
164 let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?;
165 let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?;
166 Ok(Self {
167 attn,
168 attn_ln,
169 cross_attn,
170 mlp_linear1,
171 mlp_linear2,
172 mlp_ln,
173 span,
174 })
175 }
176
177 fn forward(
178 &mut self,
179 x: &Tensor,
180 xa: Option<&Tensor>,
181 mask: Option<&Tensor>,
182 flush_kv_cache: bool,
183 ) -> Result<Tensor> {
184 let _enter = self.span.enter();
185 let attn = self
186 .attn
187 .forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?;
188 let mut x = (x + attn)?;
189 if let Some((attn, ln)) = &mut self.cross_attn {
190 x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?;
191 }
192 let mlp = self.mlp_linear2.forward(
193 &self
194 .mlp_linear1
195 .forward(&self.mlp_ln.forward(&x)?)?
196 .gelu()?,
197 )?;
198 x + mlp
199 }
200
201 fn reset_kv_cache(&mut self) {
202 self.attn.reset_kv_cache();
203 if let Some((attn, _)) = &mut self.cross_attn {
204 attn.reset_kv_cache();
205 }
206 }
207}
208
209fn sinusoids(length: usize, channels: usize, device: &Device) -> Result<Tensor> {
210 let max_timescale = 10000f32;
211 let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
212 let inv_timescales: Vec<_> = (0..channels / 2)
213 .map(|i| (i as f32 * (-log_timescale_increment)).exp())
214 .collect();
215 let inv_timescales = Tensor::new(inv_timescales.as_slice(), device)?.unsqueeze(0)?;
216 let arange = Tensor::arange(0, length as u32, device)?
217 .to_dtype(candle::DType::F32)?
218 .unsqueeze(1)?;
219 let sh = (length, channels / 2);
220 let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?;
221 let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?;
222 Ok(sincos)
223}
224
225#[derive(Debug, Clone)]
227pub struct AudioEncoder {
228 conv1: Conv1d,
229 conv2: Conv1d,
230 positional_embedding: Tensor,
231 blocks: Vec<ResidualAttentionBlock>,
232 ln_post: LayerNorm,
233 span: tracing::Span,
234 conv1_span: tracing::Span,
235 conv2_span: tracing::Span,
236}
237
238impl AudioEncoder {
239 fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
240 let span = tracing::span!(tracing::Level::TRACE, "audio-encoder");
241 let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1");
242 let conv2_span = tracing::span!(tracing::Level::TRACE, "conv2");
243 let n_state = cfg.d_model;
244 let n_head = cfg.encoder_attention_heads;
245 let n_ctx = cfg.max_source_positions;
246 let cfg1 = Conv1dConfig {
247 padding: 1,
248 stride: 1,
249 groups: 1,
250 dilation: 1,
251 };
252 let cfg2 = Conv1dConfig {
253 padding: 1,
254 stride: 2,
255 groups: 1,
256 dilation: 1,
257 };
258 let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
259 let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
260 let positional_embedding = sinusoids(n_ctx, n_state, vb.device())?;
261 let blocks = (0..cfg.encoder_layers)
262 .map(|i| {
263 ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(format!("layers.{i}")))
264 })
265 .collect::<Result<Vec<_>>>()?;
266 let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?;
267 Ok(Self {
268 conv1,
269 conv2,
270 positional_embedding,
271 blocks,
272 ln_post,
273 conv1_span,
274 conv2_span,
275 span,
276 })
277 }
278
279 pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
280 let _enter = self.span.enter();
281 let x = {
282 let _enter = self.conv1_span.enter();
283 self.conv1.forward(x)?.gelu()?
284 };
285 let x = {
286 let _enter = self.conv2_span.enter();
287 self.conv2.forward(&x)?.gelu()?
288 };
289 let x = x.transpose(1, 2)?;
290 let (_bsize, seq_len, _hidden) = x.dims3()?;
291 let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
292 let mut x = x.broadcast_add(&positional_embedding)?;
293 for block in self.blocks.iter_mut() {
294 x = block.forward(&x, None, None, flush_kv_cache)?
295 }
296 let x = self.ln_post.forward(&x)?;
297 Ok(x)
298 }
299}
300
301#[derive(Debug, Clone)]
303pub struct TextDecoder {
304 token_embedding: Embedding,
305 positional_embedding: Tensor,
306 blocks: Vec<ResidualAttentionBlock>,
307 ln: LayerNorm,
308 mask: Tensor,
309 span: tracing::Span,
310 span_final: tracing::Span,
311}
312
313impl TextDecoder {
314 fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
315 let span = tracing::span!(tracing::Level::TRACE, "text-decoder");
316 let span_final = tracing::span!(tracing::Level::TRACE, "text-decoder-final");
317 let n_state = cfg.d_model;
318 let n_head = cfg.decoder_attention_heads;
319 let n_ctx = cfg.max_target_positions;
320 let token_embedding = embedding(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?;
321 let positional_embedding = vb.get((n_ctx, n_state), "embed_positions.weight")?;
322 let blocks = (0..cfg.decoder_layers)
323 .map(|i| {
324 ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(format!("layers.{i}")))
325 })
326 .collect::<Result<Vec<_>>>()?;
327 let ln = layer_norm(n_state, vb.pp("layer_norm"))?;
328 let mask: Vec<_> = (0..n_ctx)
329 .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
330 .collect();
331 let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?;
332 Ok(Self {
333 token_embedding,
334 positional_embedding,
335 blocks,
336 ln,
337 mask,
338 span,
339 span_final,
340 })
341 }
342
343 pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
344 let _enter = self.span.enter();
345 let last = x.dim(D::Minus1)?;
346 let token_embedding = self.token_embedding.forward(x)?;
347 let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
348 let mut x = token_embedding.broadcast_add(&positional_embedding)?;
349 for block in self.blocks.iter_mut() {
350 x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?;
351 }
352 self.ln.forward(&x)
353 }
354
355 pub fn final_linear(&self, x: &Tensor) -> Result<Tensor> {
356 let b_size = x.dim(0)?;
357 let w = self.token_embedding.embeddings().broadcast_left(b_size)?;
358 let logits = {
359 let _enter = self.span_final.enter();
360 x.matmul(&w.t()?)?
361 };
362 Ok(logits)
363 }
364
365 pub fn reset_kv_cache(&mut self) {
366 for block in self.blocks.iter_mut() {
367 block.reset_kv_cache();
368 }
369 }
370}
371
372#[derive(Debug, Clone)]
374pub struct Whisper {
375 pub encoder: AudioEncoder,
376 pub decoder: TextDecoder,
377 pub config: Config,
378}
379
380impl Whisper {
381 pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> {
382 let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?;
383 let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?;
384 Ok(Self {
385 encoder,
386 decoder,
387 config,
388 })
389 }
390
391 pub fn reset_kv_cache(&mut self) {
392 self.encoder
393 .blocks
394 .iter_mut()
395 .for_each(|b| b.reset_kv_cache());
396 self.decoder.reset_kv_cache();
397 }
398}