1use crate::models::encodec;
8use candle::{IndexOp, Result, Tensor, D};
9use candle_nn::{Conv1d, Conv1dConfig, ConvTranspose1d, ConvTranspose1dConfig, VarBuilder};
10
11#[derive(serde::Deserialize, Debug, Clone)]
12pub struct Config {
13 pub num_codebooks: usize,
14 pub model_bitrate: u32,
15 pub codebook_size: usize,
16 pub latent_dim: usize,
17 pub frame_rate: u32,
18 pub sampling_rate: u32,
19}
20
21#[derive(Debug, Clone)]
22pub struct Snake1d {
23 alpha: Tensor,
24}
25
26impl Snake1d {
27 pub fn new(channels: usize, vb: VarBuilder) -> Result<Self> {
28 let alpha = vb.get((1, channels, 1), "alpha")?;
29 Ok(Self { alpha })
30 }
31}
32
33impl candle::Module for Snake1d {
34 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
35 let xs_shape = xs.shape();
36 let xs = xs.flatten_from(2)?;
37 let sin = self.alpha.broadcast_mul(&xs)?.sin()?;
38 let sin = (&sin * &sin)?;
39 (xs + (&self.alpha + 1e-9)?.recip()?.broadcast_mul(&sin)?)?.reshape(xs_shape)
40 }
41}
42
43#[derive(Debug, Clone)]
44pub struct ResidualUnit {
45 snake1: Snake1d,
46 conv1: Conv1d,
47 snake2: Snake1d,
48 conv2: Conv1d,
49}
50
51impl ResidualUnit {
52 pub fn new(dim: usize, dilation: usize, vb: VarBuilder) -> Result<Self> {
53 let pad = ((7 - 1) * dilation) / 2;
54 let vb = vb.pp("block");
55 let snake1 = Snake1d::new(dim, vb.pp(0))?;
56 let cfg1 = Conv1dConfig {
57 dilation,
58 padding: pad,
59 ..Default::default()
60 };
61 let conv1 = encodec::conv1d_weight_norm(dim, dim, 7, cfg1, vb.pp(1))?;
62 let snake2 = Snake1d::new(dim, vb.pp(2))?;
63 let conv2 = encodec::conv1d_weight_norm(dim, dim, 1, Default::default(), vb.pp(3))?;
64 Ok(Self {
65 snake1,
66 conv1,
67 snake2,
68 conv2,
69 })
70 }
71}
72
73impl candle::Module for ResidualUnit {
74 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
75 let ys = xs
76 .apply(&self.snake1)?
77 .apply(&self.conv1)?
78 .apply(&self.snake2)?
79 .apply(&self.conv2)?;
80 let pad = (xs.dim(D::Minus1)? - ys.dim(D::Minus1)?) / 2;
81 if pad > 0 {
82 &ys + xs.narrow(D::Minus1, pad, ys.dim(D::Minus1)?)
83 } else {
84 ys + xs
85 }
86 }
87}
88
89#[derive(Debug, Clone)]
90pub struct EncoderBlock {
91 res1: ResidualUnit,
92 res2: ResidualUnit,
93 res3: ResidualUnit,
94 snake1: Snake1d,
95 conv1: Conv1d,
96}
97
98impl EncoderBlock {
99 pub fn new(dim: usize, stride: usize, vb: VarBuilder) -> Result<Self> {
100 let vb = vb.pp("block");
101 let res1 = ResidualUnit::new(dim / 2, 1, vb.pp(0))?;
102 let res2 = ResidualUnit::new(dim / 2, 3, vb.pp(1))?;
103 let res3 = ResidualUnit::new(dim / 2, 9, vb.pp(2))?;
104 let snake1 = Snake1d::new(dim / 2, vb.pp(3))?;
105 let cfg1 = Conv1dConfig {
106 stride,
107 padding: (stride + 1) / 2,
108 ..Default::default()
109 };
110 let conv1 = encodec::conv1d_weight_norm(dim / 2, dim, 2 * stride, cfg1, vb.pp(4))?;
111 Ok(Self {
112 res1,
113 res2,
114 res3,
115 snake1,
116 conv1,
117 })
118 }
119}
120
121impl candle::Module for EncoderBlock {
122 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
123 xs.apply(&self.res1)?
124 .apply(&self.res2)?
125 .apply(&self.res3)?
126 .apply(&self.snake1)?
127 .apply(&self.conv1)
128 }
129}
130
131#[derive(Debug, Clone)]
132pub struct Encoder {
133 conv1: Conv1d,
134 blocks: Vec<EncoderBlock>,
135 snake1: Snake1d,
136 conv2: Conv1d,
137}
138
139impl candle::Module for Encoder {
140 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
141 let mut xs = xs.apply(&self.conv1)?;
142 for block in self.blocks.iter() {
143 xs = xs.apply(block)?
144 }
145 xs.apply(&self.snake1)?.apply(&self.conv2)
146 }
147}
148
149impl Encoder {
150 pub fn new(
151 mut d_model: usize,
152 strides: &[usize],
153 d_latent: usize,
154 vb: VarBuilder,
155 ) -> Result<Self> {
156 let vb = vb.pp("block");
157 let cfg1 = Conv1dConfig {
158 padding: 3,
159 ..Default::default()
160 };
161 let conv1 = encodec::conv1d_weight_norm(1, d_model, 7, cfg1, vb.pp(0))?;
162 let mut blocks = Vec::with_capacity(strides.len());
163 for (block_idx, stride) in strides.iter().enumerate() {
164 d_model *= 2;
165 let block = EncoderBlock::new(d_model, *stride, vb.pp(block_idx + 1))?;
166 blocks.push(block)
167 }
168 let snake1 = Snake1d::new(d_model, vb.pp(strides.len() + 1))?;
169 let cfg2 = Conv1dConfig {
170 padding: 1,
171 ..Default::default()
172 };
173 let conv2 =
174 encodec::conv1d_weight_norm(d_model, d_latent, 3, cfg2, vb.pp(strides.len() + 2))?;
175 Ok(Self {
176 conv1,
177 blocks,
178 snake1,
179 conv2,
180 })
181 }
182}
183
184#[derive(Debug, Clone)]
185pub struct DecoderBlock {
186 snake1: Snake1d,
187 conv_tr1: ConvTranspose1d,
188 res1: ResidualUnit,
189 res2: ResidualUnit,
190 res3: ResidualUnit,
191}
192
193impl DecoderBlock {
194 pub fn new(in_dim: usize, out_dim: usize, stride: usize, vb: VarBuilder) -> Result<Self> {
195 let vb = vb.pp("block");
196 let snake1 = Snake1d::new(in_dim, vb.pp(0))?;
197 let cfg = ConvTranspose1dConfig {
198 stride,
199 padding: (stride + 1) / 2,
200 ..Default::default()
201 };
202 let conv_tr1 = encodec::conv_transpose1d_weight_norm(
203 in_dim,
204 out_dim,
205 2 * stride,
206 true,
207 cfg,
208 vb.pp(1),
209 )?;
210 let res1 = ResidualUnit::new(out_dim, 1, vb.pp(2))?;
211 let res2 = ResidualUnit::new(out_dim, 3, vb.pp(3))?;
212 let res3 = ResidualUnit::new(out_dim, 9, vb.pp(4))?;
213 Ok(Self {
214 snake1,
215 conv_tr1,
216 res1,
217 res2,
218 res3,
219 })
220 }
221}
222
223impl candle_nn::Module for DecoderBlock {
224 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
225 xs.apply(&self.snake1)?
226 .apply(&self.conv_tr1)?
227 .apply(&self.res1)?
228 .apply(&self.res2)?
229 .apply(&self.res3)
230 }
231}
232
233#[derive(Debug, Clone)]
234pub struct Decoder {
235 conv1: Conv1d,
236 blocks: Vec<DecoderBlock>,
237 snake1: Snake1d,
238 conv2: Conv1d,
239}
240
241impl Decoder {
242 pub fn new(
243 in_c: usize,
244 mut channels: usize,
245 rates: &[usize],
246 d_out: usize,
247 vb: VarBuilder,
248 ) -> Result<Self> {
249 let vb = vb.pp("model");
250 let cfg1 = Conv1dConfig {
251 padding: 3,
252 ..Default::default()
253 };
254 let conv1 = encodec::conv1d_weight_norm(in_c, channels, 7, cfg1, vb.pp(0))?;
255 let mut blocks = Vec::with_capacity(rates.len());
256 for (idx, stride) in rates.iter().enumerate() {
257 let block = DecoderBlock::new(channels, channels / 2, *stride, vb.pp(idx + 1))?;
258 channels /= 2;
259 blocks.push(block)
260 }
261 let snake1 = Snake1d::new(channels, vb.pp(rates.len() + 1))?;
262 let conv2 = encodec::conv1d_weight_norm(channels, d_out, 7, cfg1, vb.pp(rates.len() + 2))?;
263 Ok(Self {
264 conv1,
265 blocks,
266 snake1,
267 conv2,
268 })
269 }
270}
271
272impl candle::Module for Decoder {
273 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
274 let mut xs = xs.apply(&self.conv1)?;
275 for block in self.blocks.iter() {
276 xs = xs.apply(block)?
277 }
278 xs.apply(&self.snake1)?.apply(&self.conv2)
279 }
280}
281
282#[allow(unused)]
283#[derive(Clone, Debug)]
284pub struct VectorQuantizer {
285 in_proj: Conv1d,
286 out_proj: Conv1d,
287 codebook: candle_nn::Embedding,
288}
289
290impl VectorQuantizer {
291 pub fn new(in_dim: usize, cb_size: usize, cb_dim: usize, vb: VarBuilder) -> Result<Self> {
292 let in_proj =
293 encodec::conv1d_weight_norm(in_dim, cb_dim, 1, Default::default(), vb.pp("in_proj"))?;
294 let out_proj =
295 encodec::conv1d_weight_norm(cb_dim, in_dim, 1, Default::default(), vb.pp("out_proj"))?;
296 let codebook = candle_nn::embedding(cb_size, cb_dim, vb.pp("codebook"))?;
297 Ok(Self {
298 in_proj,
299 out_proj,
300 codebook,
301 })
302 }
303
304 pub fn embed_code(&self, embed_id: &Tensor) -> Result<Tensor> {
305 embed_id.apply(&self.codebook)
306 }
307
308 pub fn decode_code(&self, embed_id: &Tensor) -> Result<Tensor> {
309 self.embed_code(embed_id)?.transpose(1, 2)
310 }
311}
312
313#[derive(Clone, Debug)]
314pub struct ResidualVectorQuantizer {
315 quantizers: Vec<VectorQuantizer>,
316}
317
318impl ResidualVectorQuantizer {
319 pub fn new(
320 input_dim: usize,
321 n_codebooks: usize,
322 cb_size: usize,
323 cb_dim: usize,
324 vb: VarBuilder,
325 ) -> Result<Self> {
326 let vb = &vb.pp("quantizers");
327 let quantizers = (0..n_codebooks)
328 .map(|i| VectorQuantizer::new(input_dim, cb_size, cb_dim, vb.pp(i)))
329 .collect::<Result<Vec<_>>>()?;
330 Ok(Self { quantizers })
331 }
332
333 pub fn from_codes(&self, codes: &Tensor) -> Result<Tensor> {
334 let mut sum = None;
335 for (idx, quantizer) in self.quantizers.iter().enumerate() {
336 let z_p_i = quantizer.decode_code(&codes.i((.., idx))?)?;
337 let z_q_i = z_p_i.apply(&quantizer.out_proj)?;
338 let s = match sum {
339 None => z_q_i,
340 Some(s) => (s + z_q_i)?,
341 };
342 sum = Some(s)
343 }
344 match sum {
345 Some(s) => Ok(s),
346 None => candle::bail!("empty codebooks"),
347 }
348 }
349}
350
351#[derive(Debug, Clone)]
352pub struct Model {
353 pub encoder: Encoder,
354 pub quantizer: ResidualVectorQuantizer,
355 pub decoder: Decoder,
356}
357
358impl Model {
359 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
360 let vb = vb.pp("model");
361 let encoder = Encoder::new(64, &[2, 4, 8, 8], cfg.latent_dim, vb.pp("encoder"))?;
362 let quantizer = ResidualVectorQuantizer::new(
363 cfg.latent_dim,
364 cfg.num_codebooks,
365 cfg.codebook_size,
366 8,
367 vb.pp("quantizer"),
368 )?;
369 let decoder = Decoder::new(cfg.latent_dim, 1536, &[8, 8, 4, 2], 1, vb.pp("decoder"))?;
370 Ok(Self {
371 encoder,
372 decoder,
373 quantizer,
374 })
375 }
376
377 pub fn decode_codes(&self, audio_codes: &Tensor) -> Result<Tensor> {
378 let audio_values = self.quantizer.from_codes(audio_codes)?;
379 audio_values.apply(&self.decoder)
380 }
381}