1use super::attention::{
4 AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
5};
6use super::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
7use crate::models::with_tracing::{conv2d, Conv2d};
8use candle::{Module, Result, Tensor, D};
9use candle_nn as nn;
10
/// Downsampling layer that either applies a strided convolution or, when no
/// convolution was built, falls back to `avg_pool2d(2)` in its forward pass.
#[derive(Debug)]
struct Downsample2D {
    /// Stride-2 convolution with kernel size 3; `None` when the layer was built without `use_conv`.
    conv: Option<Conv2d>,
    /// Padding handed to the convolution. When it is 0 the forward pass zero-pads
    /// the last two dimensions by one element at the end before convolving.
    padding: usize,
    /// Tracing span entered for the duration of each forward pass.
    span: tracing::Span,
}
17
18impl Downsample2D {
19 fn new(
20 vs: nn::VarBuilder,
21 in_channels: usize,
22 use_conv: bool,
23 out_channels: usize,
24 padding: usize,
25 ) -> Result<Self> {
26 let conv = if use_conv {
27 let config = nn::Conv2dConfig {
28 stride: 2,
29 padding,
30 ..Default::default()
31 };
32 let conv = conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
33 Some(conv)
34 } else {
35 None
36 };
37 let span = tracing::span!(tracing::Level::TRACE, "downsample2d");
38 Ok(Self {
39 conv,
40 padding,
41 span,
42 })
43 }
44}
45
46impl Module for Downsample2D {
47 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
48 let _enter = self.span.enter();
49 match &self.conv {
50 None => xs.avg_pool2d(2),
51 Some(conv) => {
52 if self.padding == 0 {
53 let xs = xs
54 .pad_with_zeros(D::Minus1, 0, 1)?
55 .pad_with_zeros(D::Minus2, 0, 1)?;
56 conv.forward(&xs)
57 } else {
58 conv.forward(xs)
59 }
60 }
61 }
62 }
63}
64
/// Upsampling layer: nearest-neighbour resize followed by a convolution.
#[derive(Debug)]
struct Upsample2D {
    /// Convolution (kernel size 3, padding 1) applied after the resize.
    conv: Conv2d,
    /// Tracing span entered for the duration of each forward pass.
    span: tracing::Span,
}
71
72impl Upsample2D {
73 fn new(vs: nn::VarBuilder, in_channels: usize, out_channels: usize) -> Result<Self> {
74 let config = nn::Conv2dConfig {
75 padding: 1,
76 ..Default::default()
77 };
78 let conv = conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
79 let span = tracing::span!(tracing::Level::TRACE, "upsample2d");
80 Ok(Self { conv, span })
81 }
82}
83
84impl Upsample2D {
85 fn forward(&self, xs: &Tensor, size: Option<(usize, usize)>) -> Result<Tensor> {
86 let _enter = self.span.enter();
87 let xs = match size {
88 None => {
89 let (_bsize, _channels, h, w) = xs.dims4()?;
90 xs.upsample_nearest2d(2 * h, 2 * w)?
91 }
92 Some((h, w)) => xs.upsample_nearest2d(h, w)?,
93 };
94 self.conv.forward(&xs)
95 }
96}
97
/// Settings for [`DownEncoderBlock2D`].
#[derive(Debug, Clone, Copy)]
pub struct DownEncoderBlock2DConfig {
    /// Number of resnet blocks to stack.
    pub num_layers: usize,
    /// Value forwarded to each resnet block as `eps`.
    pub resnet_eps: f64,
    /// Value forwarded to each resnet block as `groups`.
    pub resnet_groups: usize,
    /// Value forwarded to each resnet block as `output_scale_factor`.
    pub output_scale_factor: f64,
    /// Whether a downsampling layer is appended after the resnets.
    pub add_downsample: bool,
    /// Padding used by the downsampling convolution.
    pub downsample_padding: usize,
}

