candle_transformers/models/stable_diffusion/
ddim.rs

//! # Denoising Diffusion Implicit Models
//!
//! Denoising Diffusion Implicit Models (DDIM) is a simple scheduler
//! similar to Denoising Diffusion Probabilistic Models (DDPM). The DDPM
//! generative process is the reverse of a Markovian process; DDIM generalizes
//! this to non-Markovian guidance.
//!
//! Denoising Diffusion Implicit Models, J. Song et al., 2020.
//! https://arxiv.org/abs/2010.02502
10use super::schedulers::{
11    betas_for_alpha_bar, BetaSchedule, PredictionType, Scheduler, SchedulerConfig, TimestepSpacing,
12};
13use candle::{Result, Tensor};
14
15/// The configuration for the DDIM scheduler.
16#[derive(Debug, Clone, Copy)]
17pub struct DDIMSchedulerConfig {
18    /// The value of beta at the beginning of training.
19    pub beta_start: f64,
20    /// The value of beta at the end of training.
21    pub beta_end: f64,
22    /// How beta evolved during training.
23    pub beta_schedule: BetaSchedule,
24    /// The amount of noise to be added at each step.
25    pub eta: f64,
26    /// Adjust the indexes of the inference schedule by this value.
27    pub steps_offset: usize,
28    /// prediction type of the scheduler function, one of `epsilon` (predicting
29    /// the noise of the diffusion process), `sample` (directly predicting the noisy sample`)
30    /// or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf)
31    pub prediction_type: PredictionType,
32    /// number of diffusion steps used to train the model
33    pub train_timesteps: usize,
34    /// time step spacing for the diffusion process
35    pub timestep_spacing: TimestepSpacing,
36}
37
38impl Default for DDIMSchedulerConfig {
39    fn default() -> Self {
40        Self {
41            beta_start: 0.00085f64,
42            beta_end: 0.012f64,
43            beta_schedule: BetaSchedule::ScaledLinear,
44            eta: 0.,
45            steps_offset: 1,
46            prediction_type: PredictionType::Epsilon,
47            train_timesteps: 1000,
48            timestep_spacing: TimestepSpacing::Leading,
49        }
50    }
51}
52
53impl SchedulerConfig for DDIMSchedulerConfig {
54    fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>> {
55        Ok(Box::new(DDIMScheduler::new(inference_steps, *self)?))
56    }
57}
58
59/// The DDIM scheduler.
60#[derive(Debug, Clone)]
61pub struct DDIMScheduler {
62    timesteps: Vec<usize>,
63    alphas_cumprod: Vec<f64>,
64    step_ratio: usize,
65    init_noise_sigma: f64,
66    pub config: DDIMSchedulerConfig,
67}
68
69// clip_sample: False, set_alpha_to_one: False
70impl DDIMScheduler {
71    /// Creates a new DDIM scheduler given the number of steps to be
72    /// used for inference as well as the number of steps that was used
73    /// during training.
74    fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> {
75        let step_ratio = config.train_timesteps / inference_steps;
76        let timesteps: Vec<usize> = match config.timestep_spacing {
77            TimestepSpacing::Leading => (0..(inference_steps))
78                .map(|s| s * step_ratio + config.steps_offset)
79                .rev()
80                .collect(),
81            TimestepSpacing::Trailing => std::iter::successors(Some(config.train_timesteps), |n| {
82                if *n > step_ratio {
83                    Some(n - step_ratio)
84                } else {
85                    None
86                }
87            })
88            .map(|n| n - 1)
89            .collect(),
90            TimestepSpacing::Linspace => {
91                super::utils::linspace(0.0, (config.train_timesteps - 1) as f64, inference_steps)?
92                    .to_vec1::<f64>()?
93                    .iter()
94                    .map(|&f| f as usize)
95                    .rev()
96                    .collect()
97            }
98        };
99
100        let betas = match config.beta_schedule {
101            BetaSchedule::ScaledLinear => super::utils::linspace(
102                config.beta_start.sqrt(),
103                config.beta_end.sqrt(),
104                config.train_timesteps,
105            )?
106            .sqr()?,
107            BetaSchedule::Linear => {
108                super::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
109            }
110            BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?,
111        };
112        let betas = betas.to_vec1::<f64>()?;
113        let mut alphas_cumprod = Vec::with_capacity(betas.len());
114        for &beta in betas.iter() {
115            let alpha = 1.0 - beta;
116            alphas_cumprod.push(alpha * *alphas_cumprod.last().unwrap_or(&1f64))
117        }
118        Ok(Self {
119            alphas_cumprod,
120            timesteps,
121            step_ratio,
122            init_noise_sigma: 1.,
123            config,
124        })
125    }
126}
127
128impl Scheduler for DDIMScheduler {
129    /// Performs a backward step during inference.
130    fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
131        let timestep = if timestep >= self.alphas_cumprod.len() {
132            timestep - 1
133        } else {
134            timestep
135        };
136        // https://github.com/huggingface/diffusers/blob/6e099e2c8ce4c4f5c7318e970a8c093dc5c7046e/src/diffusers/schedulers/scheduling_ddim.py#L195
137        let prev_timestep = if timestep > self.step_ratio {
138            timestep - self.step_ratio
139        } else {
140            0
141        };
142
143        let alpha_prod_t = self.alphas_cumprod[timestep];
144        let alpha_prod_t_prev = self.alphas_cumprod[prev_timestep];
145        let beta_prod_t = 1. - alpha_prod_t;
146        let beta_prod_t_prev = 1. - alpha_prod_t_prev;
147
148        let (pred_original_sample, pred_epsilon) = match self.config.prediction_type {
149            PredictionType::Epsilon => {
150                let pred_original_sample = ((sample - (model_output * beta_prod_t.sqrt())?)?
151                    * (1. / alpha_prod_t.sqrt()))?;
152                (pred_original_sample, model_output.clone())
153            }
154            PredictionType::VPrediction => {
155                let pred_original_sample =
156                    ((sample * alpha_prod_t.sqrt())? - (model_output * beta_prod_t.sqrt())?)?;
157                let pred_epsilon =
158                    ((model_output * alpha_prod_t.sqrt())? + (sample * beta_prod_t.sqrt())?)?;
159                (pred_original_sample, pred_epsilon)
160            }
161            PredictionType::Sample => {
162                let pred_original_sample = model_output.clone();
163                let pred_epsilon = ((sample - &pred_original_sample * alpha_prod_t.sqrt())?
164                    * (1. / beta_prod_t.sqrt()))?;
165                (pred_original_sample, pred_epsilon)
166            }
167        };
168
169        let variance = (beta_prod_t_prev / beta_prod_t) * (1. - alpha_prod_t / alpha_prod_t_prev);
170        let std_dev_t = self.config.eta * variance.sqrt();
171
172        let pred_sample_direction =
173            (pred_epsilon * (1. - alpha_prod_t_prev - std_dev_t * std_dev_t).sqrt())?;
174        let prev_sample =
175            ((pred_original_sample * alpha_prod_t_prev.sqrt())? + pred_sample_direction)?;
176        if self.config.eta > 0. {
177            &prev_sample
178                + Tensor::randn(
179                    0f32,
180                    std_dev_t as f32,
181                    prev_sample.shape(),
182                    prev_sample.device(),
183                )?
184        } else {
185            Ok(prev_sample)
186        }
187    }
188
189    ///  Ensures interchangeability with schedulers that need to scale the denoising model input
190    /// depending on the current timestep.
191    fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> {
192        Ok(sample)
193    }
194
195    fn timesteps(&self) -> &[usize] {
196        self.timesteps.as_slice()
197    }
198
199    fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
200        let timestep = if timestep >= self.alphas_cumprod.len() {
201            timestep - 1
202        } else {
203            timestep
204        };
205        let sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt();
206        let sqrt_one_minus_alpha_prod = (1.0 - self.alphas_cumprod[timestep]).sqrt();
207        (original * sqrt_alpha_prod)? + (noise * sqrt_one_minus_alpha_prod)?
208    }
209
210    fn init_noise_sigma(&self) -> f64 {
211        self.init_noise_sigma
212    }
213}