candle_transformers/models/stable_diffusion/
unet_2d_blocks.rs

1//! 2D UNet Building Blocks
2//!
3use super::attention::{
4    AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
5};
6use super::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
7use crate::models::with_tracing::{conv2d, Conv2d};
8use candle::{Module, Result, Tensor, D};
9use candle_nn as nn;
10
11#[derive(Debug)]
12struct Downsample2D {
13    conv: Option<Conv2d>,
14    padding: usize,
15    span: tracing::Span,
16}
17
18impl Downsample2D {
19    fn new(
20        vs: nn::VarBuilder,
21        in_channels: usize,
22        use_conv: bool,
23        out_channels: usize,
24        padding: usize,
25    ) -> Result<Self> {
26        let conv = if use_conv {
27            let config = nn::Conv2dConfig {
28                stride: 2,
29                padding,
30                ..Default::default()
31            };
32            let conv = conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
33            Some(conv)
34        } else {
35            None
36        };
37        let span = tracing::span!(tracing::Level::TRACE, "downsample2d");
38        Ok(Self {
39            conv,
40            padding,
41            span,
42        })
43    }
44}
45
46impl Module for Downsample2D {
47    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
48        let _enter = self.span.enter();
49        match &self.conv {
50            None => xs.avg_pool2d(2),
51            Some(conv) => {
52                if self.padding == 0 {
53                    let xs = xs
54                        .pad_with_zeros(D::Minus1, 0, 1)?
55                        .pad_with_zeros(D::Minus2, 0, 1)?;
56                    conv.forward(&xs)
57                } else {
58                    conv.forward(xs)
59                }
60            }
61        }
62    }
63}
64
65// This does not support the conv-transpose mode.
66#[derive(Debug)]
67struct Upsample2D {
68    conv: Conv2d,
69    span: tracing::Span,
70}
71
72impl Upsample2D {
73    fn new(vs: nn::VarBuilder, in_channels: usize, out_channels: usize) -> Result<Self> {
74        let config = nn::Conv2dConfig {
75            padding: 1,
76            ..Default::default()
77        };
78        let conv = conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
79        let span = tracing::span!(tracing::Level::TRACE, "upsample2d");
80        Ok(Self { conv, span })
81    }
82}
83
84impl Upsample2D {
85    fn forward(&self, xs: &Tensor, size: Option<(usize, usize)>) -> Result<Tensor> {
86        let _enter = self.span.enter();
87        let xs = match size {
88            None => {
89                let (_bsize, _channels, h, w) = xs.dims4()?;
90                xs.upsample_nearest2d(2 * h, 2 * w)?
91            }
92            Some((h, w)) => xs.upsample_nearest2d(h, w)?,
93        };
94        self.conv.forward(&xs)
95    }
96}
97
/// Configuration for [`DownEncoderBlock2D`].
#[derive(Debug, Clone, Copy)]
pub struct DownEncoderBlock2DConfig {
    /// Number of resnet blocks stacked in the block.
    pub num_layers: usize,
    /// Epsilon forwarded to each resnet's config.
    pub resnet_eps: f64,
    /// Group count forwarded to each resnet's config.
    pub resnet_groups: usize,
    /// Output scale factor forwarded to each resnet's config.
    pub output_scale_factor: f64,
    /// Whether a downsampler is appended after the resnets.
    pub add_downsample: bool,
    /// Padding of the downsampler's convolution.
    pub downsample_padding: usize,
}

