// candle_transformers/models/dac.rs

//! Implementation of the Descript Audio Codec (DAC) model.
//!
//! See: [Descript Audio Codec](https://github.com/descriptinc/descript-audio-codec)
//!
//! An efficient neural codec for compressing/decompressing audio.
7use crate::models::encodec;
8use candle::{IndexOp, Result, Tensor, D};
9use candle_nn::{Conv1d, Conv1dConfig, ConvTranspose1d, ConvTranspose1dConfig, VarBuilder};
10
11#[derive(serde::Deserialize, Debug, Clone)]
12pub struct Config {
13    pub num_codebooks: usize,
14    pub model_bitrate: u32,
15    pub codebook_size: usize,
16    pub latent_dim: usize,
17    pub frame_rate: u32,
18    pub sampling_rate: u32,
19}
20
21#[derive(Debug, Clone)]
22pub struct Snake1d {
23    alpha: Tensor,
24}
25
26impl Snake1d {
27    pub fn new(channels: usize, vb: VarBuilder) -> Result<Self> {
28        let alpha = vb.get((1, channels, 1), "alpha")?;
29        Ok(Self { alpha })
30    }
31}
32
33impl candle::Module for Snake1d {
34    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
35        let xs_shape = xs.shape();
36        let xs = xs.flatten_from(2)?;
37        let sin = self.alpha.broadcast_mul(&xs)?.sin()?;
38        let sin = (&sin * &sin)?;
39        (xs + (&self.alpha + 1e-9)?.recip()?.broadcast_mul(&sin)?)?.reshape(xs_shape)
40    }
41}
42
43#[derive(Debug, Clone)]
44pub struct ResidualUnit {
45    snake1: Snake1d,
46    conv1: Conv1d,
47    snake2: Snake1d,
48    conv2: Conv1d,
49}
50
51impl ResidualUnit {
52    pub fn new(dim: usize, dilation: usize, vb: VarBuilder) -> Result<Self> {
53        let pad = ((7 - 1) * dilation) / 2;
54        let vb = vb.pp("block");
55        let snake1 = Snake1d::new(dim, vb.pp(0))?;
56        let cfg1 = Conv1dConfig {
57            dilation,
58            padding: pad,
59            ..Default::default()
60        };
61        let conv1 = encodec::conv1d_weight_norm(dim, dim, 7, cfg1, vb.pp(1))?;
62        let snake2 = Snake1d::new(dim, vb.pp(2))?;
63        let conv2 = encodec::conv1d_weight_norm(dim, dim, 1, Default::default(), vb.pp(3))?;
64        Ok(Self {
65            snake1,
66            conv1,
67            snake2,
68            conv2,
69        })
70    }
71}
72
73impl candle::Module for ResidualUnit {
74    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
75        let ys = xs
76            .apply(&self.snake1)?
77            .apply(&self.conv1)?
78            .apply(&self.snake2)?
79            .apply(&self.conv2)?;
80        let pad = (xs.dim(D::Minus1)? - ys.dim(D::Minus1)?) / 2;
81        if pad > 0 {
82            &ys + xs.narrow(D::Minus1, pad, ys.dim(D::Minus1)?)
83        } else {
84            ys + xs
85        }
86    }
87}
88
89#[derive(Debug, Clone)]
90pub struct EncoderBlock {
91    res1: ResidualUnit,
92    res2: ResidualUnit,
93    res3: ResidualUnit,
94    snake1: Snake1d,
95    conv1: Conv1d,
96}
97
98impl EncoderBlock {
99    pub fn new(dim: usize, stride: usize, vb: VarBuilder) -> Result<Self> {
100        let vb = vb.pp("block");
101        let res1 = ResidualUnit::new(dim / 2, 1, vb.pp(0))?;
102        let res2 = ResidualUnit::new(dim / 2, 3, vb.pp(1))?;
103        let res3 = ResidualUnit::new(dim / 2, 9, vb.pp(2))?;
104        let snake1 = Snake1d::new(dim / 2, vb.pp(3))?;
105        let cfg1 = Conv1dConfig {
106            stride,
107            padding: (stride + 1) / 2,
108            ..Default::default()
109        };
110        let conv1 = encodec::conv1d_weight_norm(dim / 2, dim, 2 * stride, cfg1, vb.pp(4))?;
111        Ok(Self {
112            res1,
113            res2,
114            res3,
115            snake1,
116            conv1,
117        })
118    }
119}
120
121impl candle::Module for EncoderBlock {
122    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
123        xs.apply(&self.res1)?
124            .apply(&self.res2)?
125            .apply(&self.res3)?
126            .apply(&self.snake1)?
127            .apply(&self.conv1)
128    }
129}
130
131#[derive(Debug, Clone)]
132pub struct Encoder {
133    conv1: Conv1d,
134    blocks: Vec<EncoderBlock>,
135    snake1: Snake1d,
136    conv2: Conv1d,
137}
138
139impl candle::Module for Encoder {
140    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
141        let mut xs = xs.apply(&self.conv1)?;
142        for block in self.blocks.iter() {
143            xs = xs.apply(block)?
144        }
145        xs.apply(&self.snake1)?.apply(&self.conv2)
146    }
147}
148
149impl Encoder {
150    pub fn new(
151        mut d_model: usize,
152        strides: &[usize],
153        d_latent: usize,
154        vb: VarBuilder,
155    ) -> Result<Self> {
156        let vb = vb.pp("block");
157        let cfg1 = Conv1dConfig {
158            padding: 3,
159            ..Default::default()
160        };
161        let conv1 = encodec::conv1d_weight_norm(1, d_model, 7, cfg1, vb.pp(0))?;
162        let mut blocks = Vec::with_capacity(strides.len());
163        for (block_idx, stride) in strides.iter().enumerate() {
164            d_model *= 2;
165            let block = EncoderBlock::new(d_model, *stride, vb.pp(block_idx + 1))?;
166            blocks.push(block)
167        }
168        let snake1 = Snake1d::new(d_model, vb.pp(strides.len() + 1))?;
169        let cfg2 = Conv1dConfig {
170            padding: 1,
171            ..Default::default()
172        };
173        let conv2 =
174            encodec::conv1d_weight_norm(d_model, d_latent, 3, cfg2, vb.pp(strides.len() + 2))?;
175        Ok(Self {
176            conv1,
177            blocks,
178            snake1,
179            conv2,
180        })
181    }
182}
183
184#[derive(Debug, Clone)]
185pub struct DecoderBlock {
186    snake1: Snake1d,
187    conv_tr1: ConvTranspose1d,
188    res1: ResidualUnit,
189    res2: ResidualUnit,
190    res3: ResidualUnit,
191}
192
193impl DecoderBlock {
194    pub fn new(in_dim: usize, out_dim: usize, stride: usize, vb: VarBuilder) -> Result<Self> {
195        let vb = vb.pp("block");
196        let snake1 = Snake1d::new(in_dim, vb.pp(0))?;
197        let cfg = ConvTranspose1dConfig {
198            stride,
199            padding: (stride + 1) / 2,
200            ..Default::default()
201        };
202        let conv_tr1 = encodec::conv_transpose1d_weight_norm(
203            in_dim,
204            out_dim,
205            2 * stride,
206            true,
207            cfg,
208            vb.pp(1),
209        )?;
210        let res1 = ResidualUnit::new(out_dim, 1, vb.pp(2))?;
211        let res2 = ResidualUnit::new(out_dim, 3, vb.pp(3))?;
212        let res3 = ResidualUnit::new(out_dim, 9, vb.pp(4))?;
213        Ok(Self {
214            snake1,
215            conv_tr1,
216            res1,
217            res2,
218            res3,
219        })
220    }
221}
222
223impl candle_nn::Module for DecoderBlock {
224    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
225        xs.apply(&self.snake1)?
226            .apply(&self.conv_tr1)?
227            .apply(&self.res1)?
228            .apply(&self.res2)?
229            .apply(&self.res3)
230    }
231}
232
233#[derive(Debug, Clone)]
234pub struct Decoder {
235    conv1: Conv1d,
236    blocks: Vec<DecoderBlock>,
237    snake1: Snake1d,
238    conv2: Conv1d,
239}
240
241impl Decoder {
242    pub fn new(
243        in_c: usize,
244        mut channels: usize,
245        rates: &[usize],
246        d_out: usize,
247        vb: VarBuilder,
248    ) -> Result<Self> {
249        let vb = vb.pp("model");
250        let cfg1 = Conv1dConfig {
251            padding: 3,
252            ..Default::default()
253        };
254        let conv1 = encodec::conv1d_weight_norm(in_c, channels, 7, cfg1, vb.pp(0))?;
255        let mut blocks = Vec::with_capacity(rates.len());
256        for (idx, stride) in rates.iter().enumerate() {
257            let block = DecoderBlock::new(channels, channels / 2, *stride, vb.pp(idx + 1))?;
258            channels /= 2;
259            blocks.push(block)
260        }
261        let snake1 = Snake1d::new(channels, vb.pp(rates.len() + 1))?;
262        let conv2 = encodec::conv1d_weight_norm(channels, d_out, 7, cfg1, vb.pp(rates.len() + 2))?;
263        Ok(Self {
264            conv1,
265            blocks,
266            snake1,
267            conv2,
268        })
269    }
270}
271
272impl candle::Module for Decoder {
273    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
274        let mut xs = xs.apply(&self.conv1)?;
275        for block in self.blocks.iter() {
276            xs = xs.apply(block)?
277        }
278        xs.apply(&self.snake1)?.apply(&self.conv2)
279    }
280}
281
282#[allow(unused)]
283#[derive(Clone, Debug)]
284pub struct VectorQuantizer {
285    in_proj: Conv1d,
286    out_proj: Conv1d,
287    codebook: candle_nn::Embedding,
288}
289
290impl VectorQuantizer {
291    pub fn new(in_dim: usize, cb_size: usize, cb_dim: usize, vb: VarBuilder) -> Result<Self> {
292        let in_proj =
293            encodec::conv1d_weight_norm(in_dim, cb_dim, 1, Default::default(), vb.pp("in_proj"))?;
294        let out_proj =
295            encodec::conv1d_weight_norm(cb_dim, in_dim, 1, Default::default(), vb.pp("out_proj"))?;
296        let codebook = candle_nn::embedding(cb_size, cb_dim, vb.pp("codebook"))?;
297        Ok(Self {
298            in_proj,
299            out_proj,
300            codebook,
301        })
302    }
303
304    pub fn embed_code(&self, embed_id: &Tensor) -> Result<Tensor> {
305        embed_id.apply(&self.codebook)
306    }
307
308    pub fn decode_code(&self, embed_id: &Tensor) -> Result<Tensor> {
309        self.embed_code(embed_id)?.transpose(1, 2)
310    }
311}
312
313#[derive(Clone, Debug)]
314pub struct ResidualVectorQuantizer {
315    quantizers: Vec<VectorQuantizer>,
316}
317
318impl ResidualVectorQuantizer {
319    pub fn new(
320        input_dim: usize,
321        n_codebooks: usize,
322        cb_size: usize,
323        cb_dim: usize,
324        vb: VarBuilder,
325    ) -> Result<Self> {
326        let vb = &vb.pp("quantizers");
327        let quantizers = (0..n_codebooks)
328            .map(|i| VectorQuantizer::new(input_dim, cb_size, cb_dim, vb.pp(i)))
329            .collect::<Result<Vec<_>>>()?;
330        Ok(Self { quantizers })
331    }
332
333    pub fn from_codes(&self, codes: &Tensor) -> Result<Tensor> {
334        let mut sum = None;
335        for (idx, quantizer) in self.quantizers.iter().enumerate() {
336            let z_p_i = quantizer.decode_code(&codes.i((.., idx))?)?;
337            let z_q_i = z_p_i.apply(&quantizer.out_proj)?;
338            let s = match sum {
339                None => z_q_i,
340                Some(s) => (s + z_q_i)?,
341            };
342            sum = Some(s)
343        }
344        match sum {
345            Some(s) => Ok(s),
346            None => candle::bail!("empty codebooks"),
347        }
348    }
349}
350
351#[derive(Debug, Clone)]
352pub struct Model {
353    pub encoder: Encoder,
354    pub quantizer: ResidualVectorQuantizer,
355    pub decoder: Decoder,
356}
357
358impl Model {
359    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
360        let vb = vb.pp("model");
361        let encoder = Encoder::new(64, &[2, 4, 8, 8], cfg.latent_dim, vb.pp("encoder"))?;
362        let quantizer = ResidualVectorQuantizer::new(
363            cfg.latent_dim,
364            cfg.num_codebooks,
365            cfg.codebook_size,
366            8,
367            vb.pp("quantizer"),
368        )?;
369        let decoder = Decoder::new(cfg.latent_dim, 1536, &[8, 8, 4, 2], 1, vb.pp("decoder"))?;
370        Ok(Self {
371            encoder,
372            decoder,
373            quantizer,
374        })
375    }
376
377    pub fn decode_codes(&self, audio_codes: &Tensor) -> Result<Tensor> {
378        let audio_values = self.quantizer.from_codes(audio_codes)?;
379        audio_values.apply(&self.decoder)
380    }
381}