candle_transformers/models/stable_diffusion/
ddpm.rs1use super::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType};
2use candle::{Result, Tensor};
3
/// Selects how the scheduler derives the amount of noise re-injected at each
/// reverse-diffusion step.
///
/// The default is [`DDPMVarianceType::FixedSmall`].
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum DDPMVarianceType {
    /// Posterior variance `(1 - alpha_bar_prev) / (1 - alpha_bar_t) * beta_t`,
    /// floored at `1e-20`.
    #[default]
    FixedSmall,
    /// Same floored posterior variance, evaluated in log space and returned as
    /// `exp(0.5 * ln(variance))`, i.e. a standard deviation.
    FixedSmallLog,
    /// The per-step `beta_t` itself.
    FixedLarge,
    /// Natural logarithm of the per-step `beta_t`.
    FixedLargeLog,
    /// The posterior variance without the `1e-20` floor.
    Learned,
}
18
19#[derive(Debug, Clone)]
20pub struct DDPMSchedulerConfig {
21 pub beta_start: f64,
23 pub beta_end: f64,
25 pub beta_schedule: BetaSchedule,
27 pub clip_sample: bool,
29 pub variance_type: DDPMVarianceType,
31 pub prediction_type: PredictionType,
33 pub train_timesteps: usize,
35}
36
37impl Default for DDPMSchedulerConfig {
38 fn default() -> Self {
39 Self {
40 beta_start: 0.00085,
41 beta_end: 0.012,
42 beta_schedule: BetaSchedule::ScaledLinear,
43 clip_sample: false,
44 variance_type: DDPMVarianceType::FixedSmall,
45 prediction_type: PredictionType::Epsilon,
46 train_timesteps: 1000,
47 }
48 }
49}
50
51pub struct DDPMScheduler {
52 alphas_cumprod: Vec<f64>,
53 init_noise_sigma: f64,
54 timesteps: Vec<usize>,
55 step_ratio: usize,
56 pub config: DDPMSchedulerConfig,
57}
58
59impl DDPMScheduler {
60 pub fn new(inference_steps: usize, config: DDPMSchedulerConfig) -> Result<Self> {
61 let betas = match config.beta_schedule {
62 BetaSchedule::ScaledLinear => super::utils::linspace(
63 config.beta_start.sqrt(),
64 config.beta_end.sqrt(),
65 config.train_timesteps,
66 )?
67 .sqr()?,
68 BetaSchedule::Linear => {
69 super::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
70 }
71 BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?,
72 };
73
74 let betas = betas.to_vec1::<f64>()?;
75 let mut alphas_cumprod = Vec::with_capacity(betas.len());
76 for &beta in betas.iter() {
77 let alpha = 1.0 - beta;
78 alphas_cumprod.push(alpha * *alphas_cumprod.last().unwrap_or(&1f64))
79 }
80
81 let inference_steps = inference_steps.min(config.train_timesteps);
84 let step_ratio = config.train_timesteps / inference_steps;
86 let timesteps: Vec<usize> = (0..inference_steps).map(|s| s * step_ratio).rev().collect();
87
88 Ok(Self {
89 alphas_cumprod,
90 init_noise_sigma: 1.0,
91 timesteps,
92 step_ratio,
93 config,
94 })
95 }
96
97 fn get_variance(&self, timestep: usize) -> f64 {
98 let prev_t = timestep as isize - self.step_ratio as isize;
99 let alpha_prod_t = self.alphas_cumprod[timestep];
100 let alpha_prod_t_prev = if prev_t >= 0 {
101 self.alphas_cumprod[prev_t as usize]
102 } else {
103 1.0
104 };
105 let current_beta_t = 1. - alpha_prod_t / alpha_prod_t_prev;
106
107 let variance = (1. - alpha_prod_t_prev) / (1. - alpha_prod_t) * current_beta_t;
111
112 match self.config.variance_type {
114 DDPMVarianceType::FixedSmall => variance.max(1e-20),
115 DDPMVarianceType::FixedSmallLog => {
117 let variance = variance.max(1e-20).ln();
118 (variance * 0.5).exp()
119 }
120 DDPMVarianceType::FixedLarge => current_beta_t,
121 DDPMVarianceType::FixedLargeLog => current_beta_t.ln(),
122 DDPMVarianceType::Learned => variance,
123 }
124 }
125
126 pub fn timesteps(&self) -> &[usize] {
127 self.timesteps.as_slice()
128 }
129
130 pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor {
133 sample
134 }
135
136 pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
137 let prev_t = timestep as isize - self.step_ratio as isize;
138
139 let alpha_prod_t = self.alphas_cumprod[timestep];
142 let alpha_prod_t_prev = if prev_t >= 0 {
143 self.alphas_cumprod[prev_t as usize]
144 } else {
145 1.0
146 };
147 let beta_prod_t = 1. - alpha_prod_t;
148 let beta_prod_t_prev = 1. - alpha_prod_t_prev;
149 let current_alpha_t = alpha_prod_t / alpha_prod_t_prev;
150 let current_beta_t = 1. - current_alpha_t;
151
152 let mut pred_original_sample = match self.config.prediction_type {
154 PredictionType::Epsilon => {
155 ((sample - model_output * beta_prod_t.sqrt())? / alpha_prod_t.sqrt())?
156 }
157 PredictionType::Sample => model_output.clone(),
158 PredictionType::VPrediction => {
159 ((sample * alpha_prod_t.sqrt())? - model_output * beta_prod_t.sqrt())?
160 }
161 };
162
163 if self.config.clip_sample {
165 pred_original_sample = pred_original_sample.clamp(-1f32, 1f32)?;
166 }
167
168 let pred_original_sample_coeff = (alpha_prod_t_prev.sqrt() * current_beta_t) / beta_prod_t;
171 let current_sample_coeff = current_alpha_t.sqrt() * beta_prod_t_prev / beta_prod_t;
172
173 let pred_prev_sample = ((&pred_original_sample * pred_original_sample_coeff)?
176 + sample * current_sample_coeff)?;
177
178 let mut variance = model_output.zeros_like()?;
181 if timestep > 0 {
182 let variance_noise = model_output.randn_like(0., 1.)?;
183 if self.config.variance_type == DDPMVarianceType::FixedSmallLog {
184 variance = (variance_noise * self.get_variance(timestep))?;
185 } else {
186 variance = (variance_noise * self.get_variance(timestep).sqrt())?;
187 }
188 }
189 &pred_prev_sample + variance
190 }
191
192 pub fn add_noise(
193 &self,
194 original_samples: &Tensor,
195 noise: Tensor,
196 timestep: usize,
197 ) -> Result<Tensor> {
198 (original_samples * self.alphas_cumprod[timestep].sqrt())?
199 + noise * (1. - self.alphas_cumprod[timestep]).sqrt()
200 }
201
202 pub fn init_noise_sigma(&self) -> f64 {
203 self.init_noise_sigma
204 }
205}