candle_transformers/models/wuerstchen/
ddpm.rs1use candle::{Result, Tensor};
2
/// Parameters of the cosine `alpha_cumprod` schedule used by `DDPMWScheduler`.
///
/// The `Default` impl yields `scaler = 1.0` and `s = 0.008`.
#[derive(Debug, Clone)]
pub struct DDPMWSchedulerConfig {
    /// Rescaling exponent applied to the timestep `t` before the cosine schedule is evaluated:
    /// above 1 the schedule uses `1 - (1 - t)^scaler`, below 1 it uses `t^scaler`, and exactly 1
    /// leaves `t` unchanged.
    pub scaler: f64,
    /// Offset `s` in the cosine term `cos²(((t + s) / (1 + s)) · π/2)`.
    pub s: f64,
}
8
9impl Default for DDPMWSchedulerConfig {
10 fn default() -> Self {
11 Self {
12 scaler: 1f64,
13 s: 0.008f64,
14 }
15 }
16}
17
18pub struct DDPMWScheduler {
19 init_alpha_cumprod: f64,
20 init_noise_sigma: f64,
21 timesteps: Vec<f64>,
22 pub config: DDPMWSchedulerConfig,
23}
24
25impl DDPMWScheduler {
26 pub fn new(inference_steps: usize, config: DDPMWSchedulerConfig) -> Result<Self> {
27 let init_alpha_cumprod = (config.s / (1. + config.s) * std::f64::consts::PI)
28 .cos()
29 .powi(2);
30 let timesteps = (0..=inference_steps)
31 .map(|i| 1. - i as f64 / inference_steps as f64)
32 .collect::<Vec<_>>();
33 Ok(Self {
34 init_alpha_cumprod,
35 init_noise_sigma: 1.0,
36 timesteps,
37 config,
38 })
39 }
40
41 pub fn timesteps(&self) -> &[f64] {
42 &self.timesteps
43 }
44
45 fn alpha_cumprod(&self, t: f64) -> f64 {
46 let scaler = self.config.scaler;
47 let s = self.config.s;
48 let t = if scaler > 1. {
49 1. - (1. - t).powf(scaler)
50 } else if scaler < 1. {
51 t.powf(scaler)
52 } else {
53 t
54 };
55 let alpha_cumprod = ((t + s) / (1. + s) * std::f64::consts::PI * 0.5)
56 .cos()
57 .powi(2)
58 / self.init_alpha_cumprod;
59 alpha_cumprod.clamp(0.0001, 0.9999)
60 }
61
62 fn previous_timestep(&self, ts: f64) -> f64 {
63 let index = self
64 .timesteps
65 .iter()
66 .enumerate()
67 .map(|(idx, v)| (idx, (v - ts).abs()))
68 .min_by(|x, y| x.1.total_cmp(&y.1))
69 .unwrap()
70 .0;
71 self.timesteps[index + 1]
72 }
73
74 pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor {
77 sample
78 }
79
80 pub fn step(&self, model_output: &Tensor, ts: f64, sample: &Tensor) -> Result<Tensor> {
81 let prev_t = self.previous_timestep(ts);
82
83 let alpha_cumprod = self.alpha_cumprod(ts);
84 let alpha_cumprod_prev = self.alpha_cumprod(prev_t);
85 let alpha = alpha_cumprod / alpha_cumprod_prev;
86
87 let mu = (sample - model_output * ((1. - alpha) / (1. - alpha_cumprod).sqrt()))?;
88 let mu = (mu * (1. / alpha).sqrt())?;
89
90 let std_noise = mu.randn_like(0., 1.)?;
91 let std =
92 std_noise * ((1. - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt();
93 if prev_t == 0. {
94 Ok(mu)
95 } else {
96 mu + std
97 }
98 }
99
100 pub fn init_noise_sigma(&self) -> f64 {
101 self.init_noise_sigma
102 }
103}