1use super::embeddings::{TimestepEmbedding, Timesteps};
6use super::unet_2d_blocks::*;
7use crate::models::with_tracing::{conv2d, Conv2d};
8use candle::{Result, Tensor};
9use candle_nn as nn;
10use candle_nn::Module;
11
/// Configuration of one resolution level of the UNet.
///
/// `UNet2DConditionModel::new` builds one down block per entry (in order) and one up block per
/// entry (in reverse order) from the same list.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BlockConfig {
    /// Number of output channels of the blocks built for this level.
    pub out_channels: usize,
    /// `Some(n)` builds cross-attention blocks with `n` transformer layers per block;
    /// `None` builds plain (resnet-only) blocks.
    pub use_cross_attn: Option<usize>,
    /// Forwarded as `attn_num_head_channels` to the cross-attention blocks of this level.
    /// When the model's `sliced_attention_size` is `Some(0)`, half of this value is used as
    /// the attention slice size.
    pub attention_head_dim: usize,
}
20
21#[derive(Debug, Clone)]
22pub struct UNet2DConditionModelConfig {
23 pub center_input_sample: bool,
24 pub flip_sin_to_cos: bool,
25 pub freq_shift: f64,
26 pub blocks: Vec<BlockConfig>,
27 pub layers_per_block: usize,
28 pub downsample_padding: usize,
29 pub mid_block_scale_factor: f64,
30 pub norm_num_groups: usize,
31 pub norm_eps: f64,
32 pub cross_attention_dim: usize,
33 pub sliced_attention_size: Option<usize>,
34 pub use_linear_projection: bool,
35}
36
37impl Default for UNet2DConditionModelConfig {
38 fn default() -> Self {
39 Self {
40 center_input_sample: false,
41 flip_sin_to_cos: true,
42 freq_shift: 0.,
43 blocks: vec![
44 BlockConfig {
45 out_channels: 320,
46 use_cross_attn: Some(1),
47 attention_head_dim: 8,
48 },
49 BlockConfig {
50 out_channels: 640,
51 use_cross_attn: Some(1),
52 attention_head_dim: 8,
53 },
54 BlockConfig {
55 out_channels: 1280,
56 use_cross_attn: Some(1),
57 attention_head_dim: 8,
58 },
59 BlockConfig {
60 out_channels: 1280,
61 use_cross_attn: None,
62 attention_head_dim: 8,
63 },
64 ],
65 layers_per_block: 2,
66 downsample_padding: 1,
67 mid_block_scale_factor: 1.,
68 norm_num_groups: 32,
69 norm_eps: 1e-5,
70 cross_attention_dim: 1280,
71 sliced_attention_size: None,
72 use_linear_projection: false,
73 }
74 }
75}
76
77#[derive(Debug)]
78pub(crate) enum UNetDownBlock {
79 Basic(DownBlock2D),
80 CrossAttn(CrossAttnDownBlock2D),
81}
82
83#[derive(Debug)]
84enum UNetUpBlock {
85 Basic(UpBlock2D),
86 CrossAttn(CrossAttnUpBlock2D),
87}
88
89#[derive(Debug)]
90pub struct UNet2DConditionModel {
91 conv_in: Conv2d,
92 time_proj: Timesteps,
93 time_embedding: TimestepEmbedding,
94 down_blocks: Vec<UNetDownBlock>,
95 mid_block: UNetMidBlock2DCrossAttn,
96 up_blocks: Vec<UNetUpBlock>,
97 conv_norm_out: nn::GroupNorm,
98 conv_out: Conv2d,
99 span: tracing::Span,
100 config: UNet2DConditionModelConfig,
101}
102
103impl UNet2DConditionModel {
104 pub fn new(
105 vs: nn::VarBuilder,
106 in_channels: usize,
107 out_channels: usize,
108 use_flash_attn: bool,
109 config: UNet2DConditionModelConfig,
110 ) -> Result<Self> {
111 let n_blocks = config.blocks.len();
112 let b_channels = config.blocks[0].out_channels;
113 let bl_channels = config.blocks.last().unwrap().out_channels;
114 let bl_attention_head_dim = config.blocks.last().unwrap().attention_head_dim;
115 let time_embed_dim = b_channels * 4;
116 let conv_cfg = nn::Conv2dConfig {
117 padding: 1,
118 ..Default::default()
119 };
120 let conv_in = conv2d(in_channels, b_channels, 3, conv_cfg, vs.pp("conv_in"))?;
121
122 let time_proj = Timesteps::new(b_channels, config.flip_sin_to_cos, config.freq_shift);
123 let time_embedding =
124 TimestepEmbedding::new(vs.pp("time_embedding"), b_channels, time_embed_dim)?;
125
126 let vs_db = vs.pp("down_blocks");
127 let down_blocks = (0..n_blocks)
128 .map(|i| {
129 let BlockConfig {
130 out_channels,
131 use_cross_attn,
132 attention_head_dim,
133 } = config.blocks[i];
134
135 let sliced_attention_size = match config.sliced_attention_size {
137 Some(0) => Some(attention_head_dim / 2),
138 _ => config.sliced_attention_size,
139 };
140
141 let in_channels = if i > 0 {
142 config.blocks[i - 1].out_channels
143 } else {
144 b_channels
145 };
146 let db_cfg = DownBlock2DConfig {
147 num_layers: config.layers_per_block,
148 resnet_eps: config.norm_eps,
149 resnet_groups: config.norm_num_groups,
150 add_downsample: i < n_blocks - 1,
151 downsample_padding: config.downsample_padding,
152 ..Default::default()
153 };
154 if let Some(transformer_layers_per_block) = use_cross_attn {
155 let config = CrossAttnDownBlock2DConfig {
156 downblock: db_cfg,
157 attn_num_head_channels: attention_head_dim,
158 cross_attention_dim: config.cross_attention_dim,
159 sliced_attention_size,
160 use_linear_projection: config.use_linear_projection,
161 transformer_layers_per_block,
162 };
163 let block = CrossAttnDownBlock2D::new(
164 vs_db.pp(i.to_string()),
165 in_channels,
166 out_channels,
167 Some(time_embed_dim),
168 use_flash_attn,
169 config,
170 )?;
171 Ok(UNetDownBlock::CrossAttn(block))
172 } else {
173 let block = DownBlock2D::new(
174 vs_db.pp(i.to_string()),
175 in_channels,
176 out_channels,
177 Some(time_embed_dim),
178 db_cfg,
179 )?;
180 Ok(UNetDownBlock::Basic(block))
181 }
182 })
183 .collect::<Result<Vec<_>>>()?;
184
185 let mid_transformer_layers_per_block = match config.blocks.last() {
187 None => 1,
188 Some(block) => block.use_cross_attn.unwrap_or(1),
189 };
190 let mid_cfg = UNetMidBlock2DCrossAttnConfig {
191 resnet_eps: config.norm_eps,
192 output_scale_factor: config.mid_block_scale_factor,
193 cross_attn_dim: config.cross_attention_dim,
194 attn_num_head_channels: bl_attention_head_dim,
195 resnet_groups: Some(config.norm_num_groups),
196 use_linear_projection: config.use_linear_projection,
197 transformer_layers_per_block: mid_transformer_layers_per_block,
198 ..Default::default()
199 };
200
201 let mid_block = UNetMidBlock2DCrossAttn::new(
202 vs.pp("mid_block"),
203 bl_channels,
204 Some(time_embed_dim),
205 use_flash_attn,
206 mid_cfg,
207 )?;
208
209 let vs_ub = vs.pp("up_blocks");
210 let up_blocks = (0..n_blocks)
211 .map(|i| {
212 let BlockConfig {
213 out_channels,
214 use_cross_attn,
215 attention_head_dim,
216 } = config.blocks[n_blocks - 1 - i];
217
218 let sliced_attention_size = match config.sliced_attention_size {
220 Some(0) => Some(attention_head_dim / 2),
221 _ => config.sliced_attention_size,
222 };
223
224 let prev_out_channels = if i > 0 {
225 config.blocks[n_blocks - i].out_channels
226 } else {
227 bl_channels
228 };
229 let in_channels = {
230 let index = if i == n_blocks - 1 {
231 0
232 } else {
233 n_blocks - i - 2
234 };
235 config.blocks[index].out_channels
236 };
237 let ub_cfg = UpBlock2DConfig {
238 num_layers: config.layers_per_block + 1,
239 resnet_eps: config.norm_eps,
240 resnet_groups: config.norm_num_groups,
241 add_upsample: i < n_blocks - 1,
242 ..Default::default()
243 };
244 if let Some(transformer_layers_per_block) = use_cross_attn {
245 let config = CrossAttnUpBlock2DConfig {
246 upblock: ub_cfg,
247 attn_num_head_channels: attention_head_dim,
248 cross_attention_dim: config.cross_attention_dim,
249 sliced_attention_size,
250 use_linear_projection: config.use_linear_projection,
251 transformer_layers_per_block,
252 };
253 let block = CrossAttnUpBlock2D::new(
254 vs_ub.pp(i.to_string()),
255 in_channels,
256 prev_out_channels,
257 out_channels,
258 Some(time_embed_dim),
259 use_flash_attn,
260 config,
261 )?;
262 Ok(UNetUpBlock::CrossAttn(block))
263 } else {
264 let block = UpBlock2D::new(
265 vs_ub.pp(i.to_string()),
266 in_channels,
267 prev_out_channels,
268 out_channels,
269 Some(time_embed_dim),
270 ub_cfg,
271 )?;
272 Ok(UNetUpBlock::Basic(block))
273 }
274 })
275 .collect::<Result<Vec<_>>>()?;
276
277 let conv_norm_out = nn::group_norm(
278 config.norm_num_groups,
279 b_channels,
280 config.norm_eps,
281 vs.pp("conv_norm_out"),
282 )?;
283 let conv_out = conv2d(b_channels, out_channels, 3, conv_cfg, vs.pp("conv_out"))?;
284 let span = tracing::span!(tracing::Level::TRACE, "unet2d");
285 Ok(Self {
286 conv_in,
287 time_proj,
288 time_embedding,
289 down_blocks,
290 mid_block,
291 up_blocks,
292 conv_norm_out,
293 conv_out,
294 span,
295 config,
296 })
297 }
298
299 pub fn forward(
300 &self,
301 xs: &Tensor,
302 timestep: f64,
303 encoder_hidden_states: &Tensor,
304 ) -> Result<Tensor> {
305 let _enter = self.span.enter();
306 self.forward_with_additional_residuals(xs, timestep, encoder_hidden_states, None, None)
307 }
308
309 pub fn forward_with_additional_residuals(
310 &self,
311 xs: &Tensor,
312 timestep: f64,
313 encoder_hidden_states: &Tensor,
314 down_block_additional_residuals: Option<&[Tensor]>,
315 mid_block_additional_residual: Option<&Tensor>,
316 ) -> Result<Tensor> {
317 let (bsize, _channels, height, width) = xs.dims4()?;
318 let device = xs.device();
319 let n_blocks = self.config.blocks.len();
320 let num_upsamplers = n_blocks - 1;
321 let default_overall_up_factor = 2usize.pow(num_upsamplers as u32);
322 let forward_upsample_size =
323 height % default_overall_up_factor != 0 || width % default_overall_up_factor != 0;
324 let xs = if self.config.center_input_sample {
326 ((xs * 2.0)? - 1.0)?
327 } else {
328 xs.clone()
329 };
330 let emb = (Tensor::ones(bsize, xs.dtype(), device)? * timestep)?;
332 let emb = self.time_proj.forward(&emb)?;
333 let emb = self.time_embedding.forward(&emb)?;
334 let xs = self.conv_in.forward(&xs)?;
336 let mut down_block_res_xs = vec![xs.clone()];
338 let mut xs = xs;
339 for down_block in self.down_blocks.iter() {
340 let (_xs, res_xs) = match down_block {
341 UNetDownBlock::Basic(b) => b.forward(&xs, Some(&emb))?,
342 UNetDownBlock::CrossAttn(b) => {
343 b.forward(&xs, Some(&emb), Some(encoder_hidden_states))?
344 }
345 };
346 down_block_res_xs.extend(res_xs);
347 xs = _xs;
348 }
349
350 let new_down_block_res_xs =
351 if let Some(down_block_additional_residuals) = down_block_additional_residuals {
352 let mut v = vec![];
353 for (i, residuals) in down_block_additional_residuals.iter().enumerate() {
356 v.push((&down_block_res_xs[i] + residuals)?)
357 }
358 v
359 } else {
360 down_block_res_xs
361 };
362 let mut down_block_res_xs = new_down_block_res_xs;
363
364 let xs = self
366 .mid_block
367 .forward(&xs, Some(&emb), Some(encoder_hidden_states))?;
368 let xs = match mid_block_additional_residual {
369 None => xs,
370 Some(m) => (m + xs)?,
371 };
372 let mut xs = xs;
374 let mut upsample_size = None;
375 for (i, up_block) in self.up_blocks.iter().enumerate() {
376 let n_resnets = match up_block {
377 UNetUpBlock::Basic(b) => b.resnets.len(),
378 UNetUpBlock::CrossAttn(b) => b.upblock.resnets.len(),
379 };
380 let res_xs = down_block_res_xs.split_off(down_block_res_xs.len() - n_resnets);
381 if i < n_blocks - 1 && forward_upsample_size {
382 let (_, _, h, w) = down_block_res_xs.last().unwrap().dims4()?;
383 upsample_size = Some((h, w))
384 }
385 xs = match up_block {
386 UNetUpBlock::Basic(b) => b.forward(&xs, &res_xs, Some(&emb), upsample_size)?,
387 UNetUpBlock::CrossAttn(b) => b.forward(
388 &xs,
389 &res_xs,
390 Some(&emb),
391 upsample_size,
392 Some(encoder_hidden_states),
393 )?,
394 };
395 }
396 let xs = self.conv_norm_out.forward(&xs)?;
398 let xs = nn::ops::silu(&xs)?;
399 self.conv_out.forward(&xs)
400 }
401}