candle_transformers/models/mmdit/
model.rs

// Implements the MMDiT model originally introduced for Stable Diffusion 3 (https://arxiv.org/abs/2403.03206),
// as well as the MMDiT-X variant introduced for Stable Diffusion 3.5-medium (https://huggingface.co/stabilityai/stable-diffusion-3.5-medium).
// This follows the implementation of the MMDiT model in the ComfyUI repository:
// https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L1
// with MMDiT-X support following the Stability-AI/sd3.5 repository:
// https://github.com/Stability-AI/sd3.5/blob/4e484e05308d83fb77ae6f680028e6c313f9da54/mmditx.py#L1
7use candle::{Module, Result, Tensor, D};
8use candle_nn as nn;
9
10use super::blocks::{
11    ContextQkvOnlyJointBlock, FinalLayer, JointBlock, MMDiTJointBlock, MMDiTXJointBlock,
12};
13use super::embedding::{
14    PatchEmbedder, PositionEmbedder, TimestepEmbedder, Unpatchifier, VectorEmbedder,
15};
16
/// Hyper-parameters of an MMDiT / MMDiT-X backbone.
///
/// Use one of the presets ([`Config::sd3_medium`], [`Config::sd3_5_medium`],
/// [`Config::sd3_5_large`]) or fill the fields in by hand.
#[derive(Debug, Clone)]
pub struct Config {
    /// Patch size shared by the patch embedder, the final layer and the unpatchifier.
    pub patch_size: usize,
    /// Channel count of the latent fed to the patch embedder.
    pub in_channels: usize,
    /// Channel count produced by the final layer / unpatchifier.
    pub out_channels: usize,
    /// Total number of joint blocks; also used as the number of attention heads.
    pub depth: usize,
    /// Per-head width; the model width is `head_size * depth`.
    pub head_size: usize,
    /// Input width of the pooled conditioning vector fed to the vector embedder.
    pub adm_in_channels: usize,
    /// Maximum size handed to the position embedder.
    pub pos_embed_max_size: usize,
    /// Input width of the text context before it is projected to the model width.
    pub context_embed_size: usize,
    /// Frequency-embedding width handed to the timestep embedder.
    pub frequency_embedding_size: usize,
}

impl Config {
    /// Preset for Stable Diffusion 3 medium (24 blocks, position-embedding size 192).
    ///
    /// The SD3.5 presets below are expressed as deltas from this one.
    pub fn sd3_medium() -> Self {
        Self {
            patch_size: 2,
            in_channels: 16,
            out_channels: 16,
            depth: 24,
            head_size: 64,
            adm_in_channels: 2048,
            pos_embed_max_size: 192,
            context_embed_size: 4096,
            frequency_embedding_size: 256,
        }
    }

    /// Preset for Stable Diffusion 3.5 medium: SD3 medium with a position-embedding size of 384.
    pub fn sd3_5_medium() -> Self {
        Self {
            pos_embed_max_size: 384,
            ..Self::sd3_medium()
        }
    }

