// candle_transformers/models/stable_diffusion/ddpm.rs

1use super::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType};
2use candle::{Result, Tensor};
3
/// Selects which variance-related quantity `DDPMScheduler::get_variance` returns for a
/// reverse diffusion step.
///
/// Notation: ᾱ_t is the cumulative product of `1 - beta` up to timestep `t`, and
/// β_t = 1 - ᾱ_t / ᾱ_{t-1}.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum DDPMVarianceType {
    /// Posterior variance β̃_t = (1 - ᾱ_{t-1}) / (1 - ᾱ_t) · β_t, clamped to at least 1e-20.
    /// This is the default.
    #[default]
    FixedSmall,
    /// Same clamped posterior variance, but returned as a standard deviation computed in the
    /// log domain (`exp(0.5 * ln(β̃_t))`); used by rl-diffuser.
    FixedSmallLog,
    /// β_t itself.
    FixedLarge,
    /// The log-variance ln(β_t).
    FixedLargeLog,
    /// The unclamped posterior variance β̃_t (no model-predicted variance is consumed).
    Learned,
}

19#[derive(Debug, Clone)]
20pub struct DDPMSchedulerConfig {
21    /// The value of beta at the beginning of training.
22    pub beta_start: f64,
23    /// The value of beta at the end of training.
24    pub beta_end: f64,
25    /// How beta evolved during training.
26    pub beta_schedule: BetaSchedule,
27    /// Option to predicted sample between -1 and 1 for numerical stability.
28    pub clip_sample: bool,
29    /// Option to clip the variance used when adding noise to the denoised sample.
30    pub variance_type: DDPMVarianceType,
31    /// prediction type of the scheduler function
32    pub prediction_type: PredictionType,
33    /// number of diffusion steps used to train the model.
34    pub train_timesteps: usize,
35}
36
37impl Default for DDPMSchedulerConfig {
38    fn default() -> Self {
39        Self {
40            beta_start: 0.00085,
41            beta_end: 0.012,
42            beta_schedule: BetaSchedule::ScaledLinear,
43            clip_sample: false,
44            variance_type: DDPMVarianceType::FixedSmall,
45            prediction_type: PredictionType::Epsilon,
46            train_timesteps: 1000,
47        }
48    }
49}
50
51pub struct DDPMScheduler {
52    alphas_cumprod: Vec<f64>,
53    init_noise_sigma: f64,
54    timesteps: Vec<usize>,
55    step_ratio: usize,
56    pub config: DDPMSchedulerConfig,
57}
58
59impl DDPMScheduler {
60    pub fn new(inference_steps: usize, config: DDPMSchedulerConfig) -> Result<Self> {
61        let betas = match config.beta_schedule {
62            BetaSchedule::ScaledLinear => super::utils::linspace(
63                config.beta_start.sqrt(),
64                config.beta_end.sqrt(),
65                config.train_timesteps,
66            )?
67            .sqr()?,
68            BetaSchedule::Linear => {
69                super::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
70            }
71            BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?,
72        };
73
74        let betas = betas.to_vec1::<f64>()?;
75        let mut alphas_cumprod = Vec::with_capacity(betas.len());
76        for &beta in betas.iter() {
77            let alpha = 1.0 - beta;
78            alphas_cumprod.push(alpha * *alphas_cumprod.last().unwrap_or(&1f64))
79        }
80
81        // min(train_timesteps, inference_steps)
82        // https://github.com/huggingface/diffusers/blob/8331da46837be40f96fbd24de6a6fb2da28acd11/src/diffusers/schedulers/scheduling_ddpm.py#L187
83        let inference_steps = inference_steps.min(config.train_timesteps);
84        // arange the number of the scheduler's timesteps
85        let step_ratio = config.train_timesteps / inference_steps;
86        let timesteps: Vec<usize> = (0..inference_steps).map(|s| s * step_ratio).rev().collect();
87
88        Ok(Self {
89            alphas_cumprod,
90            init_noise_sigma: 1.0,
91            timesteps,
92            step_ratio,
93            config,
94        })
95    }
96
97    fn get_variance(&self, timestep: usize) -> f64 {
98        let prev_t = timestep as isize - self.step_ratio as isize;
99        let alpha_prod_t = self.alphas_cumprod[timestep];
100        let alpha_prod_t_prev = if prev_t >= 0 {
101            self.alphas_cumprod[prev_t as usize]
102        } else {
103            1.0
104        };
105        let current_beta_t = 1. - alpha_prod_t / alpha_prod_t_prev;
106
107        // For t > 0, compute predicted variance βt (see formula (6) and (7) from [the pdf](https://arxiv.org/pdf/2006.11239.pdf))
108        // and sample from it to get previous sample
109        // x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
110        let variance = (1. - alpha_prod_t_prev) / (1. - alpha_prod_t) * current_beta_t;
111
112        // retrieve variance
113        match self.config.variance_type {
114            DDPMVarianceType::FixedSmall => variance.max(1e-20),
115            // for rl-diffuser https://arxiv.org/abs/2205.09991
116            DDPMVarianceType::FixedSmallLog => {
117                let variance = variance.max(1e-20).ln();
118                (variance * 0.5).exp()
119            }
120            DDPMVarianceType::FixedLarge => current_beta_t,
121            DDPMVarianceType::FixedLargeLog => current_beta_t.ln(),
122            DDPMVarianceType::Learned => variance,
123        }
124    }
125
126    pub fn timesteps(&self) -> &[usize] {
127        self.timesteps.as_slice()
128    }
129
130    ///  Ensures interchangeability with schedulers that need to scale the denoising model input
131    /// depending on the current timestep.
132    pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor {
133        sample
134    }
135
136    pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
137        let prev_t = timestep as isize - self.step_ratio as isize;
138
139        // https://github.com/huggingface/diffusers/blob/df2b548e893ccb8a888467c2508756680df22821/src/diffusers/schedulers/scheduling_ddpm.py#L272
140        // 1. compute alphas, betas
141        let alpha_prod_t = self.alphas_cumprod[timestep];
142        let alpha_prod_t_prev = if prev_t >= 0 {
143            self.alphas_cumprod[prev_t as usize]
144        } else {
145            1.0
146        };
147        let beta_prod_t = 1. - alpha_prod_t;
148        let beta_prod_t_prev = 1. - alpha_prod_t_prev;
149        let current_alpha_t = alpha_prod_t / alpha_prod_t_prev;
150        let current_beta_t = 1. - current_alpha_t;
151
152        // 2. compute predicted original sample from predicted noise also called "predicted x_0" of formula (15)
153        let mut pred_original_sample = match self.config.prediction_type {
154            PredictionType::Epsilon => {
155                ((sample - model_output * beta_prod_t.sqrt())? / alpha_prod_t.sqrt())?
156            }
157            PredictionType::Sample => model_output.clone(),
158            PredictionType::VPrediction => {
159                ((sample * alpha_prod_t.sqrt())? - model_output * beta_prod_t.sqrt())?
160            }
161        };
162
163        // 3. clip predicted x_0
164        if self.config.clip_sample {
165            pred_original_sample = pred_original_sample.clamp(-1f32, 1f32)?;
166        }
167
168        // 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
169        // See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
170        let pred_original_sample_coeff = (alpha_prod_t_prev.sqrt() * current_beta_t) / beta_prod_t;
171        let current_sample_coeff = current_alpha_t.sqrt() * beta_prod_t_prev / beta_prod_t;
172
173        // 5. Compute predicted previous sample µ_t
174        // See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
175        let pred_prev_sample = ((&pred_original_sample * pred_original_sample_coeff)?
176            + sample * current_sample_coeff)?;
177
178        // https://github.com/huggingface/diffusers/blob/df2b548e893ccb8a888467c2508756680df22821/src/diffusers/schedulers/scheduling_ddpm.py#L305
179        // 6. Add noise
180        let mut variance = model_output.zeros_like()?;
181        if timestep > 0 {
182            let variance_noise = model_output.randn_like(0., 1.)?;
183            if self.config.variance_type == DDPMVarianceType::FixedSmallLog {
184                variance = (variance_noise * self.get_variance(timestep))?;
185            } else {
186                variance = (variance_noise * self.get_variance(timestep).sqrt())?;
187            }
188        }
189        &pred_prev_sample + variance
190    }
191
192    pub fn add_noise(
193        &self,
194        original_samples: &Tensor,
195        noise: Tensor,
196        timestep: usize,
197    ) -> Result<Tensor> {
198        (original_samples * self.alphas_cumprod[timestep].sqrt())?
199            + noise * (1. - self.alphas_cumprod[timestep]).sqrt()
200    }
201
202    pub fn init_noise_sigma(&self) -> f64 {
203        self.init_noise_sigma
204    }
205}