impl Default for DownEncoderBlock2DConfig {
    /// One resnet layer, 32 groups, eps `1e-6`, unit output scale, and a
    /// downsampler with padding 1.
    fn default() -> Self {
        Self {
            add_downsample: true,
            downsample_padding: 1,
            num_layers: 1,
            output_scale_factor: 1.0,
            resnet_eps: 1e-6,
            resnet_groups: 32,
        }
    }
}
120
121#[derive(Debug)]
122pub struct DownEncoderBlock2D {
123    resnets: Vec<ResnetBlock2D>,
124    downsampler: Option<Downsample2D>,
125    span: tracing::Span,
126    pub config: DownEncoderBlock2DConfig,
127}
128
129impl DownEncoderBlock2D {
130    pub fn new(
131        vs: nn::VarBuilder,
132        in_channels: usize,
133        out_channels: usize,
134        config: DownEncoderBlock2DConfig,
135    ) -> Result<Self> {
136        let resnets: Vec<_> = {
137            let vs = vs.pp("resnets");
138            let conv_cfg = ResnetBlock2DConfig {
139                eps: config.resnet_eps,
140                out_channels: Some(out_channels),
141                groups: config.resnet_groups,
142                output_scale_factor: config.output_scale_factor,
143                temb_channels: None,
144                ..Default::default()
145            };
146            (0..(config.num_layers))
147                .map(|i| {
148                    let in_channels = if i == 0 { in_channels } else { out_channels };
149                    ResnetBlock2D::new(vs.pp(i.to_string()), in_channels, conv_cfg)
150                })
151                .collect::<Result<Vec<_>>>()?
152        };
153        let downsampler = if config.add_downsample {
154            let downsample = Downsample2D::new(
155                vs.pp("downsamplers").pp("0"),
156                out_channels,
157                true,
158                out_channels,
159                config.downsample_padding,
160            )?;
161            Some(downsample)
162        } else {
163            None
164        };
165        let span = tracing::span!(tracing::Level::TRACE, "down-enc2d");
166        Ok(Self {
167            resnets,
168            downsampler,
169            span,
170            config,
171        })
172    }
173}
174
175impl Module for DownEncoderBlock2D {
176    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
177        let _enter = self.span.enter();
178        let mut xs = xs.clone();
179        for resnet in self.resnets.iter() {
180            xs = resnet.forward(&xs, None)?
181        }
182        match &self.downsampler {
183            Some(downsampler) => downsampler.forward(&xs),
184            None => Ok(xs),
185        }
186    }
187}
188
/// Configuration for [`UpDecoderBlock2D`].
#[derive(Debug, Clone, Copy)]
pub struct UpDecoderBlock2DConfig {
    /// Number of resnet blocks stacked in the block.
    pub num_layers: usize,
    /// Epsilon forwarded to each resnet's config.
    pub resnet_eps: f64,
    /// Group count forwarded to each resnet's config.
    pub resnet_groups: usize,
    /// Output scale factor forwarded to each resnet's config.
    pub output_scale_factor: f64,
    /// Whether an upsampler is appended after the resnets.
    pub add_upsample: bool,
}

impl Default for UpDecoderBlock2DConfig {
    /// One resnet layer, 32 groups, eps `1e-6`, unit output scale, with upsampling.
    fn default() -> Self {
        Self {
            add_upsample: true,
            num_layers: 1,
            output_scale_factor: 1.0,
            resnet_eps: 1e-6,
            resnet_groups: 32,
        }
    }
}
209
210#[derive(Debug)]
211pub struct UpDecoderBlock2D {
212    resnets: Vec<ResnetBlock2D>,
213    upsampler: Option<Upsample2D>,
214    span: tracing::Span,
215    pub config: UpDecoderBlock2DConfig,
216}
217
218impl UpDecoderBlock2D {
219    pub fn new(
220        vs: nn::VarBuilder,
221        in_channels: usize,
222        out_channels: usize,
223        config: UpDecoderBlock2DConfig,
224    ) -> Result<Self> {
225        let resnets: Vec<_> = {
226            let vs = vs.pp("resnets");
227            let conv_cfg = ResnetBlock2DConfig {
228                out_channels: Some(out_channels),
229                eps: config.resnet_eps,
230                groups: config.resnet_groups,
231                output_scale_factor: config.output_scale_factor,
232                temb_channels: None,
233                ..Default::default()
234            };
235            (0..(config.num_layers))
236                .map(|i| {
237                    let in_channels = if i == 0 { in_channels } else { out_channels };
238                    ResnetBlock2D::new(vs.pp(i.to_string()), in_channels, conv_cfg)
239                })
240                .collect::<Result<Vec<_>>>()?
241        };
242        let upsampler = if config.add_upsample {
243            let upsample =
244                Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?;
245            Some(upsample)
246        } else {
247            None
248        };
249        let span = tracing::span!(tracing::Level::TRACE, "up-dec2d");
250        Ok(Self {
251            resnets,
252            upsampler,
253            span,
254            config,
255        })
256    }
257}
258
259impl Module for UpDecoderBlock2D {
260    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
261        let _enter = self.span.enter();
262        let mut xs = xs.clone();
263        for resnet in self.resnets.iter() {
264            xs = resnet.forward(&xs, None)?
265        }
266        match &self.upsampler {
267            Some(upsampler) => upsampler.forward(&xs, None),
268            None => Ok(xs),
269        }
270    }
271}
272
/// Configuration for [`UNetMidBlock2D`].
#[derive(Debug, Clone, Copy)]
pub struct UNetMidBlock2DConfig {
    /// Number of (attention, resnet) pairs following the leading resnet.
    pub num_layers: usize,
    /// Epsilon forwarded to the resnets and attention blocks.
    pub resnet_eps: f64,
    /// Group count; `None` selects `min(in_channels / 4, 32)`.
    pub resnet_groups: Option<usize>,
    /// Forwarded to `AttentionBlockConfig::num_head_channels`.
    pub attn_num_head_channels: Option<usize>,
    // attention_type: "default" (not configurable here)
    /// Forwarded as the resnets' output scale and the attention rescale factor.
    pub output_scale_factor: f64,
}