impl Default for DownEncoderBlock2DConfig {
    /// One resnet layer, eps `1e-6`, 32 groups, unit output scale, and a
    /// downsampler with padding 1.
    fn default() -> Self {
        Self {
            num_layers: 1,
            resnet_eps: 1e-6,
            resnet_groups: 32,
            output_scale_factor: 1.0,
            add_downsample: true,
            downsample_padding: 1,
        }
    }
}
120
/// A stack of resnet blocks followed by an optional downsampling layer.
#[derive(Debug)]
pub struct DownEncoderBlock2D {
    /// Resnet blocks applied in order; the forward pass gives them no time embedding.
    resnets: Vec<ResnetBlock2D>,
    /// Downsampling layer applied after the resnets, present when `add_downsample` was set.
    downsampler: Option<Downsample2D>,
    /// Tracing span entered for the duration of each forward pass.
    span: tracing::Span,
    /// Configuration this block was built from.
    pub config: DownEncoderBlock2DConfig,
}
128
129impl DownEncoderBlock2D {
130 pub fn new(
131 vs: nn::VarBuilder,
132 in_channels: usize,
133 out_channels: usize,
134 config: DownEncoderBlock2DConfig,
135 ) -> Result<Self> {
136 let resnets: Vec<_> = {
137 let vs = vs.pp("resnets");
138 let conv_cfg = ResnetBlock2DConfig {
139 eps: config.resnet_eps,
140 out_channels: Some(out_channels),
141 groups: config.resnet_groups,
142 output_scale_factor: config.output_scale_factor,
143 temb_channels: None,
144 ..Default::default()
145 };
146 (0..(config.num_layers))
147 .map(|i| {
148 let in_channels = if i == 0 { in_channels } else { out_channels };
149 ResnetBlock2D::new(vs.pp(i.to_string()), in_channels, conv_cfg)
150 })
151 .collect::<Result<Vec<_>>>()?
152 };
153 let downsampler = if config.add_downsample {
154 let downsample = Downsample2D::new(
155 vs.pp("downsamplers").pp("0"),
156 out_channels,
157 true,
158 out_channels,
159 config.downsample_padding,
160 )?;
161 Some(downsample)
162 } else {
163 None
164 };
165 let span = tracing::span!(tracing::Level::TRACE, "down-enc2d");
166 Ok(Self {
167 resnets,
168 downsampler,
169 span,
170 config,
171 })
172 }
173}
174
175impl Module for DownEncoderBlock2D {
176 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
177 let _enter = self.span.enter();
178 let mut xs = xs.clone();
179 for resnet in self.resnets.iter() {
180 xs = resnet.forward(&xs, None)?
181 }
182 match &self.downsampler {
183 Some(downsampler) => downsampler.forward(&xs),
184 None => Ok(xs),
185 }
186 }
187}
188
/// Settings for [`UpDecoderBlock2D`].
#[derive(Debug, Clone, Copy)]
pub struct UpDecoderBlock2DConfig {
    /// Number of resnet blocks to stack.
    pub num_layers: usize,
    /// Value forwarded to each resnet block as `eps`.
    pub resnet_eps: f64,
    /// Value forwarded to each resnet block as `groups`.
    pub resnet_groups: usize,
    /// Value forwarded to each resnet block as `output_scale_factor`.
    pub output_scale_factor: f64,
    /// Whether an upsampling layer is appended after the resnets.
    pub add_upsample: bool,
}

