1use candle::{Module, Result, StreamTensor, StreamingModule, Tensor, D};
6use candle_nn::{Conv1d, VarBuilder};
7
/// Normalization scheme requested for a convolution layer.
#[allow(clippy::enum_variant_names)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Norm {
    /// Weight normalization: when no plain `weight` tensor is stored, the kernel is
    /// rebuilt from the `weight_g` / `weight_v` tensors at load time.
    WeightNorm,
    /// Spectral normalization. The constructors in this file reject it as unsupported.
    SpectralNorm,
    /// A single-group group normalization applied to the convolution output.
    /// The constructors in this file reject it when `causal` is set.
    TimeGroupNorm,
}
15
/// Padding strategy used by `pad1d` along the last dimension.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum PadMode {
    /// Pad with zeros.
    Constant,
    /// Reflection padding; `pad1d` currently returns an error for this mode.
    Reflect,
    /// Pad via `Tensor::pad_with_same`.
    Replicate,
}
22
23fn conv1d_weight_norm(
27 in_c: usize,
28 out_c: usize,
29 kernel_size: usize,
30 bias: bool,
31 config: candle_nn::Conv1dConfig,
32 vb: VarBuilder,
33) -> Result<Conv1d> {
34 let weight = if vb.contains_tensor("weight") {
35 vb.get((out_c, in_c, kernel_size), "weight")?
36 } else {
37 let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
38 let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
39 let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
40 weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?
41 };
42 let bias = if bias {
43 Some(vb.get(out_c, "bias")?)
44 } else {
45 None
46 };
47 Ok(Conv1d::new(weight, bias, config))
48}
49
/// A 1d convolution optionally followed by a group normalization.
#[derive(Debug, Clone)]
pub struct NormConv1d {
    /// The convolution; for weight-norm checkpoints the kernel is already reconstructed.
    conv: Conv1d,
    /// Only set when the layer was built with `Norm::TimeGroupNorm`.
    norm: Option<candle_nn::GroupNorm>,
    /// Tracing span entered during `forward`.
    span: tracing::Span,
}
56
57impl NormConv1d {
58 #[allow(clippy::too_many_arguments)]
59 pub fn new(
60 in_c: usize,
61 out_c: usize,
62 k_size: usize,
63 causal: bool,
64 norm: Option<Norm>,
65 bias: bool,
66 cfg: candle_nn::Conv1dConfig,
67 vb: VarBuilder,
68 ) -> Result<Self> {
69 let conv = match norm {
70 None | Some(Norm::TimeGroupNorm) => {
71 if bias {
72 candle_nn::conv1d(in_c, out_c, k_size, cfg, vb.pp("conv"))?
73 } else {
74 candle_nn::conv1d_no_bias(in_c, out_c, k_size, cfg, vb.pp("conv"))?
75 }
76 }
77 Some(Norm::WeightNorm) => {
78 conv1d_weight_norm(in_c, out_c, k_size, bias, cfg, vb.pp("conv"))?
79 }
80 Some(Norm::SpectralNorm) => candle::bail!("SpectralNorm is not supported yet."),
81 };
82 let norm = match norm {
83 None | Some(Norm::WeightNorm) | Some(Norm::SpectralNorm) => None,
84 Some(Norm::TimeGroupNorm) => {
85 if causal {
86 candle::bail!("GroupNorm doesn't support causal evaluation.")
87 }
88 let norm = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
89 Some(norm)
90 }
91 };
92 Ok(Self {
93 conv,
94 norm,
95 span: tracing::span!(tracing::Level::TRACE, "norm-conv1d"),
96 })
97 }
98}
99
100impl Module for NormConv1d {
101 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
102 let _enter = self.span.enter();
103 let xs = xs.apply(&self.conv)?;
104 match self.norm.as_ref() {
105 None => Ok(xs),
106 Some(norm) => xs.apply(norm),
107 }
108 }
109}
110
/// A transposed 1d convolution with an optional bias and optional group normalization.
#[derive(Debug, Clone)]
pub struct NormConvTranspose1d {
    /// Kernel passed to `Tensor::conv_transpose1d`.
    ws: Tensor,
    /// Optional bias, broadcast over the channel dimension of the output.
    bs: Option<Tensor>,
    /// Kernel size as given to the constructor.
    k_size: usize,
    /// Stride of the transposed convolution.
    stride: usize,
    /// Group count passed to `Tensor::conv_transpose1d`; the constructor may rewrite it to 1.
    groups: usize,
    /// Only set when the layer was built with `Norm::TimeGroupNorm`.
    norm: Option<candle_nn::GroupNorm>,
    /// Tracing span entered during `forward`.
    span: tracing::Span,
}
121
122impl NormConvTranspose1d {
123 #[allow(clippy::too_many_arguments)]
124 pub fn new(
125 in_c: usize,
126 out_c: usize,
127 k_size: usize,
128 causal: bool,
129 norm: Option<Norm>,
130 bias: bool,
131 stride: usize,
132 groups: usize,
133 vb: VarBuilder,
134 ) -> Result<Self> {
135 let vb = vb.pp("conv");
136 let bs = if bias {
137 Some(vb.get(out_c, "bias")?)
138 } else {
139 None
140 };
141 let ws = match norm {
142 None | Some(Norm::TimeGroupNorm) => vb.get((in_c, out_c / groups, k_size), "weight")?,
143 Some(Norm::WeightNorm) => {
144 if vb.contains_tensor("weight") {
145 vb.get((in_c, out_c, k_size), "weight")?
146 } else {
147 let weight_g = vb.get((in_c, 1, 1), "weight_g")?;
148 let weight_v = vb.get((in_c, out_c, k_size), "weight_v")?;
149 let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
150 weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?
151 }
152 }
153 Some(Norm::SpectralNorm) => candle::bail!("SpectralNorm is not supported yet."),
154 };
155 let (ws, groups) = if groups == out_c && in_c == out_c {
156 let eye = Tensor::eye(out_c, ws.dtype(), ws.device())?;
157 let ws = ws
158 .repeat((1, out_c, 1))?
159 .mul(&eye.unsqueeze(2)?.repeat((1, 1, k_size))?)?;
160 (ws, 1)
161 } else {
162 (ws, groups)
163 };
164 let norm = match norm {
165 None | Some(Norm::WeightNorm) | Some(Norm::SpectralNorm) => None,
166 Some(Norm::TimeGroupNorm) => {
167 if causal {
168 candle::bail!("GroupNorm doesn't support causal evaluation.")
169 }
170 let norm = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
171 Some(norm)
172 }
173 };
174 Ok(Self {
175 ws,
176 bs,
177 k_size,
178 stride,
179 groups,
180 norm,
181 span: tracing::span!(tracing::Level::TRACE, "norm-conv-tr1d"),
182 })
183 }
184}
185
186impl Module for NormConvTranspose1d {
187 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
188 let _enter = self.span.enter();
189 let xs = Tensor::conv_transpose1d(xs, &self.ws, 0, 0, self.stride, 1, self.groups)?;
195 let xs = match &self.bs {
196 None => xs,
197 Some(bias) => {
198 let b = bias.dims1()?;
199 let bias = bias.reshape((1, b, 1))?;
200 xs.broadcast_add(&bias)?
201 }
202 };
203 match self.norm.as_ref() {
204 None => Ok(xs),
205 Some(norm) => xs.apply(norm),
206 }
207 }
208}
209
210fn get_extra_padding_for_conv1d(
211 xs: &Tensor,
212 k_size: usize,
213 stride: usize,
214 padding_total: usize,
215) -> Result<usize> {
216 let len = xs.dim(D::Minus1)?;
217 let n_frames = (len + padding_total).saturating_sub(k_size) as f64 / stride as f64 + 1.0;
218 let ideal_len =
219 ((n_frames.ceil() as usize - 1) * stride + k_size).saturating_sub(padding_total);
220 Ok(ideal_len.saturating_sub(len))
221}
222
223fn pad1d(xs: &Tensor, pad_l: usize, pad_r: usize, mode: PadMode) -> Result<Tensor> {
224 match mode {
225 PadMode::Constant => xs.pad_with_zeros(D::Minus1, pad_l, pad_r),
226 PadMode::Reflect => candle::bail!("pad-mode 'reflect' is not supported"),
227 PadMode::Replicate => xs.pad_with_same(D::Minus1, pad_l, pad_r),
228 }
229}
230
231fn unpad1d(xs: &Tensor, unpad_l: usize, unpad_r: usize) -> Result<Tensor> {
232 let len = xs.dim(D::Minus1)?;
233 if len < unpad_l + unpad_r {
234 candle::bail!("unpad1d: tensor len {len} is too low, {unpad_l} + {unpad_r}")
235 }
236 xs.narrow(D::Minus1, unpad_l, len - (unpad_l + unpad_r))
237}
238
/// A padded 1d convolution that can be evaluated on a whole sequence (`Module::forward`)
/// or incrementally on successive chunks (`StreamingModule::step`).
#[derive(Debug, Clone)]
pub struct StreamableConv1d {
    /// The underlying convolution, built with zero built-in padding.
    conv: NormConv1d,
    /// When set, `forward` puts all of the base padding on the left.
    causal: bool,
    /// Padding strategy used for both `forward` and `step`.
    pad_mode: PadMode,
    /// Streaming state: input samples not yet consumed by a full frame.
    state_prev_xs: StreamTensor,
    /// Streaming state: whether the initial left padding has already been added.
    left_pad_applied: bool,
    /// Kernel size as given to the constructor (before dilation).
    kernel_size: usize,
    /// Tracing span entered during `forward` and `step`.
    span: tracing::Span,
}
249
250impl StreamableConv1d {
251 #[allow(clippy::too_many_arguments)]
252 pub fn new(
253 in_c: usize,
254 out_c: usize,
255 k_size: usize,
256 stride: usize,
257 dilation: usize,
258 groups: usize,
259 bias: bool,
260 causal: bool,
261 norm: Option<Norm>,
262 pad_mode: PadMode,
263 vb: VarBuilder,
264 ) -> Result<Self> {
265 let cfg = candle_nn::Conv1dConfig {
266 padding: 0,
267 stride,
268 dilation,
269 groups,
270 };
271 let conv = NormConv1d::new(in_c, out_c, k_size, causal, norm, bias, cfg, vb)?;
272 if k_size < stride {
273 candle::bail!("kernel-size {k_size} is smaller than stride {stride}")
274 }
275 Ok(Self {
276 conv,
277 causal,
278 pad_mode,
279 state_prev_xs: StreamTensor::empty(),
280 left_pad_applied: false,
281 kernel_size: k_size,
282 span: tracing::span!(tracing::Level::TRACE, "streamable-conv1d"),
283 })
284 }
285}
286
287impl Module for StreamableConv1d {
288 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
289 let _enter = self.span.enter();
290 let (_b, _t, _c) = xs.dims3()?;
291 let k_size = self.conv.conv.weight().dim(D::Minus1)?;
292 let conv_cfg = self.conv.conv.config();
293 let k_size = (k_size - 1) * conv_cfg.dilation + 1;
295 let padding_total = k_size - conv_cfg.stride;
296 let extra_padding =
297 get_extra_padding_for_conv1d(xs, k_size, conv_cfg.stride, padding_total)?;
298 let xs = if self.causal {
299 pad1d(xs, padding_total, extra_padding, self.pad_mode)?
300 } else {
301 let padding_right = padding_total / 2;
302 let padding_left = padding_total - padding_right;
303 pad1d(
304 xs,
305 padding_left,
306 padding_right + extra_padding,
307 self.pad_mode,
308 )?
309 };
310 xs.apply(&self.conv)
311 }
312}
313
314impl StreamingModule for StreamableConv1d {
315 fn reset_state(&mut self) {
316 self.state_prev_xs.reset();
317 self.left_pad_applied = false;
318 }
319
320 fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
321 let _enter = self.span.enter();
322 let xs = match xs.as_option() {
323 None => return Ok(().into()),
324 Some(xs) => xs.clone(),
325 };
326 let xs = if self.left_pad_applied {
327 xs
328 } else {
329 self.left_pad_applied = true;
330 let k_size = self.conv.conv.weight().dim(D::Minus1)?;
331 let conv_cfg = self.conv.conv.config();
332 let k_size = (k_size - 1) * conv_cfg.dilation + 1;
333 let padding_total = k_size - conv_cfg.stride;
334 pad1d(&xs, padding_total, 0, self.pad_mode)?
335 };
336 let cfg = self.conv.conv.config();
337 let stride = cfg.stride;
338 let dilation = cfg.dilation;
339 let kernel = (self.kernel_size - 1) * dilation + 1;
340 let xs = StreamTensor::cat2(&self.state_prev_xs, &xs.into(), D::Minus1)?;
341 let seq_len = xs.seq_len(D::Minus1)?;
342 let num_frames = (seq_len + stride).saturating_sub(kernel) / stride;
343 if num_frames > 0 {
344 let offset = num_frames * stride;
345 self.state_prev_xs = xs.narrow(D::Minus1, offset, seq_len - offset)?;
346 let in_l = (num_frames - 1) * stride + kernel;
347 let xs = xs.narrow(D::Minus1, 0, in_l)?;
348 xs.apply(&self.conv.conv)
351 } else {
352 self.state_prev_xs = xs;
353 Ok(StreamTensor::empty())
354 }
355 }
356}
357
/// A transposed 1d convolution that can be evaluated on a whole sequence
/// (`Module::forward`) or incrementally on successive chunks (`StreamingModule::step`).
#[derive(Debug, Clone)]
pub struct StreamableConvTranspose1d {
    /// The underlying transposed convolution.
    convtr: NormConvTranspose1d,
    /// When set, `forward` trims the surplus output samples from the right only.
    causal: bool,
    /// Streaming state: trailing output samples that still overlap with future chunks.
    state_prev_ys: StreamTensor,
    /// Kernel size as given to the constructor.
    kernel_size: usize,
    /// Tracing span entered during `forward` and `step`.
    span: tracing::Span,
}
366
367impl StreamableConvTranspose1d {
368 #[allow(clippy::too_many_arguments)]
369 pub fn new(
370 in_c: usize,
371 out_c: usize,
372 k_size: usize,
373 stride: usize,
374 groups: usize,
375 bias: bool,
376 causal: bool,
377 norm: Option<Norm>,
378 vb: VarBuilder,
379 ) -> Result<Self> {
380 let convtr =
381 NormConvTranspose1d::new(in_c, out_c, k_size, causal, norm, bias, stride, groups, vb)?;
382 Ok(Self {
383 convtr,
384 causal,
385 kernel_size: k_size,
386 state_prev_ys: StreamTensor::empty(),
387 span: tracing::span!(tracing::Level::TRACE, "streamable-conv-tr1d"),
388 })
389 }
390}
391
392impl Module for StreamableConvTranspose1d {
393 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
394 let _enter = self.span.enter();
395 let k_size = self.convtr.k_size;
396 let stride = self.convtr.stride;
397 let padding_total = k_size.saturating_sub(stride);
398 let xs = xs.apply(&self.convtr)?;
399 if self.causal {
400 unpad1d(&xs, 0, padding_total)
402 } else {
403 let padding_right = padding_total / 2;
404 let padding_left = padding_total - padding_right;
405 unpad1d(&xs, padding_left, padding_right)
406 }
407 }
408}
409
410impl StreamingModule for StreamableConvTranspose1d {
411 fn reset_state(&mut self) {
412 self.state_prev_ys.reset()
413 }
414
415 fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
416 let _enter = self.span.enter();
417 let xs = match xs.as_option() {
418 Some(xs) => xs,
419 None => return Ok(StreamTensor::empty()),
420 };
421 let stride = self.convtr.stride;
422 let ys = self.convtr.forward(xs)?;
425 let ot = ys.dim(D::Minus1)?;
426 let ys = match self.state_prev_ys.as_option() {
427 None => ys,
428 Some(prev_ys) => {
429 let pt = prev_ys.dim(D::Minus1)?;
430 let prev_ys = match &self.convtr.bs {
432 None => prev_ys.clone(),
433 Some(bias) => {
434 let bias = bias.reshape((1, (), 1))?;
435 prev_ys.broadcast_sub(&bias)?
436 }
437 };
438 let ys1 = (ys.narrow(D::Minus1, 0, pt)? + prev_ys)?;
439 let ys2 = ys.narrow(D::Minus1, pt, ot - pt)?;
440 Tensor::cat(&[ys1, ys2], D::Minus1)?
441 }
442 };
443 let invalid_steps = self.kernel_size - stride;
444 let (ys, prev_ys) = StreamTensor::from(ys).split(D::Minus1, ot - invalid_steps)?;
445 self.state_prev_ys = prev_ys;
446 Ok(ys)
447 }
448}
449
/// Downsampling layer implemented as a strided `StreamableConv1d`.
#[derive(Debug, Clone)]
pub struct ConvDownsample1d {
    /// The strided convolution doing the downsampling.
    conv: StreamableConv1d,
}
454
455impl ConvDownsample1d {
456 pub fn new(
457 stride: usize,
458 dim: usize,
459 causal: bool,
460 learnt: bool,
461 vb: VarBuilder,
462 ) -> Result<Self> {
463 if !learnt {
464 candle::bail!("only learnt=true is supported")
465 }
466 let conv = StreamableConv1d::new(
467 dim,
468 dim,
469 2 * stride,
470 stride,
471 1,
472 1, false,
474 causal,
475 None,
476 PadMode::Replicate,
477 vb,
478 )?;
479 Ok(Self { conv })
480 }
481}
482
483impl Module for ConvDownsample1d {
484 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
485 xs.apply(&self.conv)
486 }
487}
488
489impl StreamingModule for ConvDownsample1d {
490 fn reset_state(&mut self) {
491 self.conv.reset_state()
492 }
493
494 fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
495 self.conv.step(xs)
496 }
497}
498
/// Upsampling layer implemented as a strided `StreamableConvTranspose1d`.
#[derive(Debug, Clone)]
pub struct ConvTrUpsample1d {
    /// The strided transposed convolution doing the upsampling.
    convtr: StreamableConvTranspose1d,
}
503
504impl ConvTrUpsample1d {
505 pub fn new(
506 stride: usize,
507 dim: usize,
508 causal: bool,
509 learnt: bool,
510 vb: VarBuilder,
511 ) -> Result<Self> {
512 if !learnt {
513 candle::bail!("only learnt=true is supported")
514 }
515 let convtr = StreamableConvTranspose1d::new(
516 dim,
517 dim,
518 2 * stride,
519 stride,
520 dim,
521 false,
522 causal,
523 None,
524 vb,
525 )?;
526 Ok(Self { convtr })
527 }
528}
529
530impl Module for ConvTrUpsample1d {
531 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
532 xs.apply(&self.convtr)
533 }
534}
535
536impl StreamingModule for ConvTrUpsample1d {
537 fn reset_state(&mut self) {
538 self.convtr.reset_state()
539 }
540
541 fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
542 self.convtr.step(xs)
543 }
544}
545
#[cfg(test)]
mod tests {
    use super::*;
    use candle::IndexOp;

