1use candle::{Result, Tensor, D};
2use candle_nn::{conv2d, group_norm, Conv2d, GroupNorm, VarBuilder};
3
/// Hyper-parameters used to build the [`Encoder`], [`Decoder`] and [`AutoEncoder`].
#[derive(Debug, Clone)]
pub struct Config {
    /// Nominal image resolution. Not read by the model-building code visible in this file.
    pub resolution: usize,
    /// Number of channels accepted by the encoder's input convolution.
    pub in_channels: usize,
    /// Base channel width; each level uses `ch * ch_mult[level]` channels.
    pub ch: usize,
    /// Number of channels produced by the decoder's output convolution.
    pub out_ch: usize,
    /// Per-level channel multipliers. The vector length is the number of levels.
    pub ch_mult: Vec<usize>,
    /// Residual blocks per encoder level (the decoder builds one more per level).
    pub num_res_blocks: usize,
    /// Latent channel count. The encoder emits `2 * z_channels` channels and the
    /// decoder consumes `z_channels`.
    pub z_channels: usize,
    /// Multiplier applied to latents in `AutoEncoder::encode` (divided out in `decode`).
    pub scale_factor: f64,
    /// Offset subtracted from latents in `AutoEncoder::encode` (added back in `decode`).
    pub shift_factor: f64,
}
17
18impl Config {
19 pub fn dev() -> Self {
21 Self {
22 resolution: 256,
23 in_channels: 3,
24 ch: 128,
25 out_ch: 3,
26 ch_mult: vec![1, 2, 4, 4],
27 num_res_blocks: 2,
28 z_channels: 16,
29 scale_factor: 0.3611,
30 shift_factor: 0.1159,
31 }
32 }
33
34 pub fn schnell() -> Self {
36 Self {
37 resolution: 256,
38 in_channels: 3,
39 ch: 128,
40 out_ch: 3,
41 ch_mult: vec![1, 2, 4, 4],
42 num_res_blocks: 2,
43 z_channels: 16,
44 scale_factor: 0.3611,
45 shift_factor: 0.1159,
46 }
47 }
48}
49
50fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
51 let dim = q.dim(D::Minus1)?;
52 let scale_factor = 1.0 / (dim as f64).sqrt();
53 let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?;
54 candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(v)
55}
56
57#[derive(Debug, Clone)]
58struct AttnBlock {
59 q: Conv2d,
60 k: Conv2d,
61 v: Conv2d,
62 proj_out: Conv2d,
63 norm: GroupNorm,
64}
65
66impl AttnBlock {
67 fn new(in_c: usize, vb: VarBuilder) -> Result<Self> {
68 let q = conv2d(in_c, in_c, 1, Default::default(), vb.pp("q"))?;
69 let k = conv2d(in_c, in_c, 1, Default::default(), vb.pp("k"))?;
70 let v = conv2d(in_c, in_c, 1, Default::default(), vb.pp("v"))?;
71 let proj_out = conv2d(in_c, in_c, 1, Default::default(), vb.pp("proj_out"))?;
72 let norm = group_norm(32, in_c, 1e-6, vb.pp("norm"))?;
73 Ok(Self {
74 q,
75 k,
76 v,
77 proj_out,
78 norm,
79 })
80 }
81}
82
83impl candle::Module for AttnBlock {
84 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
85 let init_xs = xs;
86 let xs = xs.apply(&self.norm)?;
87 let q = xs.apply(&self.q)?;
88 let k = xs.apply(&self.k)?;
89 let v = xs.apply(&self.v)?;
90 let (b, c, h, w) = q.dims4()?;
91 let q = q.flatten_from(2)?.t()?.unsqueeze(1)?;
92 let k = k.flatten_from(2)?.t()?.unsqueeze(1)?;
93 let v = v.flatten_from(2)?.t()?.unsqueeze(1)?;
94 let xs = scaled_dot_product_attention(&q, &k, &v)?;
95 let xs = xs.squeeze(1)?.t()?.reshape((b, c, h, w))?;
96 xs.apply(&self.proj_out)? + init_xs
97 }
98}
99
100#[derive(Debug, Clone)]
101struct ResnetBlock {
102 norm1: GroupNorm,
103 conv1: Conv2d,
104 norm2: GroupNorm,
105 conv2: Conv2d,
106 nin_shortcut: Option<Conv2d>,
107}
108
109impl ResnetBlock {
110 fn new(in_c: usize, out_c: usize, vb: VarBuilder) -> Result<Self> {
111 let conv_cfg = candle_nn::Conv2dConfig {
112 padding: 1,
113 ..Default::default()
114 };
115 let norm1 = group_norm(32, in_c, 1e-6, vb.pp("norm1"))?;
116 let conv1 = conv2d(in_c, out_c, 3, conv_cfg, vb.pp("conv1"))?;
117 let norm2 = group_norm(32, out_c, 1e-6, vb.pp("norm2"))?;
118 let conv2 = conv2d(out_c, out_c, 3, conv_cfg, vb.pp("conv2"))?;
119 let nin_shortcut = if in_c == out_c {
120 None
121 } else {
122 Some(conv2d(
123 in_c,
124 out_c,
125 1,
126 Default::default(),
127 vb.pp("nin_shortcut"),
128 )?)
129 };
130 Ok(Self {
131 norm1,
132 conv1,
133 norm2,
134 conv2,
135 nin_shortcut,
136 })
137 }
138}
139
140impl candle::Module for ResnetBlock {
141 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
142 let h = xs
143 .apply(&self.norm1)?
144 .apply(&candle_nn::Activation::Swish)?
145 .apply(&self.conv1)?
146 .apply(&self.norm2)?
147 .apply(&candle_nn::Activation::Swish)?
148 .apply(&self.conv2)?;
149 match self.nin_shortcut.as_ref() {
150 None => xs + h,
151 Some(c) => xs.apply(c)? + h,
152 }
153 }
154}
155
156#[derive(Debug, Clone)]
157struct Downsample {
158 conv: Conv2d,
159}
160
161impl Downsample {
162 fn new(in_c: usize, vb: VarBuilder) -> Result<Self> {
163 let conv_cfg = candle_nn::Conv2dConfig {
164 stride: 2,
165 ..Default::default()
166 };
167 let conv = conv2d(in_c, in_c, 3, conv_cfg, vb.pp("conv"))?;
168 Ok(Self { conv })
169 }
170}
171
172impl candle::Module for Downsample {
173 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
174 let xs = xs.pad_with_zeros(D::Minus1, 0, 1)?;
175 let xs = xs.pad_with_zeros(D::Minus2, 0, 1)?;
176 xs.apply(&self.conv)
177 }
178}
179
180#[derive(Debug, Clone)]
181struct Upsample {
182 conv: Conv2d,
183}
184
185impl Upsample {
186 fn new(in_c: usize, vb: VarBuilder) -> Result<Self> {
187 let conv_cfg = candle_nn::Conv2dConfig {
188 padding: 1,
189 ..Default::default()
190 };
191 let conv = conv2d(in_c, in_c, 3, conv_cfg, vb.pp("conv"))?;
192 Ok(Self { conv })
193 }
194}
195
196impl candle::Module for Upsample {
197 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
198 let (_, _, h, w) = xs.dims4()?;
199 xs.upsample_nearest2d(h * 2, w * 2)?.apply(&self.conv)
200 }
201}
202
203#[derive(Debug, Clone)]
204struct DownBlock {
205 block: Vec<ResnetBlock>,
206 downsample: Option<Downsample>,
207}
208
209#[derive(Debug, Clone)]
210pub struct Encoder {
211 conv_in: Conv2d,
212 mid_block_1: ResnetBlock,
213 mid_attn_1: AttnBlock,
214 mid_block_2: ResnetBlock,
215 norm_out: GroupNorm,
216 conv_out: Conv2d,
217 down: Vec<DownBlock>,
218}
219
220impl Encoder {
221 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
222 let conv_cfg = candle_nn::Conv2dConfig {
223 padding: 1,
224 ..Default::default()
225 };
226 let mut block_in = cfg.ch;
227 let conv_in = conv2d(cfg.in_channels, block_in, 3, conv_cfg, vb.pp("conv_in"))?;
228
229 let mut down = Vec::with_capacity(cfg.ch_mult.len());
230 let vb_d = vb.pp("down");
231 for (i_level, ch_mult) in cfg.ch_mult.iter().enumerate() {
232 let mut block = Vec::with_capacity(cfg.num_res_blocks);
233 let vb_d = vb_d.pp(i_level);
234 let vb_b = vb_d.pp("block");
235 let in_ch_mult = if i_level == 0 {
236 1
237 } else {
238 cfg.ch_mult[i_level - 1]
239 };
240 block_in = cfg.ch * in_ch_mult;
241 let block_out = cfg.ch * ch_mult;
242 for i_block in 0..cfg.num_res_blocks {
243 let b = ResnetBlock::new(block_in, block_out, vb_b.pp(i_block))?;
244 block.push(b);
245 block_in = block_out;
246 }
247 let downsample = if i_level != cfg.ch_mult.len() - 1 {
248 Some(Downsample::new(block_in, vb_d.pp("downsample"))?)
249 } else {
250 None
251 };
252 let block = DownBlock { block, downsample };
253 down.push(block)
254 }
255
256 let mid_block_1 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_1"))?;
257 let mid_attn_1 = AttnBlock::new(block_in, vb.pp("mid.attn_1"))?;
258 let mid_block_2 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_2"))?;
259 let conv_out = conv2d(block_in, 2 * cfg.z_channels, 3, conv_cfg, vb.pp("conv_out"))?;
260 let norm_out = group_norm(32, block_in, 1e-6, vb.pp("norm_out"))?;
261 Ok(Self {
262 conv_in,
263 mid_block_1,
264 mid_attn_1,
265 mid_block_2,
266 norm_out,
267 conv_out,
268 down,
269 })
270 }
271}
272
273impl candle_nn::Module for Encoder {
274 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
275 let mut h = xs.apply(&self.conv_in)?;
276 for block in self.down.iter() {
277 for b in block.block.iter() {
278 h = h.apply(b)?
279 }
280 if let Some(ds) = block.downsample.as_ref() {
281 h = h.apply(ds)?
282 }
283 }
284 h.apply(&self.mid_block_1)?
285 .apply(&self.mid_attn_1)?
286 .apply(&self.mid_block_2)?
287 .apply(&self.norm_out)?
288 .apply(&candle_nn::Activation::Swish)?
289 .apply(&self.conv_out)
290 }
291}
292
293#[derive(Debug, Clone)]
294struct UpBlock {
295 block: Vec<ResnetBlock>,
296 upsample: Option<Upsample>,
297}
298
299#[derive(Debug, Clone)]
300pub struct Decoder {
301 conv_in: Conv2d,
302 mid_block_1: ResnetBlock,
303 mid_attn_1: AttnBlock,
304 mid_block_2: ResnetBlock,
305 norm_out: GroupNorm,
306 conv_out: Conv2d,
307 up: Vec<UpBlock>,
308}
309
310impl Decoder {
311 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
312 let conv_cfg = candle_nn::Conv2dConfig {
313 padding: 1,
314 ..Default::default()
315 };
316 let mut block_in = cfg.ch * cfg.ch_mult.last().unwrap_or(&1);
317 let conv_in = conv2d(cfg.z_channels, block_in, 3, conv_cfg, vb.pp("conv_in"))?;
318 let mid_block_1 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_1"))?;
319 let mid_attn_1 = AttnBlock::new(block_in, vb.pp("mid.attn_1"))?;
320 let mid_block_2 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_2"))?;
321
322 let mut up = Vec::with_capacity(cfg.ch_mult.len());
323 let vb_u = vb.pp("up");
324 for (i_level, ch_mult) in cfg.ch_mult.iter().enumerate().rev() {
325 let block_out = cfg.ch * ch_mult;
326 let vb_u = vb_u.pp(i_level);
327 let vb_b = vb_u.pp("block");
328 let mut block = Vec::with_capacity(cfg.num_res_blocks + 1);
329 for i_block in 0..=cfg.num_res_blocks {
330 let b = ResnetBlock::new(block_in, block_out, vb_b.pp(i_block))?;
331 block.push(b);
332 block_in = block_out;
333 }
334 let upsample = if i_level != 0 {
335 Some(Upsample::new(block_in, vb_u.pp("upsample"))?)
336 } else {
337 None
338 };
339 let block = UpBlock { block, upsample };
340 up.push(block)
341 }
342 up.reverse();
343
344 let norm_out = group_norm(32, block_in, 1e-6, vb.pp("norm_out"))?;
345 let conv_out = conv2d(block_in, cfg.out_ch, 3, conv_cfg, vb.pp("conv_out"))?;
346 Ok(Self {
347 conv_in,
348 mid_block_1,
349 mid_attn_1,
350 mid_block_2,
351 norm_out,
352 conv_out,
353 up,
354 })
355 }
356}
357
358impl candle_nn::Module for Decoder {
359 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
360 let h = xs.apply(&self.conv_in)?;
361 let mut h = h
362 .apply(&self.mid_block_1)?
363 .apply(&self.mid_attn_1)?
364 .apply(&self.mid_block_2)?;
365 for block in self.up.iter().rev() {
366 for b in block.block.iter() {
367 h = h.apply(b)?
368 }
369 if let Some(us) = block.upsample.as_ref() {
370 h = h.apply(us)?
371 }
372 }
373 h.apply(&self.norm_out)?
374 .apply(&candle_nn::Activation::Swish)?
375 .apply(&self.conv_out)
376 }
377}
378
/// Layer that splits its input in two along one dimension and either returns the
/// first half or samples around it, using the second half as a log-variance.
#[derive(Debug, Clone)]
pub struct DiagonalGaussian {
    /// When `true`, `forward` draws a random sample; otherwise it returns the first half.
    sample: bool,
    /// Dimension along which the input is split into two halves.
    chunk_dim: usize,
}
384
385impl DiagonalGaussian {
386 pub fn new(sample: bool, chunk_dim: usize) -> Result<Self> {
387 Ok(Self { sample, chunk_dim })
388 }
389}
390
391impl candle_nn::Module for DiagonalGaussian {
392 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
393 let chunks = xs.chunk(2, self.chunk_dim)?;
394 if self.sample {
395 let std = (&chunks[1] * 0.5)?.exp()?;
396 &chunks[0] + (std * chunks[0].randn_like(0., 1.))?
397 } else {
398 Ok(chunks[0].clone())
399 }
400 }
401}
402
403#[derive(Debug, Clone)]
404pub struct AutoEncoder {
405 encoder: Encoder,
406 decoder: Decoder,
407 reg: DiagonalGaussian,
408 shift_factor: f64,
409 scale_factor: f64,
410}
411
412impl AutoEncoder {
413 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
414 let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
415 let decoder = Decoder::new(cfg, vb.pp("decoder"))?;
416 let reg = DiagonalGaussian::new(true, 1)?;
417 Ok(Self {
418 encoder,
419 decoder,
420 reg,
421 scale_factor: cfg.scale_factor,
422 shift_factor: cfg.shift_factor,
423 })
424 }
425
426 pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
427 let z = xs.apply(&self.encoder)?.apply(&self.reg)?;
428 (z - self.shift_factor)? * self.scale_factor
429 }
430 pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
431 let xs = ((xs / self.scale_factor)? + self.shift_factor)?;
432 xs.apply(&self.decoder)
433 }
434}
435
436impl candle::Module for AutoEncoder {
437 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
438 self.decode(&self.encode(xs)?)
439 }
440}