impl Default for UNetMidBlock2DConfig {
    /// One layer, 32 groups, eps `1e-6`, one head channel, unit output scale.
    fn default() -> Self {
        Self {
            attn_num_head_channels: Some(1),
            num_layers: 1,
            output_scale_factor: 1.0,
            resnet_eps: 1e-6,
            resnet_groups: Some(32),
        }
    }
}
294
295#[derive(Debug)]
296pub struct UNetMidBlock2D {
297    resnet: ResnetBlock2D,
298    attn_resnets: Vec<(AttentionBlock, ResnetBlock2D)>,
299    span: tracing::Span,
300    pub config: UNetMidBlock2DConfig,
301}
302
303impl UNetMidBlock2D {
304    pub fn new(
305        vs: nn::VarBuilder,
306        in_channels: usize,
307        temb_channels: Option<usize>,
308        config: UNetMidBlock2DConfig,
309    ) -> Result<Self> {
310        let vs_resnets = vs.pp("resnets");
311        let vs_attns = vs.pp("attentions");
312        let resnet_groups = config
313            .resnet_groups
314            .unwrap_or_else(|| usize::min(in_channels / 4, 32));
315        let resnet_cfg = ResnetBlock2DConfig {
316            eps: config.resnet_eps,
317            groups: resnet_groups,
318            output_scale_factor: config.output_scale_factor,
319            temb_channels,
320            ..Default::default()
321        };
322        let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
323        let attn_cfg = AttentionBlockConfig {
324            num_head_channels: config.attn_num_head_channels,
325            num_groups: resnet_groups,
326            rescale_output_factor: config.output_scale_factor,
327            eps: config.resnet_eps,
328        };
329        let mut attn_resnets = vec![];
330        for index in 0..config.num_layers {
331            let attn = AttentionBlock::new(vs_attns.pp(index.to_string()), in_channels, attn_cfg)?;
332            let resnet = ResnetBlock2D::new(
333                vs_resnets.pp((index + 1).to_string()),
334                in_channels,
335                resnet_cfg,
336            )?;
337            attn_resnets.push((attn, resnet))
338        }
339        let span = tracing::span!(tracing::Level::TRACE, "mid2d");
340        Ok(Self {
341            resnet,
342            attn_resnets,
343            span,
344            config,
345        })
346    }
347
348    pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
349        let _enter = self.span.enter();
350        let mut xs = self.resnet.forward(xs, temb)?;
351        for (attn, resnet) in self.attn_resnets.iter() {
352            xs = resnet.forward(&attn.forward(&xs)?, temb)?
353        }
354        Ok(xs)
355    }
356}
357
/// Configuration for [`UNetMidBlock2DCrossAttn`].
#[derive(Debug, Clone, Copy)]
pub struct UNetMidBlock2DCrossAttnConfig {
    /// Number of (spatial transformer, resnet) pairs after the leading resnet.
    pub num_layers: usize,
    /// Epsilon forwarded to the resnets.
    pub resnet_eps: f64,
    /// Group count; `None` selects `min(in_channels / 4, 32)`.
    pub resnet_groups: Option<usize>,
    /// Passed to `SpatialTransformer::new` as `n_heads`.
    pub attn_num_head_channels: usize,
    // attention_type: "default" (not configurable here)
    /// Output scale factor forwarded to the resnets.
    pub output_scale_factor: f64,
    /// Cross-attention context dimension (`context_dim`).
    pub cross_attn_dim: usize,
    /// Attention slice size; `None` leaves sliced attention disabled.
    pub sliced_attention_size: Option<usize>,
    /// Forwarded to the spatial transformers' `use_linear_projection`.
    pub use_linear_projection: bool,
    /// Depth of each spatial transformer.
    pub transformer_layers_per_block: usize,
}

