1use candle::{streaming, Module, Result, StreamTensor, StreamingModule, Tensor};
6use candle_nn::VarBuilder;
7
8use super::conv::{StreamableConv1d, StreamableConvTranspose1d};
9
10#[derive(Debug, Clone)]
11pub struct Config {
12 pub dimension: usize,
13 pub channels: usize,
14 pub causal: bool,
15 pub n_filters: usize,
16 pub n_residual_layers: usize,
17 pub ratios: Vec<usize>,
18 pub activation: candle_nn::Activation,
19 pub norm: super::conv::Norm,
20 pub kernel_size: usize,
21 pub residual_kernel_size: usize,
22 pub last_kernel_size: usize,
23 pub dilation_base: usize,
24 pub pad_mode: super::conv::PadMode,
25 pub true_skip: bool,
26 pub compress: usize,
27 pub lstm: usize,
28 pub disable_norm_outer_blocks: usize,
29 pub final_activation: Option<candle_nn::Activation>,
30}
31
32#[derive(Debug, Clone)]
33pub struct SeaNetResnetBlock {
34 block: Vec<StreamableConv1d>,
35 shortcut: Option<StreamableConv1d>,
36 activation: candle_nn::Activation,
37 skip_op: candle::StreamingBinOp,
38 span: tracing::Span,
39}
40
41impl SeaNetResnetBlock {
42 #[allow(clippy::too_many_arguments)]
43 pub fn new(
44 dim: usize,
45 k_sizes_and_dilations: &[(usize, usize)],
46 activation: candle_nn::Activation,
47 norm: Option<super::conv::Norm>,
48 causal: bool,
49 pad_mode: super::conv::PadMode,
50 compress: usize,
51 true_skip: bool,
52 vb: VarBuilder,
53 ) -> Result<Self> {
54 let mut block = Vec::with_capacity(k_sizes_and_dilations.len());
55 let hidden = dim / compress;
56 let vb_b = vb.pp("block");
57 for (i, (k_size, dilation)) in k_sizes_and_dilations.iter().enumerate() {
58 let in_c = if i == 0 { dim } else { hidden };
59 let out_c = if i == k_sizes_and_dilations.len() - 1 {
60 dim
61 } else {
62 hidden
63 };
64 let c = StreamableConv1d::new(
65 in_c,
66 out_c,
67 *k_size,
68 1,
69 *dilation,
70 1,
71 true,
72 causal,
73 norm,
74 pad_mode,
75 vb_b.pp(2 * i + 1),
76 )?;
77 block.push(c)
78 }
79 let shortcut = if true_skip {
80 None
81 } else {
82 let c = StreamableConv1d::new(
83 dim,
84 dim,
85 1,
86 1,
87 1,
88 1,
89 true,
90 causal,
91 norm,
92 pad_mode,
93 vb.pp("shortcut"),
94 )?;
95 Some(c)
96 };
97 Ok(Self {
98 block,
99 shortcut,
100 activation,
101 skip_op: streaming::StreamingBinOp::new(streaming::BinOp::Add, candle::D::Minus1),
102 span: tracing::span!(tracing::Level::TRACE, "sea-resnet"),
103 })
104 }
105}
106
107impl Module for SeaNetResnetBlock {
108 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
109 let _enter = self.span.enter();
110 let mut ys = xs.clone();
111 for block in self.block.iter() {
112 ys = ys.apply(&self.activation)?.apply(block)?;
113 }
114 match self.shortcut.as_ref() {
115 None => ys + xs,
116 Some(shortcut) => ys + xs.apply(shortcut),
117 }
118 }
119}
120
121impl StreamingModule for SeaNetResnetBlock {
122 fn reset_state(&mut self) {
123 for block in self.block.iter_mut() {
124 block.reset_state()
125 }
126 if let Some(shortcut) = self.shortcut.as_mut() {
127 shortcut.reset_state()
128 }
129 }
130
131 fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
132 let _enter = self.span.enter();
133 let mut ys = xs.clone();
134 for block in self.block.iter_mut() {
135 ys = block.step(&ys.apply(&self.activation)?)?;
136 }
137 match self.shortcut.as_ref() {
138 None => self.skip_op.step(&ys, xs),
139 Some(shortcut) => self.skip_op.step(&ys, &xs.apply(shortcut)?),
140 }
141 }
142}
143
144#[derive(Debug, Clone)]
145struct EncoderLayer {
146 residuals: Vec<SeaNetResnetBlock>,
147 downsample: StreamableConv1d,
148}
149
150#[derive(Debug, Clone)]
151pub struct SeaNetEncoder {
152 init_conv1d: StreamableConv1d,
153 activation: candle_nn::Activation,
154 layers: Vec<EncoderLayer>,
155 final_conv1d: StreamableConv1d,
156 span: tracing::Span,
157}
158
159impl SeaNetEncoder {
160 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
161 if cfg.lstm > 0 {
162 candle::bail!("seanet lstm is not supported")
163 }
164 let n_blocks = 2 + cfg.ratios.len();
165 let mut mult = 1usize;
166 let init_norm = if cfg.disable_norm_outer_blocks >= 1 {
167 None
168 } else {
169 Some(cfg.norm)
170 };
171 let mut layer_idx = 0;
172 let vb = vb.pp("layers");
173 let init_conv1d = StreamableConv1d::new(
174 cfg.channels,
175 mult * cfg.n_filters,
176 cfg.kernel_size,
177 1,
178 1,
179 1,
180 true,
181 cfg.causal,
182 init_norm,
183 cfg.pad_mode,
184 vb.pp(layer_idx),
185 )?;
186 layer_idx += 1;
187 let mut layers = Vec::with_capacity(cfg.ratios.len());
188
189 for (i, &ratio) in cfg.ratios.iter().rev().enumerate() {
190 let norm = if cfg.disable_norm_outer_blocks >= i + 2 {
191 None
192 } else {
193 Some(cfg.norm)
194 };
195 let mut residuals = Vec::with_capacity(cfg.n_residual_layers);
196 for j in 0..cfg.n_residual_layers {
197 let resnet_block = SeaNetResnetBlock::new(
198 mult * cfg.n_filters,
199 &[
200 (cfg.residual_kernel_size, cfg.dilation_base.pow(j as u32)),
201 (1, 1),
202 ],
203 cfg.activation,
204 norm,
205 cfg.causal,
206 cfg.pad_mode,
207 cfg.compress,
208 cfg.true_skip,
209 vb.pp(layer_idx),
210 )?;
211 residuals.push(resnet_block);
212 layer_idx += 1;
213 }
214 let downsample = StreamableConv1d::new(
215 mult * cfg.n_filters,
216 mult * cfg.n_filters * 2,
217 ratio * 2,
218 ratio,
219 1,
220 1,
221 true,
222 true,
223 norm,
224 cfg.pad_mode,
225 vb.pp(layer_idx + 1),
226 )?;
227 layer_idx += 2;
228 let layer = EncoderLayer {
229 downsample,
230 residuals,
231 };
232 layers.push(layer);
233 mult *= 2
234 }
235
236 let final_norm = if cfg.disable_norm_outer_blocks >= n_blocks {
237 None
238 } else {
239 Some(cfg.norm)
240 };
241 let final_conv1d = StreamableConv1d::new(
242 mult * cfg.n_filters,
243 cfg.dimension,
244 cfg.last_kernel_size,
245 1,
246 1,
247 1,
248 true,
249 cfg.causal,
250 final_norm,
251 cfg.pad_mode,
252 vb.pp(layer_idx + 1),
253 )?;
254 Ok(Self {
255 init_conv1d,
256 activation: cfg.activation,
257 layers,
258 final_conv1d,
259 span: tracing::span!(tracing::Level::TRACE, "sea-encoder"),
260 })
261 }
262}
263
264impl Module for SeaNetEncoder {
265 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
266 let _enter = self.span.enter();
267 let mut xs = xs.apply(&self.init_conv1d)?;
268 for layer in self.layers.iter() {
269 for residual in layer.residuals.iter() {
270 xs = xs.apply(residual)?
271 }
272 xs = xs.apply(&self.activation)?.apply(&layer.downsample)?;
273 }
274 xs.apply(&self.activation)?.apply(&self.final_conv1d)
275 }
276}
277
278impl StreamingModule for SeaNetEncoder {
279 fn reset_state(&mut self) {
280 self.init_conv1d.reset_state();
281 self.layers.iter_mut().for_each(|v| {
282 v.residuals.iter_mut().for_each(|v| v.reset_state());
283 v.downsample.reset_state()
284 });
285 self.final_conv1d.reset_state();
286 }
287
288 fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
289 let _enter = self.span.enter();
290 let mut xs = self.init_conv1d.step(xs)?;
291 for layer in self.layers.iter_mut() {
292 for residual in layer.residuals.iter_mut() {
293 xs = residual.step(&xs)?;
294 }
295 xs = layer.downsample.step(&xs.apply(&self.activation)?)?;
296 }
297 self.final_conv1d.step(&xs.apply(&self.activation)?)
298 }
299}
300
301#[derive(Debug, Clone)]
302struct DecoderLayer {
303 upsample: StreamableConvTranspose1d,
304 residuals: Vec<SeaNetResnetBlock>,
305}
306
307#[derive(Debug, Clone)]
308pub struct SeaNetDecoder {
309 init_conv1d: StreamableConv1d,
310 activation: candle_nn::Activation,
311 layers: Vec<DecoderLayer>,
312 final_conv1d: StreamableConv1d,
313 final_activation: Option<candle_nn::Activation>,
314 span: tracing::Span,
315}
316
317impl SeaNetDecoder {
318 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
319 if cfg.lstm > 0 {
320 candle::bail!("seanet lstm is not supported")
321 }
322 let n_blocks = 2 + cfg.ratios.len();
323 let mut mult = 1 << cfg.ratios.len();
324 let init_norm = if cfg.disable_norm_outer_blocks == n_blocks {
325 None
326 } else {
327 Some(cfg.norm)
328 };
329 let mut layer_idx = 0;
330 let vb = vb.pp("layers");
331 let init_conv1d = StreamableConv1d::new(
332 cfg.dimension,
333 mult * cfg.n_filters,
334 cfg.kernel_size,
335 1,
336 1,
337 1,
338 true,
339 cfg.causal,
340 init_norm,
341 cfg.pad_mode,
342 vb.pp(layer_idx),
343 )?;
344 layer_idx += 1;
345 let mut layers = Vec::with_capacity(cfg.ratios.len());
346 for (i, &ratio) in cfg.ratios.iter().enumerate() {
347 let norm = if cfg.disable_norm_outer_blocks + i + 1 >= n_blocks {
348 None
349 } else {
350 Some(cfg.norm)
351 };
352 let upsample = StreamableConvTranspose1d::new(
353 mult * cfg.n_filters,
354 mult * cfg.n_filters / 2,
355 ratio * 2,
356 ratio,
357 1,
358 true,
359 true,
360 norm,
361 vb.pp(layer_idx + 1),
362 )?;
363 layer_idx += 2;
364
365 let mut residuals = Vec::with_capacity(cfg.n_residual_layers);
366 for j in 0..cfg.n_residual_layers {
367 let resnet_block = SeaNetResnetBlock::new(
368 mult * cfg.n_filters / 2,
369 &[
370 (cfg.residual_kernel_size, cfg.dilation_base.pow(j as u32)),
371 (1, 1),
372 ],
373 cfg.activation,
374 norm,
375 cfg.causal,
376 cfg.pad_mode,
377 cfg.compress,
378 cfg.true_skip,
379 vb.pp(layer_idx),
380 )?;
381 residuals.push(resnet_block);
382 layer_idx += 1;
383 }
384 let layer = DecoderLayer {
385 upsample,
386 residuals,
387 };
388 layers.push(layer);
389 mult /= 2
390 }
391 let final_norm = if cfg.disable_norm_outer_blocks >= 1 {
392 None
393 } else {
394 Some(cfg.norm)
395 };
396 let final_conv1d = StreamableConv1d::new(
397 cfg.n_filters,
398 cfg.channels,
399 cfg.last_kernel_size,
400 1,
401 1,
402 1,
403 true,
404 cfg.causal,
405 final_norm,
406 cfg.pad_mode,
407 vb.pp(layer_idx + 1),
408 )?;
409 Ok(Self {
410 init_conv1d,
411 activation: cfg.activation,
412 layers,
413 final_conv1d,
414 final_activation: cfg.final_activation,
415 span: tracing::span!(tracing::Level::TRACE, "sea-decoder"),
416 })
417 }
418}
419
420impl Module for SeaNetDecoder {
421 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
422 let _enter = self.span.enter();
423 let mut xs = xs.apply(&self.init_conv1d)?;
424 for layer in self.layers.iter() {
425 xs = xs.apply(&self.activation)?.apply(&layer.upsample)?;
426 for residual in layer.residuals.iter() {
427 xs = xs.apply(residual)?
428 }
429 }
430 let xs = xs.apply(&self.activation)?.apply(&self.final_conv1d)?;
431 let xs = match self.final_activation.as_ref() {
432 None => xs,
433 Some(act) => xs.apply(act)?,
434 };
435 Ok(xs)
436 }
437}
438
439impl StreamingModule for SeaNetDecoder {
440 fn reset_state(&mut self) {
441 self.init_conv1d.reset_state();
442 self.layers.iter_mut().for_each(|v| {
443 v.residuals.iter_mut().for_each(|v| v.reset_state());
444 v.upsample.reset_state()
445 });
446 self.final_conv1d.reset_state();
447 }
448
449 fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
450 let _enter = self.span.enter();
451 let mut xs = self.init_conv1d.step(xs)?;
452 for layer in self.layers.iter_mut() {
453 xs = layer.upsample.step(&xs.apply(&self.activation)?)?;
454 for residual in layer.residuals.iter_mut() {
455 xs = residual.step(&xs)?;
456 }
457 }
458 let xs = self.final_conv1d.step(&xs.apply(&self.activation)?)?;
459 let xs = match self.final_activation.as_ref() {
460 None => xs,
461 Some(act) => xs.apply(act)?,
462 };
463 Ok(xs)
464 }
465}