    /// Preset for Stable Diffusion 3.5 large: SD3 medium with 38 joint blocks.
    pub fn sd3_5_large() -> Self {
        Self {
            depth: 38,
            ..Self::sd3_medium()
        }
    }
}
73
74pub struct MMDiT {
75    core: MMDiTCore,
76    patch_embedder: PatchEmbedder,
77    pos_embedder: PositionEmbedder,
78    timestep_embedder: TimestepEmbedder,
79    vector_embedder: VectorEmbedder,
80    context_embedder: nn::Linear,
81    unpatchifier: Unpatchifier,
82}
83
84impl MMDiT {
85    pub fn new(cfg: &Config, use_flash_attn: bool, vb: nn::VarBuilder) -> Result<Self> {
86        let hidden_size = cfg.head_size * cfg.depth;
87        let core = MMDiTCore::new(
88            cfg.depth,
89            hidden_size,
90            cfg.depth,
91            cfg.patch_size,
92            cfg.out_channels,
93            use_flash_attn,
94            vb.clone(),
95        )?;
96        let patch_embedder = PatchEmbedder::new(
97            cfg.patch_size,
98            cfg.in_channels,
99            hidden_size,
100            vb.pp("x_embedder"),
101        )?;
102        let pos_embedder = PositionEmbedder::new(
103            hidden_size,
104            cfg.patch_size,
105            cfg.pos_embed_max_size,
106            vb.clone(),
107        )?;
108        let timestep_embedder = TimestepEmbedder::new(
109            hidden_size,
110            cfg.frequency_embedding_size,
111            vb.pp("t_embedder"),
112        )?;
113        let vector_embedder =
114            VectorEmbedder::new(cfg.adm_in_channels, hidden_size, vb.pp("y_embedder"))?;
115        let context_embedder = nn::linear(
116            cfg.context_embed_size,
117            hidden_size,
118            vb.pp("context_embedder"),
119        )?;
120        let unpatchifier = Unpatchifier::new(cfg.patch_size, cfg.out_channels)?;
121
122        Ok(Self {
123            core,
124            patch_embedder,
125            pos_embedder,
126            timestep_embedder,
127            vector_embedder,
128            context_embedder,
129            unpatchifier,
130        })
131    }
132
133    pub fn forward(
134        &self,
135        x: &Tensor,
136        t: &Tensor,
137        y: &Tensor,
138        context: &Tensor,
139        skip_layers: Option<&[usize]>,
140    ) -> Result<Tensor> {
141        // Following the convention of the ComfyUI implementation.
142        // https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L919
143        //
144        // Forward pass of DiT.
145        // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
146        // t: (N,) tensor of diffusion timesteps
147        // y: (N,) tensor of class labels
148        let h = x.dim(D::Minus2)?;
149        let w = x.dim(D::Minus1)?;
150        let cropped_pos_embed = self.pos_embedder.get_cropped_pos_embed(h, w)?;
151        let x = self
152            .patch_embedder
153            .forward(x)?
154            .broadcast_add(&cropped_pos_embed)?;
155        let c = self.timestep_embedder.forward(t)?;
156        let y = self.vector_embedder.forward(y)?;
157        let c = (c + y)?;
158        let context = self.context_embedder.forward(context)?;
159
160        let x = self.core.forward(&context, &x, &c, skip_layers)?;
161        let x = self.unpatchifier.unpatchify(&x, h, w)?;
162        x.narrow(2, 0, h)?.narrow(3, 0, w)
163    }
164}
165
166pub struct MMDiTCore {
167    joint_blocks: Vec<Box<dyn JointBlock>>,
168    context_qkv_only_joint_block: ContextQkvOnlyJointBlock,
169    final_layer: FinalLayer,
170}
171
172impl MMDiTCore {
173    pub fn new(
174        depth: usize,
175        hidden_size: usize,
176        num_heads: usize,
177        patch_size: usize,
178        out_channels: usize,
179        use_flash_attn: bool,
180        vb: nn::VarBuilder,
181    ) -> Result<Self> {
182        let mut joint_blocks = Vec::with_capacity(depth - 1);
183        for i in 0..depth - 1 {
184            let joint_block_vb_pp = format!("joint_blocks.{}", i);
185            let joint_block: Box<dyn JointBlock> =
186                if vb.contains_tensor(&format!("{}.x_block.attn2.qkv.weight", joint_block_vb_pp)) {
187                    Box::new(MMDiTXJointBlock::new(
188                        hidden_size,
189                        num_heads,
190                        use_flash_attn,
191                        vb.pp(&joint_block_vb_pp),
192                    )?)
193                } else {
194                    Box::new(MMDiTJointBlock::new(
195                        hidden_size,
196                        num_heads,
197                        use_flash_attn,
198                        vb.pp(&joint_block_vb_pp),
199                    )?)
200                };
201            joint_blocks.push(joint_block);
202        }
203
204        Ok(Self {
205            joint_blocks,
206            context_qkv_only_joint_block: ContextQkvOnlyJointBlock::new(
207                hidden_size,
208                num_heads,
209                use_flash_attn,
210                vb.pp(format!("joint_blocks.{}", depth - 1)),
211            )?,
212            final_layer: FinalLayer::new(
213                hidden_size,
214                patch_size,
215                out_channels,
216                vb.pp("final_layer"),
217            )?,
218        })
219    }
220
221    pub fn forward(
222        &self,
223        context: &Tensor,
224        x: &Tensor,
225        c: &Tensor,
226        skip_layers: Option<&[usize]>,
227    ) -> Result<Tensor> {
228        let (mut context, mut x) = (context.clone(), x.clone());
229        for (i, joint_block) in self.joint_blocks.iter().enumerate() {
230            if let Some(skip_layers) = &skip_layers {
231                if skip_layers.contains(&i) {
232                    continue;
233                }
234            }
235            (context, x) = joint_block.forward(&context, &x, c)?;
236        }
237        let x = self.context_qkv_only_joint_block.forward(&context, &x, c)?;
238        self.final_layer.forward(&x, c)
239    }
240}