impl Default for UpDecoderBlock2DConfig {
    /// One resnet layer, eps `1e-6`, 32 groups, unit output scale, upsampling enabled.
    fn default() -> Self {
        Self {
            num_layers: 1,
            resnet_eps: 1e-6,
            resnet_groups: 32,
            output_scale_factor: 1.0,
            add_upsample: true,
        }
    }
}
209
/// A stack of resnet blocks followed by an optional upsampling layer.
#[derive(Debug)]
pub struct UpDecoderBlock2D {
    /// Resnet blocks applied in order; the forward pass gives them no time embedding.
    resnets: Vec<ResnetBlock2D>,
    /// Upsampling layer applied after the resnets, present when `add_upsample` was set.
    upsampler: Option<Upsample2D>,
    /// Tracing span entered for the duration of each forward pass.
    span: tracing::Span,
    /// Configuration this block was built from.
    pub config: UpDecoderBlock2DConfig,
}
217
218impl UpDecoderBlock2D {
219 pub fn new(
220 vs: nn::VarBuilder,
221 in_channels: usize,
222 out_channels: usize,
223 config: UpDecoderBlock2DConfig,
224 ) -> Result<Self> {
225 let resnets: Vec<_> = {
226 let vs = vs.pp("resnets");
227 let conv_cfg = ResnetBlock2DConfig {
228 out_channels: Some(out_channels),
229 eps: config.resnet_eps,
230 groups: config.resnet_groups,
231 output_scale_factor: config.output_scale_factor,
232 temb_channels: None,
233 ..Default::default()
234 };
235 (0..(config.num_layers))
236 .map(|i| {
237 let in_channels = if i == 0 { in_channels } else { out_channels };
238 ResnetBlock2D::new(vs.pp(i.to_string()), in_channels, conv_cfg)
239 })
240 .collect::<Result<Vec<_>>>()?
241 };
242 let upsampler = if config.add_upsample {
243 let upsample =
244 Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?;
245 Some(upsample)
246 } else {
247 None
248 };
249 let span = tracing::span!(tracing::Level::TRACE, "up-dec2d");
250 Ok(Self {
251 resnets,
252 upsampler,
253 span,
254 config,
255 })
256 }
257}
258
259impl Module for UpDecoderBlock2D {
260 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
261 let _enter = self.span.enter();
262 let mut xs = xs.clone();
263 for resnet in self.resnets.iter() {
264 xs = resnet.forward(&xs, None)?
265 }
266 match &self.upsampler {
267 Some(upsampler) => upsampler.forward(&xs, None),
268 None => Ok(xs),
269 }
270 }
271}
272
/// Settings for [`UNetMidBlock2D`].
#[derive(Debug, Clone, Copy)]
pub struct UNetMidBlock2DConfig {
    /// Number of (attention, resnet) pairs that follow the initial resnet.
    pub num_layers: usize,
    /// Value forwarded as `eps` to both the resnet and attention blocks.
    pub resnet_eps: f64,
    /// Group count for the resnet and attention blocks; `None` selects
    /// `min(in_channels / 4, 32)` at construction time.
    pub resnet_groups: Option<usize>,
    /// Value forwarded to the attention blocks as `num_head_channels`.
    pub attn_num_head_channels: Option<usize>,
    /// Value forwarded to the resnets as `output_scale_factor` and to the
    /// attention blocks as `rescale_output_factor`.
    pub output_scale_factor: f64,
}

