1use candle::{DType, IndexOp, Result, Tensor, D};
2use candle_nn::{LayerNorm, Linear, RmsNorm, VarBuilder};
3
/// Hyper-parameters describing a Flux transformer.
///
/// The two presets, [`Config::dev`] and [`Config::schnell`], share every
/// architectural value and differ only in `guidance_embed`.
#[derive(Debug, Clone)]
pub struct Config {
    /// Input width of the `img_in` projection (also the per-token output width).
    pub in_channels: usize,
    /// Input width of the `vector_in` embedder.
    pub vec_in_dim: usize,
    /// Input width of the `txt_in` projection.
    pub context_in_dim: usize,
    /// Model width shared by every block.
    pub hidden_size: usize,
    /// MLP hidden width expressed as a multiple of `hidden_size`.
    pub mlp_ratio: f64,
    /// Number of attention heads; the head dimension is `hidden_size / num_heads`.
    pub num_heads: usize,
    /// Number of double-stream blocks.
    pub depth: usize,
    /// Number of single-stream blocks.
    pub depth_single_blocks: usize,
    /// Rotary embedding width for each position axis.
    pub axes_dim: Vec<usize>,
    /// Base of the rotary embedding frequencies.
    pub theta: usize,
    /// Whether the attention `qkv` projections carry a bias.
    pub qkv_bias: bool,
    /// Whether the model contains a `guidance_in` embedder.
    pub guidance_embed: bool,
}

impl Config {
    /// Preset for the `dev` variant, which includes a guidance embedder.
    pub fn dev() -> Self {
        Self {
            in_channels: 64,
            vec_in_dim: 768,
            context_in_dim: 4096,
            hidden_size: 3072,
            mlp_ratio: 4.0,
            num_heads: 24,
            depth: 19,
            depth_single_blocks: 38,
            axes_dim: vec![16, 56, 56],
            theta: 10_000,
            qkv_bias: true,
            guidance_embed: true,
        }
    }

