// candle_transformers/models/stable_diffusion/euler_ancestral_discrete.rs

//! Ancestral sampling with Euler method steps.
//!
//! Based on the original [`k-diffusion` implementation by Katherine Crowson](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72).
//!
5use super::{
6    schedulers::{
7        betas_for_alpha_bar, BetaSchedule, PredictionType, Scheduler, SchedulerConfig,
8        TimestepSpacing,
9    },
10    utils::interp,
11};
12use candle::{bail, Error, Result, Tensor};
13
14/// The configuration for the EulerAncestral Discrete scheduler.
15#[derive(Debug, Clone, Copy)]
16pub struct EulerAncestralDiscreteSchedulerConfig {
17    /// The value of beta at the beginning of training.n
18    pub beta_start: f64,
19    /// The value of beta at the end of training.
20    pub beta_end: f64,
21    /// How beta evolved during training.
22    pub beta_schedule: BetaSchedule,
23    /// Adjust the indexes of the inference schedule by this value.
24    pub steps_offset: usize,
25    /// prediction type of the scheduler function, one of `epsilon` (predicting
26    /// the noise of the diffusion process), `sample` (directly predicting the noisy sample`)
27    /// or `v_prediction` (see [section 2.4](https://imagen.research.google/video/paper.pdf))
28    pub prediction_type: PredictionType,
29    /// number of diffusion steps used to train the model
30    pub train_timesteps: usize,
31    /// time step spacing for the diffusion process
32    pub timestep_spacing: TimestepSpacing,
33}
34
35impl Default for EulerAncestralDiscreteSchedulerConfig {
36    fn default() -> Self {
37        Self {
38            beta_start: 0.00085f64,
39            beta_end: 0.012f64,
40            beta_schedule: BetaSchedule::ScaledLinear,
41            steps_offset: 1,
42            prediction_type: PredictionType::Epsilon,
43            train_timesteps: 1000,
44            timestep_spacing: TimestepSpacing::Leading,
45        }
46    }
47}
48
49impl SchedulerConfig for EulerAncestralDiscreteSchedulerConfig {
50    fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>> {
51        Ok(Box::new(EulerAncestralDiscreteScheduler::new(
52            inference_steps,
53            *self,
54        )?))
55    }
56}
57
58/// The EulerAncestral Discrete scheduler.
59#[derive(Debug, Clone)]
60pub struct EulerAncestralDiscreteScheduler {
61    timesteps: Vec<usize>,
62    sigmas: Vec<f64>,
63    init_noise_sigma: f64,
64    pub config: EulerAncestralDiscreteSchedulerConfig,
65}
66
67// clip_sample: False, set_alpha_to_one: False
68impl EulerAncestralDiscreteScheduler {
69    /// Creates a new EulerAncestral Discrete scheduler given the number of steps to be
70    /// used for inference as well as the number of steps that was used
71    /// during training.
72    pub fn new(
73        inference_steps: usize,
74        config: EulerAncestralDiscreteSchedulerConfig,
75    ) -> Result<Self> {
76        let step_ratio = config.train_timesteps / inference_steps;
77        let timesteps: Vec<usize> = match config.timestep_spacing {
78            TimestepSpacing::Leading => (0..(inference_steps))
79                .map(|s| s * step_ratio + config.steps_offset)
80                .rev()
81                .collect(),
82            TimestepSpacing::Trailing => std::iter::successors(Some(config.train_timesteps), |n| {
83                if *n > step_ratio {
84                    Some(n - step_ratio)
85                } else {
86                    None
87                }
88            })
89            .map(|n| n - 1)
90            .collect(),
91            TimestepSpacing::Linspace => {
92                super::utils::linspace(0.0, (config.train_timesteps - 1) as f64, inference_steps)?
93                    .to_vec1::<f64>()?
94                    .iter()
95                    .map(|&f| f as usize)
96                    .rev()
97                    .collect()
98            }
99        };
100
101        let betas = match config.beta_schedule {
102            BetaSchedule::ScaledLinear => super::utils::linspace(
103                config.beta_start.sqrt(),
104                config.beta_end.sqrt(),
105                config.train_timesteps,
106            )?
107            .sqr()?,
108            BetaSchedule::Linear => {
109                super::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
110            }
111            BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?,
112        };
113        let betas = betas.to_vec1::<f64>()?;
114        let mut alphas_cumprod = Vec::with_capacity(betas.len());
115        for &beta in betas.iter() {
116            let alpha = 1.0 - beta;
117            alphas_cumprod.push(alpha * *alphas_cumprod.last().unwrap_or(&1f64))
118        }
119        let sigmas: Vec<f64> = alphas_cumprod
120            .iter()
121            .map(|&f| ((1. - f) / f).sqrt())
122            .collect();
123
124        let sigmas_xa: Vec<_> = (0..sigmas.len()).map(|i| i as f64).collect();
125
126        let mut sigmas_int = interp(
127            &timesteps.iter().map(|&t| t as f64).collect::<Vec<_>>(),
128            &sigmas_xa,
129            &sigmas,
130        );
131        sigmas_int.push(0.0);
132
133        // standard deviation of the initial noise distribution
134        // f64 does not implement Ord such that there is no `max`, so we need to use this workaround
135        let init_noise_sigma = *sigmas_int
136            .iter()
137            .chain(std::iter::once(&0.0))
138            .reduce(|a, b| if a > b { a } else { b })
139            .expect("init_noise_sigma could not be reduced from sigmas - this should never happen");
140
141        Ok(Self {
142            sigmas: sigmas_int,
143            timesteps,
144            init_noise_sigma,
145            config,
146        })
147    }
148}
149
150impl Scheduler for EulerAncestralDiscreteScheduler {
151    fn timesteps(&self) -> &[usize] {
152        self.timesteps.as_slice()
153    }
154
155    /// Ensures interchangeability with schedulers that need to scale the denoising model input
156    /// depending on the current timestep.
157    ///
158    /// Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm
159    fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Result<Tensor> {
160        let step_index = match self.timesteps.iter().position(|&t| t == timestep) {
161            Some(i) => i,
162            None => bail!("timestep out of this schedulers bounds: {timestep}"),
163        };
164
165        let sigma = self
166            .sigmas
167            .get(step_index)
168            .expect("step_index out of sigma bounds - this shouldn't happen");
169
170        sample / ((sigma.powi(2) + 1.).sqrt())
171    }
172
173    /// Performs a backward step during inference.
174    fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
175        let step_index = self
176            .timesteps
177            .iter()
178            .position(|&p| p == timestep)
179            .ok_or_else(|| Error::Msg("timestep out of this schedulers bounds".to_string()))?;
180
181        let sigma_from = &self.sigmas[step_index];
182        let sigma_to = &self.sigmas[step_index + 1];
183
184        // 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
185        let pred_original_sample = match self.config.prediction_type {
186            PredictionType::Epsilon => (sample - (model_output * *sigma_from))?,
187            PredictionType::VPrediction => {
188                ((model_output * (-sigma_from / (sigma_from.powi(2) + 1.0).sqrt()))?
189                    + (sample / (sigma_from.powi(2) + 1.0))?)?
190            }
191            PredictionType::Sample => bail!("prediction_type not implemented yet: sample"),
192        };
193
194        let sigma_up = (sigma_to.powi(2) * (sigma_from.powi(2) - sigma_to.powi(2))
195            / sigma_from.powi(2))
196        .sqrt();
197        let sigma_down = (sigma_to.powi(2) - sigma_up.powi(2)).sqrt();
198
199        // 2. convert to a ODE derivative
200        let derivative = ((sample - pred_original_sample)? / *sigma_from)?;
201        let dt = sigma_down - *sigma_from;
202        let prev_sample = (sample + derivative * dt)?;
203
204        let noise = prev_sample.randn_like(0.0, 1.0)?;
205
206        prev_sample + noise * sigma_up
207    }
208
209    fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
210        let step_index = self
211            .timesteps
212            .iter()
213            .position(|&p| p == timestep)
214            .ok_or_else(|| Error::Msg("timestep out of this schedulers bounds".to_string()))?;
215
216        let sigma = self
217            .sigmas
218            .get(step_index)
219            .expect("step_index out of sigma bounds - this shouldn't happen");
220
221        original + (noise * *sigma)?
222    }
223
224    fn init_noise_sigma(&self) -> f64 {
225        match self.config.timestep_spacing {
226            TimestepSpacing::Trailing | TimestepSpacing::Linspace => self.init_noise_sigma,
227            TimestepSpacing::Leading => (self.init_noise_sigma.powi(2) + 1.0).sqrt(),
228        }
229    }
230}