impl Default for UNetMidBlock2DCrossAttnConfig {
    /// One layer of depth 1, 32 groups, eps `1e-6`, one head, context dim 1280,
    /// no sliced attention and no linear projection.
    fn default() -> Self {
        Self {
            attn_num_head_channels: 1,
            cross_attn_dim: 1280,
            num_layers: 1,
            output_scale_factor: 1.0,
            resnet_eps: 1e-6,
            resnet_groups: Some(32),
            // Sliced attention disabled.
            sliced_attention_size: None,
            transformer_layers_per_block: 1,
            use_linear_projection: false,
        }
    }
}
387
388#[derive(Debug)]
389pub struct UNetMidBlock2DCrossAttn {
390    resnet: ResnetBlock2D,
391    attn_resnets: Vec<(SpatialTransformer, ResnetBlock2D)>,
392    span: tracing::Span,
393    pub config: UNetMidBlock2DCrossAttnConfig,
394}
395
396impl UNetMidBlock2DCrossAttn {
397    pub fn new(
398        vs: nn::VarBuilder,
399        in_channels: usize,
400        temb_channels: Option<usize>,
401        use_flash_attn: bool,
402        config: UNetMidBlock2DCrossAttnConfig,
403    ) -> Result<Self> {
404        let vs_resnets = vs.pp("resnets");
405        let vs_attns = vs.pp("attentions");
406        let resnet_groups = config
407            .resnet_groups
408            .unwrap_or_else(|| usize::min(in_channels / 4, 32));
409        let resnet_cfg = ResnetBlock2DConfig {
410            eps: config.resnet_eps,
411            groups: resnet_groups,
412            output_scale_factor: config.output_scale_factor,
413            temb_channels,
414            ..Default::default()
415        };
416        let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
417        let n_heads = config.attn_num_head_channels;
418        let attn_cfg = SpatialTransformerConfig {
419            depth: config.transformer_layers_per_block,
420            num_groups: resnet_groups,
421            context_dim: Some(config.cross_attn_dim),
422            sliced_attention_size: config.sliced_attention_size,
423            use_linear_projection: config.use_linear_projection,
424        };
425        let mut attn_resnets = vec![];
426        for index in 0..config.num_layers {
427            let attn = SpatialTransformer::new(
428                vs_attns.pp(index.to_string()),
429                in_channels,
430                n_heads,
431                in_channels / n_heads,
432                use_flash_attn,
433                attn_cfg,
434            )?;
435            let resnet = ResnetBlock2D::new(
436                vs_resnets.pp((index + 1).to_string()),
437                in_channels,
438                resnet_cfg,
439            )?;
440            attn_resnets.push((attn, resnet))
441        }
442        let span = tracing::span!(tracing::Level::TRACE, "xa-mid2d");
443        Ok(Self {
444            resnet,
445            attn_resnets,
446            span,
447            config,
448        })
449    }
450
451    pub fn forward(
452        &self,
453        xs: &Tensor,
454        temb: Option<&Tensor>,
455        encoder_hidden_states: Option<&Tensor>,
456    ) -> Result<Tensor> {
457        let _enter = self.span.enter();
458        let mut xs = self.resnet.forward(xs, temb)?;
459        for (attn, resnet) in self.attn_resnets.iter() {
460            xs = resnet.forward(&attn.forward(&xs, encoder_hidden_states)?, temb)?
461        }
462        Ok(xs)
463    }
464}
465
/// Configuration for [`DownBlock2D`].
#[derive(Debug, Clone, Copy)]
pub struct DownBlock2DConfig {
    /// Number of resnet blocks.
    pub num_layers: usize,
    /// Epsilon forwarded to each resnet.
    pub resnet_eps: f64,
    // Not configurable here: resnet_time_scale_shift "default",
    // resnet_act_fn "swish".
    /// Group count for the block; `CrossAttnDownBlock2D` also uses it as its
    /// spatial transformers' `num_groups`.
    pub resnet_groups: usize,
    /// Output scale factor forwarded to each resnet.
    pub output_scale_factor: f64,
    /// Whether a downsampler is appended after the resnets.
    pub add_downsample: bool,
    /// Padding of the downsampler's convolution.
    pub downsample_padding: usize,
}