    /// Preset for the `schnell` variant: identical to [`Config::dev`] except that
    /// it has no guidance embedder.
    pub fn schnell() -> Self {
        Self {
            guidance_embed: false,
            ..Self::dev()
        }
    }
}
58
59fn layer_norm(dim: usize, vb: VarBuilder) -> Result<LayerNorm> {
60 let ws = Tensor::ones(dim, vb.dtype(), vb.device())?;
61 Ok(LayerNorm::new_no_bias(ws, 1e-6))
62}
63
64fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
65 let dim = q.dim(D::Minus1)?;
66 let scale_factor = 1.0 / (dim as f64).sqrt();
67 let mut batch_dims = q.dims().to_vec();
68 batch_dims.pop();
69 batch_dims.pop();
70 let q = q.flatten_to(batch_dims.len() - 1)?;
71 let k = k.flatten_to(batch_dims.len() - 1)?;
72 let v = v.flatten_to(batch_dims.len() - 1)?;
73 let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?;
74 let attn_scores = candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(&v)?;
75 batch_dims.push(attn_scores.dim(D::Minus2)?);
76 batch_dims.push(attn_scores.dim(D::Minus1)?);
77 attn_scores.reshape(batch_dims)
78}
79
80fn rope(pos: &Tensor, dim: usize, theta: usize) -> Result<Tensor> {
81 if dim % 2 == 1 {
82 candle::bail!("dim {dim} is odd")
83 }
84 let dev = pos.device();
85 let theta = theta as f64;
86 let inv_freq: Vec<_> = (0..dim)
87 .step_by(2)
88 .map(|i| 1f32 / theta.powf(i as f64 / dim as f64) as f32)
89 .collect();
90 let inv_freq_len = inv_freq.len();
91 let inv_freq = Tensor::from_vec(inv_freq, (1, 1, inv_freq_len), dev)?;
92 let inv_freq = inv_freq.to_dtype(pos.dtype())?;
93 let freqs = pos.unsqueeze(2)?.broadcast_mul(&inv_freq)?;
94 let cos = freqs.cos()?;
95 let sin = freqs.sin()?;
96 let out = Tensor::stack(&[&cos, &sin.neg()?, &sin, &cos], 3)?;
97 let (b, n, d, _ij) = out.dims4()?;
98 out.reshape((b, n, d, 2, 2))
99}
100
101fn apply_rope(x: &Tensor, freq_cis: &Tensor) -> Result<Tensor> {
102 let dims = x.dims();
103 let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
104 let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
105 let x0 = x.narrow(D::Minus1, 0, 1)?;
106 let x1 = x.narrow(D::Minus1, 1, 1)?;
107 let fr0 = freq_cis.get_on_dim(D::Minus1, 0)?;
108 let fr1 = freq_cis.get_on_dim(D::Minus1, 1)?;
109 (fr0.broadcast_mul(&x0)? + fr1.broadcast_mul(&x1)?)?.reshape(dims.to_vec())
110}
111
112pub(crate) fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
113 let q = apply_rope(q, pe)?.contiguous()?;
114 let k = apply_rope(k, pe)?.contiguous()?;
115 let x = scaled_dot_product_attention(&q, &k, v)?;
116 x.transpose(1, 2)?.flatten_from(2)
117}
118
119pub(crate) fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
120 const TIME_FACTOR: f64 = 1000.;
121 const MAX_PERIOD: f64 = 10000.;
122 if dim % 2 == 1 {
123 candle::bail!("{dim} is odd")
124 }
125 let dev = t.device();
126 let half = dim / 2;
127 let t = (t * TIME_FACTOR)?;
128 let arange = Tensor::arange(0, half as u32, dev)?.to_dtype(candle::DType::F32)?;
129 let freqs = (arange * (-MAX_PERIOD.ln() / half as f64))?.exp()?;
130 let args = t
131 .unsqueeze(1)?
132 .to_dtype(candle::DType::F32)?
133 .broadcast_mul(&freqs.unsqueeze(0)?)?;
134 let emb = Tensor::cat(&[args.cos()?, args.sin()?], D::Minus1)?.to_dtype(dtype)?;
135 Ok(emb)
136}
137
138#[derive(Debug, Clone)]
139pub struct EmbedNd {
140 #[allow(unused)]
141 dim: usize,
142 theta: usize,
143 axes_dim: Vec<usize>,
144}
145
146impl EmbedNd {
147 pub fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
148 Self {
149 dim,
150 theta,
151 axes_dim,
152 }
153 }
154}
155
156impl candle::Module for EmbedNd {
157 fn forward(&self, ids: &Tensor) -> Result<Tensor> {
158 let n_axes = ids.dim(D::Minus1)?;
159 let mut emb = Vec::with_capacity(n_axes);
160 for idx in 0..n_axes {
161 let r = rope(
162 &ids.get_on_dim(D::Minus1, idx)?,
163 self.axes_dim[idx],
164 self.theta,
165 )?;
166 emb.push(r)
167 }
168 let emb = Tensor::cat(&emb, 2)?;
169 emb.unsqueeze(1)
170 }
171}
172
173#[derive(Debug, Clone)]
174pub struct MlpEmbedder {
175 in_layer: Linear,
176 out_layer: Linear,
177}
178
179impl MlpEmbedder {
180 fn new(in_sz: usize, h_sz: usize, vb: VarBuilder) -> Result<Self> {
181 let in_layer = candle_nn::linear(in_sz, h_sz, vb.pp("in_layer"))?;
182 let out_layer = candle_nn::linear(h_sz, h_sz, vb.pp("out_layer"))?;
183 Ok(Self {
184 in_layer,
185 out_layer,
186 })
187 }
188}
189
190impl candle::Module for MlpEmbedder {
191 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
192 xs.apply(&self.in_layer)?.silu()?.apply(&self.out_layer)
193 }
194}
195
196#[derive(Debug, Clone)]
197pub struct QkNorm {
198 query_norm: RmsNorm,
199 key_norm: RmsNorm,
200}
201
202impl QkNorm {
203 fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
204 let query_norm = vb.get(dim, "query_norm.scale")?;
205 let query_norm = RmsNorm::new(query_norm, 1e-6);
206 let key_norm = vb.get(dim, "key_norm.scale")?;
207 let key_norm = RmsNorm::new(key_norm, 1e-6);
208 Ok(Self {
209 query_norm,
210 key_norm,
211 })
212 }
213}
214
215struct ModulationOut {
216 shift: Tensor,
217 scale: Tensor,
218 gate: Tensor,
219}
220
221impl ModulationOut {
222 fn scale_shift(&self, xs: &Tensor) -> Result<Tensor> {
223 xs.broadcast_mul(&(&self.scale + 1.)?)?
224 .broadcast_add(&self.shift)
225 }
226
227 fn gate(&self, xs: &Tensor) -> Result<Tensor> {
228 self.gate.broadcast_mul(xs)
229 }
230}
231
232#[derive(Debug, Clone)]
233struct Modulation1 {
234 lin: Linear,
235}
236
237impl Modulation1 {
238 fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
239 let lin = candle_nn::linear(dim, 3 * dim, vb.pp("lin"))?;
240 Ok(Self { lin })
241 }
242
243 fn forward(&self, vec_: &Tensor) -> Result<ModulationOut> {
244 let ys = vec_
245 .silu()?
246 .apply(&self.lin)?
247 .unsqueeze(1)?
248 .chunk(3, D::Minus1)?;
249 if ys.len() != 3 {
250 candle::bail!("unexpected len from chunk {ys:?}")
251 }
252 Ok(ModulationOut {
253 shift: ys[0].clone(),
254 scale: ys[1].clone(),
255 gate: ys[2].clone(),
256 })
257 }
258}
259
260#[derive(Debug, Clone)]
261struct Modulation2 {
262 lin: Linear,
263}
264
265impl Modulation2 {
266 fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
267 let lin = candle_nn::linear(dim, 6 * dim, vb.pp("lin"))?;
268 Ok(Self { lin })
269 }
270
271 fn forward(&self, vec_: &Tensor) -> Result<(ModulationOut, ModulationOut)> {
272 let ys = vec_
273 .silu()?
274 .apply(&self.lin)?
275 .unsqueeze(1)?
276 .chunk(6, D::Minus1)?;
277 if ys.len() != 6 {
278 candle::bail!("unexpected len from chunk {ys:?}")
279 }
280 let mod1 = ModulationOut {
281 shift: ys[0].clone(),
282 scale: ys[1].clone(),
283 gate: ys[2].clone(),
284 };
285 let mod2 = ModulationOut {
286 shift: ys[3].clone(),
287 scale: ys[4].clone(),
288 gate: ys[5].clone(),
289 };
290 Ok((mod1, mod2))
291 }
292}
293
294#[derive(Debug, Clone)]
295pub struct SelfAttention {
296 qkv: Linear,
297 norm: QkNorm,
298 proj: Linear,
299 num_heads: usize,
300}
301
302impl SelfAttention {
303 fn new(dim: usize, num_heads: usize, qkv_bias: bool, vb: VarBuilder) -> Result<Self> {
304 let head_dim = dim / num_heads;
305 let qkv = candle_nn::linear_b(dim, dim * 3, qkv_bias, vb.pp("qkv"))?;
306 let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
307 let proj = candle_nn::linear(dim, dim, vb.pp("proj"))?;
308 Ok(Self {
309 qkv,
310 norm,
311 proj,
312 num_heads,
313 })
314 }
315
316 fn qkv(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
317 let qkv = xs.apply(&self.qkv)?;
318 let (b, l, _khd) = qkv.dims3()?;
319 let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?;
320 let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
321 let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
322 let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
323 let q = q.apply(&self.norm.query_norm)?;
324 let k = k.apply(&self.norm.key_norm)?;
325 Ok((q, k, v))
326 }
327
328 #[allow(unused)]
329 fn forward(&self, xs: &Tensor, pe: &Tensor) -> Result<Tensor> {
330 let (q, k, v) = self.qkv(xs)?;
331 attention(&q, &k, &v, pe)?.apply(&self.proj)
332 }
333}
334
335#[derive(Debug, Clone)]
336struct Mlp {
337 lin1: Linear,
338 lin2: Linear,
339}
340
341impl Mlp {
342 fn new(in_sz: usize, mlp_sz: usize, vb: VarBuilder) -> Result<Self> {
343 let lin1 = candle_nn::linear(in_sz, mlp_sz, vb.pp("0"))?;
344 let lin2 = candle_nn::linear(mlp_sz, in_sz, vb.pp("2"))?;
345 Ok(Self { lin1, lin2 })
346 }
347}
348
349impl candle::Module for Mlp {
350 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
351 xs.apply(&self.lin1)?.gelu()?.apply(&self.lin2)
352 }
353}
354
355#[derive(Debug, Clone)]
356pub struct DoubleStreamBlock {
357 img_mod: Modulation2,
358 img_norm1: LayerNorm,
359 img_attn: SelfAttention,
360 img_norm2: LayerNorm,
361 img_mlp: Mlp,
362 txt_mod: Modulation2,
363 txt_norm1: LayerNorm,
364 txt_attn: SelfAttention,
365 txt_norm2: LayerNorm,
366 txt_mlp: Mlp,
367}
368
369impl DoubleStreamBlock {
370 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
371 let h_sz = cfg.hidden_size;
372 let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize;
373 let img_mod = Modulation2::new(h_sz, vb.pp("img_mod"))?;
374 let img_norm1 = layer_norm(h_sz, vb.pp("img_norm1"))?;
375 let img_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("img_attn"))?;
376 let img_norm2 = layer_norm(h_sz, vb.pp("img_norm2"))?;
377 let img_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("img_mlp"))?;
378 let txt_mod = Modulation2::new(h_sz, vb.pp("txt_mod"))?;
379 let txt_norm1 = layer_norm(h_sz, vb.pp("txt_norm1"))?;
380 let txt_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("txt_attn"))?;
381 let txt_norm2 = layer_norm(h_sz, vb.pp("txt_norm2"))?;
382 let txt_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("txt_mlp"))?;
383 Ok(Self {
384 img_mod,
385 img_norm1,
386 img_attn,
387 img_norm2,
388 img_mlp,
389 txt_mod,
390 txt_norm1,
391 txt_attn,
392 txt_norm2,
393 txt_mlp,
394 })
395 }
396
397 fn forward(
398 &self,
399 img: &Tensor,
400 txt: &Tensor,
401 vec_: &Tensor,
402 pe: &Tensor,
403 ) -> Result<(Tensor, Tensor)> {
404 let (img_mod1, img_mod2) = self.img_mod.forward(vec_)?; let (txt_mod1, txt_mod2) = self.txt_mod.forward(vec_)?; let img_modulated = img.apply(&self.img_norm1)?;
407 let img_modulated = img_mod1.scale_shift(&img_modulated)?;
408 let (img_q, img_k, img_v) = self.img_attn.qkv(&img_modulated)?;
409
410 let txt_modulated = txt.apply(&self.txt_norm1)?;
411 let txt_modulated = txt_mod1.scale_shift(&txt_modulated)?;
412 let (txt_q, txt_k, txt_v) = self.txt_attn.qkv(&txt_modulated)?;
413
414 let q = Tensor::cat(&[txt_q, img_q], 2)?;
415 let k = Tensor::cat(&[txt_k, img_k], 2)?;
416 let v = Tensor::cat(&[txt_v, img_v], 2)?;
417
418 let attn = attention(&q, &k, &v, pe)?;
419 let txt_attn = attn.narrow(1, 0, txt.dim(1)?)?;
420 let img_attn = attn.narrow(1, txt.dim(1)?, attn.dim(1)? - txt.dim(1)?)?;
421
422 let img = (img + img_mod1.gate(&img_attn.apply(&self.img_attn.proj)?))?;
423 let img = (&img
424 + img_mod2.gate(
425 &img_mod2
426 .scale_shift(&img.apply(&self.img_norm2)?)?
427 .apply(&self.img_mlp)?,
428 )?)?;
429
430 let txt = (txt + txt_mod1.gate(&txt_attn.apply(&self.txt_attn.proj)?))?;
431 let txt = (&txt
432 + txt_mod2.gate(
433 &txt_mod2
434 .scale_shift(&txt.apply(&self.txt_norm2)?)?
435 .apply(&self.txt_mlp)?,
436 )?)?;
437
438 Ok((img, txt))
439 }
440}
441
442#[derive(Debug, Clone)]
443pub struct SingleStreamBlock {
444 linear1: Linear,
445 linear2: Linear,
446 norm: QkNorm,
447 pre_norm: LayerNorm,
448 modulation: Modulation1,
449 h_sz: usize,
450 mlp_sz: usize,
451 num_heads: usize,
452}
453
454impl SingleStreamBlock {
455 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
456 let h_sz = cfg.hidden_size;
457 let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize;
458 let head_dim = h_sz / cfg.num_heads;
459 let linear1 = candle_nn::linear(h_sz, h_sz * 3 + mlp_sz, vb.pp("linear1"))?;
460 let linear2 = candle_nn::linear(h_sz + mlp_sz, h_sz, vb.pp("linear2"))?;
461 let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
462 let pre_norm = layer_norm(h_sz, vb.pp("pre_norm"))?;
463 let modulation = Modulation1::new(h_sz, vb.pp("modulation"))?;
464 Ok(Self {
465 linear1,
466 linear2,
467 norm,
468 pre_norm,
469 modulation,
470 h_sz,
471 mlp_sz,
472 num_heads: cfg.num_heads,
473 })
474 }
475
476 fn forward(&self, xs: &Tensor, vec_: &Tensor, pe: &Tensor) -> Result<Tensor> {
477 let mod_ = self.modulation.forward(vec_)?;
478 let x_mod = mod_.scale_shift(&xs.apply(&self.pre_norm)?)?;
479 let x_mod = x_mod.apply(&self.linear1)?;
480 let qkv = x_mod.narrow(D::Minus1, 0, 3 * self.h_sz)?;
481 let (b, l, _khd) = qkv.dims3()?;
482 let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?;
483 let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
484 let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
485 let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
486 let mlp = x_mod.narrow(D::Minus1, 3 * self.h_sz, self.mlp_sz)?;
487 let q = q.apply(&self.norm.query_norm)?;
488 let k = k.apply(&self.norm.key_norm)?;
489 let attn = attention(&q, &k, &v, pe)?;
490 let output = Tensor::cat(&[attn, mlp.gelu()?], 2)?.apply(&self.linear2)?;
491 xs + mod_.gate(&output)
492 }
493}
494
495#[derive(Debug, Clone)]
496pub struct LastLayer {
497 norm_final: LayerNorm,
498 linear: Linear,
499 ada_ln_modulation: Linear,
500}
501
502impl LastLayer {
503 fn new(h_sz: usize, p_sz: usize, out_c: usize, vb: VarBuilder) -> Result<Self> {
504 let norm_final = layer_norm(h_sz, vb.pp("norm_final"))?;
505 let linear = candle_nn::linear(h_sz, p_sz * p_sz * out_c, vb.pp("linear"))?;
506 let ada_ln_modulation = candle_nn::linear(h_sz, 2 * h_sz, vb.pp("adaLN_modulation.1"))?;
507 Ok(Self {
508 norm_final,
509 linear,
510 ada_ln_modulation,
511 })
512 }
513
514 fn forward(&self, xs: &Tensor, vec: &Tensor) -> Result<Tensor> {
515 let chunks = vec.silu()?.apply(&self.ada_ln_modulation)?.chunk(2, 1)?;
516 let (shift, scale) = (&chunks[0], &chunks[1]);
517 let xs = xs
518 .apply(&self.norm_final)?
519 .broadcast_mul(&(scale.unsqueeze(1)? + 1.0)?)?
520 .broadcast_add(&shift.unsqueeze(1)?)?;
521 xs.apply(&self.linear)
522 }
523}
524
525#[derive(Debug, Clone)]
526pub struct Flux {
527 img_in: Linear,
528 txt_in: Linear,
529 time_in: MlpEmbedder,
530 vector_in: MlpEmbedder,
531 guidance_in: Option<MlpEmbedder>,
532 pe_embedder: EmbedNd,
533 double_blocks: Vec<DoubleStreamBlock>,
534 single_blocks: Vec<SingleStreamBlock>,
535 final_layer: LastLayer,
536}
537
538impl Flux {
539 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
540 let img_in = candle_nn::linear(cfg.in_channels, cfg.hidden_size, vb.pp("img_in"))?;
541 let txt_in = candle_nn::linear(cfg.context_in_dim, cfg.hidden_size, vb.pp("txt_in"))?;
542 let mut double_blocks = Vec::with_capacity(cfg.depth);
543 let vb_d = vb.pp("double_blocks");
544 for idx in 0..cfg.depth {
545 let db = DoubleStreamBlock::new(cfg, vb_d.pp(idx))?;
546 double_blocks.push(db)
547 }
548 let mut single_blocks = Vec::with_capacity(cfg.depth_single_blocks);
549 let vb_s = vb.pp("single_blocks");
550 for idx in 0..cfg.depth_single_blocks {
551 let sb = SingleStreamBlock::new(cfg, vb_s.pp(idx))?;
552 single_blocks.push(sb)
553 }
554 let time_in = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("time_in"))?;
555 let vector_in = MlpEmbedder::new(cfg.vec_in_dim, cfg.hidden_size, vb.pp("vector_in"))?;
556 let guidance_in = if cfg.guidance_embed {
557 let mlp = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("guidance_in"))?;
558 Some(mlp)
559 } else {
560 None
561 };
562 let final_layer =
563 LastLayer::new(cfg.hidden_size, 1, cfg.in_channels, vb.pp("final_layer"))?;
564 let pe_dim = cfg.hidden_size / cfg.num_heads;
565 let pe_embedder = EmbedNd::new(pe_dim, cfg.theta, cfg.axes_dim.to_vec());
566 Ok(Self {
567 img_in,
568 txt_in,
569 time_in,
570 vector_in,
571 guidance_in,
572 pe_embedder,
573 double_blocks,
574 single_blocks,
575 final_layer,
576 })
577 }
578}
579
580impl super::WithForward for Flux {
581 #[allow(clippy::too_many_arguments)]
582 fn forward(
583 &self,
584 img: &Tensor,
585 img_ids: &Tensor,
586 txt: &Tensor,
587 txt_ids: &Tensor,
588 timesteps: &Tensor,
589 y: &Tensor,
590 guidance: Option<&Tensor>,
591 ) -> Result<Tensor> {
592 if txt.rank() != 3 {
593 candle::bail!("unexpected shape for txt {:?}", txt.shape())
594 }
595 if img.rank() != 3 {
596 candle::bail!("unexpected shape for img {:?}", img.shape())
597 }
598 let dtype = img.dtype();
599 let pe = {
600 let ids = Tensor::cat(&[txt_ids, img_ids], 1)?;
601 ids.apply(&self.pe_embedder)?
602 };
603 let mut txt = txt.apply(&self.txt_in)?;
604 let mut img = img.apply(&self.img_in)?;
605 let vec_ = timestep_embedding(timesteps, 256, dtype)?.apply(&self.time_in)?;
606 let vec_ = match (self.guidance_in.as_ref(), guidance) {
607 (Some(g_in), Some(guidance)) => {
608 (vec_ + timestep_embedding(guidance, 256, dtype)?.apply(g_in))?
609 }
610 _ => vec_,
611 };
612 let vec_ = (vec_ + y.apply(&self.vector_in))?;
613
614 for block in self.double_blocks.iter() {
616 (img, txt) = block.forward(&img, &txt, &vec_, &pe)?
617 }
618 let mut img = Tensor::cat(&[&txt, &img], 1)?;
620 for block in self.single_blocks.iter() {
621 img = block.forward(&img, &vec_, &pe)?;
622 }
623 let img = img.i((.., txt.dim(1)?..))?;
624 self.final_layer.forward(&img, &vec_)
625 }
626}