impl Default for UNetMidBlock2DConfig {
    /// One layer, eps `1e-6`, 32 groups, one head channel, unit output scale.
    fn default() -> Self {
        Self {
            num_layers: 1,
            resnet_eps: 1e-6,
            resnet_groups: Some(32),
            attn_num_head_channels: Some(1),
            output_scale_factor: 1.0,
        }
    }
}
294
/// Middle block: one resnet followed by a sequence of (attention, resnet) pairs.
#[derive(Debug)]
pub struct UNetMidBlock2D {
    /// Resnet applied first, before any attention layer.
    resnet: ResnetBlock2D,
    /// Pairs applied in order after `resnet`; within a pair the attention runs before the resnet.
    attn_resnets: Vec<(AttentionBlock, ResnetBlock2D)>,
    /// Tracing span entered for the duration of each forward pass.
    span: tracing::Span,
    /// Configuration this block was built from.
    pub config: UNetMidBlock2DConfig,
}
302
303impl UNetMidBlock2D {
304 pub fn new(
305 vs: nn::VarBuilder,
306 in_channels: usize,
307 temb_channels: Option<usize>,
308 config: UNetMidBlock2DConfig,
309 ) -> Result<Self> {
310 let vs_resnets = vs.pp("resnets");
311 let vs_attns = vs.pp("attentions");
312 let resnet_groups = config
313 .resnet_groups
314 .unwrap_or_else(|| usize::min(in_channels / 4, 32));
315 let resnet_cfg = ResnetBlock2DConfig {
316 eps: config.resnet_eps,
317 groups: resnet_groups,
318 output_scale_factor: config.output_scale_factor,
319 temb_channels,
320 ..Default::default()
321 };
322 let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
323 let attn_cfg = AttentionBlockConfig {
324 num_head_channels: config.attn_num_head_channels,
325 num_groups: resnet_groups,
326 rescale_output_factor: config.output_scale_factor,
327 eps: config.resnet_eps,
328 };
329 let mut attn_resnets = vec![];
330 for index in 0..config.num_layers {
331 let attn = AttentionBlock::new(vs_attns.pp(index.to_string()), in_channels, attn_cfg)?;
332 let resnet = ResnetBlock2D::new(
333 vs_resnets.pp((index + 1).to_string()),
334 in_channels,
335 resnet_cfg,
336 )?;
337 attn_resnets.push((attn, resnet))
338 }
339 let span = tracing::span!(tracing::Level::TRACE, "mid2d");
340 Ok(Self {
341 resnet,
342 attn_resnets,
343 span,
344 config,
345 })
346 }
347
348 pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
349 let _enter = self.span.enter();
350 let mut xs = self.resnet.forward(xs, temb)?;
351 for (attn, resnet) in self.attn_resnets.iter() {
352 xs = resnet.forward(&attn.forward(&xs)?, temb)?
353 }
354 Ok(xs)
355 }
356}
357
/// Settings for [`UNetMidBlock2DCrossAttn`].
#[derive(Debug, Clone, Copy)]
pub struct UNetMidBlock2DCrossAttnConfig {
    /// Number of (transformer, resnet) pairs that follow the initial resnet.
    pub num_layers: usize,
    /// Value forwarded to each resnet block as `eps`.
    pub resnet_eps: f64,
    /// Group count for the resnets and transformers; `None` selects
    /// `min(in_channels / 4, 32)` at construction time.
    pub resnet_groups: Option<usize>,
    /// Used by the constructor as the head count (`n_heads`) of each transformer.
    pub attn_num_head_channels: usize,
    /// Value forwarded to each resnet block as `output_scale_factor`.
    pub output_scale_factor: f64,
    /// Value forwarded to each transformer as `context_dim`.
    pub cross_attn_dim: usize,
    /// Value forwarded to each transformer as `sliced_attention_size`.
    pub sliced_attention_size: Option<usize>,
    /// Value forwarded to each transformer as `use_linear_projection`.
    pub use_linear_projection: bool,
    /// Value forwarded to each transformer as `depth`.
    pub transformer_layers_per_block: usize,
}

