candle_transformers/models/stable_diffusion/
resnet.rs1use crate::models::with_tracing::{conv2d, Conv2d};
9use candle::{Result, Tensor, D};
10use candle_nn as nn;
11use candle_nn::Module;
12
/// Hyper-parameters for a [`ResnetBlock2D`].
#[derive(Debug, Clone, Copy)]
pub struct ResnetBlock2DConfig {
    /// Number of output channels; `None` keeps the block's input channel count.
    pub out_channels: Option<usize>,
    /// Width of the time embedding accepted by the block; `None` builds the block
    /// without a time-embedding projection.
    pub temb_channels: Option<usize>,
    /// Number of groups of the first group normalization (and of the second one when
    /// `groups_out` is `None`).
    pub groups: usize,
    /// Number of groups of the second group normalization; `None` reuses `groups`.
    pub groups_out: Option<usize>,
    /// Epsilon passed to both group normalizations.
    pub eps: f64,
    /// Whether the skip connection goes through a 1x1 convolution. `None` enables it
    /// exactly when the input and output channel counts differ.
    pub use_in_shortcut: Option<bool>,
    /// The sum of the main branch and the skip connection is divided by this value.
    pub output_scale_factor: f64,
}

impl Default for ResnetBlock2DConfig {
    /// Default values: 32 groups, `eps = 1e-6`, a 512-wide time embedding, no output
    /// rescaling (factor 1), and every optional override left unset.
    fn default() -> Self {
        Self {
            groups: 32,
            eps: 1e-6,
            temb_channels: Some(512),
            output_scale_factor: 1.0,
            out_channels: None,
            groups_out: None,
            use_in_shortcut: None,
        }
    }
}
46
47#[derive(Debug)]
48pub struct ResnetBlock2D {
49 norm1: nn::GroupNorm,
50 conv1: Conv2d,
51 norm2: nn::GroupNorm,
52 conv2: Conv2d,
53 time_emb_proj: Option<nn::Linear>,
54 conv_shortcut: Option<Conv2d>,
55 span: tracing::Span,
56 config: ResnetBlock2DConfig,
57}
58
59impl ResnetBlock2D {
60 pub fn new(
61 vs: nn::VarBuilder,
62 in_channels: usize,
63 config: ResnetBlock2DConfig,
64 ) -> Result<Self> {
65 let out_channels = config.out_channels.unwrap_or(in_channels);
66 let conv_cfg = nn::Conv2dConfig {
67 stride: 1,
68 padding: 1,
69 groups: 1,
70 dilation: 1,
71 };
72 let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
73 let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
74 let groups_out = config.groups_out.unwrap_or(config.groups);
75 let norm2 = nn::group_norm(groups_out, out_channels, config.eps, vs.pp("norm2"))?;
76 let conv2 = conv2d(out_channels, out_channels, 3, conv_cfg, vs.pp("conv2"))?;
77 let use_in_shortcut = config
78 .use_in_shortcut
79 .unwrap_or(in_channels != out_channels);
80 let conv_shortcut = if use_in_shortcut {
81 let conv_cfg = nn::Conv2dConfig {
82 stride: 1,
83 padding: 0,
84 groups: 1,
85 dilation: 1,
86 };
87 Some(conv2d(
88 in_channels,
89 out_channels,
90 1,
91 conv_cfg,
92 vs.pp("conv_shortcut"),
93 )?)
94 } else {
95 None
96 };
97 let time_emb_proj = match config.temb_channels {
98 None => None,
99 Some(temb_channels) => Some(nn::linear(
100 temb_channels,
101 out_channels,
102 vs.pp("time_emb_proj"),
103 )?),
104 };
105 let span = tracing::span!(tracing::Level::TRACE, "resnet2d");
106 Ok(Self {
107 norm1,
108 conv1,
109 norm2,
110 conv2,
111 time_emb_proj,
112 span,
113 config,
114 conv_shortcut,
115 })
116 }
117
118 pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
119 let _enter = self.span.enter();
120 let shortcut_xs = match &self.conv_shortcut {
121 Some(conv_shortcut) => conv_shortcut.forward(xs)?,
122 None => xs.clone(),
123 };
124 let xs = self.norm1.forward(xs)?;
125 let xs = self.conv1.forward(&nn::ops::silu(&xs)?)?;
126 let xs = match (temb, &self.time_emb_proj) {
127 (Some(temb), Some(time_emb_proj)) => time_emb_proj
128 .forward(&nn::ops::silu(temb)?)?
129 .unsqueeze(D::Minus1)?
130 .unsqueeze(D::Minus1)?
131 .broadcast_add(&xs)?,
132 _ => xs,
133 };
134 let xs = self
135 .conv2
136 .forward(&nn::ops::silu(&self.norm2.forward(&xs)?)?)?;
137 (shortcut_xs + xs)? / self.config.output_scale_factor
138 }
139}