impl Default for DownBlock2DConfig {
    /// One resnet layer, 32 groups, eps `1e-6`, unit output scale, and a
    /// downsampler with padding 1.
    fn default() -> Self {
        Self {
            add_downsample: true,
            downsample_padding: 1,
            num_layers: 1,
            output_scale_factor: 1.0,
            resnet_eps: 1e-6,
            resnet_groups: 32,
        }
    }
}
490
491#[derive(Debug)]
492pub struct DownBlock2D {
493    resnets: Vec<ResnetBlock2D>,
494    downsampler: Option<Downsample2D>,
495    span: tracing::Span,
496    pub config: DownBlock2DConfig,
497}
498
499impl DownBlock2D {
500    pub fn new(
501        vs: nn::VarBuilder,
502        in_channels: usize,
503        out_channels: usize,
504        temb_channels: Option<usize>,
505        config: DownBlock2DConfig,
506    ) -> Result<Self> {
507        let vs_resnets = vs.pp("resnets");
508        let resnet_cfg = ResnetBlock2DConfig {
509            out_channels: Some(out_channels),
510            eps: config.resnet_eps,
511            output_scale_factor: config.output_scale_factor,
512            temb_channels,
513            ..Default::default()
514        };
515        let resnets = (0..config.num_layers)
516            .map(|i| {
517                let in_channels = if i == 0 { in_channels } else { out_channels };
518                ResnetBlock2D::new(vs_resnets.pp(i.to_string()), in_channels, resnet_cfg)
519            })
520            .collect::<Result<Vec<_>>>()?;
521        let downsampler = if config.add_downsample {
522            let downsampler = Downsample2D::new(
523                vs.pp("downsamplers").pp("0"),
524                out_channels,
525                true,
526                out_channels,
527                config.downsample_padding,
528            )?;
529            Some(downsampler)
530        } else {
531            None
532        };
533        let span = tracing::span!(tracing::Level::TRACE, "down2d");
534        Ok(Self {
535            resnets,
536            downsampler,
537            span,
538            config,
539        })
540    }
541
542    pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<(Tensor, Vec<Tensor>)> {
543        let _enter = self.span.enter();
544        let mut xs = xs.clone();
545        let mut output_states = vec![];
546        for resnet in self.resnets.iter() {
547            xs = resnet.forward(&xs, temb)?;
548            output_states.push(xs.clone());
549        }
550        let xs = match &self.downsampler {
551            Some(downsampler) => {
552                let xs = downsampler.forward(&xs)?;
553                output_states.push(xs.clone());
554                xs
555            }
556            None => xs,
557        };
558        Ok((xs, output_states))
559    }
560}
561
562#[derive(Debug, Clone, Copy)]
563pub struct CrossAttnDownBlock2DConfig {
564    pub downblock: DownBlock2DConfig,
565    pub attn_num_head_channels: usize,
566    pub cross_attention_dim: usize,
567    // attention_type: "default"
568    pub sliced_attention_size: Option<usize>,
569    pub use_linear_projection: bool,
570    pub transformer_layers_per_block: usize,
571}
572
573impl Default for CrossAttnDownBlock2DConfig {
574    fn default() -> Self {
575        Self {
576            downblock: Default::default(),
577            attn_num_head_channels: 1,
578            cross_attention_dim: 1280,
579            sliced_attention_size: None,
580            use_linear_projection: false,
581            transformer_layers_per_block: 1,
582        }
583    }
584}
585
586#[derive(Debug)]
587pub struct CrossAttnDownBlock2D {
588    downblock: DownBlock2D,
589    attentions: Vec<SpatialTransformer>,
590    span: tracing::Span,
591    pub config: CrossAttnDownBlock2DConfig,
592}
593
594impl CrossAttnDownBlock2D {
595    pub fn new(
596        vs: nn::VarBuilder,
597        in_channels: usize,
598        out_channels: usize,
599        temb_channels: Option<usize>,
600        use_flash_attn: bool,
601        config: CrossAttnDownBlock2DConfig,
602    ) -> Result<Self> {
603        let downblock = DownBlock2D::new(
604            vs.clone(),
605            in_channels,
606            out_channels,
607            temb_channels,
608            config.downblock,
609        )?;
610        let n_heads = config.attn_num_head_channels;
611        let cfg = SpatialTransformerConfig {
612            depth: config.transformer_layers_per_block,
613            context_dim: Some(config.cross_attention_dim),
614            num_groups: config.downblock.resnet_groups,
615            sliced_attention_size: config.sliced_attention_size,
616            use_linear_projection: config.use_linear_projection,
617        };
618        let vs_attn = vs.pp("attentions");
619        let attentions = (0..config.downblock.num_layers)
620            .map(|i| {
621                SpatialTransformer::new(
622                    vs_attn.pp(i.to_string()),
623                    out_channels,
624                    n_heads,
625                    out_channels / n_heads,
626                    use_flash_attn,
627                    cfg,
628                )
629            })
630            .collect::<Result<Vec<_>>>()?;
631        let span = tracing::span!(tracing::Level::TRACE, "xa-down2d");
632        Ok(Self {
633            downblock,
634            attentions,
635            span,
636            config,
637        })
638    }
639
640    pub fn forward(
641        &self,
642        xs: &Tensor,
643        temb: Option<&Tensor>,
644        encoder_hidden_states: Option<&Tensor>,
645    ) -> Result<(Tensor, Vec<Tensor>)> {
646        let _enter = self.span.enter();
647        let mut output_states = vec![];
648        let mut xs = xs.clone();
649        for (resnet, attn) in self.downblock.resnets.iter().zip(self.attentions.iter()) {
650            xs = resnet.forward(&xs, temb)?;
651            xs = attn.forward(&xs, encoder_hidden_states)?;
652            output_states.push(xs.clone());
653        }
654        let xs = match &self.downblock.downsampler {
655            Some(downsampler) => {
656                let xs = downsampler.forward(&xs)?;
657                output_states.push(xs.clone());
658                xs
659            }
660            None => xs,
661        };
662        Ok((xs, output_states))
663    }
664}
665
/// Configuration for [`UpBlock2D`].
#[derive(Debug, Clone, Copy)]
pub struct UpBlock2DConfig {
    /// Number of resnet blocks.
    pub num_layers: usize,
    /// Epsilon forwarded to each resnet.
    pub resnet_eps: f64,
    // Not configurable here: resnet_time_scale_shift "default",
    // resnet_act_fn "swish".
    /// Group count for the block; `CrossAttnUpBlock2D` also uses it as its
    /// spatial transformers' `num_groups`.
    pub resnet_groups: usize,
    /// Output scale factor forwarded to each resnet.
    pub output_scale_factor: f64,
    /// Whether an upsampler is appended after the resnets.
    pub add_upsample: bool,
}

