candle_transformers/models/mimi/
seanet.rs

1// Copyright (c) Kyutai, all rights reserved.
2// This source code is licensed under the license found in the
3// LICENSE file in the root directory of this source tree.
4
5use candle::{streaming, Module, Result, StreamTensor, StreamingModule, Tensor};
6use candle_nn::VarBuilder;
7
8use super::conv::{StreamableConv1d, StreamableConvTranspose1d};
9
/// Hyper-parameters shared by [`SeaNetEncoder`] and [`SeaNetDecoder`].
#[derive(Debug, Clone)]
pub struct Config {
    /// Channels of the latent representation: the encoder's output channels and the decoder's
    /// input channels.
    pub dimension: usize,
    /// Channels of the raw signal: the encoder's input channels and the decoder's output
    /// channels.
    pub channels: usize,
    /// Forwarded as the `causal` flag when the convolution layers are built.
    pub causal: bool,
    /// Base channel count; each resampling stage scales it by a power-of-two multiplier.
    pub n_filters: usize,
    /// Number of [`SeaNetResnetBlock`]s in each resampling stage.
    pub n_residual_layers: usize,
    /// Per-stage resampling strides. The encoder visits them in reverse order, the decoder in
    /// the given order; a stage with stride `ratio` uses a kernel of size `2 * ratio`.
    pub ratios: Vec<usize>,
    /// Activation inserted before the residual-branch, resampling and final convolutions.
    pub activation: candle_nn::Activation,
    /// Normalization used by the convolutions unless disabled through
    /// `disable_norm_outer_blocks`.
    pub norm: super::conv::Norm,
    /// Kernel size of the initial convolution.
    pub kernel_size: usize,
    /// Kernel size of the first convolution of each residual block (the second one uses 1).
    pub residual_kernel_size: usize,
    /// Kernel size of the final convolution.
    pub last_kernel_size: usize,
    /// Residual block `j` of a stage uses dilation `dilation_base.pow(j)`.
    pub dilation_base: usize,
    /// Padding mode forwarded to the convolutions.
    pub pad_mode: super::conv::PadMode,
    /// When true, residual blocks add their input unchanged; otherwise the skip path goes
    /// through a kernel-size-1 convolution.
    pub true_skip: bool,
    /// Residual blocks use `dim / compress` hidden channels.
    pub compress: usize,
    /// Number of LSTM layers; only 0 is supported (the builders return an error otherwise).
    pub lstm: usize,
    /// Number of outer blocks, counted from the raw-signal side, that are built without
    /// normalization.
    pub disable_norm_outer_blocks: usize,
    /// Optional activation applied to the decoder output.
    pub final_activation: Option<candle_nn::Activation>,
}
31
/// Residual block: a chain of (activation, convolution) pairs whose output is added to the
/// block input, the input optionally going through a shortcut convolution first.
#[derive(Debug, Clone)]
pub struct SeaNetResnetBlock {
    /// Convolutions of the residual branch, each preceded by `activation`.
    block: Vec<StreamableConv1d>,
    /// Kernel-size-1 convolution on the skip path; `None` means the input is added unchanged.
    shortcut: Option<StreamableConv1d>,
    activation: candle_nn::Activation,
    /// Streaming addition of the branch and skip outputs along the last dimension.
    skip_op: candle::StreamingBinOp,
    span: tracing::Span,
}
40
41impl SeaNetResnetBlock {
42    #[allow(clippy::too_many_arguments)]
43    pub fn new(
44        dim: usize,
45        k_sizes_and_dilations: &[(usize, usize)],
46        activation: candle_nn::Activation,
47        norm: Option<super::conv::Norm>,
48        causal: bool,
49        pad_mode: super::conv::PadMode,
50        compress: usize,
51        true_skip: bool,
52        vb: VarBuilder,
53    ) -> Result<Self> {
54        let mut block = Vec::with_capacity(k_sizes_and_dilations.len());
55        let hidden = dim / compress;
56        let vb_b = vb.pp("block");
57        for (i, (k_size, dilation)) in k_sizes_and_dilations.iter().enumerate() {
58            let in_c = if i == 0 { dim } else { hidden };
59            let out_c = if i == k_sizes_and_dilations.len() - 1 {
60                dim
61            } else {
62                hidden
63            };
64            let c = StreamableConv1d::new(
65                in_c,
66                out_c,
67                /* k_size */ *k_size,
68                /* stride */ 1,
69                /* dilation */ *dilation,
70                /* groups */ 1,
71                /* bias */ true,
72                /* causal */ causal,
73                /* norm */ norm,
74                /* pad_mode */ pad_mode,
75                vb_b.pp(2 * i + 1),
76            )?;
77            block.push(c)
78        }
79        let shortcut = if true_skip {
80            None
81        } else {
82            let c = StreamableConv1d::new(
83                dim,
84                dim,
85                /* k_size */ 1,
86                /* stride */ 1,
87                /* dilation */ 1,
88                /* groups */ 1,
89                /* bias */ true,
90                /* causal */ causal,
91                /* norm */ norm,
92                /* pad_mode */ pad_mode,
93                vb.pp("shortcut"),
94            )?;
95            Some(c)
96        };
97        Ok(Self {
98            block,
99            shortcut,
100            activation,
101            skip_op: streaming::StreamingBinOp::new(streaming::BinOp::Add, candle::D::Minus1),
102            span: tracing::span!(tracing::Level::TRACE, "sea-resnet"),
103        })
104    }
105}
106
107impl Module for SeaNetResnetBlock {
108    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
109        let _enter = self.span.enter();
110        let mut ys = xs.clone();
111        for block in self.block.iter() {
112            ys = ys.apply(&self.activation)?.apply(block)?;
113        }
114        match self.shortcut.as_ref() {
115            None => ys + xs,
116            Some(shortcut) => ys + xs.apply(shortcut),
117        }
118    }
119}
120
121impl StreamingModule for SeaNetResnetBlock {
122    fn reset_state(&mut self) {
123        for block in self.block.iter_mut() {
124            block.reset_state()
125        }
126        if let Some(shortcut) = self.shortcut.as_mut() {
127            shortcut.reset_state()
128        }
129    }
130
131    fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
132        let _enter = self.span.enter();
133        let mut ys = xs.clone();
134        for block in self.block.iter_mut() {
135            ys = block.step(&ys.apply(&self.activation)?)?;
136        }
137        match self.shortcut.as_ref() {
138            None => self.skip_op.step(&ys, xs),
139            Some(shortcut) => self.skip_op.step(&ys, &xs.apply(shortcut)?),
140        }
141    }
142}
143
/// One encoder stage: residual blocks followed by a strided downsampling convolution.
#[derive(Debug, Clone)]
struct EncoderLayer {
    /// Applied in order, before `downsample`.
    residuals: Vec<SeaNetResnetBlock>,
    /// Strided convolution that doubles the channel count (built in `SeaNetEncoder::new`).
    downsample: StreamableConv1d,
}
149
/// Convolutional encoder mapping a `Config::channels`-channel signal to
/// `Config::dimension` channels.
#[derive(Debug, Clone)]
pub struct SeaNetEncoder {
    /// First convolution, applied without a preceding activation.
    init_conv1d: StreamableConv1d,
    /// Activation applied before each downsampling convolution and before `final_conv1d`.
    activation: candle_nn::Activation,
    /// Downsampling stages, applied in order.
    layers: Vec<EncoderLayer>,
    /// Projection to the latent dimension.
    final_conv1d: StreamableConv1d,
    span: tracing::Span,
}
158
159impl SeaNetEncoder {
160    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
161        if cfg.lstm > 0 {
162            candle::bail!("seanet lstm is not supported")
163        }
164        let n_blocks = 2 + cfg.ratios.len();
165        let mut mult = 1usize;
166        let init_norm = if cfg.disable_norm_outer_blocks >= 1 {
167            None
168        } else {
169            Some(cfg.norm)
170        };
171        let mut layer_idx = 0;
172        let vb = vb.pp("layers");
173        let init_conv1d = StreamableConv1d::new(
174            cfg.channels,
175            mult * cfg.n_filters,
176            cfg.kernel_size,
177            /* stride */ 1,
178            /* dilation */ 1,
179            /* groups */ 1,
180            /* bias */ true,
181            /* causal */ cfg.causal,
182            /* norm */ init_norm,
183            /* pad_mode */ cfg.pad_mode,
184            vb.pp(layer_idx),
185        )?;
186        layer_idx += 1;
187        let mut layers = Vec::with_capacity(cfg.ratios.len());
188
189        for (i, &ratio) in cfg.ratios.iter().rev().enumerate() {
190            let norm = if cfg.disable_norm_outer_blocks >= i + 2 {
191                None
192            } else {
193                Some(cfg.norm)
194            };
195            let mut residuals = Vec::with_capacity(cfg.n_residual_layers);
196            for j in 0..cfg.n_residual_layers {
197                let resnet_block = SeaNetResnetBlock::new(
198                    mult * cfg.n_filters,
199                    &[
200                        (cfg.residual_kernel_size, cfg.dilation_base.pow(j as u32)),
201                        (1, 1),
202                    ],
203                    cfg.activation,
204                    norm,
205                    cfg.causal,
206                    cfg.pad_mode,
207                    cfg.compress,
208                    cfg.true_skip,
209                    vb.pp(layer_idx),
210                )?;
211                residuals.push(resnet_block);
212                layer_idx += 1;
213            }
214            let downsample = StreamableConv1d::new(
215                mult * cfg.n_filters,
216                mult * cfg.n_filters * 2,
217                /* k_size */ ratio * 2,
218                /* stride */ ratio,
219                /* dilation */ 1,
220                /* groups */ 1,
221                /* bias */ true,
222                /* causal */ true,
223                /* norm */ norm,
224                /* pad_mode */ cfg.pad_mode,
225                vb.pp(layer_idx + 1),
226            )?;
227            layer_idx += 2;
228            let layer = EncoderLayer {
229                downsample,
230                residuals,
231            };
232            layers.push(layer);
233            mult *= 2
234        }
235
236        let final_norm = if cfg.disable_norm_outer_blocks >= n_blocks {
237            None
238        } else {
239            Some(cfg.norm)
240        };
241        let final_conv1d = StreamableConv1d::new(
242            mult * cfg.n_filters,
243            cfg.dimension,
244            cfg.last_kernel_size,
245            /* stride */ 1,
246            /* dilation */ 1,
247            /* groups */ 1,
248            /* bias */ true,
249            /* causal */ cfg.causal,
250            /* norm */ final_norm,
251            /* pad_mode */ cfg.pad_mode,
252            vb.pp(layer_idx + 1),
253        )?;
254        Ok(Self {
255            init_conv1d,
256            activation: cfg.activation,
257            layers,
258            final_conv1d,
259            span: tracing::span!(tracing::Level::TRACE, "sea-encoder"),
260        })
261    }
262}
263
264impl Module for SeaNetEncoder {
265    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
266        let _enter = self.span.enter();
267        let mut xs = xs.apply(&self.init_conv1d)?;
268        for layer in self.layers.iter() {
269            for residual in layer.residuals.iter() {
270                xs = xs.apply(residual)?
271            }
272            xs = xs.apply(&self.activation)?.apply(&layer.downsample)?;
273        }
274        xs.apply(&self.activation)?.apply(&self.final_conv1d)
275    }
276}
277
278impl StreamingModule for SeaNetEncoder {
279    fn reset_state(&mut self) {
280        self.init_conv1d.reset_state();
281        self.layers.iter_mut().for_each(|v| {
282            v.residuals.iter_mut().for_each(|v| v.reset_state());
283            v.downsample.reset_state()
284        });
285        self.final_conv1d.reset_state();
286    }
287
288    fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
289        let _enter = self.span.enter();
290        let mut xs = self.init_conv1d.step(xs)?;
291        for layer in self.layers.iter_mut() {
292            for residual in layer.residuals.iter_mut() {
293                xs = residual.step(&xs)?;
294            }
295            xs = layer.downsample.step(&xs.apply(&self.activation)?)?;
296        }
297        self.final_conv1d.step(&xs.apply(&self.activation)?)
298    }
299}
300
/// One decoder stage: a transposed-convolution upsampling that halves the channel count,
/// followed by residual blocks.
#[derive(Debug, Clone)]
struct DecoderLayer {
    /// Strided transposed convolution (built in `SeaNetDecoder::new`).
    upsample: StreamableConvTranspose1d,
    /// Applied in order, after `upsample`.
    residuals: Vec<SeaNetResnetBlock>,
}
306
/// Convolutional decoder mapping `Config::dimension` latent channels back to a
/// `Config::channels`-channel signal.
#[derive(Debug, Clone)]
pub struct SeaNetDecoder {
    /// First convolution, applied without a preceding activation.
    init_conv1d: StreamableConv1d,
    /// Activation applied before each upsampling convolution and before `final_conv1d`.
    activation: candle_nn::Activation,
    /// Upsampling stages, applied in order.
    layers: Vec<DecoderLayer>,
    /// Projection to the output channels.
    final_conv1d: StreamableConv1d,
    /// Optional activation applied to the output of `final_conv1d`.
    final_activation: Option<candle_nn::Activation>,
    span: tracing::Span,
}
316
317impl SeaNetDecoder {
318    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
319        if cfg.lstm > 0 {
320            candle::bail!("seanet lstm is not supported")
321        }
322        let n_blocks = 2 + cfg.ratios.len();
323        let mut mult = 1 << cfg.ratios.len();
324        let init_norm = if cfg.disable_norm_outer_blocks == n_blocks {
325            None
326        } else {
327            Some(cfg.norm)
328        };
329        let mut layer_idx = 0;
330        let vb = vb.pp("layers");
331        let init_conv1d = StreamableConv1d::new(
332            cfg.dimension,
333            mult * cfg.n_filters,
334            cfg.kernel_size,
335            /* stride */ 1,
336            /* dilation */ 1,
337            /* groups */ 1,
338            /* bias */ true,
339            /* causal */ cfg.causal,
340            /* norm */ init_norm,
341            /* pad_mode */ cfg.pad_mode,
342            vb.pp(layer_idx),
343        )?;
344        layer_idx += 1;
345        let mut layers = Vec::with_capacity(cfg.ratios.len());
346        for (i, &ratio) in cfg.ratios.iter().enumerate() {
347            let norm = if cfg.disable_norm_outer_blocks + i + 1 >= n_blocks {
348                None
349            } else {
350                Some(cfg.norm)
351            };
352            let upsample = StreamableConvTranspose1d::new(
353                mult * cfg.n_filters,
354                mult * cfg.n_filters / 2,
355                /* k_size */ ratio * 2,
356                /* stride */ ratio,
357                /* groups */ 1,
358                /* bias */ true,
359                /* causal */ true,
360                /* norm */ norm,
361                vb.pp(layer_idx + 1),
362            )?;
363            layer_idx += 2;
364
365            let mut residuals = Vec::with_capacity(cfg.n_residual_layers);
366            for j in 0..cfg.n_residual_layers {
367                let resnet_block = SeaNetResnetBlock::new(
368                    mult * cfg.n_filters / 2,
369                    &[
370                        (cfg.residual_kernel_size, cfg.dilation_base.pow(j as u32)),
371                        (1, 1),
372                    ],
373                    cfg.activation,
374                    norm,
375                    cfg.causal,
376                    cfg.pad_mode,
377                    cfg.compress,
378                    cfg.true_skip,
379                    vb.pp(layer_idx),
380                )?;
381                residuals.push(resnet_block);
382                layer_idx += 1;
383            }
384            let layer = DecoderLayer {
385                upsample,
386                residuals,
387            };
388            layers.push(layer);
389            mult /= 2
390        }
391        let final_norm = if cfg.disable_norm_outer_blocks >= 1 {
392            None
393        } else {
394            Some(cfg.norm)
395        };
396        let final_conv1d = StreamableConv1d::new(
397            cfg.n_filters,
398            cfg.channels,
399            cfg.last_kernel_size,
400            /* stride */ 1,
401            /* dilation */ 1,
402            /* groups */ 1,
403            /* bias */ true,
404            /* causal */ cfg.causal,
405            /* norm */ final_norm,
406            /* pad_mode */ cfg.pad_mode,
407            vb.pp(layer_idx + 1),
408        )?;
409        Ok(Self {
410            init_conv1d,
411            activation: cfg.activation,
412            layers,
413            final_conv1d,
414            final_activation: cfg.final_activation,
415            span: tracing::span!(tracing::Level::TRACE, "sea-decoder"),
416        })
417    }
418}
419
420impl Module for SeaNetDecoder {
421    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
422        let _enter = self.span.enter();
423        let mut xs = xs.apply(&self.init_conv1d)?;
424        for layer in self.layers.iter() {
425            xs = xs.apply(&self.activation)?.apply(&layer.upsample)?;
426            for residual in layer.residuals.iter() {
427                xs = xs.apply(residual)?
428            }
429        }
430        let xs = xs.apply(&self.activation)?.apply(&self.final_conv1d)?;
431        let xs = match self.final_activation.as_ref() {
432            None => xs,
433            Some(act) => xs.apply(act)?,
434        };
435        Ok(xs)
436    }
437}
438
439impl StreamingModule for SeaNetDecoder {
440    fn reset_state(&mut self) {
441        self.init_conv1d.reset_state();
442        self.layers.iter_mut().for_each(|v| {
443            v.residuals.iter_mut().for_each(|v| v.reset_state());
444            v.upsample.reset_state()
445        });
446        self.final_conv1d.reset_state();
447    }
448
449    fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
450        let _enter = self.span.enter();
451        let mut xs = self.init_conv1d.step(xs)?;
452        for layer in self.layers.iter_mut() {
453            xs = layer.upsample.step(&xs.apply(&self.activation)?)?;
454            for residual in layer.residuals.iter_mut() {
455                xs = residual.step(&xs)?;
456            }
457        }
458        let xs = self.final_conv1d.step(&xs.apply(&self.activation)?)?;
459        let xs = match self.final_activation.as_ref() {
460            None => xs,
461            Some(act) => xs.apply(act)?,
462        };
463        Ok(xs)
464    }
465}