candle_transformers/models/flux/autoencoder.rs

1use candle::{Result, Tensor, D};
2use candle_nn::{conv2d, group_norm, Conv2d, GroupNorm, VarBuilder};
3
4// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/modules/autoencoder.py#L9
/// Hyper-parameters of the Flux autoencoder: encoder/decoder channel widths, latent size,
/// and the affine normalisation applied to latents by `AutoEncoder::encode`/`decode`.
#[derive(Debug, Clone)]
pub struct Config {
    /// Nominal image resolution; not read by the encoder/decoder constructors in this file.
    pub resolution: usize,
    /// Number of channels of the encoder input.
    pub in_channels: usize,
    /// Base channel width; level `i` uses `ch * ch_mult[i]` channels.
    pub ch: usize,
    /// Number of channels produced by the decoder.
    pub out_ch: usize,
    /// Per-level channel multipliers; the length is the number of resolution levels.
    pub ch_mult: Vec<usize>,
    /// Residual blocks per encoder level (the decoder uses one extra block per level).
    pub num_res_blocks: usize,
    /// Latent channels; the encoder emits `2 * z_channels` before sampling.
    pub z_channels: usize,
    /// Multiplier applied to latents after shifting: `(z - shift_factor) * scale_factor`.
    pub scale_factor: f64,
    /// Offset subtracted from latents before scaling.
    pub shift_factor: f64,
}

impl Config {
    /// Autoencoder settings shared by the `dev` and `schnell` presets, which currently use
    /// identical values.
    fn flux_autoencoder() -> Self {
        Self {
            resolution: 256,
            in_channels: 3,
            ch: 128,
            out_ch: 3,
            ch_mult: vec![1, 2, 4, 4],
            num_res_blocks: 2,
            z_channels: 16,
            scale_factor: 0.3611,
            shift_factor: 0.1159,
        }
    }

    /// Autoencoder configuration for the `dev` model variant.
    ///
    /// Upstream: <https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L47>
    pub fn dev() -> Self {
        Self::flux_autoencoder()
    }