impl Default for UpBlock2DConfig {
    /// One resnet layer, 32 groups, eps `1e-6`, unit output scale, with upsampling.
    fn default() -> Self {
        Self {
            add_upsample: true,
            num_layers: 1,
            output_scale_factor: 1.0,
            resnet_eps: 1e-6,
            resnet_groups: 32,
        }
    }
}
688
689#[derive(Debug)]
690pub struct UpBlock2D {
691    pub resnets: Vec<ResnetBlock2D>,
692    upsampler: Option<Upsample2D>,
693    span: tracing::Span,
694    pub config: UpBlock2DConfig,
695}
696
697impl UpBlock2D {
698    pub fn new(
699        vs: nn::VarBuilder,
700        in_channels: usize,
701        prev_output_channels: usize,
702        out_channels: usize,
703        temb_channels: Option<usize>,
704        config: UpBlock2DConfig,
705    ) -> Result<Self> {
706        let vs_resnets = vs.pp("resnets");
707        let resnet_cfg = ResnetBlock2DConfig {
708            out_channels: Some(out_channels),
709            temb_channels,
710            eps: config.resnet_eps,
711            output_scale_factor: config.output_scale_factor,
712            ..Default::default()
713        };
714        let resnets = (0..config.num_layers)
715            .map(|i| {
716                let res_skip_channels = if i == config.num_layers - 1 {
717                    in_channels
718                } else {
719                    out_channels
720                };
721                let resnet_in_channels = if i == 0 {
722                    prev_output_channels
723                } else {
724                    out_channels
725                };
726                let in_channels = resnet_in_channels + res_skip_channels;
727                ResnetBlock2D::new(vs_resnets.pp(i.to_string()), in_channels, resnet_cfg)
728            })
729            .collect::<Result<Vec<_>>>()?;
730        let upsampler = if config.add_upsample {
731            let upsampler =
732                Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?;
733            Some(upsampler)
734        } else {
735            None
736        };
737        let span = tracing::span!(tracing::Level::TRACE, "up2d");
738        Ok(Self {
739            resnets,
740            upsampler,
741            span,
742            config,
743        })
744    }
745
746    pub fn forward(
747        &self,
748        xs: &Tensor,
749        res_xs: &[Tensor],
750        temb: Option<&Tensor>,
751        upsample_size: Option<(usize, usize)>,
752    ) -> Result<Tensor> {
753        let _enter = self.span.enter();
754        let mut xs = xs.clone();
755        for (index, resnet) in self.resnets.iter().enumerate() {
756            xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
757            xs = xs.contiguous()?;
758            xs = resnet.forward(&xs, temb)?;
759        }
760        match &self.upsampler {
761            Some(upsampler) => upsampler.forward(&xs, upsample_size),
762            None => Ok(xs),
763        }
764    }
765}
766
767#[derive(Debug, Clone, Copy)]
768pub struct CrossAttnUpBlock2DConfig {
769    pub upblock: UpBlock2DConfig,
770    pub attn_num_head_channels: usize,
771    pub cross_attention_dim: usize,
772    // attention_type: "default"
773    pub sliced_attention_size: Option<usize>,
774    pub use_linear_projection: bool,
775    pub transformer_layers_per_block: usize,
776}
777
778impl Default for CrossAttnUpBlock2DConfig {
779    fn default() -> Self {
780        Self {
781            upblock: Default::default(),
782            attn_num_head_channels: 1,
783            cross_attention_dim: 1280,
784            sliced_attention_size: None,
785            use_linear_projection: false,
786            transformer_layers_per_block: 1,
787        }
788    }
789}
790
791#[derive(Debug)]
792pub struct CrossAttnUpBlock2D {
793    pub upblock: UpBlock2D,
794    pub attentions: Vec<SpatialTransformer>,
795    span: tracing::Span,
796    pub config: CrossAttnUpBlock2DConfig,
797}
798
799impl CrossAttnUpBlock2D {
800    pub fn new(
801        vs: nn::VarBuilder,
802        in_channels: usize,
803        prev_output_channels: usize,
804        out_channels: usize,
805        temb_channels: Option<usize>,
806        use_flash_attn: bool,
807        config: CrossAttnUpBlock2DConfig,
808    ) -> Result<Self> {
809        let upblock = UpBlock2D::new(
810            vs.clone(),
811            in_channels,
812            prev_output_channels,
813            out_channels,
814            temb_channels,
815            config.upblock,
816        )?;
817        let n_heads = config.attn_num_head_channels;
818        let cfg = SpatialTransformerConfig {
819            depth: config.transformer_layers_per_block,
820            context_dim: Some(config.cross_attention_dim),
821            num_groups: config.upblock.resnet_groups,
822            sliced_attention_size: config.sliced_attention_size,
823            use_linear_projection: config.use_linear_projection,
824        };
825        let vs_attn = vs.pp("attentions");
826        let attentions = (0..config.upblock.num_layers)
827            .map(|i| {
828                SpatialTransformer::new(
829                    vs_attn.pp(i.to_string()),
830                    out_channels,
831                    n_heads,
832                    out_channels / n_heads,
833                    use_flash_attn,
834                    cfg,
835                )
836            })
837            .collect::<Result<Vec<_>>>()?;
838        let span = tracing::span!(tracing::Level::TRACE, "xa-up2d");
839        Ok(Self {
840            upblock,
841            attentions,
842            span,
843            config,
844        })
845    }
846
847    pub fn forward(
848        &self,
849        xs: &Tensor,
850        res_xs: &[Tensor],
851        temb: Option<&Tensor>,
852        upsample_size: Option<(usize, usize)>,
853        encoder_hidden_states: Option<&Tensor>,
854    ) -> Result<Tensor> {
855        let _enter = self.span.enter();
856        let mut xs = xs.clone();
857        for (index, resnet) in self.upblock.resnets.iter().enumerate() {
858            xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
859            xs = xs.contiguous()?;
860            xs = resnet.forward(&xs, temb)?;
861            xs = self.attentions[index].forward(&xs, encoder_hidden_states)?;
862        }
863        match &self.upblock.upsampler {
864            Some(upsampler) => upsampler.forward(&xs, upsample_size),
865            None => Ok(xs),
866        }
867    }
868}