candle_transformers/models/mimi/
encodec.rs

1// Copyright (c) Kyutai, all rights reserved.
2// This source code is licensed under the license found in the
3// LICENSE file in the root directory of this source tree.
4
5use super::{conv, quantization, seanet, transformer};
6use candle::{DType, Device, Module, Result, StreamTensor, StreamingModule, Tensor};
7use candle_nn::VarBuilder;
8
/// Resampling strategy selectable through [`Config::resample_method`].
///
/// NOTE(review): nothing in this file branches on the selected variant;
/// `Encodec::new` always builds convolutional down/up-samplers. Confirm whether
/// `Interpolate` is honoured anywhere before relying on it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ResampleMethod {
    /// Resample with convolutions (the value used by [`Config::v0_1`]).
    Conv,
    /// Resample by interpolation (presumably; not exercised in this file).
    Interpolate,
}
14
15#[derive(Debug, Clone)]
16pub struct Config {
17    pub channels: usize,
18    pub sample_rate: f64,
19    pub frame_rate: f64,
20    pub renormalize: bool,
21    pub resample_method: ResampleMethod,
22    pub seanet: seanet::Config,
23    pub transformer: transformer::Config,
24    pub quantizer_n_q: usize,
25    pub quantizer_bins: usize,
26    pub quantizer_dim: usize,
27}
28
29impl Config {
30    // /lustre/scwpod02/client/kyutai/alex/mimi_exp/xps/b7d2bd5a/.hydra/config.yaml
31    pub fn v0_1(num_codebooks: Option<usize>) -> Self {
32        let seanet_cfg = seanet::Config {
33            dimension: 512,
34            channels: 1,
35            causal: true,
36            n_filters: 64,
37            n_residual_layers: 1,
38            activation: candle_nn::Activation::Elu(1.),
39            compress: 2,
40            dilation_base: 2,
41            disable_norm_outer_blocks: 0,
42            final_activation: None,
43            kernel_size: 7,
44            residual_kernel_size: 3,
45            last_kernel_size: 3,
46            lstm: 0,
47            norm: conv::Norm::WeightNorm,
48            pad_mode: conv::PadMode::Constant,
49            ratios: vec![8, 6, 5, 4],
50            true_skip: true,
51        };
52        let transformer_cfg = transformer::Config {
53            d_model: seanet_cfg.dimension,
54            num_heads: 8,
55            num_layers: 8,
56            causal: true,
57            norm_first: true,
58            bias_ff: false,
59            bias_attn: false,
60            layer_scale: Some(0.01),
61            context: 250,
62            conv_kernel_size: 5,
63            use_conv_bias: true,
64            use_conv_block: false,
65            cross_attention: false,
66            max_period: 10000,
67            gating: None,
68            norm: super::NormType::LayerNorm,
69            positional_embedding: transformer::PositionalEmbedding::Rope,
70
71            dim_feedforward: 2048,
72            kv_repeat: 1,
73            conv_layout: true, // see builders.py
74            max_seq_len: 8192, // the transformer works at 25hz so this is ~5 mins.
75        };
76        Config {
77            channels: 1,
78            sample_rate: 24_000.,
79            frame_rate: 12.5,
80            renormalize: true,
81            resample_method: ResampleMethod::Conv,
82            seanet: seanet_cfg,
83            transformer: transformer_cfg,
84            quantizer_n_q: num_codebooks.unwrap_or(16),
85            quantizer_bins: 2048,
86            quantizer_dim: 256,
87        }
88    }
89}
90
91#[derive(Debug, Clone)]
92pub struct Encodec {
93    encoder: seanet::SeaNetEncoder,
94    decoder: seanet::SeaNetDecoder,
95    encoder_transformer: transformer::ProjectedTransformer,
96    decoder_transformer: transformer::ProjectedTransformer,
97    downsample: conv::ConvDownsample1d,
98    upsample: conv::ConvTrUpsample1d,
99    quantizer: quantization::SplitResidualVectorQuantizer,
100    config: Config,
101}
102
103impl Encodec {
104    pub fn new(cfg: Config, vb: VarBuilder) -> Result<Self> {
105        let dim = cfg.seanet.dimension;
106        let encoder = seanet::SeaNetEncoder::new(&cfg.seanet, vb.pp("encoder"))?;
107        let decoder = seanet::SeaNetDecoder::new(&cfg.seanet, vb.pp("decoder"))?;
108        let encoder_transformer = transformer::ProjectedTransformer::new(
109            dim,
110            &[dim],
111            &cfg.transformer,
112            vb.pp("encoder_transformer"),
113        )?;
114        let decoder_transformer = transformer::ProjectedTransformer::new(
115            dim,
116            &[dim],
117            &cfg.transformer,
118            vb.pp("decoder_transformer"),
119        )?;
120        let quantizer = quantization::SplitResidualVectorQuantizer::new(
121            /* dim */ cfg.quantizer_dim,
122            /* input_dim */ Some(dim),
123            /* output_dim */ Some(dim),
124            /* n_q */ cfg.quantizer_n_q,
125            /* bins */ cfg.quantizer_bins,
126            vb.pp("quantizer"),
127        )?;
128        let encoder_frame_rate =
129            cfg.sample_rate / cfg.seanet.ratios.iter().product::<usize>() as f64;
130
131        let downsample_stride = (encoder_frame_rate / cfg.frame_rate) as usize;
132        // `upsample` and `downsample` only apply if frame_rate is different from encoder_frame_rate.
133        let downsample = conv::ConvDownsample1d::new(
134            /* stride */ downsample_stride,
135            /* dim */ dim,
136            /* causal */ true,
137            /* learnt */ true,
138            vb.pp("downsample"),
139        )?;
140        let upsample = conv::ConvTrUpsample1d::new(
141            /* stride */ downsample_stride,
142            /* dim */ dim,
143            /* causal */ true,
144            /* learnt */ true,
145            vb.pp("upsample"),
146        )?;
147
148        Ok(Self {
149            encoder,
150            decoder,
151            encoder_transformer,
152            decoder_transformer,
153            quantizer,
154            downsample,
155            upsample,
156            config: cfg,
157        })
158    }
159
160    pub fn config(&self) -> &Config {
161        &self.config
162    }
163
164    pub fn encode_pre_quantize(&mut self, xs: &Tensor) -> Result<Tensor> {
165        let xs = self.encoder.forward(xs)?;
166        self.encoder_transformer.reset_state();
167        let xs = self.encoder_transformer.forward(&xs)?;
168        let xs = &xs[0];
169        xs.apply(&self.downsample)
170    }
171
172    pub fn encode(&mut self, xs: &Tensor) -> Result<Tensor> {
173        let xs = self.encoder.forward(xs)?;
174        self.encoder_transformer.reset_state();
175        let xs = self.encoder_transformer.forward(&xs)?;
176        let xs = &xs[0];
177        let xs = xs.apply(&self.downsample)?;
178        let codes = self.quantizer.encode(&xs)?;
179        Ok(codes)
180    }
181
182    pub fn encode_step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
183        let xs = self.encoder.step(xs)?;
184        let xs = self.encoder_transformer.step(&xs)?;
185        let xs = self.downsample.step(&xs)?;
186        match xs.as_option() {
187            None => Ok(().into()),
188            Some(xs) => {
189                let codes = self.quantizer.encode(xs)?;
190                Ok(codes.into())
191            }
192        }
193    }
194
195    pub fn decode(&mut self, codes: &Tensor) -> Result<Tensor> {
196        let emb = self.quantizer.decode(codes)?;
197        let emb = emb.apply(&self.upsample)?;
198        self.decoder_transformer.reset_state();
199        let outs = self.decoder_transformer.forward(&emb)?;
200        let out = &outs[0];
201        self.decoder.forward(out)
202    }
203
204    pub fn decode_step(&mut self, codes: &StreamTensor) -> Result<StreamTensor> {
205        let emb = match codes.as_option() {
206            Some(codes) => StreamTensor::from_tensor(self.quantizer.decode(codes)?),
207            None => StreamTensor::empty(),
208        };
209        let emb = self.upsample.step(&emb)?;
210        let out = self.decoder_transformer.step(&emb)?;
211        self.decoder.step(&out)
212    }
213
214    pub fn reset_state(&mut self) {
215        self.encoder.reset_state();
216        self.encoder_transformer.reset_state();
217        self.decoder.reset_state();
218        self.decoder_transformer.reset_state();
219        self.upsample.reset_state();
220    }
221}
222
223pub fn load(model_file: &str, num_codebooks: Option<usize>, dev: &Device) -> Result<Encodec> {
224    let vb =
225        unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, dev)? };
226    let cfg = Config::v0_1(num_codebooks);
227    let encodec = Encodec::new(cfg, vb)?;
228    Ok(encodec)
229}