candle_transformers/models/mimi/
encodec.rs1use super::{conv, quantization, seanet, transformer};
6use candle::{DType, Device, Module, Result, StreamTensor, StreamingModule, Tensor};
7use candle_nn::VarBuilder;
8
/// How the latent sequence is resampled between the encoder frame rate and the
/// codec frame rate.
///
/// NOTE(review): `Encodec::new` in this file always builds learnt conv
/// down/up-samplers and does not read this value — confirm whether
/// `Interpolate` is handled elsewhere.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ResampleMethod {
    /// Resample with (transposed) convolutions.
    Conv,
    /// Resample by interpolation.
    Interpolate,
}
14
15#[derive(Debug, Clone)]
16pub struct Config {
17 pub channels: usize,
18 pub sample_rate: f64,
19 pub frame_rate: f64,
20 pub renormalize: bool,
21 pub resample_method: ResampleMethod,
22 pub seanet: seanet::Config,
23 pub transformer: transformer::Config,
24 pub quantizer_n_q: usize,
25 pub quantizer_bins: usize,
26 pub quantizer_dim: usize,
27}
28
29impl Config {
30 pub fn v0_1(num_codebooks: Option<usize>) -> Self {
32 let seanet_cfg = seanet::Config {
33 dimension: 512,
34 channels: 1,
35 causal: true,
36 n_filters: 64,
37 n_residual_layers: 1,
38 activation: candle_nn::Activation::Elu(1.),
39 compress: 2,
40 dilation_base: 2,
41 disable_norm_outer_blocks: 0,
42 final_activation: None,
43 kernel_size: 7,
44 residual_kernel_size: 3,
45 last_kernel_size: 3,
46 lstm: 0,
47 norm: conv::Norm::WeightNorm,
48 pad_mode: conv::PadMode::Constant,
49 ratios: vec![8, 6, 5, 4],
50 true_skip: true,
51 };
52 let transformer_cfg = transformer::Config {
53 d_model: seanet_cfg.dimension,
54 num_heads: 8,
55 num_layers: 8,
56 causal: true,
57 norm_first: true,
58 bias_ff: false,
59 bias_attn: false,
60 layer_scale: Some(0.01),
61 context: 250,
62 conv_kernel_size: 5,
63 use_conv_bias: true,
64 use_conv_block: false,
65 cross_attention: false,
66 max_period: 10000,
67 gating: None,
68 norm: super::NormType::LayerNorm,
69 positional_embedding: transformer::PositionalEmbedding::Rope,
70
71 dim_feedforward: 2048,
72 kv_repeat: 1,
73 conv_layout: true, max_seq_len: 8192, };
76 Config {
77 channels: 1,
78 sample_rate: 24_000.,
79 frame_rate: 12.5,
80 renormalize: true,
81 resample_method: ResampleMethod::Conv,
82 seanet: seanet_cfg,
83 transformer: transformer_cfg,
84 quantizer_n_q: num_codebooks.unwrap_or(16),
85 quantizer_bins: 2048,
86 quantizer_dim: 256,
87 }
88 }
89}
90
91#[derive(Debug, Clone)]
92pub struct Encodec {
93 encoder: seanet::SeaNetEncoder,
94 decoder: seanet::SeaNetDecoder,
95 encoder_transformer: transformer::ProjectedTransformer,
96 decoder_transformer: transformer::ProjectedTransformer,
97 downsample: conv::ConvDownsample1d,
98 upsample: conv::ConvTrUpsample1d,
99 quantizer: quantization::SplitResidualVectorQuantizer,
100 config: Config,
101}
102
103impl Encodec {
104 pub fn new(cfg: Config, vb: VarBuilder) -> Result<Self> {
105 let dim = cfg.seanet.dimension;
106 let encoder = seanet::SeaNetEncoder::new(&cfg.seanet, vb.pp("encoder"))?;
107 let decoder = seanet::SeaNetDecoder::new(&cfg.seanet, vb.pp("decoder"))?;
108 let encoder_transformer = transformer::ProjectedTransformer::new(
109 dim,
110 &[dim],
111 &cfg.transformer,
112 vb.pp("encoder_transformer"),
113 )?;
114 let decoder_transformer = transformer::ProjectedTransformer::new(
115 dim,
116 &[dim],
117 &cfg.transformer,
118 vb.pp("decoder_transformer"),
119 )?;
120 let quantizer = quantization::SplitResidualVectorQuantizer::new(
121 cfg.quantizer_dim,
122 Some(dim),
123 Some(dim),
124 cfg.quantizer_n_q,
125 cfg.quantizer_bins,
126 vb.pp("quantizer"),
127 )?;
128 let encoder_frame_rate =
129 cfg.sample_rate / cfg.seanet.ratios.iter().product::<usize>() as f64;
130
131 let downsample_stride = (encoder_frame_rate / cfg.frame_rate) as usize;
132 let downsample = conv::ConvDownsample1d::new(
134 downsample_stride,
135 dim,
136 true,
137 true,
138 vb.pp("downsample"),
139 )?;
140 let upsample = conv::ConvTrUpsample1d::new(
141 downsample_stride,
142 dim,
143 true,
144 true,
145 vb.pp("upsample"),
146 )?;
147
148 Ok(Self {
149 encoder,
150 decoder,
151 encoder_transformer,
152 decoder_transformer,
153 quantizer,
154 downsample,
155 upsample,
156 config: cfg,
157 })
158 }
159
160 pub fn config(&self) -> &Config {
161 &self.config
162 }
163
164 pub fn encode_pre_quantize(&mut self, xs: &Tensor) -> Result<Tensor> {
165 let xs = self.encoder.forward(xs)?;
166 self.encoder_transformer.reset_state();
167 let xs = self.encoder_transformer.forward(&xs)?;
168 let xs = &xs[0];
169 xs.apply(&self.downsample)
170 }
171
172 pub fn encode(&mut self, xs: &Tensor) -> Result<Tensor> {
173 let xs = self.encoder.forward(xs)?;
174 self.encoder_transformer.reset_state();
175 let xs = self.encoder_transformer.forward(&xs)?;
176 let xs = &xs[0];
177 let xs = xs.apply(&self.downsample)?;
178 let codes = self.quantizer.encode(&xs)?;
179 Ok(codes)
180 }
181
182 pub fn encode_step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
183 let xs = self.encoder.step(xs)?;
184 let xs = self.encoder_transformer.step(&xs)?;
185 let xs = self.downsample.step(&xs)?;
186 match xs.as_option() {
187 None => Ok(().into()),
188 Some(xs) => {
189 let codes = self.quantizer.encode(xs)?;
190 Ok(codes.into())
191 }
192 }
193 }
194
195 pub fn decode(&mut self, codes: &Tensor) -> Result<Tensor> {
196 let emb = self.quantizer.decode(codes)?;
197 let emb = emb.apply(&self.upsample)?;
198 self.decoder_transformer.reset_state();
199 let outs = self.decoder_transformer.forward(&emb)?;
200 let out = &outs[0];
201 self.decoder.forward(out)
202 }
203
204 pub fn decode_step(&mut self, codes: &StreamTensor) -> Result<StreamTensor> {
205 let emb = match codes.as_option() {
206 Some(codes) => StreamTensor::from_tensor(self.quantizer.decode(codes)?),
207 None => StreamTensor::empty(),
208 };
209 let emb = self.upsample.step(&emb)?;
210 let out = self.decoder_transformer.step(&emb)?;
211 self.decoder.step(&out)
212 }
213
214 pub fn reset_state(&mut self) {
215 self.encoder.reset_state();
216 self.encoder_transformer.reset_state();
217 self.decoder.reset_state();
218 self.decoder_transformer.reset_state();
219 self.upsample.reset_state();
220 }
221}
222
223pub fn load(model_file: &str, num_codebooks: Option<usize>, dev: &Device) -> Result<Encodec> {
224 let vb =
225 unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, dev)? };
226 let cfg = Config::v0_1(num_codebooks);
227 let encodec = Encodec::new(cfg, vb)?;
228 Ok(encodec)
229}