candle_transformers/models/mmdit/
model.rs1use candle::{Module, Result, Tensor, D};
8use candle_nn as nn;
9
10use super::blocks::{
11 ContextQkvOnlyJointBlock, FinalLayer, JointBlock, MMDiTJointBlock, MMDiTXJointBlock,
12};
13use super::embedding::{
14 PatchEmbedder, PositionEmbedder, TimestepEmbedder, Unpatchifier, VectorEmbedder,
15};
16
/// Hyper-parameters used by [`MMDiT::new`] to size and load the model.
///
/// The transformer's hidden width is not stored directly: `MMDiT::new` derives it as
/// `head_size * depth`, and also passes `depth` as the number of attention heads.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Config {
    /// Patch size handed to the patch embedder, position embedder, final layer and unpatchifier.
    pub patch_size: usize,
    /// Channel count handed to the patch embedder for the input `x`.
    pub in_channels: usize,
    /// Channel count handed to the final layer and the unpatchifier.
    pub out_channels: usize,
    /// Total number of joint blocks (`depth - 1` regular blocks plus one final
    /// context-qkv-only block). Also used as the head count and as a factor of the hidden width.
    pub depth: usize,
    /// Per-head width; multiplied by `depth` to obtain the hidden width.
    pub head_size: usize,
    /// First size argument of the vector embedder applied to `y`
    /// (presumably the width of the pooled conditioning vector; confirm against `VectorEmbedder`).
    pub adm_in_channels: usize,
    /// Maximum size handed to the position embedder, from which a cropped embedding is taken.
    pub pos_embed_max_size: usize,
    /// Input feature size of the linear layer that projects `context` to the hidden width.
    pub context_embed_size: usize,
    /// Frequency-embedding size handed to the timestep embedder.
    pub frequency_embedding_size: usize,
}
29
30impl Config {
31 pub fn sd3_medium() -> Self {
32 Self {
33 patch_size: 2,
34 in_channels: 16,
35 out_channels: 16,
36 depth: 24,
37 head_size: 64,
38 adm_in_channels: 2048,
39 pos_embed_max_size: 192,
40 context_embed_size: 4096,
41 frequency_embedding_size: 256,
42 }
43 }
44
45 pub fn sd3_5_medium() -> Self {
46 Self {
47 patch_size: 2,
48 in_channels: 16,
49 out_channels: 16,
50 depth: 24,
51 head_size: 64,
52 adm_in_channels: 2048,
53 pos_embed_max_size: 384,
54 context_embed_size: 4096,
55 frequency_embedding_size: 256,
56 }
57 }
58
59 pub fn sd3_5_large() -> Self {
60 Self {
61 patch_size: 2,
62 in_channels: 16,
63 out_channels: 16,
64 depth: 38,
65 head_size: 64,
66 adm_in_channels: 2048,
67 pos_embed_max_size: 192,
68 context_embed_size: 4096,
69 frequency_embedding_size: 256,
70 }
71 }
72}
73
74pub struct MMDiT {
75 core: MMDiTCore,
76 patch_embedder: PatchEmbedder,
77 pos_embedder: PositionEmbedder,
78 timestep_embedder: TimestepEmbedder,
79 vector_embedder: VectorEmbedder,
80 context_embedder: nn::Linear,
81 unpatchifier: Unpatchifier,
82}
83
84impl MMDiT {
85 pub fn new(cfg: &Config, use_flash_attn: bool, vb: nn::VarBuilder) -> Result<Self> {
86 let hidden_size = cfg.head_size * cfg.depth;
87 let core = MMDiTCore::new(
88 cfg.depth,
89 hidden_size,
90 cfg.depth,
91 cfg.patch_size,
92 cfg.out_channels,
93 use_flash_attn,
94 vb.clone(),
95 )?;
96 let patch_embedder = PatchEmbedder::new(
97 cfg.patch_size,
98 cfg.in_channels,
99 hidden_size,
100 vb.pp("x_embedder"),
101 )?;
102 let pos_embedder = PositionEmbedder::new(
103 hidden_size,
104 cfg.patch_size,
105 cfg.pos_embed_max_size,
106 vb.clone(),
107 )?;
108 let timestep_embedder = TimestepEmbedder::new(
109 hidden_size,
110 cfg.frequency_embedding_size,
111 vb.pp("t_embedder"),
112 )?;
113 let vector_embedder =
114 VectorEmbedder::new(cfg.adm_in_channels, hidden_size, vb.pp("y_embedder"))?;
115 let context_embedder = nn::linear(
116 cfg.context_embed_size,
117 hidden_size,
118 vb.pp("context_embedder"),
119 )?;
120 let unpatchifier = Unpatchifier::new(cfg.patch_size, cfg.out_channels)?;
121
122 Ok(Self {
123 core,
124 patch_embedder,
125 pos_embedder,
126 timestep_embedder,
127 vector_embedder,
128 context_embedder,
129 unpatchifier,
130 })
131 }
132
133 pub fn forward(
134 &self,
135 x: &Tensor,
136 t: &Tensor,
137 y: &Tensor,
138 context: &Tensor,
139 skip_layers: Option<&[usize]>,
140 ) -> Result<Tensor> {
141 let h = x.dim(D::Minus2)?;
149 let w = x.dim(D::Minus1)?;
150 let cropped_pos_embed = self.pos_embedder.get_cropped_pos_embed(h, w)?;
151 let x = self
152 .patch_embedder
153 .forward(x)?
154 .broadcast_add(&cropped_pos_embed)?;
155 let c = self.timestep_embedder.forward(t)?;
156 let y = self.vector_embedder.forward(y)?;
157 let c = (c + y)?;
158 let context = self.context_embedder.forward(context)?;
159
160 let x = self.core.forward(&context, &x, &c, skip_layers)?;
161 let x = self.unpatchifier.unpatchify(&x, h, w)?;
162 x.narrow(2, 0, h)?.narrow(3, 0, w)
163 }
164}
165
166pub struct MMDiTCore {
167 joint_blocks: Vec<Box<dyn JointBlock>>,
168 context_qkv_only_joint_block: ContextQkvOnlyJointBlock,
169 final_layer: FinalLayer,
170}
171
172impl MMDiTCore {
173 pub fn new(
174 depth: usize,
175 hidden_size: usize,
176 num_heads: usize,
177 patch_size: usize,
178 out_channels: usize,
179 use_flash_attn: bool,
180 vb: nn::VarBuilder,
181 ) -> Result<Self> {
182 let mut joint_blocks = Vec::with_capacity(depth - 1);
183 for i in 0..depth - 1 {
184 let joint_block_vb_pp = format!("joint_blocks.{}", i);
185 let joint_block: Box<dyn JointBlock> =
186 if vb.contains_tensor(&format!("{}.x_block.attn2.qkv.weight", joint_block_vb_pp)) {
187 Box::new(MMDiTXJointBlock::new(
188 hidden_size,
189 num_heads,
190 use_flash_attn,
191 vb.pp(&joint_block_vb_pp),
192 )?)
193 } else {
194 Box::new(MMDiTJointBlock::new(
195 hidden_size,
196 num_heads,
197 use_flash_attn,
198 vb.pp(&joint_block_vb_pp),
199 )?)
200 };
201 joint_blocks.push(joint_block);
202 }
203
204 Ok(Self {
205 joint_blocks,
206 context_qkv_only_joint_block: ContextQkvOnlyJointBlock::new(
207 hidden_size,
208 num_heads,
209 use_flash_attn,
210 vb.pp(format!("joint_blocks.{}", depth - 1)),
211 )?,
212 final_layer: FinalLayer::new(
213 hidden_size,
214 patch_size,
215 out_channels,
216 vb.pp("final_layer"),
217 )?,
218 })
219 }
220
221 pub fn forward(
222 &self,
223 context: &Tensor,
224 x: &Tensor,
225 c: &Tensor,
226 skip_layers: Option<&[usize]>,
227 ) -> Result<Tensor> {
228 let (mut context, mut x) = (context.clone(), x.clone());
229 for (i, joint_block) in self.joint_blocks.iter().enumerate() {
230 if let Some(skip_layers) = &skip_layers {
231 if skip_layers.contains(&i) {
232 continue;
233 }
234 }
235 (context, x) = joint_block.forward(&context, &x, c)?;
236 }
237 let x = self.context_qkv_only_joint_block.forward(&context, &x, c)?;
238 self.final_layer.forward(&x, c)
239 }
240}