candle_transformers/models/
encodec.rs

//! EnCodec neural audio codec.
//!
//! See ["High Fidelity Neural Audio Compression"](https://arxiv.org/abs/2210.13438)
//!
//! Based on the implementation from [huggingface/transformers](https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py)
6
7use candle::{DType, IndexOp, Layout, Module, Result, Shape, Tensor, D};
8use candle_nn::{conv1d, Conv1d, ConvTranspose1d, VarBuilder};
9
10// Encodec Model
11// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
12
/// Normalization scheme used by the convolution layers of the model.
#[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Deserialize)]
pub enum NormType {
    /// The convolution weight is rebuilt from the stored `weight_g` / `weight_v` tensors.
    WeightNorm,
    /// A single-group group-norm is applied to the convolution output.
    TimeGroupNorm,
    /// Plain convolution, no normalization.
    None,
}
19
/// Padding strategy applied by `pad1d` to the last dimension before a convolution.
#[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Deserialize)]
pub enum PadMode {
    /// Pad with zeros.
    Constant,
    /// Reflection padding; `pad1d` currently rejects this mode with an error.
    Reflect,
    /// Pad through `Tensor::pad_with_same`.
    Replicate,
}
26
/// Hyper-parameters of the EnCodec model.
///
/// Several fields are carried for compatibility with the serialized configuration and
/// are not read by the code visible in this file; they are marked as such below.
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Config {
    /// Supported bandwidths; only the last entry is read here, to derive the number of
    /// quantizers (presumably expressed in kbps - confirm against the checkpoint config).
    pub target_bandwidths: Vec<f64>,
    /// Sampling rate, divided by the product of `upsampling_ratios` to get the frame rate.
    pub sampling_rate: usize,
    /// Input channels of the encoder's first conv and output channels of the decoder's last conv.
    pub audio_channels: usize,
    /// Not read by the code visible in this file.
    pub normalize: bool,
    /// Not read by the code visible in this file.
    pub chunk_length_s: Option<usize>,
    /// Not read by the code visible in this file.
    pub overlap: Option<usize>,
    /// Channel count at the encoder output / decoder input; also the default codebook dimension.
    pub hidden_size: usize,
    /// Base channel count; it is doubled after each encoder downsampling stage.
    pub num_filters: usize,
    /// Number of residual blocks per sampling stage.
    pub num_residual_layers: usize,
    /// Stride of each sampling stage (kernel size is twice the stride). The decoder walks
    /// this list in order, the encoder in reverse.
    pub upsampling_ratios: Vec<usize>,
    /// Normalization used by the `EncodecConv1d` layers.
    pub norm_type: NormType,
    /// Kernel size of the encoder's first convolution.
    pub kernel_size: usize,
    /// Kernel size of the encoder's last convolution and of the decoder's first and last ones.
    pub last_kernel_size: usize,
    /// Kernel size of the first convolution inside each residual block.
    pub residual_kernel_size: usize,
    /// The j-th residual block of a stage uses a dilation of `dilation_growth_rate^j`.
    pub dilation_growth_rate: usize,
    /// When true, the fixed padding of `EncodecConv1d` is placed entirely on the left.
    pub use_causal_conv: bool,
    /// Padding strategy used by `EncodecConv1d`.
    pub pad_mode: PadMode,
    /// Residual blocks use `dim / compress` channels internally.
    pub compress: usize,
    /// Number of stacked LSTM layers in `EncodecLSTM`.
    pub num_lstm_layers: usize,
    /// Not read by the code visible in this file.
    pub trim_right_ratio: f64,
    /// Number of entries in each codebook.
    pub codebook_size: usize,
    /// Dimension of the codebook entries; `hidden_size` is used when this is `None`.
    pub codebook_dim: Option<usize>,
    /// When true, residual blocks use a kernel-size-1 convolution on the skip path.
    pub use_conv_shortcut: bool,
}
53
54impl Default for Config {
55    fn default() -> Self {
56        Self {
57            target_bandwidths: vec![1.5, 3.0, 6.0, 12.0, 24.0],
58            sampling_rate: 24_000,
59            audio_channels: 1,
60            normalize: false,
61            chunk_length_s: None,
62            overlap: None,
63            hidden_size: 128,
64            num_filters: 32,
65            num_residual_layers: 1,
66            upsampling_ratios: vec![8, 5, 4, 2],
67            norm_type: NormType::WeightNorm,
68            kernel_size: 7,
69            last_kernel_size: 7,
70            residual_kernel_size: 3,
71            dilation_growth_rate: 2,
72            use_causal_conv: true,
73            // This should be PadMode::Reflect which is currently unsupported in candle.
74            pad_mode: PadMode::Replicate,
75            compress: 2,
76            num_lstm_layers: 2,
77            trim_right_ratio: 1.0,
78            codebook_size: 1024,
79            codebook_dim: None,
80            use_conv_shortcut: true,
81        }
82    }
83}
84
85impl Config {
86    fn codebook_dim(&self) -> usize {
87        self.codebook_dim.unwrap_or(self.hidden_size)
88    }
89
90    fn frame_rate(&self) -> usize {
91        let hop_length: usize = self.upsampling_ratios.iter().product();
92        self.sampling_rate.div_ceil(hop_length)
93    }
94
95    fn num_quantizers(&self) -> usize {
96        let num = 1000f64
97            * self
98                .target_bandwidths
99                .last()
100                .expect("empty target_bandwidths");
101        (num as usize) / (self.frame_rate() * 10)
102    }
103}
104
105fn get_extra_padding_for_conv1d(
106    xs: &Tensor,
107    k_size: usize,
108    stride: usize,
109    padding_total: usize,
110) -> Result<usize> {
111    let len = xs.dim(D::Minus1)?;
112    let n_frames = (len + padding_total).saturating_sub(k_size) as f64 / stride as f64 + 1.0;
113    let ideal_len =
114        ((n_frames.ceil() as usize - 1) * stride + k_size).saturating_sub(padding_total);
115    Ok(ideal_len.saturating_sub(len))
116}
117
118fn pad1d(xs: &Tensor, pad_l: usize, pad_r: usize, mode: PadMode) -> Result<Tensor> {
119    match mode {
120        PadMode::Constant => xs.pad_with_zeros(D::Minus1, pad_l, pad_r),
121        PadMode::Reflect => candle::bail!("pad-mode 'reflect' is not supported"),
122        PadMode::Replicate => xs.pad_with_same(D::Minus1, pad_l, pad_r),
123    }
124}
125
126// Applies weight norm for inference by recomputing the weight tensor. This
127// does not apply to training.
128// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
129pub fn conv1d_weight_norm(
130    in_c: usize,
131    out_c: usize,
132    kernel_size: usize,
133    config: candle_nn::Conv1dConfig,
134    vb: VarBuilder,
135) -> Result<Conv1d> {
136    let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
137    let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
138    let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
139    let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
140    let bias = vb.get(out_c, "bias")?;
141    Ok(Conv1d::new(weight, Some(bias), config))
142}
143
144pub fn conv_transpose1d_weight_norm(
145    in_c: usize,
146    out_c: usize,
147    kernel_size: usize,
148    bias: bool,
149    config: candle_nn::ConvTranspose1dConfig,
150    vb: VarBuilder,
151) -> Result<ConvTranspose1d> {
152    let weight_g = vb.get((in_c, 1, 1), "weight_g")?;
153    let weight_v = vb.get((in_c, out_c, kernel_size), "weight_v")?;
154    let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
155    let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
156    let bias = if bias {
157        Some(vb.get(out_c, "bias")?)
158    } else {
159        None
160    };
161    Ok(ConvTranspose1d::new(weight, bias, config))
162}
163
164struct CodebookEncode;
165
166impl candle::CustomOp2 for CodebookEncode {
167    fn name(&self) -> &'static str {
168        "cb"
169    }
170
171    fn cpu_fwd(
172        &self,
173        lhs_storage: &candle::CpuStorage,
174        lhs_layout: &Layout,
175        rhs_storage: &candle::CpuStorage,
176        rhs_layout: &Layout,
177    ) -> Result<(candle::CpuStorage, Shape)> {
178        use rayon::prelude::*;
179
180        let (lhs_dim1, lhs_dim2) = lhs_layout.shape().dims2()?;
181        let (rhs_dim1, rhs_dim2) = rhs_layout.shape().dims2()?;
182        if lhs_dim2 != rhs_dim2 {
183            candle::bail!("CodebookEncode, mismatch on last dim, {lhs_layout:?} {rhs_layout:?}");
184        }
185        if lhs_dim2 == 0 {
186            candle::bail!("CodebookEncode, empty last dim {lhs_layout:?}")
187        }
188        let lhs = match lhs_layout.contiguous_offsets() {
189            None => candle::bail!("CodebookEncode, lhs has to be contiguous, got {lhs_layout:?}"),
190            Some((o1, o2)) => {
191                let slice = lhs_storage.as_slice::<f32>()?;
192                &slice[o1..o2]
193            }
194        };
195        let rhs = match rhs_layout.contiguous_offsets() {
196            None => candle::bail!("CodebookEncode, rhs has to be contiguous, got {rhs_layout:?}"),
197            Some((o1, o2)) => {
198                let slice = rhs_storage.as_slice::<f32>()?;
199                &slice[o1..o2]
200            }
201        };
202        let dst = (0..lhs_dim1)
203            .into_par_iter()
204            .map(|idx1| {
205                let mut where_min = 0;
206                let mut min_dist = f32::INFINITY;
207                let lhs = &lhs[idx1 * lhs_dim2..(idx1 + 1) * lhs_dim2];
208                for idx2 in 0..rhs_dim1 {
209                    let rhs = &rhs[idx2 * rhs_dim2..(idx2 + 1) * rhs_dim2];
210                    let mut dist = 0f32;
211                    for (a, b) in lhs.iter().zip(rhs.iter()) {
212                        dist += (a - b) * (a - b)
213                    }
214                    if dist < min_dist {
215                        min_dist = dist;
216                        where_min = idx2;
217                    }
218                }
219                where_min as u32
220            })
221            .collect();
222        let storage = candle::WithDType::to_cpu_storage_owned(dst);
223        Ok((storage, (lhs_dim1,).into()))
224    }
225}
226
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L340
/// A single codebook with nearest-neighbour lookup under the Euclidean distance.
///
/// `inited`, `cluster_size` and `embed_avg` are loaded from the weights but are not
/// read by the encode/decode paths in this file, hence the `allow(unused)`.
#[allow(unused)]
#[derive(Clone, Debug)]
pub struct EuclideanCodebook {
    inited: Tensor,
    cluster_size: Tensor,
    // Codebook entries, shape (codebook_size, codebook_dim).
    embed: candle_nn::Embedding,
    embed_avg: Tensor,
    // Half of the squared norm of each codebook entry, precomputed for `encode_slow`.
    c2: Tensor,
}
237
238impl EuclideanCodebook {
239    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
240        let inited = vb.get(1, "inited")?;
241        let cluster_size = vb.get(cfg.codebook_size, "cluster_size")?;
242        let e_shape = (cfg.codebook_size, cfg.codebook_dim());
243        let embed = vb.get(e_shape, "embed")?;
244        let c2 = ((&embed * &embed)?.sum(D::Minus1)? / 2.0)?;
245        let embed_avg = vb.get(e_shape, "embed_avg")?;
246        Ok(Self {
247            inited,
248            cluster_size,
249            embed: candle_nn::Embedding::new(embed, cfg.codebook_dim()),
250            embed_avg,
251            c2,
252        })
253    }
254
255    pub fn encode_slow(&self, xs: &Tensor) -> Result<Tensor> {
256        let mut target_shape = xs.dims().to_vec();
257        target_shape.pop();
258        let xs = xs.flatten_to(D::Minus2)?;
259        let _ = xs.dims2()?;
260        let dot_prod = xs.matmul(&self.embed.embeddings().t()?)?;
261        let codes = self.c2.broadcast_sub(&dot_prod)?.argmin(D::Minus1)?;
262        codes.reshape(target_shape)
263    }
264
265    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
266        let mut target_shape = xs.dims().to_vec();
267        target_shape.pop();
268        let xs = xs.flatten_to(D::Minus2)?;
269        let _ = xs.dims2()?;
270        let codes = Tensor::apply_op2(&xs, self.embed.embeddings(), CodebookEncode)?;
271        codes.reshape(target_shape)
272    }
273
274    pub fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
275        let quantize = self.embed.forward(embed_ind)?;
276        Ok(quantize)
277    }
278}
279
/// One quantization layer: a codebook plus the transposes between the layout used by
/// the convolutions and the layout expected by the codebook.
#[derive(Clone, Debug)]
pub struct VectorQuantization {
    codebook: EuclideanCodebook,
}
284
285impl VectorQuantization {
286    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
287        let codebook = EuclideanCodebook::new(cfg, vb.pp("codebook"))?;
288        Ok(Self { codebook })
289    }
290
291    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
292        let xs = xs.transpose(1, 2)?;
293        self.codebook.encode_slow(&xs)
294    }
295
296    pub fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
297        let quantize = self.codebook.decode(embed_ind)?;
298        let quantize = quantize.transpose(1, 2)?;
299        Ok(quantize)
300    }
301}
302
/// Stack of quantization layers where each layer encodes the residual left by the
/// previous ones.
#[derive(Clone, Debug)]
pub struct ResidualVectorQuantizer {
    layers: Vec<VectorQuantization>,
    // Dtype of the accumulator created in `decode`; taken from the var-builder.
    dtype: DType,
}
308
309impl ResidualVectorQuantizer {
310    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
311        let vb = &vb.pp("layers");
312        let layers = (0..cfg.num_quantizers())
313            .map(|i| VectorQuantization::new(cfg, vb.pp(i)))
314            .collect::<Result<Vec<_>>>()?;
315        Ok(Self {
316            layers,
317            dtype: vb.dtype(),
318        })
319    }
320
321    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
322        let mut codes = Vec::with_capacity(self.layers.len());
323        let mut residual = xs.clone();
324        for layer in self.layers.iter() {
325            let indices = layer.encode(&residual)?;
326            let quantized = layer.decode(&indices)?;
327            residual = (residual - quantized)?;
328            codes.push(indices)
329        }
330        Tensor::stack(&codes, 0)
331    }
332
333    pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
334        let mut quantized_out = Tensor::zeros((), self.dtype, codes.device())?;
335        let ncodes = codes.dim(0)?;
336        if ncodes > self.layers.len() {
337            candle::bail!(
338                "codes shape {:?} does not match the number of quantization layers {}",
339                codes.shape(),
340                self.layers.len()
341            )
342        }
343        for (i, layer) in self.layers.iter().take(ncodes).enumerate() {
344            let quantized = layer.decode(&codes.i(i)?)?;
345            quantized_out = quantized.broadcast_add(&quantized_out)?;
346        }
347        Ok(quantized_out)
348    }
349}
350
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226
/// Stack of LSTM layers with a residual connection around the whole stack.
#[derive(Clone, Debug)]
pub struct EncodecLSTM {
    // Applied in order; each layer consumes the output sequence of the previous one.
    layers: Vec<candle_nn::LSTM>,
}
356
357impl EncodecLSTM {
358    pub fn new(dim: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
359        let vb = &vb.pp("lstm");
360        let mut layers = vec![];
361        for layer_idx in 0..cfg.num_lstm_layers {
362            let config = candle_nn::LSTMConfig {
363                layer_idx,
364                ..Default::default()
365            };
366            let lstm = candle_nn::lstm(dim, dim, config, vb.clone())?;
367            layers.push(lstm)
368        }
369        Ok(Self { layers })
370    }
371}
372
373impl Module for EncodecLSTM {
374    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
375        use candle_nn::RNN;
376        // This is different from the Python transformers version as candle LSTM is batch first.
377        let xs = xs.t()?;
378        let residual = &xs;
379        let mut xs = xs.clone();
380        for layer in self.layers.iter() {
381            let states = layer.seq(&xs)?;
382            xs = layer.states_to_tensor(&states)?;
383        }
384        let xs = (xs + residual)?.t()?;
385        Ok(xs)
386    }
387}
388
/// Weight-normalized transposed convolution used by the decoder to upsample.
#[derive(Clone, Debug)]
pub struct EncodecConvTranspose1d {
    conv: ConvTranspose1d,
}
393
394impl EncodecConvTranspose1d {
395    fn new(
396        in_c: usize,
397        out_c: usize,
398        k: usize,
399        stride: usize,
400        _cfg: &Config,
401        vb: VarBuilder,
402    ) -> Result<Self> {
403        let cfg = candle_nn::ConvTranspose1dConfig {
404            stride,
405            ..Default::default()
406        };
407        let conv = conv_transpose1d_weight_norm(in_c, out_c, k, true, cfg, vb.pp("conv"))?;
408        Ok(Self { conv })
409    }
410}
411
412impl Module for EncodecConvTranspose1d {
413    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
414        xs.apply(&self.conv)
415    }
416}
417
/// 1-D convolution with explicit (optionally causal) padding and optional normalization.
#[derive(Clone, Debug)]
pub struct EncodecConv1d {
    // When true, the fixed padding is placed entirely on the left of the input.
    causal: bool,
    conv: Conv1d,
    // Only set for `NormType::TimeGroupNorm`; applied to the convolution output.
    norm: Option<candle_nn::GroupNorm>,
    pad_mode: PadMode,
}
425
426impl EncodecConv1d {
427    pub fn new(
428        in_c: usize,
429        out_c: usize,
430        kernel_size: usize,
431        stride: usize,
432        dilation: usize,
433        cfg: &Config,
434        vb: VarBuilder,
435    ) -> Result<Self> {
436        let conv = match cfg.norm_type {
437            NormType::WeightNorm => conv1d_weight_norm(
438                in_c,
439                out_c,
440                kernel_size,
441                candle_nn::Conv1dConfig {
442                    stride,
443                    dilation,
444                    ..Default::default()
445                },
446                vb.pp("conv"),
447            )?,
448            NormType::None | NormType::TimeGroupNorm => conv1d(
449                in_c,
450                out_c,
451                kernel_size,
452                candle_nn::Conv1dConfig {
453                    padding: 0,
454                    stride,
455                    groups: 1,
456                    dilation: 1,
457                },
458                vb.pp("conv"),
459            )?,
460        };
461        let norm = match cfg.norm_type {
462            NormType::None | NormType::WeightNorm => None,
463            NormType::TimeGroupNorm => {
464                let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
465                Some(gn)
466            }
467        };
468        Ok(Self {
469            causal: cfg.use_causal_conv,
470            conv,
471            norm,
472            pad_mode: cfg.pad_mode,
473        })
474    }
475}
476
477impl Module for EncodecConv1d {
478    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
479        let (_b, _t, _c) = xs.dims3()?;
480        let k_size = self.conv.weight().dim(D::Minus1)?;
481        let conv_cfg = self.conv.config();
482        // Effective kernel size with dilations.
483        let k_size = (k_size - 1) * conv_cfg.dilation + 1;
484        let padding_total = k_size - conv_cfg.stride;
485        let extra_padding =
486            get_extra_padding_for_conv1d(xs, k_size, conv_cfg.stride, padding_total)?;
487        let xs = if self.causal {
488            pad1d(xs, padding_total, extra_padding, self.pad_mode)?
489        } else {
490            let padding_right = padding_total / 2;
491            let padding_left = padding_total - padding_right;
492            pad1d(
493                xs,
494                padding_left,
495                padding_right + extra_padding,
496                self.pad_mode,
497            )?
498        };
499        let xs = self.conv.forward(&xs)?;
500        match &self.norm {
501            None => Ok(xs),
502            Some(norm) => xs.apply(norm),
503        }
504    }
505}
506
/// Residual block: two ELU + convolution stages plus a skip connection.
#[derive(Clone, Debug)]
pub struct EncodecResnetBlock {
    // Reduces the channel count to `dim / compress`.
    block_conv1: EncodecConv1d,
    // Kernel-size-1 convolution restoring the channel count to `dim`.
    block_conv2: EncodecConv1d,
    // Optional kernel-size-1 convolution on the skip path; identity when `None`.
    shortcut: Option<EncodecConv1d>,
}
513
514impl EncodecResnetBlock {
515    pub fn new(
516        dim: usize,
517        (dilation1, dilation2): (usize, usize),
518        cfg: &Config,
519        vb: VarBuilder,
520    ) -> Result<Self> {
521        let h = dim / cfg.compress;
522        let mut layer = Layer::new(vb.pp("block"));
523        // TODO: Apply dilations!
524        layer.inc();
525        let block_conv1 = EncodecConv1d::new(
526            dim,
527            h,
528            cfg.residual_kernel_size,
529            1,
530            dilation1,
531            cfg,
532            layer.next(),
533        )?;
534        layer.inc();
535        let block_conv2 = EncodecConv1d::new(h, dim, 1, 1, dilation2, cfg, layer.next())?;
536        let shortcut = if cfg.use_conv_shortcut {
537            let conv = EncodecConv1d::new(dim, dim, 1, 1, 1, cfg, vb.pp("shortcut"))?;
538            Some(conv)
539        } else {
540            None
541        };
542        Ok(Self {
543            block_conv1,
544            block_conv2,
545            shortcut,
546        })
547    }
548}
549
550impl Module for EncodecResnetBlock {
551    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
552        let residual = xs.clone();
553        let xs = xs.elu(1.)?;
554        let xs = self.block_conv1.forward(&xs)?;
555        let xs = xs.elu(1.)?;
556        let xs = self.block_conv2.forward(&xs)?;
557        let xs = match &self.shortcut {
558            None => (xs + residual)?,
559            Some(shortcut) => xs.add(&shortcut.forward(&residual)?)?,
560        };
561        Ok(xs)
562    }
563}
564
/// Helper handing out var-builders for consecutively numbered sub-layers.
struct Layer<'a> {
    // Parent var-builder under which the numbered prefixes are created.
    vb: VarBuilder<'a>,
    // Index of the next sub-layer.
    cnt: usize,
}
569
570impl<'a> Layer<'a> {
571    fn new(vb: VarBuilder<'a>) -> Self {
572        Self { vb, cnt: 0 }
573    }
574
575    fn inc(&mut self) {
576        self.cnt += 1;
577    }
578
579    fn next(&mut self) -> VarBuilder {
580        let vb = self.vb.pp(self.cnt.to_string());
581        self.cnt += 1;
582        vb
583    }
584}
585
/// Convolutional encoder turning audio into latent frames.
#[derive(Clone, Debug)]
pub struct Encoder {
    init_conv: EncodecConv1d,
    // One entry per downsampling stage: residual blocks followed by a strided convolution.
    sampling_layers: Vec<(Vec<EncodecResnetBlock>, EncodecConv1d)>,
    final_lstm: EncodecLSTM,
    final_conv: EncodecConv1d,
}
593
impl Encoder {
    /// Builds the encoder, loading weights from the numbered sub-layers of `layers` in `vb`.
    ///
    /// The order of the `layer.next()` / `layer.inc()` calls below determines which
    /// index each module is loaded from, so it must not be changed.
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let mut layer = Layer::new(vb.pp("layers"));
        // First convolution: audio channels -> `num_filters`, stride 1, no dilation.
        let init_conv = EncodecConv1d::new(
            cfg.audio_channels,
            cfg.num_filters,
            cfg.kernel_size,
            1,
            1,
            cfg,
            layer.next(),
        )?;
        let mut sampling_layers = vec![];
        // Channel multiplier, doubled after every downsampling stage.
        let mut scaling = 1;
        // The ratios are walked in reverse order on the encoder side.
        for &ratio in cfg.upsampling_ratios.iter().rev() {
            let current_scale = scaling * cfg.num_filters;
            let mut resnets = vec![];
            for j in 0..(cfg.num_residual_layers as u32) {
                // The first conv of the j-th block uses a dilation of `dilation_growth_rate^j`.
                let resnet = EncodecResnetBlock::new(
                    current_scale,
                    (cfg.dilation_growth_rate.pow(j), 1),
                    cfg,
                    layer.next(),
                )?;
                resnets.push(resnet)
            }
            layer.inc(); // ELU
            // Strided convolution: kernel `2 * ratio`, stride `ratio`, doubling the channels.
            let conv1d = EncodecConv1d::new(
                current_scale,
                current_scale * 2,
                ratio * 2,
                ratio,
                1,
                cfg,
                layer.next(),
            )?;
            sampling_layers.push((resnets, conv1d));
            scaling *= 2;
        }
        let final_lstm = EncodecLSTM::new(cfg.num_filters * scaling, cfg, layer.next())?;
        layer.inc(); // ELU
        // Last convolution projecting to `hidden_size` channels.
        let final_conv = EncodecConv1d::new(
            cfg.num_filters * scaling,
            cfg.hidden_size,
            cfg.last_kernel_size,
            1,
            1,
            cfg,
            layer.next(),
        )?;
        Ok(Self {
            init_conv,
            sampling_layers,
            final_conv,
            final_lstm,
        })
    }
}
652
653impl Module for Encoder {
654    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
655        let mut xs = xs.apply(&self.init_conv)?;
656        for (resnets, conv) in self.sampling_layers.iter() {
657            for resnet in resnets.iter() {
658                xs = xs.apply(resnet)?;
659            }
660            xs = xs.elu(1.0)?.apply(conv)?;
661        }
662        xs.apply(&self.final_lstm)?
663            .elu(1.0)?
664            .apply(&self.final_conv)
665    }
666}
667
/// Convolutional decoder turning latent frames back into audio.
#[derive(Clone, Debug)]
pub struct Decoder {
    init_conv: EncodecConv1d,
    init_lstm: EncodecLSTM,
    // One entry per upsampling stage: a transposed convolution followed by residual blocks.
    sampling_layers: Vec<(EncodecConvTranspose1d, Vec<EncodecResnetBlock>)>,
    final_conv: EncodecConv1d,
}
675
impl Decoder {
    /// Builds the decoder, loading weights from the numbered sub-layers of `layers` in `vb`.
    ///
    /// The order of the `layer.next()` / `layer.inc()` calls below determines which
    /// index each module is loaded from, so it must not be changed.
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let mut layer = Layer::new(vb.pp("layers"));
        // Channel multiplier: starts at 2^(number of stages) and is halved after each stage.
        let mut scaling = usize::pow(2, cfg.upsampling_ratios.len() as u32);
        // NOTE(review): this first convolution uses `last_kernel_size`; the reference
        // implementation may use `kernel_size` here - confirm (both default to 7).
        let init_conv = EncodecConv1d::new(
            cfg.hidden_size,
            cfg.num_filters * scaling,
            cfg.last_kernel_size,
            1,
            1,
            cfg,
            layer.next(),
        )?;
        let init_lstm = EncodecLSTM::new(cfg.num_filters * scaling, cfg, layer.next())?;
        let mut sampling_layers = vec![];
        for &ratio in cfg.upsampling_ratios.iter() {
            let current_scale = scaling * cfg.num_filters;
            layer.inc(); // ELU
            // Transposed convolution: kernel `2 * ratio`, stride `ratio`, halving the channels.
            let conv1d = EncodecConvTranspose1d::new(
                current_scale,
                current_scale / 2,
                ratio * 2,
                ratio,
                cfg,
                layer.next(),
            )?;
            let mut resnets = vec![];
            for j in 0..(cfg.num_residual_layers as u32) {
                // The first conv of the j-th block uses a dilation of `dilation_growth_rate^j`.
                let resnet = EncodecResnetBlock::new(
                    current_scale / 2,
                    (cfg.dilation_growth_rate.pow(j), 1),
                    cfg,
                    layer.next(),
                )?;
                resnets.push(resnet)
            }
            sampling_layers.push((conv1d, resnets));
            scaling /= 2;
        }
        layer.inc(); // ELU
        // Last convolution projecting back to the audio channels.
        let final_conv = EncodecConv1d::new(
            cfg.num_filters,
            cfg.audio_channels,
            cfg.last_kernel_size,
            1,
            1,
            cfg,
            layer.next(),
        )?;
        Ok(Self {
            init_conv,
            init_lstm,
            sampling_layers,
            final_conv,
        })
    }
}
733
734impl Module for Decoder {
735    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
736        let mut xs = xs.apply(&self.init_conv)?.apply(&self.init_lstm)?;
737        for (conv, resnets) in self.sampling_layers.iter() {
738            xs = xs.elu(1.)?.apply(conv)?;
739            for resnet in resnets.iter() {
740                xs = xs.apply(resnet)?
741            }
742        }
743        xs.elu(1.)?.apply(&self.final_conv)
744    }
745}
746
/// Full EnCodec model: encoder, residual vector quantizer and decoder.
#[derive(Debug)]
pub struct Model {
    encoder: Encoder,
    decoder: Decoder,
    quantizer: ResidualVectorQuantizer,
}
753
754impl Model {
755    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
756        let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
757        let decoder = Decoder::new(cfg, vb.pp("decoder"))?;
758        let quantizer = ResidualVectorQuantizer::new(cfg, vb.pp("quantizer"))?;
759        Ok(Self {
760            encoder,
761            decoder,
762            quantizer,
763        })
764    }
765
766    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
767        let xs = self.encoder.forward(xs)?;
768        let codes = self.quantizer.encode(&xs)?;
769        codes.transpose(0, 1)
770    }
771
772    pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
773        let (_b_sz, _codebooks, _seqlen) = codes.dims3()?;
774        let codes = codes.transpose(0, 1)?;
775        let embeddings = self.quantizer.decode(&codes)?;
776        let outputs = self.decoder.forward(&embeddings)?;
777        Ok(outputs)
778    }
779}