impl Default for UNetMidBlock2DCrossAttnConfig {
    /// One layer, eps `1e-6`, 32 groups, one head, unit output scale, a
    /// cross-attention width of 1280, no attention slicing, no linear
    /// projection, and one transformer layer per block.
    fn default() -> Self {
        Self {
            num_layers: 1,
            resnet_eps: 1e-6,
            resnet_groups: Some(32),
            attn_num_head_channels: 1,
            output_scale_factor: 1.0,
            cross_attn_dim: 1280,
            sliced_attention_size: None,
            use_linear_projection: false,
            transformer_layers_per_block: 1,
        }
    }
}
387
/// Middle block with cross-attention: one resnet followed by a sequence of
/// (spatial transformer, resnet) pairs.
#[derive(Debug)]
pub struct UNetMidBlock2DCrossAttn {
    /// Resnet applied first, before any transformer.
    resnet: ResnetBlock2D,
    /// Pairs applied in order after `resnet`; within a pair the transformer runs before the resnet.
    attn_resnets: Vec<(SpatialTransformer, ResnetBlock2D)>,
    /// Tracing span entered for the duration of each forward pass.
    span: tracing::Span,
    /// Configuration this block was built from.
    pub config: UNetMidBlock2DCrossAttnConfig,
}
395
396impl UNetMidBlock2DCrossAttn {
397 pub fn new(
398 vs: nn::VarBuilder,
399 in_channels: usize,
400 temb_channels: Option<usize>,
401 use_flash_attn: bool,
402 config: UNetMidBlock2DCrossAttnConfig,
403 ) -> Result<Self> {
404 let vs_resnets = vs.pp("resnets");
405 let vs_attns = vs.pp("attentions");
406 let resnet_groups = config
407 .resnet_groups
408 .unwrap_or_else(|| usize::min(in_channels / 4, 32));
409 let resnet_cfg = ResnetBlock2DConfig {
410 eps: config.resnet_eps,
411 groups: resnet_groups,
412 output_scale_factor: config.output_scale_factor,
413 temb_channels,
414 ..Default::default()
415 };
416 let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
417 let n_heads = config.attn_num_head_channels;
418 let attn_cfg = SpatialTransformerConfig {
419 depth: config.transformer_layers_per_block,
420 num_groups: resnet_groups,
421 context_dim: Some(config.cross_attn_dim),
422 sliced_attention_size: config.sliced_attention_size,
423 use_linear_projection: config.use_linear_projection,
424 };
425 let mut attn_resnets = vec![];
426 for index in 0..config.num_layers {
427 let attn = SpatialTransformer::new(
428 vs_attns.pp(index.to_string()),
429 in_channels,
430 n_heads,
431 in_channels / n_heads,
432 use_flash_attn,
433 attn_cfg,
434 )?;
435 let resnet = ResnetBlock2D::new(
436 vs_resnets.pp((index + 1).to_string()),
437 in_channels,
438 resnet_cfg,
439 )?;
440 attn_resnets.push((attn, resnet))
441 }
442 let span = tracing::span!(tracing::Level::TRACE, "xa-mid2d");
443 Ok(Self {
444 resnet,
445 attn_resnets,
446 span,
447 config,
448 })
449 }
450
451 pub fn forward(
452 &self,
453 xs: &Tensor,
454 temb: Option<&Tensor>,
455 encoder_hidden_states: Option<&Tensor>,
456 ) -> Result<Tensor> {
457 let _enter = self.span.enter();
458 let mut xs = self.resnet.forward(xs, temb)?;
459 for (attn, resnet) in self.attn_resnets.iter() {
460 xs = resnet.forward(&attn.forward(&xs, encoder_hidden_states)?, temb)?
461 }
462 Ok(xs)
463 }
464}
465
/// Settings for [`DownBlock2D`].
#[derive(Debug, Clone, Copy)]
pub struct DownBlock2DConfig {
    /// Number of resnet blocks to stack.
    pub num_layers: usize,
    /// Value forwarded to each resnet block as `eps`.
    pub resnet_eps: f64,
    /// Group count made available to sub-layers; `CrossAttnDownBlock2D` passes
    /// it to its transformers as `num_groups`.
    pub resnet_groups: usize,
    /// Value forwarded to each resnet block as `output_scale_factor`.
    pub output_scale_factor: f64,
    /// Whether a downsampling layer is appended after the resnets.
    pub add_downsample: bool,
    /// Padding used by the downsampling convolution.
    pub downsample_padding: usize,
}

