candle_transformers/models/flux/
model.rs

1use candle::{DType, IndexOp, Result, Tensor, D};
2use candle_nn::{LayerNorm, Linear, RmsNorm, VarBuilder};
3
// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/model.py#L12
/// Hyper-parameters of a Flux transformer.
#[derive(Debug, Clone)]
pub struct Config {
    /// Feature size of the image tokens; input width of the image projection
    /// and channel count of the final layer output.
    pub in_channels: usize,
    /// Input width of the pooled-vector embedder.
    pub vec_in_dim: usize,
    /// Input width of the text-token projection.
    pub context_in_dim: usize,
    /// Width of the transformer hidden state.
    pub hidden_size: usize,
    /// MLP width expressed as a multiple of `hidden_size`.
    pub mlp_ratio: f64,
    /// Number of attention heads; `hidden_size / num_heads` is the head size.
    pub num_heads: usize,
    /// Number of double-stream blocks.
    pub depth: usize,
    /// Number of single-stream blocks.
    pub depth_single_blocks: usize,
    /// Rotary embedding size allotted to each positional axis.
    pub axes_dim: Vec<usize>,
    /// Base used for the rotary embedding frequencies.
    pub theta: usize,
    /// Whether the fused qkv projection of the attention layers has a bias.
    pub qkv_bias: bool,
    /// Whether the model owns a guidance embedder.
    pub guidance_embed: bool,
}

impl Config {
    // https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L32
    /// Preset for the "dev" variant (guidance embedding enabled).
    pub fn dev() -> Self {
        Self {
            in_channels: 64,
            vec_in_dim: 768,
            context_in_dim: 4096,
            hidden_size: 3072,
            mlp_ratio: 4.0,
            num_heads: 24,
            depth: 19,
            depth_single_blocks: 38,
            axes_dim: vec![16, 56, 56],
            theta: 10_000,
            qkv_bias: true,
            guidance_embed: true,
        }
    }

