1use super::model::{attention, timestep_embedding, Config, EmbedNd};
2use crate::quantized_nn::{linear, linear_b, Linear};
3use crate::quantized_var_builder::VarBuilder;
4use candle::{DType, IndexOp, Result, Tensor, D};
5use candle_nn::{LayerNorm, RmsNorm};
6
7fn layer_norm(dim: usize, vb: VarBuilder) -> Result<LayerNorm> {
8 let ws = Tensor::ones(dim, DType::F32, vb.device())?;
9 Ok(LayerNorm::new_no_bias(ws, 1e-6))
10}
11
12#[derive(Debug, Clone)]
13pub struct MlpEmbedder {
14 in_layer: Linear,
15 out_layer: Linear,
16}
17
18impl MlpEmbedder {
19 fn new(in_sz: usize, h_sz: usize, vb: VarBuilder) -> Result<Self> {
20 let in_layer = linear(in_sz, h_sz, vb.pp("in_layer"))?;
21 let out_layer = linear(h_sz, h_sz, vb.pp("out_layer"))?;
22 Ok(Self {
23 in_layer,
24 out_layer,
25 })
26 }
27}
28
29impl candle::Module for MlpEmbedder {
30 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
31 xs.apply(&self.in_layer)?.silu()?.apply(&self.out_layer)
32 }
33}
34
35#[derive(Debug, Clone)]
36pub struct QkNorm {
37 query_norm: RmsNorm,
38 key_norm: RmsNorm,
39}
40
41impl QkNorm {
42 fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
43 let query_norm = vb.get(dim, "query_norm.scale")?.dequantize(vb.device())?;
44 let query_norm = RmsNorm::new(query_norm, 1e-6);
45 let key_norm = vb.get(dim, "key_norm.scale")?.dequantize(vb.device())?;
46 let key_norm = RmsNorm::new(key_norm, 1e-6);
47 Ok(Self {
48 query_norm,
49 key_norm,
50 })
51 }
52}
53
54struct ModulationOut {
55 shift: Tensor,
56 scale: Tensor,
57 gate: Tensor,
58}
59
60impl ModulationOut {
61 fn scale_shift(&self, xs: &Tensor) -> Result<Tensor> {
62 xs.broadcast_mul(&(&self.scale + 1.)?)?
63 .broadcast_add(&self.shift)
64 }
65
66 fn gate(&self, xs: &Tensor) -> Result<Tensor> {
67 self.gate.broadcast_mul(xs)
68 }
69}
70
71#[derive(Debug, Clone)]
72struct Modulation1 {
73 lin: Linear,
74}
75
76impl Modulation1 {
77 fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
78 let lin = linear(dim, 3 * dim, vb.pp("lin"))?;
79 Ok(Self { lin })
80 }
81
82 fn forward(&self, vec_: &Tensor) -> Result<ModulationOut> {
83 let ys = vec_
84 .silu()?
85 .apply(&self.lin)?
86 .unsqueeze(1)?
87 .chunk(3, D::Minus1)?;
88 if ys.len() != 3 {
89 candle::bail!("unexpected len from chunk {ys:?}")
90 }
91 Ok(ModulationOut {
92 shift: ys[0].clone(),
93 scale: ys[1].clone(),
94 gate: ys[2].clone(),
95 })
96 }
97}
98
99#[derive(Debug, Clone)]
100struct Modulation2 {
101 lin: Linear,
102}
103
104impl Modulation2 {
105 fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
106 let lin = linear(dim, 6 * dim, vb.pp("lin"))?;
107 Ok(Self { lin })
108 }
109
110 fn forward(&self, vec_: &Tensor) -> Result<(ModulationOut, ModulationOut)> {
111 let ys = vec_
112 .silu()?
113 .apply(&self.lin)?
114 .unsqueeze(1)?
115 .chunk(6, D::Minus1)?;
116 if ys.len() != 6 {
117 candle::bail!("unexpected len from chunk {ys:?}")
118 }
119 let mod1 = ModulationOut {
120 shift: ys[0].clone(),
121 scale: ys[1].clone(),
122 gate: ys[2].clone(),
123 };
124 let mod2 = ModulationOut {
125 shift: ys[3].clone(),
126 scale: ys[4].clone(),
127 gate: ys[5].clone(),
128 };
129 Ok((mod1, mod2))
130 }
131}
132
133#[derive(Debug, Clone)]
134pub struct SelfAttention {
135 qkv: Linear,
136 norm: QkNorm,
137 proj: Linear,
138 num_heads: usize,
139}
140
141impl SelfAttention {
142 fn new(dim: usize, num_heads: usize, qkv_bias: bool, vb: VarBuilder) -> Result<Self> {
143 let head_dim = dim / num_heads;
144 let qkv = linear_b(dim, dim * 3, qkv_bias, vb.pp("qkv"))?;
145 let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
146 let proj = linear(dim, dim, vb.pp("proj"))?;
147 Ok(Self {
148 qkv,
149 norm,
150 proj,
151 num_heads,
152 })
153 }
154
155 fn qkv(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
156 let qkv = xs.apply(&self.qkv)?;
157 let (b, l, _khd) = qkv.dims3()?;
158 let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?;
159 let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
160 let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
161 let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
162 let q = q.apply(&self.norm.query_norm)?;
163 let k = k.apply(&self.norm.key_norm)?;
164 Ok((q, k, v))
165 }
166
167 #[allow(unused)]
168 fn forward(&self, xs: &Tensor, pe: &Tensor) -> Result<Tensor> {
169 let (q, k, v) = self.qkv(xs)?;
170 attention(&q, &k, &v, pe)?.apply(&self.proj)
171 }
172}
173
174#[derive(Debug, Clone)]
175struct Mlp {
176 lin1: Linear,
177 lin2: Linear,
178}
179
180impl Mlp {
181 fn new(in_sz: usize, mlp_sz: usize, vb: VarBuilder) -> Result<Self> {
182 let lin1 = linear(in_sz, mlp_sz, vb.pp("0"))?;
183 let lin2 = linear(mlp_sz, in_sz, vb.pp("2"))?;
184 Ok(Self { lin1, lin2 })
185 }
186}
187
188impl candle::Module for Mlp {
189 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
190 xs.apply(&self.lin1)?.gelu()?.apply(&self.lin2)
191 }
192}
193
194#[derive(Debug, Clone)]
195pub struct DoubleStreamBlock {
196 img_mod: Modulation2,
197 img_norm1: LayerNorm,
198 img_attn: SelfAttention,
199 img_norm2: LayerNorm,
200 img_mlp: Mlp,
201 txt_mod: Modulation2,
202 txt_norm1: LayerNorm,
203 txt_attn: SelfAttention,
204 txt_norm2: LayerNorm,
205 txt_mlp: Mlp,
206}
207
208impl DoubleStreamBlock {
209 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
210 let h_sz = cfg.hidden_size;
211 let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize;
212 let img_mod = Modulation2::new(h_sz, vb.pp("img_mod"))?;
213 let img_norm1 = layer_norm(h_sz, vb.pp("img_norm1"))?;
214 let img_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("img_attn"))?;
215 let img_norm2 = layer_norm(h_sz, vb.pp("img_norm2"))?;
216 let img_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("img_mlp"))?;
217 let txt_mod = Modulation2::new(h_sz, vb.pp("txt_mod"))?;
218 let txt_norm1 = layer_norm(h_sz, vb.pp("txt_norm1"))?;
219 let txt_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("txt_attn"))?;
220 let txt_norm2 = layer_norm(h_sz, vb.pp("txt_norm2"))?;
221 let txt_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("txt_mlp"))?;
222 Ok(Self {
223 img_mod,
224 img_norm1,
225 img_attn,
226 img_norm2,
227 img_mlp,
228 txt_mod,
229 txt_norm1,
230 txt_attn,
231 txt_norm2,
232 txt_mlp,
233 })
234 }
235
236 fn forward(
237 &self,
238 img: &Tensor,
239 txt: &Tensor,
240 vec_: &Tensor,
241 pe: &Tensor,
242 ) -> Result<(Tensor, Tensor)> {
243 let (img_mod1, img_mod2) = self.img_mod.forward(vec_)?; let (txt_mod1, txt_mod2) = self.txt_mod.forward(vec_)?; let img_modulated = img.apply(&self.img_norm1)?;
246 let img_modulated = img_mod1.scale_shift(&img_modulated)?;
247 let (img_q, img_k, img_v) = self.img_attn.qkv(&img_modulated)?;
248
249 let txt_modulated = txt.apply(&self.txt_norm1)?;
250 let txt_modulated = txt_mod1.scale_shift(&txt_modulated)?;
251 let (txt_q, txt_k, txt_v) = self.txt_attn.qkv(&txt_modulated)?;
252
253 let q = Tensor::cat(&[txt_q, img_q], 2)?;
254 let k = Tensor::cat(&[txt_k, img_k], 2)?;
255 let v = Tensor::cat(&[txt_v, img_v], 2)?;
256
257 let attn = attention(&q, &k, &v, pe)?;
258 let txt_attn = attn.narrow(1, 0, txt.dim(1)?)?;
259 let img_attn = attn.narrow(1, txt.dim(1)?, attn.dim(1)? - txt.dim(1)?)?;
260
261 let img = (img + img_mod1.gate(&img_attn.apply(&self.img_attn.proj)?))?;
262 let img = (&img
263 + img_mod2.gate(
264 &img_mod2
265 .scale_shift(&img.apply(&self.img_norm2)?)?
266 .apply(&self.img_mlp)?,
267 )?)?;
268
269 let txt = (txt + txt_mod1.gate(&txt_attn.apply(&self.txt_attn.proj)?))?;
270 let txt = (&txt
271 + txt_mod2.gate(
272 &txt_mod2
273 .scale_shift(&txt.apply(&self.txt_norm2)?)?
274 .apply(&self.txt_mlp)?,
275 )?)?;
276
277 Ok((img, txt))
278 }
279}
280
281#[derive(Debug, Clone)]
282pub struct SingleStreamBlock {
283 linear1: Linear,
284 linear2: Linear,
285 norm: QkNorm,
286 pre_norm: LayerNorm,
287 modulation: Modulation1,
288 h_sz: usize,
289 mlp_sz: usize,
290 num_heads: usize,
291}
292
293impl SingleStreamBlock {
294 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
295 let h_sz = cfg.hidden_size;
296 let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize;
297 let head_dim = h_sz / cfg.num_heads;
298 let linear1 = linear(h_sz, h_sz * 3 + mlp_sz, vb.pp("linear1"))?;
299 let linear2 = linear(h_sz + mlp_sz, h_sz, vb.pp("linear2"))?;
300 let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
301 let pre_norm = layer_norm(h_sz, vb.pp("pre_norm"))?;
302 let modulation = Modulation1::new(h_sz, vb.pp("modulation"))?;
303 Ok(Self {
304 linear1,
305 linear2,
306 norm,
307 pre_norm,
308 modulation,
309 h_sz,
310 mlp_sz,
311 num_heads: cfg.num_heads,
312 })
313 }
314
315 fn forward(&self, xs: &Tensor, vec_: &Tensor, pe: &Tensor) -> Result<Tensor> {
316 let mod_ = self.modulation.forward(vec_)?;
317 let x_mod = mod_.scale_shift(&xs.apply(&self.pre_norm)?)?;
318 let x_mod = x_mod.apply(&self.linear1)?;
319 let qkv = x_mod.narrow(D::Minus1, 0, 3 * self.h_sz)?;
320 let (b, l, _khd) = qkv.dims3()?;
321 let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?;
322 let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
323 let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
324 let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
325 let mlp = x_mod.narrow(D::Minus1, 3 * self.h_sz, self.mlp_sz)?;
326 let q = q.apply(&self.norm.query_norm)?;
327 let k = k.apply(&self.norm.key_norm)?;
328 let attn = attention(&q, &k, &v, pe)?;
329 let output = Tensor::cat(&[attn, mlp.gelu()?], 2)?.apply(&self.linear2)?;
330 xs + mod_.gate(&output)
331 }
332}
333
334#[derive(Debug, Clone)]
335pub struct LastLayer {
336 norm_final: LayerNorm,
337 linear: Linear,
338 ada_ln_modulation: Linear,
339}
340
341impl LastLayer {
342 fn new(h_sz: usize, p_sz: usize, out_c: usize, vb: VarBuilder) -> Result<Self> {
343 let norm_final = layer_norm(h_sz, vb.pp("norm_final"))?;
344 let linear_ = linear(h_sz, p_sz * p_sz * out_c, vb.pp("linear"))?;
345 let ada_ln_modulation = linear(h_sz, 2 * h_sz, vb.pp("adaLN_modulation.1"))?;
346 Ok(Self {
347 norm_final,
348 linear: linear_,
349 ada_ln_modulation,
350 })
351 }
352
353 fn forward(&self, xs: &Tensor, vec: &Tensor) -> Result<Tensor> {
354 let chunks = vec.silu()?.apply(&self.ada_ln_modulation)?.chunk(2, 1)?;
355 let (shift, scale) = (&chunks[0], &chunks[1]);
356 let xs = xs
357 .apply(&self.norm_final)?
358 .broadcast_mul(&(scale.unsqueeze(1)? + 1.0)?)?
359 .broadcast_add(&shift.unsqueeze(1)?)?;
360 xs.apply(&self.linear)
361 }
362}
363
364#[derive(Debug, Clone)]
365pub struct Flux {
366 img_in: Linear,
367 txt_in: Linear,
368 time_in: MlpEmbedder,
369 vector_in: MlpEmbedder,
370 guidance_in: Option<MlpEmbedder>,
371 pe_embedder: EmbedNd,
372 double_blocks: Vec<DoubleStreamBlock>,
373 single_blocks: Vec<SingleStreamBlock>,
374 final_layer: LastLayer,
375}
376
377impl Flux {
378 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
379 let img_in = linear(cfg.in_channels, cfg.hidden_size, vb.pp("img_in"))?;
380 let txt_in = linear(cfg.context_in_dim, cfg.hidden_size, vb.pp("txt_in"))?;
381 let mut double_blocks = Vec::with_capacity(cfg.depth);
382 let vb_d = vb.pp("double_blocks");
383 for idx in 0..cfg.depth {
384 let db = DoubleStreamBlock::new(cfg, vb_d.pp(idx))?;
385 double_blocks.push(db)
386 }
387 let mut single_blocks = Vec::with_capacity(cfg.depth_single_blocks);
388 let vb_s = vb.pp("single_blocks");
389 for idx in 0..cfg.depth_single_blocks {
390 let sb = SingleStreamBlock::new(cfg, vb_s.pp(idx))?;
391 single_blocks.push(sb)
392 }
393 let time_in = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("time_in"))?;
394 let vector_in = MlpEmbedder::new(cfg.vec_in_dim, cfg.hidden_size, vb.pp("vector_in"))?;
395 let guidance_in = if cfg.guidance_embed {
396 let mlp = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("guidance_in"))?;
397 Some(mlp)
398 } else {
399 None
400 };
401 let final_layer =
402 LastLayer::new(cfg.hidden_size, 1, cfg.in_channels, vb.pp("final_layer"))?;
403 let pe_dim = cfg.hidden_size / cfg.num_heads;
404 let pe_embedder = EmbedNd::new(pe_dim, cfg.theta, cfg.axes_dim.to_vec());
405 Ok(Self {
406 img_in,
407 txt_in,
408 time_in,
409 vector_in,
410 guidance_in,
411 pe_embedder,
412 double_blocks,
413 single_blocks,
414 final_layer,
415 })
416 }
417}
418
419impl super::WithForward for Flux {
420 #[allow(clippy::too_many_arguments)]
421 fn forward(
422 &self,
423 img: &Tensor,
424 img_ids: &Tensor,
425 txt: &Tensor,
426 txt_ids: &Tensor,
427 timesteps: &Tensor,
428 y: &Tensor,
429 guidance: Option<&Tensor>,
430 ) -> Result<Tensor> {
431 if txt.rank() != 3 {
432 candle::bail!("unexpected shape for txt {:?}", txt.shape())
433 }
434 if img.rank() != 3 {
435 candle::bail!("unexpected shape for img {:?}", img.shape())
436 }
437 let dtype = img.dtype();
438 let pe = {
439 let ids = Tensor::cat(&[txt_ids, img_ids], 1)?;
440 ids.apply(&self.pe_embedder)?
441 };
442 let mut txt = txt.apply(&self.txt_in)?;
443 let mut img = img.apply(&self.img_in)?;
444 let vec_ = timestep_embedding(timesteps, 256, dtype)?.apply(&self.time_in)?;
445 let vec_ = match (self.guidance_in.as_ref(), guidance) {
446 (Some(g_in), Some(guidance)) => {
447 (vec_ + timestep_embedding(guidance, 256, dtype)?.apply(g_in))?
448 }
449 _ => vec_,
450 };
451 let vec_ = (vec_ + y.apply(&self.vector_in))?;
452
453 for block in self.double_blocks.iter() {
455 (img, txt) = block.forward(&img, &txt, &vec_, &pe)?
456 }
457 let mut img = Tensor::cat(&[&txt, &img], 1)?;
459 for block in self.single_blocks.iter() {
460 img = block.forward(&img, &vec_, &pe)?;
461 }
462 let img = img.i((.., txt.dim(1)?..))?;
463 self.final_layer.forward(&img, &vec_)
464 }
465}