candle_transformers/models/stable_diffusion/
unet_2d.rs

//! 2D UNet Denoising Models
//!
//! The 2D UNet models take as input a noisy sample and the current diffusion
//! timestep and return a denoised version of the input.
5use super::embeddings::{TimestepEmbedding, Timesteps};
6use super::unet_2d_blocks::*;
7use crate::models::with_tracing::{conv2d, Conv2d};
8use candle::{Result, Tensor};
9use candle_nn as nn;
10use candle_nn::Module;
11
/// Settings for one resolution level of the UNet.
///
/// The same entry configures a block on the down path and its mirrored block on the up path.
#[derive(Clone, Copy, Debug)]
pub struct BlockConfig {
    /// Number of channels produced by the block at this level.
    pub out_channels: usize,
    /// When `None` no cross-attn is used, when `Some(d)` then cross-attn is used and `d` is the
    /// number of transformer blocks to be used.
    pub use_cross_attn: Option<usize>,
    /// Forwarded to the cross-attention blocks as their number of head channels; also used to
    /// derive the slice size when automatic attention slicing is requested.
    pub attention_head_dim: usize,
}
20
21#[derive(Debug, Clone)]
22pub struct UNet2DConditionModelConfig {
23    pub center_input_sample: bool,
24    pub flip_sin_to_cos: bool,
25    pub freq_shift: f64,
26    pub blocks: Vec<BlockConfig>,
27    pub layers_per_block: usize,
28    pub downsample_padding: usize,
29    pub mid_block_scale_factor: f64,
30    pub norm_num_groups: usize,
31    pub norm_eps: f64,
32    pub cross_attention_dim: usize,
33    pub sliced_attention_size: Option<usize>,
34    pub use_linear_projection: bool,
35}
36
37impl Default for UNet2DConditionModelConfig {
38    fn default() -> Self {
39        Self {
40            center_input_sample: false,
41            flip_sin_to_cos: true,
42            freq_shift: 0.,
43            blocks: vec![
44                BlockConfig {
45                    out_channels: 320,
46                    use_cross_attn: Some(1),
47                    attention_head_dim: 8,
48                },
49                BlockConfig {
50                    out_channels: 640,
51                    use_cross_attn: Some(1),
52                    attention_head_dim: 8,
53                },
54                BlockConfig {
55                    out_channels: 1280,
56                    use_cross_attn: Some(1),
57                    attention_head_dim: 8,
58                },
59                BlockConfig {
60                    out_channels: 1280,
61                    use_cross_attn: None,
62                    attention_head_dim: 8,
63                },
64            ],
65            layers_per_block: 2,
66            downsample_padding: 1,
67            mid_block_scale_factor: 1.,
68            norm_num_groups: 32,
69            norm_eps: 1e-5,
70            cross_attention_dim: 1280,
71            sliced_attention_size: None,
72            use_linear_projection: false,
73        }
74    }
75}
76
77#[derive(Debug)]
78pub(crate) enum UNetDownBlock {
79    Basic(DownBlock2D),
80    CrossAttn(CrossAttnDownBlock2D),
81}
82
83#[derive(Debug)]
84enum UNetUpBlock {
85    Basic(UpBlock2D),
86    CrossAttn(CrossAttnUpBlock2D),
87}
88
89#[derive(Debug)]
90pub struct UNet2DConditionModel {
91    conv_in: Conv2d,
92    time_proj: Timesteps,
93    time_embedding: TimestepEmbedding,
94    down_blocks: Vec<UNetDownBlock>,
95    mid_block: UNetMidBlock2DCrossAttn,
96    up_blocks: Vec<UNetUpBlock>,
97    conv_norm_out: nn::GroupNorm,
98    conv_out: Conv2d,
99    span: tracing::Span,
100    config: UNet2DConditionModelConfig,
101}
102
103impl UNet2DConditionModel {
104    pub fn new(
105        vs: nn::VarBuilder,
106        in_channels: usize,
107        out_channels: usize,
108        use_flash_attn: bool,
109        config: UNet2DConditionModelConfig,
110    ) -> Result<Self> {
111        let n_blocks = config.blocks.len();
112        let b_channels = config.blocks[0].out_channels;
113        let bl_channels = config.blocks.last().unwrap().out_channels;
114        let bl_attention_head_dim = config.blocks.last().unwrap().attention_head_dim;
115        let time_embed_dim = b_channels * 4;
116        let conv_cfg = nn::Conv2dConfig {
117            padding: 1,
118            ..Default::default()
119        };
120        let conv_in = conv2d(in_channels, b_channels, 3, conv_cfg, vs.pp("conv_in"))?;
121
122        let time_proj = Timesteps::new(b_channels, config.flip_sin_to_cos, config.freq_shift);
123        let time_embedding =
124            TimestepEmbedding::new(vs.pp("time_embedding"), b_channels, time_embed_dim)?;
125
126        let vs_db = vs.pp("down_blocks");
127        let down_blocks = (0..n_blocks)
128            .map(|i| {
129                let BlockConfig {
130                    out_channels,
131                    use_cross_attn,
132                    attention_head_dim,
133                } = config.blocks[i];
134
135                // Enable automatic attention slicing if the config sliced_attention_size is set to 0.
136                let sliced_attention_size = match config.sliced_attention_size {
137                    Some(0) => Some(attention_head_dim / 2),
138                    _ => config.sliced_attention_size,
139                };
140
141                let in_channels = if i > 0 {
142                    config.blocks[i - 1].out_channels
143                } else {
144                    b_channels
145                };
146                let db_cfg = DownBlock2DConfig {
147                    num_layers: config.layers_per_block,
148                    resnet_eps: config.norm_eps,
149                    resnet_groups: config.norm_num_groups,
150                    add_downsample: i < n_blocks - 1,
151                    downsample_padding: config.downsample_padding,
152                    ..Default::default()
153                };
154                if let Some(transformer_layers_per_block) = use_cross_attn {
155                    let config = CrossAttnDownBlock2DConfig {
156                        downblock: db_cfg,
157                        attn_num_head_channels: attention_head_dim,
158                        cross_attention_dim: config.cross_attention_dim,
159                        sliced_attention_size,
160                        use_linear_projection: config.use_linear_projection,
161                        transformer_layers_per_block,
162                    };
163                    let block = CrossAttnDownBlock2D::new(
164                        vs_db.pp(i.to_string()),
165                        in_channels,
166                        out_channels,
167                        Some(time_embed_dim),
168                        use_flash_attn,
169                        config,
170                    )?;
171                    Ok(UNetDownBlock::CrossAttn(block))
172                } else {
173                    let block = DownBlock2D::new(
174                        vs_db.pp(i.to_string()),
175                        in_channels,
176                        out_channels,
177                        Some(time_embed_dim),
178                        db_cfg,
179                    )?;
180                    Ok(UNetDownBlock::Basic(block))
181                }
182            })
183            .collect::<Result<Vec<_>>>()?;
184
185        // https://github.com/huggingface/diffusers/blob/a76f2ad538e73b34d5fe7be08c8eb8ab38c7e90c/src/diffusers/models/unet_2d_condition.py#L462
186        let mid_transformer_layers_per_block = match config.blocks.last() {
187            None => 1,
188            Some(block) => block.use_cross_attn.unwrap_or(1),
189        };
190        let mid_cfg = UNetMidBlock2DCrossAttnConfig {
191            resnet_eps: config.norm_eps,
192            output_scale_factor: config.mid_block_scale_factor,
193            cross_attn_dim: config.cross_attention_dim,
194            attn_num_head_channels: bl_attention_head_dim,
195            resnet_groups: Some(config.norm_num_groups),
196            use_linear_projection: config.use_linear_projection,
197            transformer_layers_per_block: mid_transformer_layers_per_block,
198            ..Default::default()
199        };
200
201        let mid_block = UNetMidBlock2DCrossAttn::new(
202            vs.pp("mid_block"),
203            bl_channels,
204            Some(time_embed_dim),
205            use_flash_attn,
206            mid_cfg,
207        )?;
208
209        let vs_ub = vs.pp("up_blocks");
210        let up_blocks = (0..n_blocks)
211            .map(|i| {
212                let BlockConfig {
213                    out_channels,
214                    use_cross_attn,
215                    attention_head_dim,
216                } = config.blocks[n_blocks - 1 - i];
217
218                // Enable automatic attention slicing if the config sliced_attention_size is set to 0.
219                let sliced_attention_size = match config.sliced_attention_size {
220                    Some(0) => Some(attention_head_dim / 2),
221                    _ => config.sliced_attention_size,
222                };
223
224                let prev_out_channels = if i > 0 {
225                    config.blocks[n_blocks - i].out_channels
226                } else {
227                    bl_channels
228                };
229                let in_channels = {
230                    let index = if i == n_blocks - 1 {
231                        0
232                    } else {
233                        n_blocks - i - 2
234                    };
235                    config.blocks[index].out_channels
236                };
237                let ub_cfg = UpBlock2DConfig {
238                    num_layers: config.layers_per_block + 1,
239                    resnet_eps: config.norm_eps,
240                    resnet_groups: config.norm_num_groups,
241                    add_upsample: i < n_blocks - 1,
242                    ..Default::default()
243                };
244                if let Some(transformer_layers_per_block) = use_cross_attn {
245                    let config = CrossAttnUpBlock2DConfig {
246                        upblock: ub_cfg,
247                        attn_num_head_channels: attention_head_dim,
248                        cross_attention_dim: config.cross_attention_dim,
249                        sliced_attention_size,
250                        use_linear_projection: config.use_linear_projection,
251                        transformer_layers_per_block,
252                    };
253                    let block = CrossAttnUpBlock2D::new(
254                        vs_ub.pp(i.to_string()),
255                        in_channels,
256                        prev_out_channels,
257                        out_channels,
258                        Some(time_embed_dim),
259                        use_flash_attn,
260                        config,
261                    )?;
262                    Ok(UNetUpBlock::CrossAttn(block))
263                } else {
264                    let block = UpBlock2D::new(
265                        vs_ub.pp(i.to_string()),
266                        in_channels,
267                        prev_out_channels,
268                        out_channels,
269                        Some(time_embed_dim),
270                        ub_cfg,
271                    )?;
272                    Ok(UNetUpBlock::Basic(block))
273                }
274            })
275            .collect::<Result<Vec<_>>>()?;
276
277        let conv_norm_out = nn::group_norm(
278            config.norm_num_groups,
279            b_channels,
280            config.norm_eps,
281            vs.pp("conv_norm_out"),
282        )?;
283        let conv_out = conv2d(b_channels, out_channels, 3, conv_cfg, vs.pp("conv_out"))?;
284        let span = tracing::span!(tracing::Level::TRACE, "unet2d");
285        Ok(Self {
286            conv_in,
287            time_proj,
288            time_embedding,
289            down_blocks,
290            mid_block,
291            up_blocks,
292            conv_norm_out,
293            conv_out,
294            span,
295            config,
296        })
297    }
298
299    pub fn forward(
300        &self,
301        xs: &Tensor,
302        timestep: f64,
303        encoder_hidden_states: &Tensor,
304    ) -> Result<Tensor> {
305        let _enter = self.span.enter();
306        self.forward_with_additional_residuals(xs, timestep, encoder_hidden_states, None, None)
307    }
308
309    pub fn forward_with_additional_residuals(
310        &self,
311        xs: &Tensor,
312        timestep: f64,
313        encoder_hidden_states: &Tensor,
314        down_block_additional_residuals: Option<&[Tensor]>,
315        mid_block_additional_residual: Option<&Tensor>,
316    ) -> Result<Tensor> {
317        let (bsize, _channels, height, width) = xs.dims4()?;
318        let device = xs.device();
319        let n_blocks = self.config.blocks.len();
320        let num_upsamplers = n_blocks - 1;
321        let default_overall_up_factor = 2usize.pow(num_upsamplers as u32);
322        let forward_upsample_size =
323            height % default_overall_up_factor != 0 || width % default_overall_up_factor != 0;
324        // 0. center input if necessary
325        let xs = if self.config.center_input_sample {
326            ((xs * 2.0)? - 1.0)?
327        } else {
328            xs.clone()
329        };
330        // 1. time
331        let emb = (Tensor::ones(bsize, xs.dtype(), device)? * timestep)?;
332        let emb = self.time_proj.forward(&emb)?;
333        let emb = self.time_embedding.forward(&emb)?;
334        // 2. pre-process
335        let xs = self.conv_in.forward(&xs)?;
336        // 3. down
337        let mut down_block_res_xs = vec![xs.clone()];
338        let mut xs = xs;
339        for down_block in self.down_blocks.iter() {
340            let (_xs, res_xs) = match down_block {
341                UNetDownBlock::Basic(b) => b.forward(&xs, Some(&emb))?,
342                UNetDownBlock::CrossAttn(b) => {
343                    b.forward(&xs, Some(&emb), Some(encoder_hidden_states))?
344                }
345            };
346            down_block_res_xs.extend(res_xs);
347            xs = _xs;
348        }
349
350        let new_down_block_res_xs =
351            if let Some(down_block_additional_residuals) = down_block_additional_residuals {
352                let mut v = vec![];
353                // A previous version of this code had a bug because of the addition being made
354                // in place via += hence modifying the input of the mid block.
355                for (i, residuals) in down_block_additional_residuals.iter().enumerate() {
356                    v.push((&down_block_res_xs[i] + residuals)?)
357                }
358                v
359            } else {
360                down_block_res_xs
361            };
362        let mut down_block_res_xs = new_down_block_res_xs;
363
364        // 4. mid
365        let xs = self
366            .mid_block
367            .forward(&xs, Some(&emb), Some(encoder_hidden_states))?;
368        let xs = match mid_block_additional_residual {
369            None => xs,
370            Some(m) => (m + xs)?,
371        };
372        // 5. up
373        let mut xs = xs;
374        let mut upsample_size = None;
375        for (i, up_block) in self.up_blocks.iter().enumerate() {
376            let n_resnets = match up_block {
377                UNetUpBlock::Basic(b) => b.resnets.len(),
378                UNetUpBlock::CrossAttn(b) => b.upblock.resnets.len(),
379            };
380            let res_xs = down_block_res_xs.split_off(down_block_res_xs.len() - n_resnets);
381            if i < n_blocks - 1 && forward_upsample_size {
382                let (_, _, h, w) = down_block_res_xs.last().unwrap().dims4()?;
383                upsample_size = Some((h, w))
384            }
385            xs = match up_block {
386                UNetUpBlock::Basic(b) => b.forward(&xs, &res_xs, Some(&emb), upsample_size)?,
387                UNetUpBlock::CrossAttn(b) => b.forward(
388                    &xs,
389                    &res_xs,
390                    Some(&emb),
391                    upsample_size,
392                    Some(encoder_hidden_states),
393                )?,
394            };
395        }
396        // 6. post-process
397        let xs = self.conv_norm_out.forward(&xs)?;
398        let xs = nn::ops::silu(&xs)?;
399        self.conv_out.forward(&xs)
400    }
401}