candle_transformers/models/mimi/
conv.rs

1// Copyright (c) Kyutai, all rights reserved.
2// This source code is licensed under the license found in the
3// LICENSE file in the root directory of this source tree.
4
5use candle::{Module, Result, StreamTensor, StreamingModule, Tensor, D};
6use candle_nn::{Conv1d, VarBuilder};
7
/// Normalization scheme that can be attached to a convolution layer.
// Every variant intentionally ends in `Norm`, which is what this clippy lint complains about.
#[allow(clippy::enum_variant_names)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Norm {
    /// Weight normalization: the kernel is read as a plain `weight` tensor or rebuilt from the
    /// `weight_g` / `weight_v` parametrization.
    WeightNorm,
    /// Spectral normalization; the constructors in this module reject it.
    SpectralNorm,
    /// A single-group `GroupNorm` applied after the convolution; not usable in causal mode.
    TimeGroupNorm,
}
15
/// Padding strategy applied to the time axis before a streamable convolution.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum PadMode {
    /// Pad with zeros.
    Constant,
    /// Reflection padding; `pad1d` currently rejects this mode.
    Reflect,
    /// Pad by repeating the boundary values.
    Replicate,
}
22
23// Applies weight norm for inference by recomputing the weight tensor. This
24// does not apply to training.
25// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
26fn conv1d_weight_norm(
27    in_c: usize,
28    out_c: usize,
29    kernel_size: usize,
30    bias: bool,
31    config: candle_nn::Conv1dConfig,
32    vb: VarBuilder,
33) -> Result<Conv1d> {
34    let weight = if vb.contains_tensor("weight") {
35        vb.get((out_c, in_c, kernel_size), "weight")?
36    } else {
37        let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
38        let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
39        let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
40        weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?
41    };
42    let bias = if bias {
43        Some(vb.get(out_c, "bias")?)
44    } else {
45        None
46    };
47    Ok(Conv1d::new(weight, bias, config))
48}
49
50#[derive(Debug, Clone)]
51pub struct NormConv1d {
52    conv: Conv1d,
53    norm: Option<candle_nn::GroupNorm>,
54    span: tracing::Span,
55}
56
57impl NormConv1d {
58    #[allow(clippy::too_many_arguments)]
59    pub fn new(
60        in_c: usize,
61        out_c: usize,
62        k_size: usize,
63        causal: bool,
64        norm: Option<Norm>,
65        bias: bool,
66        cfg: candle_nn::Conv1dConfig,
67        vb: VarBuilder,
68    ) -> Result<Self> {
69        let conv = match norm {
70            None | Some(Norm::TimeGroupNorm) => {
71                if bias {
72                    candle_nn::conv1d(in_c, out_c, k_size, cfg, vb.pp("conv"))?
73                } else {
74                    candle_nn::conv1d_no_bias(in_c, out_c, k_size, cfg, vb.pp("conv"))?
75                }
76            }
77            Some(Norm::WeightNorm) => {
78                conv1d_weight_norm(in_c, out_c, k_size, bias, cfg, vb.pp("conv"))?
79            }
80            Some(Norm::SpectralNorm) => candle::bail!("SpectralNorm is not supported yet."),
81        };
82        let norm = match norm {
83            None | Some(Norm::WeightNorm) | Some(Norm::SpectralNorm) => None,
84            Some(Norm::TimeGroupNorm) => {
85                if causal {
86                    candle::bail!("GroupNorm doesn't support causal evaluation.")
87                }
88                let norm = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
89                Some(norm)
90            }
91        };
92        Ok(Self {
93            conv,
94            norm,
95            span: tracing::span!(tracing::Level::TRACE, "norm-conv1d"),
96        })
97    }
98}
99
100impl Module for NormConv1d {
101    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
102        let _enter = self.span.enter();
103        let xs = xs.apply(&self.conv)?;
104        match self.norm.as_ref() {
105            None => Ok(xs),
106            Some(norm) => xs.apply(norm),
107        }
108    }
109}
110
111#[derive(Debug, Clone)]
112pub struct NormConvTranspose1d {
113    ws: Tensor,
114    bs: Option<Tensor>,
115    k_size: usize,
116    stride: usize,
117    groups: usize,
118    norm: Option<candle_nn::GroupNorm>,
119    span: tracing::Span,
120}
121
122impl NormConvTranspose1d {
123    #[allow(clippy::too_many_arguments)]
124    pub fn new(
125        in_c: usize,
126        out_c: usize,
127        k_size: usize,
128        causal: bool,
129        norm: Option<Norm>,
130        bias: bool,
131        stride: usize,
132        groups: usize,
133        vb: VarBuilder,
134    ) -> Result<Self> {
135        let vb = vb.pp("conv");
136        let bs = if bias {
137            Some(vb.get(out_c, "bias")?)
138        } else {
139            None
140        };
141        let ws = match norm {
142            None | Some(Norm::TimeGroupNorm) => vb.get((in_c, out_c / groups, k_size), "weight")?,
143            Some(Norm::WeightNorm) => {
144                if vb.contains_tensor("weight") {
145                    vb.get((in_c, out_c, k_size), "weight")?
146                } else {
147                    let weight_g = vb.get((in_c, 1, 1), "weight_g")?;
148                    let weight_v = vb.get((in_c, out_c, k_size), "weight_v")?;
149                    let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
150                    weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?
151                }
152            }
153            Some(Norm::SpectralNorm) => candle::bail!("SpectralNorm is not supported yet."),
154        };
155        let (ws, groups) = if groups == out_c && in_c == out_c {
156            let eye = Tensor::eye(out_c, ws.dtype(), ws.device())?;
157            let ws = ws
158                .repeat((1, out_c, 1))?
159                .mul(&eye.unsqueeze(2)?.repeat((1, 1, k_size))?)?;
160            (ws, 1)
161        } else {
162            (ws, groups)
163        };
164        let norm = match norm {
165            None | Some(Norm::WeightNorm) | Some(Norm::SpectralNorm) => None,
166            Some(Norm::TimeGroupNorm) => {
167                if causal {
168                    candle::bail!("GroupNorm doesn't support causal evaluation.")
169                }
170                let norm = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
171                Some(norm)
172            }
173        };
174        Ok(Self {
175            ws,
176            bs,
177            k_size,
178            stride,
179            groups,
180            norm,
181            span: tracing::span!(tracing::Level::TRACE, "norm-conv-tr1d"),
182        })
183    }
184}
185
186impl Module for NormConvTranspose1d {
187    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
188        let _enter = self.span.enter();
189        // conv-transpose1d seems to be broken on metal after enough iterations. Causing
190        // the following error:
191        // _status < MTLCommandBufferStatusCommitted >
192        // -[IOGPUMetalCommandBuffer setCurrentCommandEncoder:]
193        // This is now fixed in candle.
194        let xs = Tensor::conv_transpose1d(xs, &self.ws, 0, 0, self.stride, 1, self.groups)?;
195        let xs = match &self.bs {
196            None => xs,
197            Some(bias) => {
198                let b = bias.dims1()?;
199                let bias = bias.reshape((1, b, 1))?;
200                xs.broadcast_add(&bias)?
201            }
202        };
203        match self.norm.as_ref() {
204            None => Ok(xs),
205            Some(norm) => xs.apply(norm),
206        }
207    }
208}
209
210fn get_extra_padding_for_conv1d(
211    xs: &Tensor,
212    k_size: usize,
213    stride: usize,
214    padding_total: usize,
215) -> Result<usize> {
216    let len = xs.dim(D::Minus1)?;
217    let n_frames = (len + padding_total).saturating_sub(k_size) as f64 / stride as f64 + 1.0;
218    let ideal_len =
219        ((n_frames.ceil() as usize - 1) * stride + k_size).saturating_sub(padding_total);
220    Ok(ideal_len.saturating_sub(len))
221}
222
223fn pad1d(xs: &Tensor, pad_l: usize, pad_r: usize, mode: PadMode) -> Result<Tensor> {
224    match mode {
225        PadMode::Constant => xs.pad_with_zeros(D::Minus1, pad_l, pad_r),
226        PadMode::Reflect => candle::bail!("pad-mode 'reflect' is not supported"),
227        PadMode::Replicate => xs.pad_with_same(D::Minus1, pad_l, pad_r),
228    }
229}
230
231fn unpad1d(xs: &Tensor, unpad_l: usize, unpad_r: usize) -> Result<Tensor> {
232    let len = xs.dim(D::Minus1)?;
233    if len < unpad_l + unpad_r {
234        candle::bail!("unpad1d: tensor len {len} is too low, {unpad_l} + {unpad_r}")
235    }
236    xs.narrow(D::Minus1, unpad_l, len - (unpad_l + unpad_r))
237}
238
239#[derive(Debug, Clone)]
240pub struct StreamableConv1d {
241    conv: NormConv1d,
242    causal: bool,
243    pad_mode: PadMode,
244    state_prev_xs: StreamTensor,
245    left_pad_applied: bool,
246    kernel_size: usize,
247    span: tracing::Span,
248}
249
250impl StreamableConv1d {
251    #[allow(clippy::too_many_arguments)]
252    pub fn new(
253        in_c: usize,
254        out_c: usize,
255        k_size: usize,
256        stride: usize,
257        dilation: usize,
258        groups: usize,
259        bias: bool,
260        causal: bool,
261        norm: Option<Norm>,
262        pad_mode: PadMode,
263        vb: VarBuilder,
264    ) -> Result<Self> {
265        let cfg = candle_nn::Conv1dConfig {
266            padding: 0,
267            stride,
268            dilation,
269            groups,
270        };
271        let conv = NormConv1d::new(in_c, out_c, k_size, causal, norm, bias, cfg, vb)?;
272        if k_size < stride {
273            candle::bail!("kernel-size {k_size} is smaller than stride {stride}")
274        }
275        Ok(Self {
276            conv,
277            causal,
278            pad_mode,
279            state_prev_xs: StreamTensor::empty(),
280            left_pad_applied: false,
281            kernel_size: k_size,
282            span: tracing::span!(tracing::Level::TRACE, "streamable-conv1d"),
283        })
284    }
285}
286
287impl Module for StreamableConv1d {
288    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
289        let _enter = self.span.enter();
290        let (_b, _t, _c) = xs.dims3()?;
291        let k_size = self.conv.conv.weight().dim(D::Minus1)?;
292        let conv_cfg = self.conv.conv.config();
293        // Effective kernel size with dilations.
294        let k_size = (k_size - 1) * conv_cfg.dilation + 1;
295        let padding_total = k_size - conv_cfg.stride;
296        let extra_padding =
297            get_extra_padding_for_conv1d(xs, k_size, conv_cfg.stride, padding_total)?;
298        let xs = if self.causal {
299            pad1d(xs, padding_total, extra_padding, self.pad_mode)?
300        } else {
301            let padding_right = padding_total / 2;
302            let padding_left = padding_total - padding_right;
303            pad1d(
304                xs,
305                padding_left,
306                padding_right + extra_padding,
307                self.pad_mode,
308            )?
309        };
310        xs.apply(&self.conv)
311    }
312}
313
314impl StreamingModule for StreamableConv1d {
315    fn reset_state(&mut self) {
316        self.state_prev_xs.reset();
317        self.left_pad_applied = false;
318    }
319
320    fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
321        let _enter = self.span.enter();
322        let xs = match xs.as_option() {
323            None => return Ok(().into()),
324            Some(xs) => xs.clone(),
325        };
326        let xs = if self.left_pad_applied {
327            xs
328        } else {
329            self.left_pad_applied = true;
330            let k_size = self.conv.conv.weight().dim(D::Minus1)?;
331            let conv_cfg = self.conv.conv.config();
332            let k_size = (k_size - 1) * conv_cfg.dilation + 1;
333            let padding_total = k_size - conv_cfg.stride;
334            pad1d(&xs, padding_total, 0, self.pad_mode)?
335        };
336        let cfg = self.conv.conv.config();
337        let stride = cfg.stride;
338        let dilation = cfg.dilation;
339        let kernel = (self.kernel_size - 1) * dilation + 1;
340        let xs = StreamTensor::cat2(&self.state_prev_xs, &xs.into(), D::Minus1)?;
341        let seq_len = xs.seq_len(D::Minus1)?;
342        let num_frames = (seq_len + stride).saturating_sub(kernel) / stride;
343        if num_frames > 0 {
344            let offset = num_frames * stride;
345            self.state_prev_xs = xs.narrow(D::Minus1, offset, seq_len - offset)?;
346            let in_l = (num_frames - 1) * stride + kernel;
347            let xs = xs.narrow(D::Minus1, 0, in_l)?;
348            // We apply the underlying convtr directly rather than through forward so as
349            // not to apply any padding here.
350            xs.apply(&self.conv.conv)
351        } else {
352            self.state_prev_xs = xs;
353            Ok(StreamTensor::empty())
354        }
355    }
356}
357
358#[derive(Debug, Clone)]
359pub struct StreamableConvTranspose1d {
360    convtr: NormConvTranspose1d,
361    causal: bool,
362    state_prev_ys: StreamTensor,
363    kernel_size: usize,
364    span: tracing::Span,
365}
366
367impl StreamableConvTranspose1d {
368    #[allow(clippy::too_many_arguments)]
369    pub fn new(
370        in_c: usize,
371        out_c: usize,
372        k_size: usize,
373        stride: usize,
374        groups: usize,
375        bias: bool,
376        causal: bool,
377        norm: Option<Norm>,
378        vb: VarBuilder,
379    ) -> Result<Self> {
380        let convtr =
381            NormConvTranspose1d::new(in_c, out_c, k_size, causal, norm, bias, stride, groups, vb)?;
382        Ok(Self {
383            convtr,
384            causal,
385            kernel_size: k_size,
386            state_prev_ys: StreamTensor::empty(),
387            span: tracing::span!(tracing::Level::TRACE, "streamable-conv-tr1d"),
388        })
389    }
390}
391
392impl Module for StreamableConvTranspose1d {
393    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
394        let _enter = self.span.enter();
395        let k_size = self.convtr.k_size;
396        let stride = self.convtr.stride;
397        let padding_total = k_size.saturating_sub(stride);
398        let xs = xs.apply(&self.convtr)?;
399        if self.causal {
400            // This corresponds to trim_right_ratio = 1.
401            unpad1d(&xs, 0, padding_total)
402        } else {
403            let padding_right = padding_total / 2;
404            let padding_left = padding_total - padding_right;
405            unpad1d(&xs, padding_left, padding_right)
406        }
407    }
408}
409
410impl StreamingModule for StreamableConvTranspose1d {
411    fn reset_state(&mut self) {
412        self.state_prev_ys.reset()
413    }
414
415    fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
416        let _enter = self.span.enter();
417        let xs = match xs.as_option() {
418            Some(xs) => xs,
419            None => return Ok(StreamTensor::empty()),
420        };
421        let stride = self.convtr.stride;
422        // We apply the underlying convtr directly rather than through forward so as
423        // not to apply any padding here.
424        let ys = self.convtr.forward(xs)?;
425        let ot = ys.dim(D::Minus1)?;
426        let ys = match self.state_prev_ys.as_option() {
427            None => ys,
428            Some(prev_ys) => {
429                let pt = prev_ys.dim(D::Minus1)?;
430                // Remove the bias as it will be applied multiple times.
431                let prev_ys = match &self.convtr.bs {
432                    None => prev_ys.clone(),
433                    Some(bias) => {
434                        let bias = bias.reshape((1, (), 1))?;
435                        prev_ys.broadcast_sub(&bias)?
436                    }
437                };
438                let ys1 = (ys.narrow(D::Minus1, 0, pt)? + prev_ys)?;
439                let ys2 = ys.narrow(D::Minus1, pt, ot - pt)?;
440                Tensor::cat(&[ys1, ys2], D::Minus1)?
441            }
442        };
443        let invalid_steps = self.kernel_size - stride;
444        let (ys, prev_ys) = StreamTensor::from(ys).split(D::Minus1, ot - invalid_steps)?;
445        self.state_prev_ys = prev_ys;
446        Ok(ys)
447    }
448}
449
450#[derive(Debug, Clone)]
451pub struct ConvDownsample1d {
452    conv: StreamableConv1d,
453}
454
455impl ConvDownsample1d {
456    pub fn new(
457        stride: usize,
458        dim: usize,
459        causal: bool,
460        learnt: bool,
461        vb: VarBuilder,
462    ) -> Result<Self> {
463        if !learnt {
464            candle::bail!("only learnt=true is supported")
465        }
466        let conv = StreamableConv1d::new(
467            /* in_c */ dim,
468            /* out_c */ dim,
469            /* k_size_c */ 2 * stride,
470            /* stride */ stride,
471            /* dilation */ 1,
472            /* groups */ 1, // channel_wise = false
473            /* bias */ false,
474            /* causal */ causal,
475            /* norm */ None,
476            /* pad_mode */ PadMode::Replicate,
477            vb,
478        )?;
479        Ok(Self { conv })
480    }
481}
482
483impl Module for ConvDownsample1d {
484    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
485        xs.apply(&self.conv)
486    }
487}
488
489impl StreamingModule for ConvDownsample1d {
490    fn reset_state(&mut self) {
491        self.conv.reset_state()
492    }
493
494    fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
495        self.conv.step(xs)
496    }
497}
498
499#[derive(Debug, Clone)]
500pub struct ConvTrUpsample1d {
501    convtr: StreamableConvTranspose1d,
502}
503
504impl ConvTrUpsample1d {
505    pub fn new(
506        stride: usize,
507        dim: usize,
508        causal: bool,
509        learnt: bool,
510        vb: VarBuilder,
511    ) -> Result<Self> {
512        if !learnt {
513            candle::bail!("only learnt=true is supported")
514        }
515        let convtr = StreamableConvTranspose1d::new(
516            dim,
517            dim,
518            /* k_size */ 2 * stride,
519            /* stride */ stride,
520            /* groups */ dim,
521            /* bias */ false,
522            /* causal */ causal,
523            /* norm */ None,
524            vb,
525        )?;
526        Ok(Self { convtr })
527    }
528}
529
530impl Module for ConvTrUpsample1d {
531    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
532        xs.apply(&self.convtr)
533    }
534}
535
536impl StreamingModule for ConvTrUpsample1d {
537    fn reset_state(&mut self) {
538        self.convtr.reset_state()
539    }
540
541    fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
542        self.convtr.step(xs)
543    }
544}
545
#[cfg(test)]
mod tests {
    use super::*;
    use candle::IndexOp;