    // https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L64
    /// Preset for the "schnell" variant: identical to [`Config::dev`] except
    /// that the guidance embedding is disabled.
    pub fn schnell() -> Self {
        Self {
            guidance_embed: false,
            ..Self::dev()
        }
    }
}
58
59fn layer_norm(dim: usize, vb: VarBuilder) -> Result<LayerNorm> {
60    let ws = Tensor::ones(dim, vb.dtype(), vb.device())?;
61    Ok(LayerNorm::new_no_bias(ws, 1e-6))
62}
63
64fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
65    let dim = q.dim(D::Minus1)?;
66    let scale_factor = 1.0 / (dim as f64).sqrt();
67    let mut batch_dims = q.dims().to_vec();
68    batch_dims.pop();
69    batch_dims.pop();
70    let q = q.flatten_to(batch_dims.len() - 1)?;
71    let k = k.flatten_to(batch_dims.len() - 1)?;
72    let v = v.flatten_to(batch_dims.len() - 1)?;
73    let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?;
74    let attn_scores = candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(&v)?;
75    batch_dims.push(attn_scores.dim(D::Minus2)?);
76    batch_dims.push(attn_scores.dim(D::Minus1)?);
77    attn_scores.reshape(batch_dims)
78}
79
80fn rope(pos: &Tensor, dim: usize, theta: usize) -> Result<Tensor> {
81    if dim % 2 == 1 {
82        candle::bail!("dim {dim} is odd")
83    }
84    let dev = pos.device();
85    let theta = theta as f64;
86    let inv_freq: Vec<_> = (0..dim)
87        .step_by(2)
88        .map(|i| 1f32 / theta.powf(i as f64 / dim as f64) as f32)
89        .collect();
90    let inv_freq_len = inv_freq.len();
91    let inv_freq = Tensor::from_vec(inv_freq, (1, 1, inv_freq_len), dev)?;
92    let inv_freq = inv_freq.to_dtype(pos.dtype())?;
93    let freqs = pos.unsqueeze(2)?.broadcast_mul(&inv_freq)?;
94    let cos = freqs.cos()?;
95    let sin = freqs.sin()?;
96    let out = Tensor::stack(&[&cos, &sin.neg()?, &sin, &cos], 3)?;
97    let (b, n, d, _ij) = out.dims4()?;
98    out.reshape((b, n, d, 2, 2))
99}
100
101fn apply_rope(x: &Tensor, freq_cis: &Tensor) -> Result<Tensor> {
102    let dims = x.dims();
103    let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
104    let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
105    let x0 = x.narrow(D::Minus1, 0, 1)?;
106    let x1 = x.narrow(D::Minus1, 1, 1)?;
107    let fr0 = freq_cis.get_on_dim(D::Minus1, 0)?;
108    let fr1 = freq_cis.get_on_dim(D::Minus1, 1)?;
109    (fr0.broadcast_mul(&x0)? + fr1.broadcast_mul(&x1)?)?.reshape(dims.to_vec())
110}
111
112pub(crate) fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
113    let q = apply_rope(q, pe)?.contiguous()?;
114    let k = apply_rope(k, pe)?.contiguous()?;
115    let x = scaled_dot_product_attention(&q, &k, v)?;
116    x.transpose(1, 2)?.flatten_from(2)
117}
118
119pub(crate) fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
120    const TIME_FACTOR: f64 = 1000.;
121    const MAX_PERIOD: f64 = 10000.;
122    if dim % 2 == 1 {
123        candle::bail!("{dim} is odd")
124    }
125    let dev = t.device();
126    let half = dim / 2;
127    let t = (t * TIME_FACTOR)?;
128    let arange = Tensor::arange(0, half as u32, dev)?.to_dtype(candle::DType::F32)?;
129    let freqs = (arange * (-MAX_PERIOD.ln() / half as f64))?.exp()?;
130    let args = t
131        .unsqueeze(1)?
132        .to_dtype(candle::DType::F32)?
133        .broadcast_mul(&freqs.unsqueeze(0)?)?;
134    let emb = Tensor::cat(&[args.cos()?, args.sin()?], D::Minus1)?.to_dtype(dtype)?;
135    Ok(emb)
136}
137
138#[derive(Debug, Clone)]
139pub struct EmbedNd {
140    #[allow(unused)]
141    dim: usize,
142    theta: usize,
143    axes_dim: Vec<usize>,
144}
145
146impl EmbedNd {
147    pub fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
148        Self {
149            dim,
150            theta,
151            axes_dim,
152        }
153    }
154}
155
156impl candle::Module for EmbedNd {
157    fn forward(&self, ids: &Tensor) -> Result<Tensor> {
158        let n_axes = ids.dim(D::Minus1)?;
159        let mut emb = Vec::with_capacity(n_axes);
160        for idx in 0..n_axes {
161            let r = rope(
162                &ids.get_on_dim(D::Minus1, idx)?,
163                self.axes_dim[idx],
164                self.theta,
165            )?;
166            emb.push(r)
167        }
168        let emb = Tensor::cat(&emb, 2)?;
169        emb.unsqueeze(1)
170    }
171}
172
173#[derive(Debug, Clone)]
174pub struct MlpEmbedder {
175    in_layer: Linear,
176    out_layer: Linear,
177}
178
179impl MlpEmbedder {
180    fn new(in_sz: usize, h_sz: usize, vb: VarBuilder) -> Result<Self> {
181        let in_layer = candle_nn::linear(in_sz, h_sz, vb.pp("in_layer"))?;
182        let out_layer = candle_nn::linear(h_sz, h_sz, vb.pp("out_layer"))?;
183        Ok(Self {
184            in_layer,
185            out_layer,
186        })
187    }
188}
189
190impl candle::Module for MlpEmbedder {
191    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
192        xs.apply(&self.in_layer)?.silu()?.apply(&self.out_layer)
193    }
194}
195
196#[derive(Debug, Clone)]
197pub struct QkNorm {
198    query_norm: RmsNorm,
199    key_norm: RmsNorm,
200}
201
202impl QkNorm {
203    fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
204        let query_norm = vb.get(dim, "query_norm.scale")?;
205        let query_norm = RmsNorm::new(query_norm, 1e-6);
206        let key_norm = vb.get(dim, "key_norm.scale")?;
207        let key_norm = RmsNorm::new(key_norm, 1e-6);
208        Ok(Self {
209            query_norm,
210            key_norm,
211        })
212    }
213}
214
215struct ModulationOut {
216    shift: Tensor,
217    scale: Tensor,
218    gate: Tensor,
219}
220
221impl ModulationOut {
222    fn scale_shift(&self, xs: &Tensor) -> Result<Tensor> {
223        xs.broadcast_mul(&(&self.scale + 1.)?)?
224            .broadcast_add(&self.shift)
225    }
226
227    fn gate(&self, xs: &Tensor) -> Result<Tensor> {
228        self.gate.broadcast_mul(xs)
229    }
230}
231
232#[derive(Debug, Clone)]
233struct Modulation1 {
234    lin: Linear,
235}
236
237impl Modulation1 {
238    fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
239        let lin = candle_nn::linear(dim, 3 * dim, vb.pp("lin"))?;
240        Ok(Self { lin })
241    }
242
243    fn forward(&self, vec_: &Tensor) -> Result<ModulationOut> {
244        let ys = vec_
245            .silu()?
246            .apply(&self.lin)?
247            .unsqueeze(1)?
248            .chunk(3, D::Minus1)?;
249        if ys.len() != 3 {
250            candle::bail!("unexpected len from chunk {ys:?}")
251        }
252        Ok(ModulationOut {
253            shift: ys[0].clone(),
254            scale: ys[1].clone(),
255            gate: ys[2].clone(),
256        })
257    }
258}
259
260#[derive(Debug, Clone)]
261struct Modulation2 {
262    lin: Linear,
263}
264
265impl Modulation2 {
266    fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
267        let lin = candle_nn::linear(dim, 6 * dim, vb.pp("lin"))?;
268        Ok(Self { lin })
269    }
270
271    fn forward(&self, vec_: &Tensor) -> Result<(ModulationOut, ModulationOut)> {
272        let ys = vec_
273            .silu()?
274            .apply(&self.lin)?
275            .unsqueeze(1)?
276            .chunk(6, D::Minus1)?;
277        if ys.len() != 6 {
278            candle::bail!("unexpected len from chunk {ys:?}")
279        }
280        let mod1 = ModulationOut {
281            shift: ys[0].clone(),
282            scale: ys[1].clone(),
283            gate: ys[2].clone(),
284        };
285        let mod2 = ModulationOut {
286            shift: ys[3].clone(),
287            scale: ys[4].clone(),
288            gate: ys[5].clone(),
289        };
290        Ok((mod1, mod2))
291    }
292}
293
294#[derive(Debug, Clone)]
295pub struct SelfAttention {
296    qkv: Linear,
297    norm: QkNorm,
298    proj: Linear,
299    num_heads: usize,
300}
301
302impl SelfAttention {
303    fn new(dim: usize, num_heads: usize, qkv_bias: bool, vb: VarBuilder) -> Result<Self> {
304        let head_dim = dim / num_heads;
305        let qkv = candle_nn::linear_b(dim, dim * 3, qkv_bias, vb.pp("qkv"))?;
306        let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
307        let proj = candle_nn::linear(dim, dim, vb.pp("proj"))?;
308        Ok(Self {
309            qkv,
310            norm,
311            proj,
312            num_heads,
313        })
314    }
315
316    fn qkv(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
317        let qkv = xs.apply(&self.qkv)?;
318        let (b, l, _khd) = qkv.dims3()?;
319        let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?;
320        let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
321        let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
322        let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
323        let q = q.apply(&self.norm.query_norm)?;
324        let k = k.apply(&self.norm.key_norm)?;
325        Ok((q, k, v))
326    }
327
328    #[allow(unused)]
329    fn forward(&self, xs: &Tensor, pe: &Tensor) -> Result<Tensor> {
330        let (q, k, v) = self.qkv(xs)?;
331        attention(&q, &k, &v, pe)?.apply(&self.proj)
332    }
333}
334
335#[derive(Debug, Clone)]
336struct Mlp {
337    lin1: Linear,
338    lin2: Linear,
339}
340
341impl Mlp {
342    fn new(in_sz: usize, mlp_sz: usize, vb: VarBuilder) -> Result<Self> {
343        let lin1 = candle_nn::linear(in_sz, mlp_sz, vb.pp("0"))?;
344        let lin2 = candle_nn::linear(mlp_sz, in_sz, vb.pp("2"))?;
345        Ok(Self { lin1, lin2 })
346    }
347}
348
349impl candle::Module for Mlp {
350    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
351        xs.apply(&self.lin1)?.gelu()?.apply(&self.lin2)
352    }
353}
354
355#[derive(Debug, Clone)]
356pub struct DoubleStreamBlock {
357    img_mod: Modulation2,
358    img_norm1: LayerNorm,
359    img_attn: SelfAttention,
360    img_norm2: LayerNorm,
361    img_mlp: Mlp,
362    txt_mod: Modulation2,
363    txt_norm1: LayerNorm,
364    txt_attn: SelfAttention,
365    txt_norm2: LayerNorm,
366    txt_mlp: Mlp,
367}
368
369impl DoubleStreamBlock {
370    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
371        let h_sz = cfg.hidden_size;
372        let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize;
373        let img_mod = Modulation2::new(h_sz, vb.pp("img_mod"))?;
374        let img_norm1 = layer_norm(h_sz, vb.pp("img_norm1"))?;
375        let img_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("img_attn"))?;
376        let img_norm2 = layer_norm(h_sz, vb.pp("img_norm2"))?;
377        let img_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("img_mlp"))?;
378        let txt_mod = Modulation2::new(h_sz, vb.pp("txt_mod"))?;
379        let txt_norm1 = layer_norm(h_sz, vb.pp("txt_norm1"))?;
380        let txt_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("txt_attn"))?;
381        let txt_norm2 = layer_norm(h_sz, vb.pp("txt_norm2"))?;
382        let txt_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("txt_mlp"))?;
383        Ok(Self {
384            img_mod,
385            img_norm1,
386            img_attn,
387            img_norm2,
388            img_mlp,
389            txt_mod,
390            txt_norm1,
391            txt_attn,
392            txt_norm2,
393            txt_mlp,
394        })
395    }
396
397    fn forward(
398        &self,
399        img: &Tensor,
400        txt: &Tensor,
401        vec_: &Tensor,
402        pe: &Tensor,
403    ) -> Result<(Tensor, Tensor)> {
404        let (img_mod1, img_mod2) = self.img_mod.forward(vec_)?; // shift, scale, gate
405        let (txt_mod1, txt_mod2) = self.txt_mod.forward(vec_)?; // shift, scale, gate
406        let img_modulated = img.apply(&self.img_norm1)?;
407        let img_modulated = img_mod1.scale_shift(&img_modulated)?;
408        let (img_q, img_k, img_v) = self.img_attn.qkv(&img_modulated)?;
409
410        let txt_modulated = txt.apply(&self.txt_norm1)?;
411        let txt_modulated = txt_mod1.scale_shift(&txt_modulated)?;
412        let (txt_q, txt_k, txt_v) = self.txt_attn.qkv(&txt_modulated)?;
413
414        let q = Tensor::cat(&[txt_q, img_q], 2)?;
415        let k = Tensor::cat(&[txt_k, img_k], 2)?;
416        let v = Tensor::cat(&[txt_v, img_v], 2)?;
417
418        let attn = attention(&q, &k, &v, pe)?;
419        let txt_attn = attn.narrow(1, 0, txt.dim(1)?)?;
420        let img_attn = attn.narrow(1, txt.dim(1)?, attn.dim(1)? - txt.dim(1)?)?;
421
422        let img = (img + img_mod1.gate(&img_attn.apply(&self.img_attn.proj)?))?;
423        let img = (&img
424            + img_mod2.gate(
425                &img_mod2
426                    .scale_shift(&img.apply(&self.img_norm2)?)?
427                    .apply(&self.img_mlp)?,
428            )?)?;
429
430        let txt = (txt + txt_mod1.gate(&txt_attn.apply(&self.txt_attn.proj)?))?;
431        let txt = (&txt
432            + txt_mod2.gate(
433                &txt_mod2
434                    .scale_shift(&txt.apply(&self.txt_norm2)?)?
435                    .apply(&self.txt_mlp)?,
436            )?)?;
437
438        Ok((img, txt))
439    }
440}
441
442#[derive(Debug, Clone)]
443pub struct SingleStreamBlock {
444    linear1: Linear,
445    linear2: Linear,
446    norm: QkNorm,
447    pre_norm: LayerNorm,
448    modulation: Modulation1,
449    h_sz: usize,
450    mlp_sz: usize,
451    num_heads: usize,
452}
453
454impl SingleStreamBlock {
455    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
456        let h_sz = cfg.hidden_size;
457        let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize;
458        let head_dim = h_sz / cfg.num_heads;
459        let linear1 = candle_nn::linear(h_sz, h_sz * 3 + mlp_sz, vb.pp("linear1"))?;
460        let linear2 = candle_nn::linear(h_sz + mlp_sz, h_sz, vb.pp("linear2"))?;
461        let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
462        let pre_norm = layer_norm(h_sz, vb.pp("pre_norm"))?;
463        let modulation = Modulation1::new(h_sz, vb.pp("modulation"))?;
464        Ok(Self {
465            linear1,
466            linear2,
467            norm,
468            pre_norm,
469            modulation,
470            h_sz,
471            mlp_sz,
472            num_heads: cfg.num_heads,
473        })
474    }
475
476    fn forward(&self, xs: &Tensor, vec_: &Tensor, pe: &Tensor) -> Result<Tensor> {
477        let mod_ = self.modulation.forward(vec_)?;
478        let x_mod = mod_.scale_shift(&xs.apply(&self.pre_norm)?)?;
479        let x_mod = x_mod.apply(&self.linear1)?;
480        let qkv = x_mod.narrow(D::Minus1, 0, 3 * self.h_sz)?;
481        let (b, l, _khd) = qkv.dims3()?;
482        let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?;
483        let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
484        let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
485        let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
486        let mlp = x_mod.narrow(D::Minus1, 3 * self.h_sz, self.mlp_sz)?;
487        let q = q.apply(&self.norm.query_norm)?;
488        let k = k.apply(&self.norm.key_norm)?;
489        let attn = attention(&q, &k, &v, pe)?;
490        let output = Tensor::cat(&[attn, mlp.gelu()?], 2)?.apply(&self.linear2)?;
491        xs + mod_.gate(&output)
492    }
493}
494
495#[derive(Debug, Clone)]
496pub struct LastLayer {
497    norm_final: LayerNorm,
498    linear: Linear,
499    ada_ln_modulation: Linear,
500}
501
502impl LastLayer {
503    fn new(h_sz: usize, p_sz: usize, out_c: usize, vb: VarBuilder) -> Result<Self> {
504        let norm_final = layer_norm(h_sz, vb.pp("norm_final"))?;
505        let linear = candle_nn::linear(h_sz, p_sz * p_sz * out_c, vb.pp("linear"))?;
506        let ada_ln_modulation = candle_nn::linear(h_sz, 2 * h_sz, vb.pp("adaLN_modulation.1"))?;
507        Ok(Self {
508            norm_final,
509            linear,
510            ada_ln_modulation,
511        })
512    }
513
514    fn forward(&self, xs: &Tensor, vec: &Tensor) -> Result<Tensor> {
515        let chunks = vec.silu()?.apply(&self.ada_ln_modulation)?.chunk(2, 1)?;
516        let (shift, scale) = (&chunks[0], &chunks[1]);
517        let xs = xs
518            .apply(&self.norm_final)?
519            .broadcast_mul(&(scale.unsqueeze(1)? + 1.0)?)?
520            .broadcast_add(&shift.unsqueeze(1)?)?;
521        xs.apply(&self.linear)
522    }
523}
524
525#[derive(Debug, Clone)]
526pub struct Flux {
527    img_in: Linear,
528    txt_in: Linear,
529    time_in: MlpEmbedder,
530    vector_in: MlpEmbedder,
531    guidance_in: Option<MlpEmbedder>,
532    pe_embedder: EmbedNd,
533    double_blocks: Vec<DoubleStreamBlock>,
534    single_blocks: Vec<SingleStreamBlock>,
535    final_layer: LastLayer,
536}
537
538impl Flux {
539    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
540        let img_in = candle_nn::linear(cfg.in_channels, cfg.hidden_size, vb.pp("img_in"))?;
541        let txt_in = candle_nn::linear(cfg.context_in_dim, cfg.hidden_size, vb.pp("txt_in"))?;
542        let mut double_blocks = Vec::with_capacity(cfg.depth);
543        let vb_d = vb.pp("double_blocks");
544        for idx in 0..cfg.depth {
545            let db = DoubleStreamBlock::new(cfg, vb_d.pp(idx))?;
546            double_blocks.push(db)
547        }
548        let mut single_blocks = Vec::with_capacity(cfg.depth_single_blocks);
549        let vb_s = vb.pp("single_blocks");
550        for idx in 0..cfg.depth_single_blocks {
551            let sb = SingleStreamBlock::new(cfg, vb_s.pp(idx))?;
552            single_blocks.push(sb)
553        }
554        let time_in = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("time_in"))?;
555        let vector_in = MlpEmbedder::new(cfg.vec_in_dim, cfg.hidden_size, vb.pp("vector_in"))?;
556        let guidance_in = if cfg.guidance_embed {
557            let mlp = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("guidance_in"))?;
558            Some(mlp)
559        } else {
560            None
561        };
562        let final_layer =
563            LastLayer::new(cfg.hidden_size, 1, cfg.in_channels, vb.pp("final_layer"))?;
564        let pe_dim = cfg.hidden_size / cfg.num_heads;
565        let pe_embedder = EmbedNd::new(pe_dim, cfg.theta, cfg.axes_dim.to_vec());
566        Ok(Self {
567            img_in,
568            txt_in,
569            time_in,
570            vector_in,
571            guidance_in,
572            pe_embedder,
573            double_blocks,
574            single_blocks,
575            final_layer,
576        })
577    }
578}
579
580impl super::WithForward for Flux {
581    #[allow(clippy::too_many_arguments)]
582    fn forward(
583        &self,
584        img: &Tensor,
585        img_ids: &Tensor,
586        txt: &Tensor,
587        txt_ids: &Tensor,
588        timesteps: &Tensor,
589        y: &Tensor,
590        guidance: Option<&Tensor>,
591    ) -> Result<Tensor> {
592        if txt.rank() != 3 {
593            candle::bail!("unexpected shape for txt {:?}", txt.shape())
594        }
595        if img.rank() != 3 {
596            candle::bail!("unexpected shape for img {:?}", img.shape())
597        }
598        let dtype = img.dtype();
599        let pe = {
600            let ids = Tensor::cat(&[txt_ids, img_ids], 1)?;
601            ids.apply(&self.pe_embedder)?
602        };
603        let mut txt = txt.apply(&self.txt_in)?;
604        let mut img = img.apply(&self.img_in)?;
605        let vec_ = timestep_embedding(timesteps, 256, dtype)?.apply(&self.time_in)?;
606        let vec_ = match (self.guidance_in.as_ref(), guidance) {
607            (Some(g_in), Some(guidance)) => {
608                (vec_ + timestep_embedding(guidance, 256, dtype)?.apply(g_in))?
609            }
610            _ => vec_,
611        };
612        let vec_ = (vec_ + y.apply(&self.vector_in))?;
613
614        // Double blocks
615        for block in self.double_blocks.iter() {
616            (img, txt) = block.forward(&img, &txt, &vec_, &pe)?
617        }
618        // Single blocks
619        let mut img = Tensor::cat(&[&txt, &img], 1)?;
620        for block in self.single_blocks.iter() {
621            img = block.forward(&img, &vec_, &pe)?;
622        }
623        let img = img.i((.., txt.dim(1)?..))?;
624        self.final_layer.forward(&img, &vec_)
625    }
626}