candle_transformers/models/stable_diffusion/
resnet.rs

1//! ResNet Building Blocks
2//!
3//! Some Residual Network blocks used in UNet models.
4//!
5//! Deep Residual Learning for Image Recognition, K. He et al., 2015.
6//! - [Paper](https://arxiv.org/abs/1512.03385)
7//!
8use crate::models::with_tracing::{conv2d, Conv2d};
9use candle::{Result, Tensor, D};
10use candle_nn as nn;
11use candle_nn::Module;
12
/// Configuration for a ResNet block.
///
/// The non-linearity is not configurable: the block always applies SiLU.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ResnetBlock2DConfig {
    /// The number of output channels, defaults to the number of input channels.
    pub out_channels: Option<usize>,
    /// Input dimension of the `time_emb_proj` linear layer, which projects the
    /// embedding passed to `forward` onto `out_channels`. When `None`, no such
    /// layer is created and any embedding passed to `forward` is ignored.
    pub temb_channels: Option<usize>,
    /// The number of groups to use in group normalization.
    pub groups: usize,
    /// The number of groups used by the second group normalization (`norm2`),
    /// defaults to `groups`.
    pub groups_out: Option<usize>,
    /// The epsilon to be used in the group normalization operations.
    pub eps: f64,
    /// Whether to use a 2D convolution in the skip connection. When using None,
    /// such a convolution is used if the number of input channels is different from
    /// the number of output channels.
    pub use_in_shortcut: Option<bool>,
    /// The final output is scaled by dividing by this value.
    pub output_scale_factor: f64,
}

impl Default for ResnetBlock2DConfig {
    /// Returns the default configuration: output channels follow the input,
    /// a 512-wide embedding projection, 32 normalization groups for both
    /// normalizations, `eps = 1e-6`, an automatically selected shortcut
    /// convolution, and no output rescaling (scale factor of 1).
    fn default() -> Self {
        Self {
            out_channels: None,
            temb_channels: Some(512),
            groups: 32,
            groups_out: None,
            eps: 1e-6,
            use_in_shortcut: None,
            output_scale_factor: 1.0,
        }
    }
}
46
47#[derive(Debug)]
48pub struct ResnetBlock2D {
49    norm1: nn::GroupNorm,
50    conv1: Conv2d,
51    norm2: nn::GroupNorm,
52    conv2: Conv2d,
53    time_emb_proj: Option<nn::Linear>,
54    conv_shortcut: Option<Conv2d>,
55    span: tracing::Span,
56    config: ResnetBlock2DConfig,
57}
58
59impl ResnetBlock2D {
60    pub fn new(
61        vs: nn::VarBuilder,
62        in_channels: usize,
63        config: ResnetBlock2DConfig,
64    ) -> Result<Self> {
65        let out_channels = config.out_channels.unwrap_or(in_channels);
66        let conv_cfg = nn::Conv2dConfig {
67            stride: 1,
68            padding: 1,
69            groups: 1,
70            dilation: 1,
71        };
72        let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
73        let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
74        let groups_out = config.groups_out.unwrap_or(config.groups);
75        let norm2 = nn::group_norm(groups_out, out_channels, config.eps, vs.pp("norm2"))?;
76        let conv2 = conv2d(out_channels, out_channels, 3, conv_cfg, vs.pp("conv2"))?;
77        let use_in_shortcut = config
78            .use_in_shortcut
79            .unwrap_or(in_channels != out_channels);
80        let conv_shortcut = if use_in_shortcut {
81            let conv_cfg = nn::Conv2dConfig {
82                stride: 1,
83                padding: 0,
84                groups: 1,
85                dilation: 1,
86            };
87            Some(conv2d(
88                in_channels,
89                out_channels,
90                1,
91                conv_cfg,
92                vs.pp("conv_shortcut"),
93            )?)
94        } else {
95            None
96        };
97        let time_emb_proj = match config.temb_channels {
98            None => None,
99            Some(temb_channels) => Some(nn::linear(
100                temb_channels,
101                out_channels,
102                vs.pp("time_emb_proj"),
103            )?),
104        };
105        let span = tracing::span!(tracing::Level::TRACE, "resnet2d");
106        Ok(Self {
107            norm1,
108            conv1,
109            norm2,
110            conv2,
111            time_emb_proj,
112            span,
113            config,
114            conv_shortcut,
115        })
116    }
117
118    pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
119        let _enter = self.span.enter();
120        let shortcut_xs = match &self.conv_shortcut {
121            Some(conv_shortcut) => conv_shortcut.forward(xs)?,
122            None => xs.clone(),
123        };
124        let xs = self.norm1.forward(xs)?;
125        let xs = self.conv1.forward(&nn::ops::silu(&xs)?)?;
126        let xs = match (temb, &self.time_emb_proj) {
127            (Some(temb), Some(time_emb_proj)) => time_emb_proj
128                .forward(&nn::ops::silu(temb)?)?
129                .unsqueeze(D::Minus1)?
130                .unsqueeze(D::Minus1)?
131                .broadcast_add(&xs)?,
132            _ => xs,
133        };
134        let xs = self
135            .conv2
136            .forward(&nn::ops::silu(&self.norm2.forward(&xs)?)?)?;
137        (shortcut_xs + xs)? / self.config.output_scale_factor
138    }
139}