    /// Runs `model` on the whole of `xs`, then again chunk by chunk (`len` chunks of
    /// `step_size` samples) through `step`. Fails if the concatenated streaming output
    /// differs from the one-shot output by more than 1e-5.
    fn check_streaming_matches_forward<M: Module + StreamingModule>(
        mut model: M,
        xs: &Tensor,
        step_size: usize,
        len: usize,
    ) -> Result<()> {
        let ys = model.forward(xs)?;
        let mut ys_steps = vec![];
        for idx in 0..len {
            let chunk = xs.i((.., .., step_size * idx..step_size * (idx + 1)))?;
            let out = model.step(&chunk.into())?;
            if let Some(out) = out.as_option() {
                ys_steps.push(out.clone())
            }
        }
        let ys_steps = Tensor::cat(&ys_steps, D::Minus1)?;
        let diff = (&ys - &ys_steps)?
            .abs()?
            .flatten_all()?
            .max(0)?
            .to_vec0::<f32>()?;
        if diff > 1e-5 {
            println!("{xs}");
            println!("{ys}");
            println!("{ys_steps}");
            candle::bail!("larger diff than expected {diff}")
        }
        Ok(())
    }

    fn run_conv1d(
        k_size: usize,
        stride: usize,
        dilation: usize,
        step_size: usize,
        len: usize,
        bias: bool,
    ) -> Result<()> {
        // TODO: We should ensure for the seed to be constant when running these tests.
        let dev = &candle::Device::Cpu;
        let vm = candle_nn::VarMap::new();
        let vb = VarBuilder::from_varmap(&vm, candle::DType::F32, dev);
        let conv1d = StreamableConv1d::new(
            /* in_c */ 2,
            /* out_c */ 3,
            /* k_size */ k_size,
            /* stride */ stride,
            /* dilation */ dilation,
            /* groups */ 1,
            /* bias */ bias,
            /* causal */ true,
            /* norm */ None,
            /* pad_mode */ PadMode::Constant,
            vb,
        )?;
        let xs = Tensor::randn(0f32, 1., (1, 2, step_size * len), dev)?;
        check_streaming_matches_forward(conv1d, &xs, step_size, len)
    }

