1use candle::{Module, Result, Tensor, D};
2use candle_nn as nn;
3
4use super::projections::{AttnProjections, Mlp, Qkv, QkvOnlyAttnProjections};
5
6pub struct ModulateIntermediates {
7 gate_msa: Tensor,
8 shift_mlp: Tensor,
9 scale_mlp: Tensor,
10 gate_mlp: Tensor,
11}
12
13pub struct DiTBlock {
14 norm1: LayerNormNoAffine,
15 attn: AttnProjections,
16 norm2: LayerNormNoAffine,
17 mlp: Mlp,
18 ada_ln_modulation: nn::Sequential,
19}
20
21pub struct LayerNormNoAffine {
22 eps: f64,
23}
24
25impl LayerNormNoAffine {
26 pub fn new(eps: f64) -> Self {
27 Self { eps }
28 }
29}
30
31impl Module for LayerNormNoAffine {
32 fn forward(&self, x: &Tensor) -> Result<Tensor> {
33 nn::LayerNorm::new_no_bias(Tensor::ones_like(x)?, self.eps).forward(x)
34 }
35}
36
37impl DiTBlock {
38 pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
39 let norm1 = LayerNormNoAffine::new(1e-6);
40 let attn = AttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?;
41 let norm2 = LayerNormNoAffine::new(1e-6);
42 let mlp_ratio = 4;
43 let mlp = Mlp::new(hidden_size, hidden_size * mlp_ratio, vb.pp("mlp"))?;
44 let n_mods = 6;
45 let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear(
46 hidden_size,
47 n_mods * hidden_size,
48 vb.pp("adaLN_modulation.1"),
49 )?);
50
51 Ok(Self {
52 norm1,
53 attn,
54 norm2,
55 mlp,
56 ada_ln_modulation,
57 })
58 }
59
60 pub fn pre_attention(&self, x: &Tensor, c: &Tensor) -> Result<(Qkv, ModulateIntermediates)> {
61 let modulation = self.ada_ln_modulation.forward(c)?;
62 let chunks = modulation.chunk(6, D::Minus1)?;
63 let (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = (
64 chunks[0].clone(),
65 chunks[1].clone(),
66 chunks[2].clone(),
67 chunks[3].clone(),
68 chunks[4].clone(),
69 chunks[5].clone(),
70 );
71
72 let norm_x = self.norm1.forward(x)?;
73 let modulated_x = modulate(&norm_x, &shift_msa, &scale_msa)?;
74 let qkv = self.attn.pre_attention(&modulated_x)?;
75
76 Ok((
77 qkv,
78 ModulateIntermediates {
79 gate_msa,
80 shift_mlp,
81 scale_mlp,
82 gate_mlp,
83 },
84 ))
85 }
86
87 pub fn post_attention(
88 &self,
89 attn: &Tensor,
90 x: &Tensor,
91 mod_interm: &ModulateIntermediates,
92 ) -> Result<Tensor> {
93 let attn_out = self.attn.post_attention(attn)?;
94 let x = x.add(&attn_out.broadcast_mul(&mod_interm.gate_msa.unsqueeze(1)?)?)?;
95
96 let norm_x = self.norm2.forward(&x)?;
97 let modulated_x = modulate(&norm_x, &mod_interm.shift_mlp, &mod_interm.scale_mlp)?;
98 let mlp_out = self.mlp.forward(&modulated_x)?;
99 let x = x.add(&mlp_out.broadcast_mul(&mod_interm.gate_mlp.unsqueeze(1)?)?)?;
100
101 Ok(x)
102 }
103}
104
105pub struct SelfAttnModulateIntermediates {
106 gate_msa: Tensor,
107 shift_mlp: Tensor,
108 scale_mlp: Tensor,
109 gate_mlp: Tensor,
110 gate_msa2: Tensor,
111}
112
113pub struct SelfAttnDiTBlock {
114 norm1: LayerNormNoAffine,
115 attn: AttnProjections,
116 attn2: AttnProjections,
117 norm2: LayerNormNoAffine,
118 mlp: Mlp,
119 ada_ln_modulation: nn::Sequential,
120}
121
122impl SelfAttnDiTBlock {
123 pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
124 let norm1 = LayerNormNoAffine::new(1e-6);
125 let attn = AttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?;
126 let attn2 = AttnProjections::new(hidden_size, num_heads, vb.pp("attn2"))?;
127 let norm2 = LayerNormNoAffine::new(1e-6);
128 let mlp_ratio = 4;
129 let mlp = Mlp::new(hidden_size, hidden_size * mlp_ratio, vb.pp("mlp"))?;
130 let n_mods = 9;
131 let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear(
132 hidden_size,
133 n_mods * hidden_size,
134 vb.pp("adaLN_modulation.1"),
135 )?);
136
137 Ok(Self {
138 norm1,
139 attn,
140 attn2,
141 norm2,
142 mlp,
143 ada_ln_modulation,
144 })
145 }
146
147 pub fn pre_attention(
148 &self,
149 x: &Tensor,
150 c: &Tensor,
151 ) -> Result<(Qkv, Qkv, SelfAttnModulateIntermediates)> {
152 let modulation = self.ada_ln_modulation.forward(c)?;
153 let chunks = modulation.chunk(9, D::Minus1)?;
154 let (
155 shift_msa,
156 scale_msa,
157 gate_msa,
158 shift_mlp,
159 scale_mlp,
160 gate_mlp,
161 shift_msa2,
162 scale_msa2,
163 gate_msa2,
164 ) = (
165 chunks[0].clone(),
166 chunks[1].clone(),
167 chunks[2].clone(),
168 chunks[3].clone(),
169 chunks[4].clone(),
170 chunks[5].clone(),
171 chunks[6].clone(),
172 chunks[7].clone(),
173 chunks[8].clone(),
174 );
175
176 let norm_x = self.norm1.forward(x)?;
177 let modulated_x = modulate(&norm_x, &shift_msa, &scale_msa)?;
178 let qkv = self.attn.pre_attention(&modulated_x)?;
179
180 let modulated_x2 = modulate(&norm_x, &shift_msa2, &scale_msa2)?;
181 let qkv2 = self.attn2.pre_attention(&modulated_x2)?;
182
183 Ok((
184 qkv,
185 qkv2,
186 SelfAttnModulateIntermediates {
187 gate_msa,
188 shift_mlp,
189 scale_mlp,
190 gate_mlp,
191 gate_msa2,
192 },
193 ))
194 }
195
196 pub fn post_attention(
197 &self,
198 attn: &Tensor,
199 attn2: &Tensor,
200 x: &Tensor,
201 mod_interm: &SelfAttnModulateIntermediates,
202 ) -> Result<Tensor> {
203 let attn_out = self.attn.post_attention(attn)?;
204 let x = x.add(&attn_out.broadcast_mul(&mod_interm.gate_msa.unsqueeze(1)?)?)?;
205 let attn_out2 = self.attn2.post_attention(attn2)?;
206 let x = x.add(&attn_out2.broadcast_mul(&mod_interm.gate_msa2.unsqueeze(1)?)?)?;
207
208 let norm_x = self.norm2.forward(&x)?;
209 let modulated_x = modulate(&norm_x, &mod_interm.shift_mlp, &mod_interm.scale_mlp)?;
210 let mlp_out = self.mlp.forward(&modulated_x)?;
211 let x = x.add(&mlp_out.broadcast_mul(&mod_interm.gate_mlp.unsqueeze(1)?)?)?;
212 Ok(x)
213 }
214}
215
216pub struct QkvOnlyDiTBlock {
217 norm1: LayerNormNoAffine,
218 attn: QkvOnlyAttnProjections,
219 ada_ln_modulation: nn::Sequential,
220}
221
222impl QkvOnlyDiTBlock {
223 pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
224 let norm1 = LayerNormNoAffine::new(1e-6);
225 let attn = QkvOnlyAttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?;
226 let n_mods = 2;
227 let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear(
228 hidden_size,
229 n_mods * hidden_size,
230 vb.pp("adaLN_modulation.1"),
231 )?);
232
233 Ok(Self {
234 norm1,
235 attn,
236 ada_ln_modulation,
237 })
238 }
239
240 pub fn pre_attention(&self, x: &Tensor, c: &Tensor) -> Result<Qkv> {
241 let modulation = self.ada_ln_modulation.forward(c)?;
242 let chunks = modulation.chunk(2, D::Minus1)?;
243 let (shift_msa, scale_msa) = (chunks[0].clone(), chunks[1].clone());
244
245 let norm_x = self.norm1.forward(x)?;
246 let modulated_x = modulate(&norm_x, &shift_msa, &scale_msa)?;
247 self.attn.pre_attention(&modulated_x)
248 }
249}
250
251pub struct FinalLayer {
252 norm_final: LayerNormNoAffine,
253 linear: nn::Linear,
254 ada_ln_modulation: nn::Sequential,
255}
256
257impl FinalLayer {
258 pub fn new(
259 hidden_size: usize,
260 patch_size: usize,
261 out_channels: usize,
262 vb: nn::VarBuilder,
263 ) -> Result<Self> {
264 let norm_final = LayerNormNoAffine::new(1e-6);
265 let linear = nn::linear(
266 hidden_size,
267 patch_size * patch_size * out_channels,
268 vb.pp("linear"),
269 )?;
270 let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear(
271 hidden_size,
272 2 * hidden_size,
273 vb.pp("adaLN_modulation.1"),
274 )?);
275
276 Ok(Self {
277 norm_final,
278 linear,
279 ada_ln_modulation,
280 })
281 }
282
283 pub fn forward(&self, x: &Tensor, c: &Tensor) -> Result<Tensor> {
284 let modulation = self.ada_ln_modulation.forward(c)?;
285 let chunks = modulation.chunk(2, D::Minus1)?;
286 let (shift, scale) = (chunks[0].clone(), chunks[1].clone());
287
288 let norm_x = self.norm_final.forward(x)?;
289 let modulated_x = modulate(&norm_x, &shift, &scale)?;
290 let output = self.linear.forward(&modulated_x)?;
291
292 Ok(output)
293 }
294}
295
296fn modulate(x: &Tensor, shift: &Tensor, scale: &Tensor) -> Result<Tensor> {
297 let shift = shift.unsqueeze(1)?;
298 let scale = scale.unsqueeze(1)?;
299 let scale_plus_one = scale.add(&Tensor::ones_like(&scale)?)?;
300 shift.broadcast_add(&x.broadcast_mul(&scale_plus_one)?)
301}
302
303pub trait JointBlock {
304 fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)>;
305}
306
307pub struct MMDiTJointBlock {
308 x_block: DiTBlock,
309 context_block: DiTBlock,
310 num_heads: usize,
311 use_flash_attn: bool,
312}
313
314impl MMDiTJointBlock {
315 pub fn new(
316 hidden_size: usize,
317 num_heads: usize,
318 use_flash_attn: bool,
319 vb: nn::VarBuilder,
320 ) -> Result<Self> {
321 let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
322 let context_block = DiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
323
324 Ok(Self {
325 x_block,
326 context_block,
327 num_heads,
328 use_flash_attn,
329 })
330 }
331}
332
333impl JointBlock for MMDiTJointBlock {
334 fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
335 let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?;
336 let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?;
337 let (context_attn, x_attn) =
338 joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?;
339 let context_out =
340 self.context_block
341 .post_attention(&context_attn, context, &context_interm)?;
342 let x_out = self.x_block.post_attention(&x_attn, x, &x_interm)?;
343 Ok((context_out, x_out))
344 }
345}
346
347pub struct MMDiTXJointBlock {
348 x_block: SelfAttnDiTBlock,
349 context_block: DiTBlock,
350 num_heads: usize,
351 use_flash_attn: bool,
352}
353
354impl MMDiTXJointBlock {
355 pub fn new(
356 hidden_size: usize,
357 num_heads: usize,
358 use_flash_attn: bool,
359 vb: nn::VarBuilder,
360 ) -> Result<Self> {
361 let x_block = SelfAttnDiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
362 let context_block = DiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
363
364 Ok(Self {
365 x_block,
366 context_block,
367 num_heads,
368 use_flash_attn,
369 })
370 }
371}
372
373impl JointBlock for MMDiTXJointBlock {
374 fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
375 let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?;
376 let (x_qkv, x_qkv2, x_interm) = self.x_block.pre_attention(x, c)?;
377 let (context_attn, x_attn) =
378 joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?;
379 let x_attn2 = attn(&x_qkv2, self.num_heads, self.use_flash_attn)?;
380 let context_out =
381 self.context_block
382 .post_attention(&context_attn, context, &context_interm)?;
383 let x_out = self
384 .x_block
385 .post_attention(&x_attn, &x_attn2, x, &x_interm)?;
386 Ok((context_out, x_out))
387 }
388}
389
390pub struct ContextQkvOnlyJointBlock {
391 x_block: DiTBlock,
392 context_block: QkvOnlyDiTBlock,
393 num_heads: usize,
394 use_flash_attn: bool,
395}
396
397impl ContextQkvOnlyJointBlock {
398 pub fn new(
399 hidden_size: usize,
400 num_heads: usize,
401 use_flash_attn: bool,
402 vb: nn::VarBuilder,
403 ) -> Result<Self> {
404 let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
405 let context_block = QkvOnlyDiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
406 Ok(Self {
407 x_block,
408 context_block,
409 num_heads,
410 use_flash_attn,
411 })
412 }
413
414 pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<Tensor> {
415 let context_qkv = self.context_block.pre_attention(context, c)?;
416 let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?;
417
418 let (_, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?;
419
420 let x_out = self.x_block.post_attention(&x_attn, x, &x_interm)?;
421 Ok(x_out)
422 }
423}
424
425fn flash_compatible_attention(
428 q: &Tensor,
429 k: &Tensor,
430 v: &Tensor,
431 softmax_scale: f32,
432) -> Result<Tensor> {
433 let q_dims_for_matmul = q.transpose(1, 2)?.dims().to_vec();
434 let rank = q_dims_for_matmul.len();
435 let q = q.transpose(1, 2)?.flatten_to(rank - 3)?;
436 let k = k.transpose(1, 2)?.flatten_to(rank - 3)?;
437 let v = v.transpose(1, 2)?.flatten_to(rank - 3)?;
438 let attn_weights = (q.matmul(&k.t()?)? * softmax_scale as f64)?;
439 let attn_scores = candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(&v)?;
440 attn_scores.reshape(q_dims_for_matmul)?.transpose(1, 2)
441}
442
/// Forwards to `candle_flash_attn::flash_attn`.
///
/// Only compiled when the `flash-attn` feature is enabled.
#[cfg(feature = "flash-attn")]
fn flash_attn(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    softmax_scale: f32,
    causal: bool,
) -> Result<Tensor> {
    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
}
453
454#[cfg(not(feature = "flash-attn"))]
455fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
456 unimplemented!("compile with '--features flash-attn'")
457}
458
459fn joint_attn(
460 context_qkv: &Qkv,
461 x_qkv: &Qkv,
462 num_heads: usize,
463 use_flash_attn: bool,
464) -> Result<(Tensor, Tensor)> {
465 let qkv = Qkv {
466 q: Tensor::cat(&[&context_qkv.q, &x_qkv.q], 1)?,
467 k: Tensor::cat(&[&context_qkv.k, &x_qkv.k], 1)?,
468 v: Tensor::cat(&[&context_qkv.v, &x_qkv.v], 1)?,
469 };
470
471 let seqlen = qkv.q.dim(1)?;
472 let attn = attn(&qkv, num_heads, use_flash_attn)?;
473 let context_qkv_seqlen = context_qkv.q.dim(1)?;
474 let context_attn = attn.narrow(1, 0, context_qkv_seqlen)?;
475 let x_attn = attn.narrow(1, context_qkv_seqlen, seqlen - context_qkv_seqlen)?;
476
477 Ok((context_attn, x_attn))
478}
479
480fn attn(qkv: &Qkv, num_heads: usize, use_flash_attn: bool) -> Result<Tensor> {
481 let batch_size = qkv.q.dim(0)?;
482 let seqlen = qkv.q.dim(1)?;
483 let qkv = Qkv {
484 q: qkv.q.reshape((batch_size, seqlen, num_heads, ()))?,
485 k: qkv.k.reshape((batch_size, seqlen, num_heads, ()))?,
486 v: qkv.v.clone(),
487 };
488
489 let headdim = qkv.q.dim(D::Minus1)?;
490 let softmax_scale = 1.0 / (headdim as f64).sqrt();
491 let attn = if use_flash_attn {
492 flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?
493 } else {
494 flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?
495 };
496 attn.reshape((batch_size, seqlen, ()))
497}