candle_transformers/models/mmdit/
blocks.rs

1use candle::{Module, Result, Tensor, D};
2use candle_nn as nn;
3
4use super::projections::{AttnProjections, Mlp, Qkv, QkvOnlyAttnProjections};
5
/// Modulation tensors produced by `DiTBlock::pre_attention` and consumed later
/// by `DiTBlock::post_attention`.
pub struct ModulateIntermediates {
    /// Gate multiplied into the attention output before the residual add.
    gate_msa: Tensor,
    /// Shift passed to `modulate` for the normalized MLP input.
    shift_mlp: Tensor,
    /// Scale passed to `modulate` for the normalized MLP input.
    scale_mlp: Tensor,
    /// Gate multiplied into the MLP output before the residual add.
    gate_mlp: Tensor,
}
12
/// Block combining an attention branch and an MLP branch, both modulated by
/// values derived from a conditioning tensor.
pub struct DiTBlock {
    /// Normalization applied to the input before the attention projections.
    norm1: LayerNormNoAffine,
    /// Projections used before and after the attention computation.
    attn: AttnProjections,
    /// Normalization applied before the MLP.
    norm2: LayerNormNoAffine,
    /// Feed-forward branch.
    mlp: Mlp,
    /// SiLU followed by a linear layer; its output is split into six chunks
    /// (shift/scale/gate for attention, shift/scale/gate for the MLP).
    ada_ln_modulation: nn::Sequential,
}
20
/// Layer normalization that uses an all-ones weight and no bias.
pub struct LayerNormNoAffine {
    /// Epsilon handed to `nn::LayerNorm` on every forward call.
    eps: f64,
}
24
25impl LayerNormNoAffine {
26    pub fn new(eps: f64) -> Self {
27        Self { eps }
28    }
29}
30
31impl Module for LayerNormNoAffine {
32    fn forward(&self, x: &Tensor) -> Result<Tensor> {
33        nn::LayerNorm::new_no_bias(Tensor::ones_like(x)?, self.eps).forward(x)
34    }
35}
36
37impl DiTBlock {
38    pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
39        let norm1 = LayerNormNoAffine::new(1e-6);
40        let attn = AttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?;
41        let norm2 = LayerNormNoAffine::new(1e-6);
42        let mlp_ratio = 4;
43        let mlp = Mlp::new(hidden_size, hidden_size * mlp_ratio, vb.pp("mlp"))?;
44        let n_mods = 6;
45        let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear(
46            hidden_size,
47            n_mods * hidden_size,
48            vb.pp("adaLN_modulation.1"),
49        )?);
50
51        Ok(Self {
52            norm1,
53            attn,
54            norm2,
55            mlp,
56            ada_ln_modulation,
57        })
58    }
59
60    pub fn pre_attention(&self, x: &Tensor, c: &Tensor) -> Result<(Qkv, ModulateIntermediates)> {
61        let modulation = self.ada_ln_modulation.forward(c)?;
62        let chunks = modulation.chunk(6, D::Minus1)?;
63        let (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = (
64            chunks[0].clone(),
65            chunks[1].clone(),
66            chunks[2].clone(),
67            chunks[3].clone(),
68            chunks[4].clone(),
69            chunks[5].clone(),
70        );
71
72        let norm_x = self.norm1.forward(x)?;
73        let modulated_x = modulate(&norm_x, &shift_msa, &scale_msa)?;
74        let qkv = self.attn.pre_attention(&modulated_x)?;
75
76        Ok((
77            qkv,
78            ModulateIntermediates {
79                gate_msa,
80                shift_mlp,
81                scale_mlp,
82                gate_mlp,
83            },
84        ))
85    }
86
87    pub fn post_attention(
88        &self,
89        attn: &Tensor,
90        x: &Tensor,
91        mod_interm: &ModulateIntermediates,
92    ) -> Result<Tensor> {
93        let attn_out = self.attn.post_attention(attn)?;
94        let x = x.add(&attn_out.broadcast_mul(&mod_interm.gate_msa.unsqueeze(1)?)?)?;
95
96        let norm_x = self.norm2.forward(&x)?;
97        let modulated_x = modulate(&norm_x, &mod_interm.shift_mlp, &mod_interm.scale_mlp)?;
98        let mlp_out = self.mlp.forward(&modulated_x)?;
99        let x = x.add(&mlp_out.broadcast_mul(&mod_interm.gate_mlp.unsqueeze(1)?)?)?;
100
101        Ok(x)
102    }
103}
104
/// Modulation tensors produced by `SelfAttnDiTBlock::pre_attention` and
/// consumed later by `SelfAttnDiTBlock::post_attention`.
pub struct SelfAttnModulateIntermediates {
    /// Gate multiplied into the first attention output before the residual add.
    gate_msa: Tensor,
    /// Shift passed to `modulate` for the normalized MLP input.
    shift_mlp: Tensor,
    /// Scale passed to `modulate` for the normalized MLP input.
    scale_mlp: Tensor,
    /// Gate multiplied into the MLP output before the residual add.
    gate_mlp: Tensor,
    /// Gate multiplied into the second attention output before the residual add.
    gate_msa2: Tensor,
}
112
/// Variant of `DiTBlock` with a second set of attention projections (`attn2`)
/// and nine modulation chunks instead of six.
pub struct SelfAttnDiTBlock {
    /// Normalization applied to the input; its output feeds both attention paths.
    norm1: LayerNormNoAffine,
    /// Projections for the first attention path.
    attn: AttnProjections,
    /// Projections for the second attention path.
    attn2: AttnProjections,
    /// Normalization applied before the MLP.
    norm2: LayerNormNoAffine,
    /// Feed-forward branch.
    mlp: Mlp,
    /// SiLU followed by a linear layer; its output is split into nine chunks.
    ada_ln_modulation: nn::Sequential,
}
121
122impl SelfAttnDiTBlock {
123    pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
124        let norm1 = LayerNormNoAffine::new(1e-6);
125        let attn = AttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?;
126        let attn2 = AttnProjections::new(hidden_size, num_heads, vb.pp("attn2"))?;
127        let norm2 = LayerNormNoAffine::new(1e-6);
128        let mlp_ratio = 4;
129        let mlp = Mlp::new(hidden_size, hidden_size * mlp_ratio, vb.pp("mlp"))?;
130        let n_mods = 9;
131        let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear(
132            hidden_size,
133            n_mods * hidden_size,
134            vb.pp("adaLN_modulation.1"),
135        )?);
136
137        Ok(Self {
138            norm1,
139            attn,
140            attn2,
141            norm2,
142            mlp,
143            ada_ln_modulation,
144        })
145    }
146
147    pub fn pre_attention(
148        &self,
149        x: &Tensor,
150        c: &Tensor,
151    ) -> Result<(Qkv, Qkv, SelfAttnModulateIntermediates)> {
152        let modulation = self.ada_ln_modulation.forward(c)?;
153        let chunks = modulation.chunk(9, D::Minus1)?;
154        let (
155            shift_msa,
156            scale_msa,
157            gate_msa,
158            shift_mlp,
159            scale_mlp,
160            gate_mlp,
161            shift_msa2,
162            scale_msa2,
163            gate_msa2,
164        ) = (
165            chunks[0].clone(),
166            chunks[1].clone(),
167            chunks[2].clone(),
168            chunks[3].clone(),
169            chunks[4].clone(),
170            chunks[5].clone(),
171            chunks[6].clone(),
172            chunks[7].clone(),
173            chunks[8].clone(),
174        );
175
176        let norm_x = self.norm1.forward(x)?;
177        let modulated_x = modulate(&norm_x, &shift_msa, &scale_msa)?;
178        let qkv = self.attn.pre_attention(&modulated_x)?;
179
180        let modulated_x2 = modulate(&norm_x, &shift_msa2, &scale_msa2)?;
181        let qkv2 = self.attn2.pre_attention(&modulated_x2)?;
182
183        Ok((
184            qkv,
185            qkv2,
186            SelfAttnModulateIntermediates {
187                gate_msa,
188                shift_mlp,
189                scale_mlp,
190                gate_mlp,
191                gate_msa2,
192            },
193        ))
194    }
195
196    pub fn post_attention(
197        &self,
198        attn: &Tensor,
199        attn2: &Tensor,
200        x: &Tensor,
201        mod_interm: &SelfAttnModulateIntermediates,
202    ) -> Result<Tensor> {
203        let attn_out = self.attn.post_attention(attn)?;
204        let x = x.add(&attn_out.broadcast_mul(&mod_interm.gate_msa.unsqueeze(1)?)?)?;
205        let attn_out2 = self.attn2.post_attention(attn2)?;
206        let x = x.add(&attn_out2.broadcast_mul(&mod_interm.gate_msa2.unsqueeze(1)?)?)?;
207
208        let norm_x = self.norm2.forward(&x)?;
209        let modulated_x = modulate(&norm_x, &mod_interm.shift_mlp, &mod_interm.scale_mlp)?;
210        let mlp_out = self.mlp.forward(&modulated_x)?;
211        let x = x.add(&mlp_out.broadcast_mul(&mod_interm.gate_mlp.unsqueeze(1)?)?)?;
212        Ok(x)
213    }
214}
215
/// Block that only produces the attention inputs (QKV); it has no MLP and no
/// post-attention step.
pub struct QkvOnlyDiTBlock {
    /// Normalization applied to the input before the QKV projection.
    norm1: LayerNormNoAffine,
    /// QKV-only attention projections.
    attn: QkvOnlyAttnProjections,
    /// SiLU followed by a linear layer; its output is split into shift and scale.
    ada_ln_modulation: nn::Sequential,
}
221
222impl QkvOnlyDiTBlock {
223    pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
224        let norm1 = LayerNormNoAffine::new(1e-6);
225        let attn = QkvOnlyAttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?;
226        let n_mods = 2;
227        let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear(
228            hidden_size,
229            n_mods * hidden_size,
230            vb.pp("adaLN_modulation.1"),
231        )?);
232
233        Ok(Self {
234            norm1,
235            attn,
236            ada_ln_modulation,
237        })
238    }
239
240    pub fn pre_attention(&self, x: &Tensor, c: &Tensor) -> Result<Qkv> {
241        let modulation = self.ada_ln_modulation.forward(c)?;
242        let chunks = modulation.chunk(2, D::Minus1)?;
243        let (shift_msa, scale_msa) = (chunks[0].clone(), chunks[1].clone());
244
245        let norm_x = self.norm1.forward(x)?;
246        let modulated_x = modulate(&norm_x, &shift_msa, &scale_msa)?;
247        self.attn.pre_attention(&modulated_x)
248    }
249}
250
/// Output layer: modulated normalization followed by a linear projection.
pub struct FinalLayer {
    /// Normalization applied to the input before modulation.
    norm_final: LayerNormNoAffine,
    /// Projection from `hidden_size` to `patch_size * patch_size * out_channels`.
    linear: nn::Linear,
    /// SiLU followed by a linear layer; its output is split into shift and scale.
    ada_ln_modulation: nn::Sequential,
}
256
257impl FinalLayer {
258    pub fn new(
259        hidden_size: usize,
260        patch_size: usize,
261        out_channels: usize,
262        vb: nn::VarBuilder,
263    ) -> Result<Self> {
264        let norm_final = LayerNormNoAffine::new(1e-6);
265        let linear = nn::linear(
266            hidden_size,
267            patch_size * patch_size * out_channels,
268            vb.pp("linear"),
269        )?;
270        let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear(
271            hidden_size,
272            2 * hidden_size,
273            vb.pp("adaLN_modulation.1"),
274        )?);
275
276        Ok(Self {
277            norm_final,
278            linear,
279            ada_ln_modulation,
280        })
281    }
282
283    pub fn forward(&self, x: &Tensor, c: &Tensor) -> Result<Tensor> {
284        let modulation = self.ada_ln_modulation.forward(c)?;
285        let chunks = modulation.chunk(2, D::Minus1)?;
286        let (shift, scale) = (chunks[0].clone(), chunks[1].clone());
287
288        let norm_x = self.norm_final.forward(x)?;
289        let modulated_x = modulate(&norm_x, &shift, &scale)?;
290        let output = self.linear.forward(&modulated_x)?;
291
292        Ok(output)
293    }
294}
295
296fn modulate(x: &Tensor, shift: &Tensor, scale: &Tensor) -> Result<Tensor> {
297    let shift = shift.unsqueeze(1)?;
298    let scale = scale.unsqueeze(1)?;
299    let scale_plus_one = scale.add(&Tensor::ones_like(&scale)?)?;
300    shift.broadcast_add(&x.broadcast_mul(&scale_plus_one)?)
301}
302
/// Interface of blocks that process a context stream and an `x` stream together.
pub trait JointBlock {
    /// Runs the block on `context` and `x`, conditioned on `c`, and returns the
    /// pair `(context_out, x_out)`.
    fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)>;
}
306
/// Joint block in which both the `x` stream and the context stream use a
/// `DiTBlock` and attend over their concatenated sequences.
pub struct MMDiTJointBlock {
    /// Block processing the `x` stream.
    x_block: DiTBlock,
    /// Block processing the context stream.
    context_block: DiTBlock,
    /// Number of attention heads used when reshaping q and k.
    num_heads: usize,
    /// Selects the flash-attention kernel over the plain implementation.
    use_flash_attn: bool,
}
313
314impl MMDiTJointBlock {
315    pub fn new(
316        hidden_size: usize,
317        num_heads: usize,
318        use_flash_attn: bool,
319        vb: nn::VarBuilder,
320    ) -> Result<Self> {
321        let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
322        let context_block = DiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
323
324        Ok(Self {
325            x_block,
326            context_block,
327            num_heads,
328            use_flash_attn,
329        })
330    }
331}
332
333impl JointBlock for MMDiTJointBlock {
334    fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
335        let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?;
336        let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?;
337        let (context_attn, x_attn) =
338            joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?;
339        let context_out =
340            self.context_block
341                .post_attention(&context_attn, context, &context_interm)?;
342        let x_out = self.x_block.post_attention(&x_attn, x, &x_interm)?;
343        Ok((context_out, x_out))
344    }
345}
346
/// Joint block whose `x` stream uses a `SelfAttnDiTBlock` (adding a second
/// attention over `x` alone) while the context stream uses a `DiTBlock`.
pub struct MMDiTXJointBlock {
    /// Block processing the `x` stream, with two attention paths.
    x_block: SelfAttnDiTBlock,
    /// Block processing the context stream.
    context_block: DiTBlock,
    /// Number of attention heads used when reshaping q and k.
    num_heads: usize,
    /// Selects the flash-attention kernel over the plain implementation.
    use_flash_attn: bool,
}
353
354impl MMDiTXJointBlock {
355    pub fn new(
356        hidden_size: usize,
357        num_heads: usize,
358        use_flash_attn: bool,
359        vb: nn::VarBuilder,
360    ) -> Result<Self> {
361        let x_block = SelfAttnDiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
362        let context_block = DiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
363
364        Ok(Self {
365            x_block,
366            context_block,
367            num_heads,
368            use_flash_attn,
369        })
370    }
371}
372
373impl JointBlock for MMDiTXJointBlock {
374    fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
375        let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?;
376        let (x_qkv, x_qkv2, x_interm) = self.x_block.pre_attention(x, c)?;
377        let (context_attn, x_attn) =
378            joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?;
379        let x_attn2 = attn(&x_qkv2, self.num_heads, self.use_flash_attn)?;
380        let context_out =
381            self.context_block
382                .post_attention(&context_attn, context, &context_interm)?;
383        let x_out = self
384            .x_block
385            .post_attention(&x_attn, &x_attn2, x, &x_interm)?;
386        Ok((context_out, x_out))
387    }
388}
389
/// Joint block whose context stream only contributes QKV to the joint
/// attention; only the `x` stream produces an output.
pub struct ContextQkvOnlyJointBlock {
    /// Block processing the `x` stream.
    x_block: DiTBlock,
    /// Block producing the context QKV.
    context_block: QkvOnlyDiTBlock,
    /// Number of attention heads used when reshaping q and k.
    num_heads: usize,
    /// Selects the flash-attention kernel over the plain implementation.
    use_flash_attn: bool,
}
396
397impl ContextQkvOnlyJointBlock {
398    pub fn new(
399        hidden_size: usize,
400        num_heads: usize,
401        use_flash_attn: bool,
402        vb: nn::VarBuilder,
403    ) -> Result<Self> {
404        let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
405        let context_block = QkvOnlyDiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
406        Ok(Self {
407            x_block,
408            context_block,
409            num_heads,
410            use_flash_attn,
411        })
412    }
413
414    pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<Tensor> {
415        let context_qkv = self.context_block.pre_attention(context, c)?;
416        let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?;
417
418        let (_, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?;
419
420        let x_out = self.x_block.post_attention(&x_attn, x, &x_interm)?;
421        Ok(x_out)
422    }
423}
424
425// A QKV-attention that is compatible with the interface of candle_flash_attn::flash_attn
426// Flash attention regards q, k, v dimensions as (batch_size, seqlen, nheads, headdim)
427fn flash_compatible_attention(
428    q: &Tensor,
429    k: &Tensor,
430    v: &Tensor,
431    softmax_scale: f32,
432) -> Result<Tensor> {
433    let q_dims_for_matmul = q.transpose(1, 2)?.dims().to_vec();
434    let rank = q_dims_for_matmul.len();
435    let q = q.transpose(1, 2)?.flatten_to(rank - 3)?;
436    let k = k.transpose(1, 2)?.flatten_to(rank - 3)?;
437    let v = v.transpose(1, 2)?.flatten_to(rank - 3)?;
438    let attn_weights = (q.matmul(&k.t()?)? * softmax_scale as f64)?;
439    let attn_scores = candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(&v)?;
440    attn_scores.reshape(q_dims_for_matmul)?.transpose(1, 2)
441}
442
443#[cfg(feature = "flash-attn")]
444fn flash_attn(
445    q: &Tensor,
446    k: &Tensor,
447    v: &Tensor,
448    softmax_scale: f32,
449    causal: bool,
450) -> Result<Tensor> {
451    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
452}
453
454#[cfg(not(feature = "flash-attn"))]
455fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
456    unimplemented!("compile with '--features flash-attn'")
457}
458
459fn joint_attn(
460    context_qkv: &Qkv,
461    x_qkv: &Qkv,
462    num_heads: usize,
463    use_flash_attn: bool,
464) -> Result<(Tensor, Tensor)> {
465    let qkv = Qkv {
466        q: Tensor::cat(&[&context_qkv.q, &x_qkv.q], 1)?,
467        k: Tensor::cat(&[&context_qkv.k, &x_qkv.k], 1)?,
468        v: Tensor::cat(&[&context_qkv.v, &x_qkv.v], 1)?,
469    };
470
471    let seqlen = qkv.q.dim(1)?;
472    let attn = attn(&qkv, num_heads, use_flash_attn)?;
473    let context_qkv_seqlen = context_qkv.q.dim(1)?;
474    let context_attn = attn.narrow(1, 0, context_qkv_seqlen)?;
475    let x_attn = attn.narrow(1, context_qkv_seqlen, seqlen - context_qkv_seqlen)?;
476
477    Ok((context_attn, x_attn))
478}
479
480fn attn(qkv: &Qkv, num_heads: usize, use_flash_attn: bool) -> Result<Tensor> {
481    let batch_size = qkv.q.dim(0)?;
482    let seqlen = qkv.q.dim(1)?;
483    let qkv = Qkv {
484        q: qkv.q.reshape((batch_size, seqlen, num_heads, ()))?,
485        k: qkv.k.reshape((batch_size, seqlen, num_heads, ()))?,
486        v: qkv.v.clone(),
487    };
488
489    let headdim = qkv.q.dim(D::Minus1)?;
490    let softmax_scale = 1.0 / (headdim as f64).sqrt();
491    let attn = if use_flash_attn {
492        flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?
493    } else {
494        flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?
495    };
496    attn.reshape((batch_size, seqlen, ()))
497}