candle_transformers/models/flux/
sampling.rs1use candle::{Device, Result, Tensor};
2
3pub fn get_noise(
4 num_samples: usize,
5 height: usize,
6 width: usize,
7 device: &Device,
8) -> Result<Tensor> {
9 let height = (height + 15) / 16 * 2;
10 let width = (width + 15) / 16 * 2;
11 Tensor::randn(0f32, 1., (num_samples, 16, height, width), device)
12}
13
14#[derive(Debug, Clone)]
15pub struct State {
16 pub img: Tensor,
17 pub img_ids: Tensor,
18 pub txt: Tensor,
19 pub txt_ids: Tensor,
20 pub vec: Tensor,
21}
22
23impl State {
24 pub fn new(t5_emb: &Tensor, clip_emb: &Tensor, img: &Tensor) -> Result<Self> {
25 let dtype = img.dtype();
26 let (bs, c, h, w) = img.dims4()?;
27 let dev = img.device();
28 let img = img.reshape((bs, c, h / 2, 2, w / 2, 2))?; let img = img.permute((0, 2, 4, 1, 3, 5))?; let img = img.reshape((bs, h / 2 * w / 2, c * 4))?;
31 let img_ids = Tensor::stack(
32 &[
33 Tensor::full(0u32, (h / 2, w / 2), dev)?,
34 Tensor::arange(0u32, h as u32 / 2, dev)?
35 .reshape(((), 1))?
36 .broadcast_as((h / 2, w / 2))?,
37 Tensor::arange(0u32, w as u32 / 2, dev)?
38 .reshape((1, ()))?
39 .broadcast_as((h / 2, w / 2))?,
40 ],
41 2,
42 )?
43 .to_dtype(dtype)?;
44 let img_ids = img_ids.reshape((1, h / 2 * w / 2, 3))?;
45 let img_ids = img_ids.repeat((bs, 1, 1))?;
46 let txt = t5_emb.repeat(bs)?;
47 let txt_ids = Tensor::zeros((bs, txt.dim(1)?, 3), dtype, dev)?;
48 let vec = clip_emb.repeat(bs)?;
49 Ok(Self {
50 img,
51 img_ids,
52 txt,
53 txt_ids,
54 vec,
55 })
56 }
57}
58
/// Warps a time `t` in `[0, 1]` with `exp(mu) / (exp(mu) + (1 / t - 1)^sigma)`.
///
/// For finite `mu` and positive `sigma` the end points are preserved: `t = 1` maps to
/// `1` and `t = 0` maps to `0` (the `1 / t` term becomes infinite and the quotient
/// collapses to zero).
fn time_shift(mu: f64, sigma: f64, t: f64) -> f64 {
    let e = mu.exp();
    e / (e + (1. / t - 1.).powf(sigma))
}

/// Returns the `num_steps + 1` sampling times, going from `1.0` down to `0.0`.
///
/// Without `shift` the times are evenly spaced. With `shift = Some((image_seq_len, y1, y2))`
/// each time is warped by [`time_shift`] with `sigma = 1` and a `mu` obtained by linear
/// interpolation in `image_seq_len`: the line passes through `(256, y1)` and `(4096, y2)`.
///
/// `num_steps == 0` yields the single starting time `[1.0]` (the same as a one-point
/// linspace from 1 to 0) instead of the `NaN` that `0 / 0` would produce.
pub fn get_schedule(num_steps: usize, shift: Option<(usize, f64, f64)>) -> Vec<f64> {
    if num_steps == 0 {
        // `time_shift` leaves 1.0 unchanged for any finite `mu`, so the shifted and
        // unshifted schedules coincide here.
        return vec![1.0];
    }
    let steps = num_steps as f64;
    // Evenly spaced times 1, (n - 1) / n, ..., 1 / n, 0.
    let timesteps = (0..=num_steps).rev().map(|v| v as f64 / steps);
    match shift {
        None => timesteps.collect(),
        Some((image_seq_len, y1, y2)) => {
            // Line through (x1, y1) and (x2, y2), evaluated at `image_seq_len`.
            let (x1, x2) = (256., 4096.);
            let m = (y2 - y1) / (x2 - x1);
            let b = y1 - m * x1;
            let mu = m * image_seq_len as f64 + b;
            timesteps.map(|v| time_shift(mu, 1., v)).collect()
        }
    }
}
84
85pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result<Tensor> {
86 let (b, _h_w, c_ph_pw) = xs.dims3()?;
87 let height = (height + 15) / 16;
88 let width = (width + 15) / 16;
89 xs.reshape((b, height, width, c_ph_pw / 4, 2, 2))? .permute((0, 3, 1, 4, 2, 5))? .reshape((b, c_ph_pw / 4, height * 2, width * 2))
92}
93
94#[allow(clippy::too_many_arguments)]
95pub fn denoise<M: super::WithForward>(
96 model: &M,
97 img: &Tensor,
98 img_ids: &Tensor,
99 txt: &Tensor,
100 txt_ids: &Tensor,
101 vec_: &Tensor,
102 timesteps: &[f64],
103 guidance: f64,
104) -> Result<Tensor> {
105 let b_sz = img.dim(0)?;
106 let dev = img.device();
107 let guidance = Tensor::full(guidance as f32, b_sz, dev)?;
108 let mut img = img.clone();
109 for window in timesteps.windows(2) {
110 let (t_curr, t_prev) = match window {
111 [a, b] => (a, b),
112 _ => continue,
113 };
114 let t_vec = Tensor::full(*t_curr as f32, b_sz, dev)?;
115 let pred = model.forward(&img, img_ids, txt, txt_ids, &t_vec, vec_, Some(&guidance))?;
116 img = (img + pred * (t_prev - t_curr))?
117 }
118 Ok(img)
119}