1use candle::{DType, IndexOp, Layout, Module, Result, Shape, Tensor, D};
8use candle_nn::{conv1d, Conv1d, ConvTranspose1d, VarBuilder};
9
/// Normalization scheme selected through `Config::norm_type`.
///
/// `EncodecConv1d::new` in this file loads weight-normalized convolution
/// parameters for `WeightNorm`, appends a group-norm layer after a plain
/// convolution for `TimeGroupNorm`, and uses a plain convolution alone for
/// `None`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Deserialize)]
pub enum NormType {
    WeightNorm,
    TimeGroupNorm,
    None,
}
19
20#[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Deserialize)]
21pub enum PadMode {
22 Constant,
23 Reflect,
24 Replicate,
25}
26
/// Hyper-parameters for the model defined in this file.
///
/// Deserializable with serde; `Config::default()` supplies a complete preset.
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Config {
    /// Supported bandwidths; only the last entry is read here (by `num_quantizers`).
    pub target_bandwidths: Vec<f64>,
    /// Input sampling rate; divided by the hop length in `frame_rate`.
    pub sampling_rate: usize,
    /// Channel count of the encoder's first and the decoder's last convolution.
    pub audio_channels: usize,
    // `normalize`, `chunk_length_s` and `overlap` are not read by the code
    // visible in this file.
    pub normalize: bool,
    pub chunk_length_s: Option<usize>,
    pub overlap: Option<usize>,
    /// Output channels of the encoder, input channels of the decoder, and the
    /// fallback codebook dimension when `codebook_dim` is `None`.
    pub hidden_size: usize,
    /// Base channel count; each sampling stage scales it by a power of two.
    pub num_filters: usize,
    /// Number of residual blocks built per sampling stage.
    pub num_residual_layers: usize,
    /// Per-stage strides. The decoder walks them in order, the encoder in
    /// reverse; their product is the hop length used by `frame_rate`.
    pub upsampling_ratios: Vec<usize>,
    /// Selects how `EncodecConv1d` builds its convolution and optional norm.
    pub norm_type: NormType,
    /// Kernel size of the encoder's first convolution.
    pub kernel_size: usize,
    /// Kernel size of the encoder's last convolution and of the decoder's
    /// first and last convolutions.
    pub last_kernel_size: usize,
    /// Kernel size of the first convolution inside each residual block.
    pub residual_kernel_size: usize,
    /// The j-th residual block of a stage uses dilation `dilation_growth_rate^j`
    /// for its first convolution.
    pub dilation_growth_rate: usize,
    /// When true, `EncodecConv1d` puts the whole base padding on the left.
    pub use_causal_conv: bool,
    /// Padding strategy handed to `pad1d`.
    pub pad_mode: PadMode,
    /// Residual blocks use `dim / compress` hidden channels.
    pub compress: usize,
    /// Number of stacked LSTM layers in `EncodecLSTM`.
    pub num_lstm_layers: usize,
    // Not read by the code visible in this file.
    pub trim_right_ratio: f64,
    /// Number of entries in each codebook.
    pub codebook_size: usize,
    /// Dimension of each codebook entry; `None` means `hidden_size`.
    pub codebook_dim: Option<usize>,
    /// When true, the residual skip path goes through a kernel-size-1 convolution.
    pub use_conv_shortcut: bool,
}
53
/// Default preset.
///
/// With these values the hop length is `8 * 5 * 4 * 2 = 320`, so
/// `frame_rate()` is `24_000 / 320 = 75` and `num_quantizers()` is
/// `24_000 / (75 * 10) = 32`.
impl Default for Config {
    fn default() -> Self {
        Self {
            target_bandwidths: vec![1.5, 3.0, 6.0, 12.0, 24.0],
            sampling_rate: 24_000,
            audio_channels: 1,
            normalize: false,
            chunk_length_s: None,
            overlap: None,
            hidden_size: 128,
            num_filters: 32,
            num_residual_layers: 1,
            upsampling_ratios: vec![8, 5, 4, 2],
            norm_type: NormType::WeightNorm,
            kernel_size: 7,
            last_kernel_size: 7,
            residual_kernel_size: 3,
            dilation_growth_rate: 2,
            use_causal_conv: true,
            // NOTE(review): `pad1d` rejects `PadMode::Reflect`, so `Replicate`
            // is presumably a stand-in for the reference padding mode — confirm.
            pad_mode: PadMode::Replicate,
            compress: 2,
            num_lstm_layers: 2,
            trim_right_ratio: 1.0,
            codebook_size: 1024,
            // `None` makes `Config::codebook_dim()` fall back to `hidden_size`.
            codebook_dim: None,
            use_conv_shortcut: true,
        }
    }
}
84
85impl Config {
86 fn codebook_dim(&self) -> usize {
87 self.codebook_dim.unwrap_or(self.hidden_size)
88 }
89
90 fn frame_rate(&self) -> usize {
91 let hop_length: usize = self.upsampling_ratios.iter().product();
92 self.sampling_rate.div_ceil(hop_length)
93 }
94
95 fn num_quantizers(&self) -> usize {
96 let num = 1000f64
97 * self
98 .target_bandwidths
99 .last()
100 .expect("empty target_bandwidths");
101 (num as usize) / (self.frame_rate() * 10)
102 }
103}
104
105fn get_extra_padding_for_conv1d(
106 xs: &Tensor,
107 k_size: usize,
108 stride: usize,
109 padding_total: usize,
110) -> Result<usize> {
111 let len = xs.dim(D::Minus1)?;
112 let n_frames = (len + padding_total).saturating_sub(k_size) as f64 / stride as f64 + 1.0;
113 let ideal_len =
114 ((n_frames.ceil() as usize - 1) * stride + k_size).saturating_sub(padding_total);
115 Ok(ideal_len.saturating_sub(len))
116}
117
118fn pad1d(xs: &Tensor, pad_l: usize, pad_r: usize, mode: PadMode) -> Result<Tensor> {
119 match mode {
120 PadMode::Constant => xs.pad_with_zeros(D::Minus1, pad_l, pad_r),
121 PadMode::Reflect => candle::bail!("pad-mode 'reflect' is not supported"),
122 PadMode::Replicate => xs.pad_with_same(D::Minus1, pad_l, pad_r),
123 }
124}
125
126pub fn conv1d_weight_norm(
130 in_c: usize,
131 out_c: usize,
132 kernel_size: usize,
133 config: candle_nn::Conv1dConfig,
134 vb: VarBuilder,
135) -> Result<Conv1d> {
136 let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
137 let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
138 let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
139 let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
140 let bias = vb.get(out_c, "bias")?;
141 Ok(Conv1d::new(weight, Some(bias), config))
142}
143
144pub fn conv_transpose1d_weight_norm(
145 in_c: usize,
146 out_c: usize,
147 kernel_size: usize,
148 bias: bool,
149 config: candle_nn::ConvTranspose1dConfig,
150 vb: VarBuilder,
151) -> Result<ConvTranspose1d> {
152 let weight_g = vb.get((in_c, 1, 1), "weight_g")?;
153 let weight_v = vb.get((in_c, out_c, kernel_size), "weight_v")?;
154 let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
155 let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
156 let bias = if bias {
157 Some(vb.get(out_c, "bias")?)
158 } else {
159 None
160 };
161 Ok(ConvTranspose1d::new(weight, bias, config))
162}
163
164struct CodebookEncode;
165
166impl candle::CustomOp2 for CodebookEncode {
167 fn name(&self) -> &'static str {
168 "cb"
169 }
170
171 fn cpu_fwd(
172 &self,
173 lhs_storage: &candle::CpuStorage,
174 lhs_layout: &Layout,
175 rhs_storage: &candle::CpuStorage,
176 rhs_layout: &Layout,
177 ) -> Result<(candle::CpuStorage, Shape)> {
178 use rayon::prelude::*;
179
180 let (lhs_dim1, lhs_dim2) = lhs_layout.shape().dims2()?;
181 let (rhs_dim1, rhs_dim2) = rhs_layout.shape().dims2()?;
182 if lhs_dim2 != rhs_dim2 {
183 candle::bail!("CodebookEncode, mismatch on last dim, {lhs_layout:?} {rhs_layout:?}");
184 }
185 if lhs_dim2 == 0 {
186 candle::bail!("CodebookEncode, empty last dim {lhs_layout:?}")
187 }
188 let lhs = match lhs_layout.contiguous_offsets() {
189 None => candle::bail!("CodebookEncode, lhs has to be contiguous, got {lhs_layout:?}"),
190 Some((o1, o2)) => {
191 let slice = lhs_storage.as_slice::<f32>()?;
192 &slice[o1..o2]
193 }
194 };
195 let rhs = match rhs_layout.contiguous_offsets() {
196 None => candle::bail!("CodebookEncode, rhs has to be contiguous, got {rhs_layout:?}"),
197 Some((o1, o2)) => {
198 let slice = rhs_storage.as_slice::<f32>()?;
199 &slice[o1..o2]
200 }
201 };
202 let dst = (0..lhs_dim1)
203 .into_par_iter()
204 .map(|idx1| {
205 let mut where_min = 0;
206 let mut min_dist = f32::INFINITY;
207 let lhs = &lhs[idx1 * lhs_dim2..(idx1 + 1) * lhs_dim2];
208 for idx2 in 0..rhs_dim1 {
209 let rhs = &rhs[idx2 * rhs_dim2..(idx2 + 1) * rhs_dim2];
210 let mut dist = 0f32;
211 for (a, b) in lhs.iter().zip(rhs.iter()) {
212 dist += (a - b) * (a - b)
213 }
214 if dist < min_dist {
215 min_dist = dist;
216 where_min = idx2;
217 }
218 }
219 where_min as u32
220 })
221 .collect();
222 let storage = candle::WithDType::to_cpu_storage_owned(dst);
223 Ok((storage, (lhs_dim1,).into()))
224 }
225}
226
/// A single codebook with nearest-neighbour encoding and table-lookup decoding.
// `inited`, `cluster_size` and `embed_avg` are loaded in `new` but never read
// by the methods in this file, hence the `allow(unused)`.
#[allow(unused)]
#[derive(Clone, Debug)]
pub struct EuclideanCodebook {
    inited: Tensor,
    cluster_size: Tensor,
    /// Codebook table of shape `(codebook_size, codebook_dim)`.
    embed: candle_nn::Embedding,
    embed_avg: Tensor,
    /// Half of the squared L2 norm of every codebook row, precomputed in `new`
    /// for `encode_slow`.
    c2: Tensor,
}
237
238impl EuclideanCodebook {
239 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
240 let inited = vb.get(1, "inited")?;
241 let cluster_size = vb.get(cfg.codebook_size, "cluster_size")?;
242 let e_shape = (cfg.codebook_size, cfg.codebook_dim());
243 let embed = vb.get(e_shape, "embed")?;
244 let c2 = ((&embed * &embed)?.sum(D::Minus1)? / 2.0)?;
245 let embed_avg = vb.get(e_shape, "embed_avg")?;
246 Ok(Self {
247 inited,
248 cluster_size,
249 embed: candle_nn::Embedding::new(embed, cfg.codebook_dim()),
250 embed_avg,
251 c2,
252 })
253 }
254
255 pub fn encode_slow(&self, xs: &Tensor) -> Result<Tensor> {
256 let mut target_shape = xs.dims().to_vec();
257 target_shape.pop();
258 let xs = xs.flatten_to(D::Minus2)?;
259 let _ = xs.dims2()?;
260 let dot_prod = xs.matmul(&self.embed.embeddings().t()?)?;
261 let codes = self.c2.broadcast_sub(&dot_prod)?.argmin(D::Minus1)?;
262 codes.reshape(target_shape)
263 }
264
265 pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
266 let mut target_shape = xs.dims().to_vec();
267 target_shape.pop();
268 let xs = xs.flatten_to(D::Minus2)?;
269 let _ = xs.dims2()?;
270 let codes = Tensor::apply_op2(&xs, self.embed.embeddings(), CodebookEncode)?;
271 codes.reshape(target_shape)
272 }
273
274 pub fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
275 let quantize = self.embed.forward(embed_ind)?;
276 Ok(quantize)
277 }
278}
279
/// One quantization stage: a thin wrapper around an [`EuclideanCodebook`]
/// that swaps dimensions 1 and 2 on the way in and out.
#[derive(Clone, Debug)]
pub struct VectorQuantization {
    codebook: EuclideanCodebook,
}
284
285impl VectorQuantization {
286 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
287 let codebook = EuclideanCodebook::new(cfg, vb.pp("codebook"))?;
288 Ok(Self { codebook })
289 }
290
291 pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
292 let xs = xs.transpose(1, 2)?;
293 self.codebook.encode_slow(&xs)
294 }
295
296 pub fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
297 let quantize = self.codebook.decode(embed_ind)?;
298 let quantize = quantize.transpose(1, 2)?;
299 Ok(quantize)
300 }
301}
302
/// A stack of quantization stages where each stage encodes what the previous
/// stages left unexplained.
#[derive(Clone, Debug)]
pub struct ResidualVectorQuantizer {
    /// Stages in application order; `Config::num_quantizers()` of them.
    layers: Vec<VectorQuantization>,
    /// Dtype of the variable builder, used for the accumulator in `decode`.
    dtype: DType,
}
308
309impl ResidualVectorQuantizer {
310 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
311 let vb = &vb.pp("layers");
312 let layers = (0..cfg.num_quantizers())
313 .map(|i| VectorQuantization::new(cfg, vb.pp(i)))
314 .collect::<Result<Vec<_>>>()?;
315 Ok(Self {
316 layers,
317 dtype: vb.dtype(),
318 })
319 }
320
321 pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
322 let mut codes = Vec::with_capacity(self.layers.len());
323 let mut residual = xs.clone();
324 for layer in self.layers.iter() {
325 let indices = layer.encode(&residual)?;
326 let quantized = layer.decode(&indices)?;
327 residual = (residual - quantized)?;
328 codes.push(indices)
329 }
330 Tensor::stack(&codes, 0)
331 }
332
333 pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
334 let mut quantized_out = Tensor::zeros((), self.dtype, codes.device())?;
335 let ncodes = codes.dim(0)?;
336 if ncodes > self.layers.len() {
337 candle::bail!(
338 "codes shape {:?} does not match the number of quantization layers {}",
339 codes.shape(),
340 self.layers.len()
341 )
342 }
343 for (i, layer) in self.layers.iter().take(ncodes).enumerate() {
344 let quantized = layer.decode(&codes.i(i)?)?;
345 quantized_out = quantized.broadcast_add(&quantized_out)?;
346 }
347 Ok(quantized_out)
348 }
349}
350
/// A stack of LSTM layers applied with a residual connection around the
/// whole stack (see the `Module` impl).
#[derive(Clone, Debug)]
pub struct EncodecLSTM {
    /// Layers in application order; `Config::num_lstm_layers` of them.
    layers: Vec<candle_nn::LSTM>,
}
356
357impl EncodecLSTM {
358 pub fn new(dim: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
359 let vb = &vb.pp("lstm");
360 let mut layers = vec![];
361 for layer_idx in 0..cfg.num_lstm_layers {
362 let config = candle_nn::LSTMConfig {
363 layer_idx,
364 ..Default::default()
365 };
366 let lstm = candle_nn::lstm(dim, dim, config, vb.clone())?;
367 layers.push(lstm)
368 }
369 Ok(Self { layers })
370 }
371}
372
373impl Module for EncodecLSTM {
374 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
375 use candle_nn::RNN;
376 let xs = xs.t()?;
378 let residual = &xs;
379 let mut xs = xs.clone();
380 for layer in self.layers.iter() {
381 let states = layer.seq(&xs)?;
382 xs = layer.states_to_tensor(&states)?;
383 }
384 let xs = (xs + residual)?.t()?;
385 Ok(xs)
386 }
387}
388
/// Weight-normalized transposed 1D convolution used by the decoder's
/// sampling stages.
#[derive(Clone, Debug)]
pub struct EncodecConvTranspose1d {
    conv: ConvTranspose1d,
}
393
394impl EncodecConvTranspose1d {
395 fn new(
396 in_c: usize,
397 out_c: usize,
398 k: usize,
399 stride: usize,
400 _cfg: &Config,
401 vb: VarBuilder,
402 ) -> Result<Self> {
403 let cfg = candle_nn::ConvTranspose1dConfig {
404 stride,
405 ..Default::default()
406 };
407 let conv = conv_transpose1d_weight_norm(in_c, out_c, k, true, cfg, vb.pp("conv"))?;
408 Ok(Self { conv })
409 }
410}
411
412impl Module for EncodecConvTranspose1d {
413 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
414 xs.apply(&self.conv)
415 }
416}
417
/// 1D convolution that pads its input itself and optionally applies a group
/// norm to its output.
#[derive(Clone, Debug)]
pub struct EncodecConv1d {
    /// When true, the base padding goes entirely on the left.
    causal: bool,
    conv: Conv1d,
    /// Present only for `NormType::TimeGroupNorm`.
    norm: Option<candle_nn::GroupNorm>,
    pad_mode: PadMode,
}
425
426impl EncodecConv1d {
427 pub fn new(
428 in_c: usize,
429 out_c: usize,
430 kernel_size: usize,
431 stride: usize,
432 dilation: usize,
433 cfg: &Config,
434 vb: VarBuilder,
435 ) -> Result<Self> {
436 let conv = match cfg.norm_type {
437 NormType::WeightNorm => conv1d_weight_norm(
438 in_c,
439 out_c,
440 kernel_size,
441 candle_nn::Conv1dConfig {
442 stride,
443 dilation,
444 ..Default::default()
445 },
446 vb.pp("conv"),
447 )?,
448 NormType::None | NormType::TimeGroupNorm => conv1d(
449 in_c,
450 out_c,
451 kernel_size,
452 candle_nn::Conv1dConfig {
453 padding: 0,
454 stride,
455 groups: 1,
456 dilation: 1,
457 },
458 vb.pp("conv"),
459 )?,
460 };
461 let norm = match cfg.norm_type {
462 NormType::None | NormType::WeightNorm => None,
463 NormType::TimeGroupNorm => {
464 let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
465 Some(gn)
466 }
467 };
468 Ok(Self {
469 causal: cfg.use_causal_conv,
470 conv,
471 norm,
472 pad_mode: cfg.pad_mode,
473 })
474 }
475}
476
477impl Module for EncodecConv1d {
478 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
479 let (_b, _t, _c) = xs.dims3()?;
480 let k_size = self.conv.weight().dim(D::Minus1)?;
481 let conv_cfg = self.conv.config();
482 let k_size = (k_size - 1) * conv_cfg.dilation + 1;
484 let padding_total = k_size - conv_cfg.stride;
485 let extra_padding =
486 get_extra_padding_for_conv1d(xs, k_size, conv_cfg.stride, padding_total)?;
487 let xs = if self.causal {
488 pad1d(xs, padding_total, extra_padding, self.pad_mode)?
489 } else {
490 let padding_right = padding_total / 2;
491 let padding_left = padding_total - padding_right;
492 pad1d(
493 xs,
494 padding_left,
495 padding_right + extra_padding,
496 self.pad_mode,
497 )?
498 };
499 let xs = self.conv.forward(&xs)?;
500 match &self.norm {
501 None => Ok(xs),
502 Some(norm) => xs.apply(norm),
503 }
504 }
505}
506
/// Residual block: two convolutions, each preceded by an ELU, plus a skip
/// path that is either the identity or a kernel-size-1 convolution.
#[derive(Clone, Debug)]
pub struct EncodecResnetBlock {
    block_conv1: EncodecConv1d,
    block_conv2: EncodecConv1d,
    /// `Some` when `Config::use_conv_shortcut` is set; `None` means identity.
    shortcut: Option<EncodecConv1d>,
}
513
514impl EncodecResnetBlock {
515 pub fn new(
516 dim: usize,
517 (dilation1, dilation2): (usize, usize),
518 cfg: &Config,
519 vb: VarBuilder,
520 ) -> Result<Self> {
521 let h = dim / cfg.compress;
522 let mut layer = Layer::new(vb.pp("block"));
523 layer.inc();
525 let block_conv1 = EncodecConv1d::new(
526 dim,
527 h,
528 cfg.residual_kernel_size,
529 1,
530 dilation1,
531 cfg,
532 layer.next(),
533 )?;
534 layer.inc();
535 let block_conv2 = EncodecConv1d::new(h, dim, 1, 1, dilation2, cfg, layer.next())?;
536 let shortcut = if cfg.use_conv_shortcut {
537 let conv = EncodecConv1d::new(dim, dim, 1, 1, 1, cfg, vb.pp("shortcut"))?;
538 Some(conv)
539 } else {
540 None
541 };
542 Ok(Self {
543 block_conv1,
544 block_conv2,
545 shortcut,
546 })
547 }
548}
549
550impl Module for EncodecResnetBlock {
551 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
552 let residual = xs.clone();
553 let xs = xs.elu(1.)?;
554 let xs = self.block_conv1.forward(&xs)?;
555 let xs = xs.elu(1.)?;
556 let xs = self.block_conv2.forward(&xs)?;
557 let xs = match &self.shortcut {
558 None => (xs + residual)?,
559 Some(shortcut) => xs.add(&shortcut.forward(&residual)?)?,
560 };
561 Ok(xs)
562 }
563}
564
565struct Layer<'a> {
566 vb: VarBuilder<'a>,
567 cnt: usize,
568}
569
570impl<'a> Layer<'a> {
571 fn new(vb: VarBuilder<'a>) -> Self {
572 Self { vb, cnt: 0 }
573 }
574
575 fn inc(&mut self) {
576 self.cnt += 1;
577 }
578
579 fn next(&mut self) -> VarBuilder {
580 let vb = self.vb.pp(self.cnt.to_string());
581 self.cnt += 1;
582 vb
583 }
584}
585
/// Convolutional encoder: an initial convolution, a series of
/// (residual blocks, strided convolution) stages, an LSTM stack and a final
/// convolution.
#[derive(Clone, Debug)]
pub struct Encoder {
    init_conv: EncodecConv1d,
    /// One entry per ratio in `Config::upsampling_ratios`, visited in reverse
    /// order: the stage's residual blocks and its strided convolution.
    sampling_layers: Vec<(Vec<EncodecResnetBlock>, EncodecConv1d)>,
    final_lstm: EncodecLSTM,
    final_conv: EncodecConv1d,
}
593
594impl Encoder {
595 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
596 let mut layer = Layer::new(vb.pp("layers"));
597 let init_conv = EncodecConv1d::new(
598 cfg.audio_channels,
599 cfg.num_filters,
600 cfg.kernel_size,
601 1,
602 1,
603 cfg,
604 layer.next(),
605 )?;
606 let mut sampling_layers = vec![];
607 let mut scaling = 1;
608 for &ratio in cfg.upsampling_ratios.iter().rev() {
609 let current_scale = scaling * cfg.num_filters;
610 let mut resnets = vec![];
611 for j in 0..(cfg.num_residual_layers as u32) {
612 let resnet = EncodecResnetBlock::new(
613 current_scale,
614 (cfg.dilation_growth_rate.pow(j), 1),
615 cfg,
616 layer.next(),
617 )?;
618 resnets.push(resnet)
619 }
620 layer.inc(); let conv1d = EncodecConv1d::new(
622 current_scale,
623 current_scale * 2,
624 ratio * 2,
625 ratio,
626 1,
627 cfg,
628 layer.next(),
629 )?;
630 sampling_layers.push((resnets, conv1d));
631 scaling *= 2;
632 }
633 let final_lstm = EncodecLSTM::new(cfg.num_filters * scaling, cfg, layer.next())?;
634 layer.inc(); let final_conv = EncodecConv1d::new(
636 cfg.num_filters * scaling,
637 cfg.hidden_size,
638 cfg.last_kernel_size,
639 1,
640 1,
641 cfg,
642 layer.next(),
643 )?;
644 Ok(Self {
645 init_conv,
646 sampling_layers,
647 final_conv,
648 final_lstm,
649 })
650 }
651}
652
653impl Module for Encoder {
654 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
655 let mut xs = xs.apply(&self.init_conv)?;
656 for (resnets, conv) in self.sampling_layers.iter() {
657 for resnet in resnets.iter() {
658 xs = xs.apply(resnet)?;
659 }
660 xs = xs.elu(1.0)?.apply(conv)?;
661 }
662 xs.apply(&self.final_lstm)?
663 .elu(1.0)?
664 .apply(&self.final_conv)
665 }
666}
667
/// Convolutional decoder: an initial convolution, an LSTM stack, a series of
/// (transposed convolution, residual blocks) stages and a final convolution.
#[derive(Clone, Debug)]
pub struct Decoder {
    init_conv: EncodecConv1d,
    init_lstm: EncodecLSTM,
    /// One entry per ratio in `Config::upsampling_ratios`, in order: the
    /// stage's transposed convolution and its residual blocks.
    sampling_layers: Vec<(EncodecConvTranspose1d, Vec<EncodecResnetBlock>)>,
    final_conv: EncodecConv1d,
}
675
676impl Decoder {
677 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
678 let mut layer = Layer::new(vb.pp("layers"));
679 let mut scaling = usize::pow(2, cfg.upsampling_ratios.len() as u32);
680 let init_conv = EncodecConv1d::new(
681 cfg.hidden_size,
682 cfg.num_filters * scaling,
683 cfg.last_kernel_size,
684 1,
685 1,
686 cfg,
687 layer.next(),
688 )?;
689 let init_lstm = EncodecLSTM::new(cfg.num_filters * scaling, cfg, layer.next())?;
690 let mut sampling_layers = vec![];
691 for &ratio in cfg.upsampling_ratios.iter() {
692 let current_scale = scaling * cfg.num_filters;
693 layer.inc(); let conv1d = EncodecConvTranspose1d::new(
695 current_scale,
696 current_scale / 2,
697 ratio * 2,
698 ratio,
699 cfg,
700 layer.next(),
701 )?;
702 let mut resnets = vec![];
703 for j in 0..(cfg.num_residual_layers as u32) {
704 let resnet = EncodecResnetBlock::new(
705 current_scale / 2,
706 (cfg.dilation_growth_rate.pow(j), 1),
707 cfg,
708 layer.next(),
709 )?;
710 resnets.push(resnet)
711 }
712 sampling_layers.push((conv1d, resnets));
713 scaling /= 2;
714 }
715 layer.inc(); let final_conv = EncodecConv1d::new(
717 cfg.num_filters,
718 cfg.audio_channels,
719 cfg.last_kernel_size,
720 1,
721 1,
722 cfg,
723 layer.next(),
724 )?;
725 Ok(Self {
726 init_conv,
727 init_lstm,
728 sampling_layers,
729 final_conv,
730 })
731 }
732}
733
734impl Module for Decoder {
735 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
736 let mut xs = xs.apply(&self.init_conv)?.apply(&self.init_lstm)?;
737 for (conv, resnets) in self.sampling_layers.iter() {
738 xs = xs.elu(1.)?.apply(conv)?;
739 for resnet in resnets.iter() {
740 xs = xs.apply(resnet)?
741 }
742 }
743 xs.elu(1.)?.apply(&self.final_conv)
744 }
745}
746
/// Full codec: encoder, residual vector quantizer and decoder.
#[derive(Debug)]
pub struct Model {
    encoder: Encoder,
    decoder: Decoder,
    quantizer: ResidualVectorQuantizer,
}
753
754impl Model {
755 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
756 let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
757 let decoder = Decoder::new(cfg, vb.pp("decoder"))?;
758 let quantizer = ResidualVectorQuantizer::new(cfg, vb.pp("quantizer"))?;
759 Ok(Self {
760 encoder,
761 decoder,
762 quantizer,
763 })
764 }
765
766 pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
767 let xs = self.encoder.forward(xs)?;
768 let codes = self.quantizer.encode(&xs)?;
769 codes.transpose(0, 1)
770 }
771
772 pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
773 let (_b_sz, _codebooks, _seqlen) = codes.dims3()?;
774 let codes = codes.transpose(0, 1)?;
775 let embeddings = self.quantizer.decode(&codes)?;
776 let outputs = self.decoder.forward(&embeddings)?;
777 Ok(outputs)
778 }
779}