    /// Feeds `xs` to `module` as `len` consecutive chunks of `step_size` samples and checks
    /// that the concatenated streaming outputs match `ys`, the whole-sequence output.
    fn check_streaming<M: StreamingModule>(
        module: &mut M,
        xs: &Tensor,
        ys: &Tensor,
        step_size: usize,
        len: usize,
    ) -> Result<()> {
        let mut chunks = Vec::new();
        for idx in 0..len {
            let chunk = xs.i((.., .., step_size * idx..step_size * (idx + 1)))?;
            let out = module.step(&chunk.into())?;
            if let Some(out) = out.as_option() {
                chunks.push(out.clone())
            }
        }
        let ys_steps = Tensor::cat(&chunks, D::Minus1)?;
        let diff = (ys - &ys_steps)?
            .abs()?
            .flatten_all()?
            .max(0)?
            .to_vec0::<f32>()?;
        if diff > 1e-5 {
            println!("{xs}");
            println!("{ys}");
            println!("{ys_steps}");
            candle::bail!("larger diff than expected {diff}")
        }
        Ok(())
    }

    /// Compares batch and streaming evaluation of a causal, constant-padded
    /// `StreamableConv1d` with 2 input and 3 output channels.
    fn run_conv1d(
        k_size: usize,
        stride: usize,
        dilation: usize,
        step_size: usize,
        len: usize,
        bias: bool,
    ) -> Result<()> {
        let dev = &candle::Device::Cpu;
        let vm = candle_nn::VarMap::new();
        let vb = VarBuilder::from_varmap(&vm, candle::DType::F32, dev);
        let mut conv1d = StreamableConv1d::new(
            2,
            3,
            k_size,
            stride,
            dilation,
            1,
            bias,
            true,
            None,
            PadMode::Constant,
            vb,
        )?;
        let xs = Tensor::randn(0f32, 1., (1, 2, step_size * len), dev)?;
        let ys = conv1d.forward(&xs)?;
        check_streaming(&mut conv1d, &xs, &ys, step_size, len)
    }