impl Default for DownBlock2DConfig {
    /// One resnet layer, eps `1e-6`, 32 groups, unit output scale, and a
    /// downsampler with padding 1.
    fn default() -> Self {
        Self {
            num_layers: 1,
            resnet_eps: 1e-6,
            resnet_groups: 32,
            output_scale_factor: 1.0,
            add_downsample: true,
            downsample_padding: 1,
        }
    }
}
490
/// A stack of time-conditioned resnet blocks followed by an optional downsampling layer.
#[derive(Debug)]
pub struct DownBlock2D {
    /// Resnet blocks applied in order; each receives the time embedding given to `forward`.
    resnets: Vec<ResnetBlock2D>,
    /// Downsampling layer applied after the resnets, present when `add_downsample` was set.
    downsampler: Option<Downsample2D>,
    /// Tracing span entered for the duration of each forward pass.
    span: tracing::Span,
    /// Configuration this block was built from.
    pub config: DownBlock2DConfig,
}
498
499impl DownBlock2D {
500 pub fn new(
501 vs: nn::VarBuilder,
502 in_channels: usize,
503 out_channels: usize,
504 temb_channels: Option<usize>,
505 config: DownBlock2DConfig,
506 ) -> Result<Self> {
507 let vs_resnets = vs.pp("resnets");
508 let resnet_cfg = ResnetBlock2DConfig {
509 out_channels: Some(out_channels),
510 eps: config.resnet_eps,
511 output_scale_factor: config.output_scale_factor,
512 temb_channels,
513 ..Default::default()
514 };
515 let resnets = (0..config.num_layers)
516 .map(|i| {
517 let in_channels = if i == 0 { in_channels } else { out_channels };
518 ResnetBlock2D::new(vs_resnets.pp(i.to_string()), in_channels, resnet_cfg)
519 })
520 .collect::<Result<Vec<_>>>()?;
521 let downsampler = if config.add_downsample {
522 let downsampler = Downsample2D::new(
523 vs.pp("downsamplers").pp("0"),
524 out_channels,
525 true,
526 out_channels,
527 config.downsample_padding,
528 )?;
529 Some(downsampler)
530 } else {
531 None
532 };
533 let span = tracing::span!(tracing::Level::TRACE, "down2d");
534 Ok(Self {
535 resnets,
536 downsampler,
537 span,
538 config,
539 })
540 }
541
542 pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<(Tensor, Vec<Tensor>)> {
543 let _enter = self.span.enter();
544 let mut xs = xs.clone();
545 let mut output_states = vec![];
546 for resnet in self.resnets.iter() {
547 xs = resnet.forward(&xs, temb)?;
548 output_states.push(xs.clone());
549 }
550 let xs = match &self.downsampler {
551 Some(downsampler) => {
552 let xs = downsampler.forward(&xs)?;
553 output_states.push(xs.clone());
554 xs
555 }
556 None => xs,
557 };
558 Ok((xs, output_states))
559 }
560}
561
562#[derive(Debug, Clone, Copy)]
563pub struct CrossAttnDownBlock2DConfig {
564 pub downblock: DownBlock2DConfig,
565 pub attn_num_head_channels: usize,
566 pub cross_attention_dim: usize,
567 pub sliced_attention_size: Option<usize>,
569 pub use_linear_projection: bool,
570 pub transformer_layers_per_block: usize,
571}
572
573impl Default for CrossAttnDownBlock2DConfig {
574 fn default() -> Self {
575 Self {
576 downblock: Default::default(),
577 attn_num_head_channels: 1,
578 cross_attention_dim: 1280,
579 sliced_attention_size: None,
580 use_linear_projection: false,
581 transformer_layers_per_block: 1,
582 }
583 }
584}
585
/// Down block that interleaves a spatial transformer after every resnet.
#[derive(Debug)]
pub struct CrossAttnDownBlock2D {
    /// Supplies the resnets and the optional downsampler; this block's `forward`
    /// drives them directly rather than calling the inner block's `forward`.
    downblock: DownBlock2D,
    /// Transformers paired with the resnets by position.
    attentions: Vec<SpatialTransformer>,
    /// Tracing span entered for the duration of each forward pass.
    span: tracing::Span,
    /// Configuration this block was built from.
    pub config: CrossAttnDownBlock2DConfig,
}
593
594impl CrossAttnDownBlock2D {
595 pub fn new(
596 vs: nn::VarBuilder,
597 in_channels: usize,
598 out_channels: usize,
599 temb_channels: Option<usize>,
600 use_flash_attn: bool,
601 config: CrossAttnDownBlock2DConfig,
602 ) -> Result<Self> {
603 let downblock = DownBlock2D::new(
604 vs.clone(),
605 in_channels,
606 out_channels,
607 temb_channels,
608 config.downblock,
609 )?;
610 let n_heads = config.attn_num_head_channels;
611 let cfg = SpatialTransformerConfig {
612 depth: config.transformer_layers_per_block,
613 context_dim: Some(config.cross_attention_dim),
614 num_groups: config.downblock.resnet_groups,
615 sliced_attention_size: config.sliced_attention_size,
616 use_linear_projection: config.use_linear_projection,
617 };
618 let vs_attn = vs.pp("attentions");
619 let attentions = (0..config.downblock.num_layers)
620 .map(|i| {
621 SpatialTransformer::new(
622 vs_attn.pp(i.to_string()),
623 out_channels,
624 n_heads,
625 out_channels / n_heads,
626 use_flash_attn,
627 cfg,
628 )
629 })
630 .collect::<Result<Vec<_>>>()?;
631 let span = tracing::span!(tracing::Level::TRACE, "xa-down2d");
632 Ok(Self {
633 downblock,
634 attentions,
635 span,
636 config,
637 })
638 }
639
640 pub fn forward(
641 &self,
642 xs: &Tensor,
643 temb: Option<&Tensor>,
644 encoder_hidden_states: Option<&Tensor>,
645 ) -> Result<(Tensor, Vec<Tensor>)> {
646 let _enter = self.span.enter();
647 let mut output_states = vec![];
648 let mut xs = xs.clone();
649 for (resnet, attn) in self.downblock.resnets.iter().zip(self.attentions.iter()) {
650 xs = resnet.forward(&xs, temb)?;
651 xs = attn.forward(&xs, encoder_hidden_states)?;
652 output_states.push(xs.clone());
653 }
654 let xs = match &self.downblock.downsampler {
655 Some(downsampler) => {
656 let xs = downsampler.forward(&xs)?;
657 output_states.push(xs.clone());
658 xs
659 }
660 None => xs,
661 };
662 Ok((xs, output_states))
663 }
664}
665
/// Settings for [`UpBlock2D`].
#[derive(Debug, Clone, Copy)]
pub struct UpBlock2DConfig {
    /// Number of resnet blocks to stack.
    pub num_layers: usize,
    /// Value forwarded to each resnet block as `eps`.
    pub resnet_eps: f64,
    /// Group count made available to sub-layers; `CrossAttnUpBlock2D` passes it
    /// to its transformers as `num_groups`.
    pub resnet_groups: usize,
    /// Value forwarded to each resnet block as `output_scale_factor`.
    pub output_scale_factor: f64,
    /// Whether an upsampling layer is appended after the resnets.
    pub add_upsample: bool,
}