    fn run_conv_tr1d(
        k_size: usize,
        stride: usize,
        step_size: usize,
        len: usize,
        bias: bool,
    ) -> Result<()> {
        // TODO: We should ensure for the seed to be constant when running these tests.
        let dev = &candle::Device::Cpu;
        let vm = candle_nn::VarMap::new();
        let vb = VarBuilder::from_varmap(&vm, candle::DType::F32, dev);
        let conv_tr1d = StreamableConvTranspose1d::new(
            /* in_c */ 2,
            /* out_c */ 3,
            /* k_size */ k_size,
            /* stride */ stride,
            /* groups */ 1,
            /* bias */ bias,
            /* causal */ true,
            /* norm */ None,
            vb,
        )?;
        let xs = Tensor::randn(0f32, 1., (1, 2, step_size * len), dev)?;
        check_streaming_matches_forward(conv_tr1d, &xs, step_size, len)
    }

    #[test]
    fn conv1d() -> Result<()> {
        // (k_size, stride, dilation, len)
        let cases = [
            (1, 1, 1, 5),
            (2, 1, 1, 5),
            (2, 2, 1, 6),
            (3, 2, 1, 8),
            (3, 2, 2, 8),
        ];
        for step_size in [1, 2, 3] {
            for bias in [false, true] {
                for &(k_size, stride, dilation, len) in cases.iter() {
                    run_conv1d(k_size, stride, dilation, step_size, len, bias)?;
                }
            }
        }
        Ok(())
    }

    #[test]
    fn conv_tr1d() -> Result<()> {
        // (k_size, stride), each run over 5 chunks.
        let cases = [(1, 1), (2, 1), (3, 1), (3, 2)];
        for step_size in [1, 2, 3] {
            for bias in [false, true] {
                for &(k_size, stride) in cases.iter() {
                    run_conv_tr1d(k_size, stride, step_size, 5, bias)?;
                }
            }
        }
        Ok(())
    }
}