candle_transformers/models/wuerstchen/
ddpm.rs

1use candle::{Result, Tensor};
2
/// Tunable parameters of the noise schedule used by [`DDPMWScheduler`].
#[derive(Clone, Debug)]
pub struct DDPMWSchedulerConfig {
    /// Warps the timestep before the cosine schedule is evaluated: values above `1` map `t` to
    /// `1 - (1 - t)^scaler`, values below `1` map `t` to `t^scaler`, and exactly `1` leaves `t`
    /// untouched.
    scaler: f64,
    /// Offset of the cosine schedule, which is evaluated at `(t + s) / (1 + s)`.
    s: f64,
}
8
9impl Default for DDPMWSchedulerConfig {
10    fn default() -> Self {
11        Self {
12            scaler: 1f64,
13            s: 0.008f64,
14        }
15    }
16}
17
18pub struct DDPMWScheduler {
19    init_alpha_cumprod: f64,
20    init_noise_sigma: f64,
21    timesteps: Vec<f64>,
22    pub config: DDPMWSchedulerConfig,
23}
24
25impl DDPMWScheduler {
26    pub fn new(inference_steps: usize, config: DDPMWSchedulerConfig) -> Result<Self> {
27        let init_alpha_cumprod = (config.s / (1. + config.s) * std::f64::consts::PI)
28            .cos()
29            .powi(2);
30        let timesteps = (0..=inference_steps)
31            .map(|i| 1. - i as f64 / inference_steps as f64)
32            .collect::<Vec<_>>();
33        Ok(Self {
34            init_alpha_cumprod,
35            init_noise_sigma: 1.0,
36            timesteps,
37            config,
38        })
39    }
40
41    pub fn timesteps(&self) -> &[f64] {
42        &self.timesteps
43    }
44
45    fn alpha_cumprod(&self, t: f64) -> f64 {
46        let scaler = self.config.scaler;
47        let s = self.config.s;
48        let t = if scaler > 1. {
49            1. - (1. - t).powf(scaler)
50        } else if scaler < 1. {
51            t.powf(scaler)
52        } else {
53            t
54        };
55        let alpha_cumprod = ((t + s) / (1. + s) * std::f64::consts::PI * 0.5)
56            .cos()
57            .powi(2)
58            / self.init_alpha_cumprod;
59        alpha_cumprod.clamp(0.0001, 0.9999)
60    }
61
62    fn previous_timestep(&self, ts: f64) -> f64 {
63        let index = self
64            .timesteps
65            .iter()
66            .enumerate()
67            .map(|(idx, v)| (idx, (v - ts).abs()))
68            .min_by(|x, y| x.1.total_cmp(&y.1))
69            .unwrap()
70            .0;
71        self.timesteps[index + 1]
72    }
73
74    ///  Ensures interchangeability with schedulers that need to scale the denoising model input
75    /// depending on the current timestep.
76    pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor {
77        sample
78    }
79
80    pub fn step(&self, model_output: &Tensor, ts: f64, sample: &Tensor) -> Result<Tensor> {
81        let prev_t = self.previous_timestep(ts);
82
83        let alpha_cumprod = self.alpha_cumprod(ts);
84        let alpha_cumprod_prev = self.alpha_cumprod(prev_t);
85        let alpha = alpha_cumprod / alpha_cumprod_prev;
86
87        let mu = (sample - model_output * ((1. - alpha) / (1. - alpha_cumprod).sqrt()))?;
88        let mu = (mu * (1. / alpha).sqrt())?;
89
90        let std_noise = mu.randn_like(0., 1.)?;
91        let std =
92            std_noise * ((1. - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt();
93        if prev_t == 0. {
94            Ok(mu)
95        } else {
96            mu + std
97        }
98    }
99
100    pub fn init_noise_sigma(&self) -> f64 {
101        self.init_noise_sigma
102    }
103}