impl Default for UpBlock2DConfig {
    /// One resnet layer, eps `1e-6`, 32 groups, unit output scale, upsampling enabled.
    fn default() -> Self {
        Self {
            num_layers: 1,
            resnet_eps: 1e-6,
            resnet_groups: 32,
            output_scale_factor: 1.0,
            add_upsample: true,
        }
    }
}
688
/// A stack of time-conditioned resnet blocks, each fed the running tensor
/// concatenated with a residual, followed by an optional upsampling layer.
#[derive(Debug)]
pub struct UpBlock2D {
    /// Resnet blocks applied in order; each consumes one residual tensor,
    /// concatenated to its input along dimension 1.
    pub resnets: Vec<ResnetBlock2D>,
    /// Upsampling layer applied after the resnets, present when `add_upsample` was set.
    upsampler: Option<Upsample2D>,
    /// Tracing span entered for the duration of each forward pass.
    span: tracing::Span,
    /// Configuration this block was built from.
    pub config: UpBlock2DConfig,
}
696
697impl UpBlock2D {
698 pub fn new(
699 vs: nn::VarBuilder,
700 in_channels: usize,
701 prev_output_channels: usize,
702 out_channels: usize,
703 temb_channels: Option<usize>,
704 config: UpBlock2DConfig,
705 ) -> Result<Self> {
706 let vs_resnets = vs.pp("resnets");
707 let resnet_cfg = ResnetBlock2DConfig {
708 out_channels: Some(out_channels),
709 temb_channels,
710 eps: config.resnet_eps,
711 output_scale_factor: config.output_scale_factor,
712 ..Default::default()
713 };
714 let resnets = (0..config.num_layers)
715 .map(|i| {
716 let res_skip_channels = if i == config.num_layers - 1 {
717 in_channels
718 } else {
719 out_channels
720 };
721 let resnet_in_channels = if i == 0 {
722 prev_output_channels
723 } else {
724 out_channels
725 };
726 let in_channels = resnet_in_channels + res_skip_channels;
727 ResnetBlock2D::new(vs_resnets.pp(i.to_string()), in_channels, resnet_cfg)
728 })
729 .collect::<Result<Vec<_>>>()?;
730 let upsampler = if config.add_upsample {
731 let upsampler =
732 Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?;
733 Some(upsampler)
734 } else {
735 None
736 };
737 let span = tracing::span!(tracing::Level::TRACE, "up2d");
738 Ok(Self {
739 resnets,
740 upsampler,
741 span,
742 config,
743 })
744 }
745
746 pub fn forward(
747 &self,
748 xs: &Tensor,
749 res_xs: &[Tensor],
750 temb: Option<&Tensor>,
751 upsample_size: Option<(usize, usize)>,
752 ) -> Result<Tensor> {
753 let _enter = self.span.enter();
754 let mut xs = xs.clone();
755 for (index, resnet) in self.resnets.iter().enumerate() {
756 xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
757 xs = xs.contiguous()?;
758 xs = resnet.forward(&xs, temb)?;
759 }
760 match &self.upsampler {
761 Some(upsampler) => upsampler.forward(&xs, upsample_size),
762 None => Ok(xs),
763 }
764 }
765}
766
767#[derive(Debug, Clone, Copy)]
768pub struct CrossAttnUpBlock2DConfig {
769 pub upblock: UpBlock2DConfig,
770 pub attn_num_head_channels: usize,
771 pub cross_attention_dim: usize,
772 pub sliced_attention_size: Option<usize>,
774 pub use_linear_projection: bool,
775 pub transformer_layers_per_block: usize,
776}
777
778impl Default for CrossAttnUpBlock2DConfig {
779 fn default() -> Self {
780 Self {
781 upblock: Default::default(),
782 attn_num_head_channels: 1,
783 cross_attention_dim: 1280,
784 sliced_attention_size: None,
785 use_linear_projection: false,
786 transformer_layers_per_block: 1,
787 }
788 }
789}
790
791#[derive(Debug)]
792pub struct CrossAttnUpBlock2D {
793 pub upblock: UpBlock2D,
794 pub attentions: Vec<SpatialTransformer>,
795 span: tracing::Span,
796 pub config: CrossAttnUpBlock2DConfig,
797}
798
799impl CrossAttnUpBlock2D {
800 pub fn new(
801 vs: nn::VarBuilder,
802 in_channels: usize,
803 prev_output_channels: usize,
804 out_channels: usize,
805 temb_channels: Option<usize>,
806 use_flash_attn: bool,
807 config: CrossAttnUpBlock2DConfig,
808 ) -> Result<Self> {
809 let upblock = UpBlock2D::new(
810 vs.clone(),
811 in_channels,
812 prev_output_channels,
813 out_channels,
814 temb_channels,
815 config.upblock,
816 )?;
817 let n_heads = config.attn_num_head_channels;
818 let cfg = SpatialTransformerConfig {
819 depth: config.transformer_layers_per_block,
820 context_dim: Some(config.cross_attention_dim),
821 num_groups: config.upblock.resnet_groups,
822 sliced_attention_size: config.sliced_attention_size,
823 use_linear_projection: config.use_linear_projection,
824 };
825 let vs_attn = vs.pp("attentions");
826 let attentions = (0..config.upblock.num_layers)
827 .map(|i| {
828 SpatialTransformer::new(
829 vs_attn.pp(i.to_string()),
830 out_channels,
831 n_heads,
832 out_channels / n_heads,
833 use_flash_attn,
834 cfg,
835 )
836 })
837 .collect::<Result<Vec<_>>>()?;
838 let span = tracing::span!(tracing::Level::TRACE, "xa-up2d");
839 Ok(Self {
840 upblock,
841 attentions,
842 span,
843 config,
844 })
845 }
846
847 pub fn forward(
848 &self,
849 xs: &Tensor,
850 res_xs: &[Tensor],
851 temb: Option<&Tensor>,
852 upsample_size: Option<(usize, usize)>,
853 encoder_hidden_states: Option<&Tensor>,
854 ) -> Result<Tensor> {
855 let _enter = self.span.enter();
856 let mut xs = xs.clone();
857 for (index, resnet) in self.upblock.resnets.iter().enumerate() {
858 xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
859 xs = xs.contiguous()?;
860 xs = resnet.forward(&xs, temb)?;
861 xs = self.attentions[index].forward(&xs, encoder_hidden_states)?;
862 }
863 match &self.upblock.upsampler {
864 Some(upsampler) => upsampler.forward(&xs, upsample_size),
865 None => Ok(xs),
866 }
867 }
868}