candle_transformers/models/stable_diffusion/schedulers.rs

1#![allow(dead_code)]
2//! # Diffusion pipelines and models
3//!
4//! Noise schedulers can be used to set the trade-off between
5//! inference speed and quality.
6use candle::{Result, Tensor};
7
/// A scheduler configuration that can build a [`Scheduler`] for a chosen number of
/// inference steps.
///
/// The `Send + Sync` supertraits allow a configuration to be shared across threads, and the
/// `Debug` supertrait allows it to be printed for diagnostics.
pub trait SchedulerConfig: std::fmt::Debug + Send + Sync {
    /// Builds a boxed scheduler configured for `inference_steps` inference steps.
    ///
    /// # Errors
    /// Returns an error if the implementation fails to build the scheduler.
    fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>>;
}
11
12/// This trait represents a scheduler for the diffusion process.
13pub trait Scheduler {
14    fn timesteps(&self) -> &[usize];
15
16    fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor>;
17
18    fn init_noise_sigma(&self) -> f64;
19
20    fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor>;
21
22    fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor>;
23}
24
/// Describes how `beta` moves from its minimum to its maximum value over the course of
/// training.
#[derive(Clone, Copy, Debug)]
pub enum BetaSchedule {
    /// Beta itself is interpolated linearly.
    Linear,
    /// The square root of beta is interpolated linearly.
    ScaledLinear,
    /// The Glide cosine schedule.
    SquaredcosCapV2,
}
36
/// The quantity the denoising model's output represents.
///
/// NOTE(review): the variant descriptions below follow the conventional diffusion naming;
/// confirm against the scheduler implementations that match on this enum.
#[derive(Clone, Copy, Debug)]
pub enum PredictionType {
    /// The model predicts the noise (epsilon) added to the sample.
    Epsilon,
    /// The model predicts the "velocity" `v` (v-prediction parameterization).
    VPrediction,
    /// The model predicts the denoised sample directly.
    Sample,
}
43
/// Time step spacing for the diffusion process.
///
/// The "linspace", "leading" and "trailing" strategies correspond to the annotations in
/// Table 2 of the [paper](https://arxiv.org/abs/2305.08891). [`TimestepSpacing::Leading`] is
/// the default.
// `Default` is derived with `#[default]` instead of a hand-written impl
// (clippy::derivable_impls); behavior is unchanged.
#[derive(Debug, Default, Clone, Copy)]
pub enum TimestepSpacing {
    /// "leading" spacing; this is the `Default` variant.
    #[default]
    Leading,
    /// "linspace" spacing.
    Linspace,
    /// "trailing" spacing.
    Trailing,
}
59
60/// Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
61/// `(1-beta)` over time from `t = [0,1]`.
62///
63/// Contains a function `alpha_bar` that takes an argument `t` and transforms it to the cumulative product of `(1-beta)`
64/// up to that part of the diffusion process.
65pub(crate) fn betas_for_alpha_bar(num_diffusion_timesteps: usize, max_beta: f64) -> Result<Tensor> {
66    let alpha_bar = |time_step: usize| {
67        f64::cos((time_step as f64 + 0.008) / 1.008 * std::f64::consts::FRAC_PI_2).powi(2)
68    };
69    let mut betas = Vec::with_capacity(num_diffusion_timesteps);
70    for i in 0..num_diffusion_timesteps {
71        let t1 = i / num_diffusion_timesteps;
72        let t2 = (i + 1) / num_diffusion_timesteps;
73        betas.push((1.0 - alpha_bar(t2) / alpha_bar(t1)).min(max_beta));
74    }
75    let betas_len = betas.len();
76    Tensor::from_vec(betas, betas_len, &candle::Device::Cpu)
77}