    /// Autoencoder configuration for the `schnell` model variant.
    ///
    /// Upstream: <https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L79>
    pub fn schnell() -> Self {
        Self::flux_autoencoder()
    }
}
49
50fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
51    let dim = q.dim(D::Minus1)?;
52    let scale_factor = 1.0 / (dim as f64).sqrt();
53    let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?;
54    candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(v)
55}
56
57#[derive(Debug, Clone)]
58struct AttnBlock {
59    q: Conv2d,
60    k: Conv2d,
61    v: Conv2d,
62    proj_out: Conv2d,
63    norm: GroupNorm,
64}
65
66impl AttnBlock {
67    fn new(in_c: usize, vb: VarBuilder) -> Result<Self> {
68        let q = conv2d(in_c, in_c, 1, Default::default(), vb.pp("q"))?;
69        let k = conv2d(in_c, in_c, 1, Default::default(), vb.pp("k"))?;
70        let v = conv2d(in_c, in_c, 1, Default::default(), vb.pp("v"))?;
71        let proj_out = conv2d(in_c, in_c, 1, Default::default(), vb.pp("proj_out"))?;
72        let norm = group_norm(32, in_c, 1e-6, vb.pp("norm"))?;
73        Ok(Self {
74            q,
75            k,
76            v,
77            proj_out,
78            norm,
79        })
80    }
81}
82
83impl candle::Module for AttnBlock {
84    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
85        let init_xs = xs;
86        let xs = xs.apply(&self.norm)?;
87        let q = xs.apply(&self.q)?;
88        let k = xs.apply(&self.k)?;
89        let v = xs.apply(&self.v)?;
90        let (b, c, h, w) = q.dims4()?;
91        let q = q.flatten_from(2)?.t()?.unsqueeze(1)?;
92        let k = k.flatten_from(2)?.t()?.unsqueeze(1)?;
93        let v = v.flatten_from(2)?.t()?.unsqueeze(1)?;
94        let xs = scaled_dot_product_attention(&q, &k, &v)?;
95        let xs = xs.squeeze(1)?.t()?.reshape((b, c, h, w))?;
96        xs.apply(&self.proj_out)? + init_xs
97    }
98}
99
100#[derive(Debug, Clone)]
101struct ResnetBlock {
102    norm1: GroupNorm,
103    conv1: Conv2d,
104    norm2: GroupNorm,
105    conv2: Conv2d,
106    nin_shortcut: Option<Conv2d>,
107}
108
109impl ResnetBlock {
110    fn new(in_c: usize, out_c: usize, vb: VarBuilder) -> Result<Self> {
111        let conv_cfg = candle_nn::Conv2dConfig {
112            padding: 1,
113            ..Default::default()
114        };
115        let norm1 = group_norm(32, in_c, 1e-6, vb.pp("norm1"))?;
116        let conv1 = conv2d(in_c, out_c, 3, conv_cfg, vb.pp("conv1"))?;
117        let norm2 = group_norm(32, out_c, 1e-6, vb.pp("norm2"))?;
118        let conv2 = conv2d(out_c, out_c, 3, conv_cfg, vb.pp("conv2"))?;
119        let nin_shortcut = if in_c == out_c {
120            None
121        } else {
122            Some(conv2d(
123                in_c,
124                out_c,
125                1,
126                Default::default(),
127                vb.pp("nin_shortcut"),
128            )?)
129        };
130        Ok(Self {
131            norm1,
132            conv1,
133            norm2,
134            conv2,
135            nin_shortcut,
136        })
137    }
138}
139
140impl candle::Module for ResnetBlock {
141    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
142        let h = xs
143            .apply(&self.norm1)?
144            .apply(&candle_nn::Activation::Swish)?
145            .apply(&self.conv1)?
146            .apply(&self.norm2)?
147            .apply(&candle_nn::Activation::Swish)?
148            .apply(&self.conv2)?;
149        match self.nin_shortcut.as_ref() {
150            None => xs + h,
151            Some(c) => xs.apply(c)? + h,
152        }
153    }
154}
155
156#[derive(Debug, Clone)]
157struct Downsample {
158    conv: Conv2d,
159}
160
161impl Downsample {
162    fn new(in_c: usize, vb: VarBuilder) -> Result<Self> {
163        let conv_cfg = candle_nn::Conv2dConfig {
164            stride: 2,
165            ..Default::default()
166        };
167        let conv = conv2d(in_c, in_c, 3, conv_cfg, vb.pp("conv"))?;
168        Ok(Self { conv })
169    }
170}
171
172impl candle::Module for Downsample {
173    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
174        let xs = xs.pad_with_zeros(D::Minus1, 0, 1)?;
175        let xs = xs.pad_with_zeros(D::Minus2, 0, 1)?;
176        xs.apply(&self.conv)
177    }
178}
179
180#[derive(Debug, Clone)]
181struct Upsample {
182    conv: Conv2d,
183}
184
185impl Upsample {
186    fn new(in_c: usize, vb: VarBuilder) -> Result<Self> {
187        let conv_cfg = candle_nn::Conv2dConfig {
188            padding: 1,
189            ..Default::default()
190        };
191        let conv = conv2d(in_c, in_c, 3, conv_cfg, vb.pp("conv"))?;
192        Ok(Self { conv })
193    }
194}
195
196impl candle::Module for Upsample {
197    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
198        let (_, _, h, w) = xs.dims4()?;
199        xs.upsample_nearest2d(h * 2, w * 2)?.apply(&self.conv)
200    }
201}
202
203#[derive(Debug, Clone)]
204struct DownBlock {
205    block: Vec<ResnetBlock>,
206    downsample: Option<Downsample>,
207}
208
209#[derive(Debug, Clone)]
210pub struct Encoder {
211    conv_in: Conv2d,
212    mid_block_1: ResnetBlock,
213    mid_attn_1: AttnBlock,
214    mid_block_2: ResnetBlock,
215    norm_out: GroupNorm,
216    conv_out: Conv2d,
217    down: Vec<DownBlock>,
218}
219
220impl Encoder {
221    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
222        let conv_cfg = candle_nn::Conv2dConfig {
223            padding: 1,
224            ..Default::default()
225        };
226        let mut block_in = cfg.ch;
227        let conv_in = conv2d(cfg.in_channels, block_in, 3, conv_cfg, vb.pp("conv_in"))?;
228
229        let mut down = Vec::with_capacity(cfg.ch_mult.len());
230        let vb_d = vb.pp("down");
231        for (i_level, ch_mult) in cfg.ch_mult.iter().enumerate() {
232            let mut block = Vec::with_capacity(cfg.num_res_blocks);
233            let vb_d = vb_d.pp(i_level);
234            let vb_b = vb_d.pp("block");
235            let in_ch_mult = if i_level == 0 {
236                1
237            } else {
238                cfg.ch_mult[i_level - 1]
239            };
240            block_in = cfg.ch * in_ch_mult;
241            let block_out = cfg.ch * ch_mult;
242            for i_block in 0..cfg.num_res_blocks {
243                let b = ResnetBlock::new(block_in, block_out, vb_b.pp(i_block))?;
244                block.push(b);
245                block_in = block_out;
246            }
247            let downsample = if i_level != cfg.ch_mult.len() - 1 {
248                Some(Downsample::new(block_in, vb_d.pp("downsample"))?)
249            } else {
250                None
251            };
252            let block = DownBlock { block, downsample };
253            down.push(block)
254        }
255
256        let mid_block_1 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_1"))?;
257        let mid_attn_1 = AttnBlock::new(block_in, vb.pp("mid.attn_1"))?;
258        let mid_block_2 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_2"))?;
259        let conv_out = conv2d(block_in, 2 * cfg.z_channels, 3, conv_cfg, vb.pp("conv_out"))?;
260        let norm_out = group_norm(32, block_in, 1e-6, vb.pp("norm_out"))?;
261        Ok(Self {
262            conv_in,
263            mid_block_1,
264            mid_attn_1,
265            mid_block_2,
266            norm_out,
267            conv_out,
268            down,
269        })
270    }
271}
272
273impl candle_nn::Module for Encoder {
274    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
275        let mut h = xs.apply(&self.conv_in)?;
276        for block in self.down.iter() {
277            for b in block.block.iter() {
278                h = h.apply(b)?
279            }
280            if let Some(ds) = block.downsample.as_ref() {
281                h = h.apply(ds)?
282            }
283        }
284        h.apply(&self.mid_block_1)?
285            .apply(&self.mid_attn_1)?
286            .apply(&self.mid_block_2)?
287            .apply(&self.norm_out)?
288            .apply(&candle_nn::Activation::Swish)?
289            .apply(&self.conv_out)
290    }
291}
292
293#[derive(Debug, Clone)]
294struct UpBlock {
295    block: Vec<ResnetBlock>,
296    upsample: Option<Upsample>,
297}
298
299#[derive(Debug, Clone)]
300pub struct Decoder {
301    conv_in: Conv2d,
302    mid_block_1: ResnetBlock,
303    mid_attn_1: AttnBlock,
304    mid_block_2: ResnetBlock,
305    norm_out: GroupNorm,
306    conv_out: Conv2d,
307    up: Vec<UpBlock>,
308}
309
310impl Decoder {
311    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
312        let conv_cfg = candle_nn::Conv2dConfig {
313            padding: 1,
314            ..Default::default()
315        };
316        let mut block_in = cfg.ch * cfg.ch_mult.last().unwrap_or(&1);
317        let conv_in = conv2d(cfg.z_channels, block_in, 3, conv_cfg, vb.pp("conv_in"))?;
318        let mid_block_1 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_1"))?;
319        let mid_attn_1 = AttnBlock::new(block_in, vb.pp("mid.attn_1"))?;
320        let mid_block_2 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_2"))?;
321
322        let mut up = Vec::with_capacity(cfg.ch_mult.len());
323        let vb_u = vb.pp("up");
324        for (i_level, ch_mult) in cfg.ch_mult.iter().enumerate().rev() {
325            let block_out = cfg.ch * ch_mult;
326            let vb_u = vb_u.pp(i_level);
327            let vb_b = vb_u.pp("block");
328            let mut block = Vec::with_capacity(cfg.num_res_blocks + 1);
329            for i_block in 0..=cfg.num_res_blocks {
330                let b = ResnetBlock::new(block_in, block_out, vb_b.pp(i_block))?;
331                block.push(b);
332                block_in = block_out;
333            }
334            let upsample = if i_level != 0 {
335                Some(Upsample::new(block_in, vb_u.pp("upsample"))?)
336            } else {
337                None
338            };
339            let block = UpBlock { block, upsample };
340            up.push(block)
341        }
342        up.reverse();
343
344        let norm_out = group_norm(32, block_in, 1e-6, vb.pp("norm_out"))?;
345        let conv_out = conv2d(block_in, cfg.out_ch, 3, conv_cfg, vb.pp("conv_out"))?;
346        Ok(Self {
347            conv_in,
348            mid_block_1,
349            mid_attn_1,
350            mid_block_2,
351            norm_out,
352            conv_out,
353            up,
354        })
355    }
356}
357
358impl candle_nn::Module for Decoder {
359    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
360        let h = xs.apply(&self.conv_in)?;
361        let mut h = h
362            .apply(&self.mid_block_1)?
363            .apply(&self.mid_attn_1)?
364            .apply(&self.mid_block_2)?;
365        for block in self.up.iter().rev() {
366            for b in block.block.iter() {
367                h = h.apply(b)?
368            }
369            if let Some(us) = block.upsample.as_ref() {
370                h = h.apply(us)?
371            }
372        }
373        h.apply(&self.norm_out)?
374            .apply(&candle_nn::Activation::Swish)?
375            .apply(&self.conv_out)
376    }
377}
378
379#[derive(Debug, Clone)]
380pub struct DiagonalGaussian {
381    sample: bool,
382    chunk_dim: usize,
383}
384
385impl DiagonalGaussian {
386    pub fn new(sample: bool, chunk_dim: usize) -> Result<Self> {
387        Ok(Self { sample, chunk_dim })
388    }
389}
390
391impl candle_nn::Module for DiagonalGaussian {
392    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
393        let chunks = xs.chunk(2, self.chunk_dim)?;
394        if self.sample {
395            let std = (&chunks[1] * 0.5)?.exp()?;
396            &chunks[0] + (std * chunks[0].randn_like(0., 1.))?
397        } else {
398            Ok(chunks[0].clone())
399        }
400    }
401}
402
403#[derive(Debug, Clone)]
404pub struct AutoEncoder {
405    encoder: Encoder,
406    decoder: Decoder,
407    reg: DiagonalGaussian,
408    shift_factor: f64,
409    scale_factor: f64,
410}
411
412impl AutoEncoder {
413    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
414        let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
415        let decoder = Decoder::new(cfg, vb.pp("decoder"))?;
416        let reg = DiagonalGaussian::new(true, 1)?;
417        Ok(Self {
418            encoder,
419            decoder,
420            reg,
421            scale_factor: cfg.scale_factor,
422            shift_factor: cfg.shift_factor,
423        })
424    }
425
426    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
427        let z = xs.apply(&self.encoder)?.apply(&self.reg)?;
428        (z - self.shift_factor)? * self.scale_factor
429    }
430    pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
431        let xs = ((xs / self.scale_factor)? + self.shift_factor)?;
432        xs.apply(&self.decoder)
433    }
434}
435
436impl candle::Module for AutoEncoder {
437    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
438        self.decode(&self.encode(xs)?)
439    }
440}