    /// Compares batch and streaming evaluation of a causal `StreamableConvTranspose1d`
    /// with 2 input and 3 output channels and a single group.
    fn run_conv_tr1d(
        k_size: usize,
        stride: usize,
        step_size: usize,
        len: usize,
        bias: bool,
    ) -> Result<()> {
        let dev = &candle::Device::Cpu;
        let vm = candle_nn::VarMap::new();
        let vb = VarBuilder::from_varmap(&vm, candle::DType::F32, dev);
        let mut conv_tr1d =
            StreamableConvTranspose1d::new(2, 3, k_size, stride, 1, bias, true, None, vb)?;
        let xs = Tensor::randn(0f32, 1., (1, 2, step_size * len), dev)?;
        let ys = conv_tr1d.forward(&xs)?;
        check_streaming(&mut conv_tr1d, &xs, &ys, step_size, len)
    }

    #[test]
    fn conv1d() -> Result<()> {
        // (k_size, stride, dilation, len)
        const CASES: [(usize, usize, usize, usize); 5] = [
            (1, 1, 1, 5),
            (2, 1, 1, 5),
            (2, 2, 1, 6),
            (3, 2, 1, 8),
            (3, 2, 2, 8),
        ];
        for step_size in [1, 2, 3] {
            for bias in [false, true] {
                for (k_size, stride, dilation, len) in CASES {
                    run_conv1d(k_size, stride, dilation, step_size, len, bias)?;
                }
            }
        }
        Ok(())
    }

    #[test]
    fn conv_tr1d() -> Result<()> {
        // (k_size, stride, len)
        const CASES: [(usize, usize, usize); 4] = [(1, 1, 5), (2, 1, 5), (3, 1, 5), (3, 2, 5)];
        for step_size in [1, 2, 3] {
            for bias in [false, true] {
                for (k_size, stride, len) in CASES {
                    run_conv_tr1d(k_size, stride, step_size, len, bias)?;
                }
            }
        }
        Ok(())
    }
}