candle_transformers/models/stable_diffusion/
ddim.rs1use super::schedulers::{
11 betas_for_alpha_bar, BetaSchedule, PredictionType, Scheduler, SchedulerConfig, TimestepSpacing,
12};
13use candle::{Result, Tensor};
14
/// Configuration for the DDIM scheduler.
///
/// A scheduler instance is obtained from this configuration through
/// `SchedulerConfig::build`.
#[derive(Debug, Clone, Copy)]
pub struct DDIMSchedulerConfig {
    /// First value of the beta schedule. Used as the start point of the
    /// `Linear` and `ScaledLinear` schedules.
    pub beta_start: f64,
    /// Last value of the beta schedule. Used as the end point of the
    /// `Linear` and `ScaledLinear` schedules.
    pub beta_end: f64,
    /// Which formula is used to generate the per-timestep betas.
    pub beta_schedule: BetaSchedule,
    /// Scales the standard deviation of the noise added at each step; no
    /// noise is added when this is not strictly positive.
    pub eta: f64,
    /// Offset added to every timestep when using `TimestepSpacing::Leading`.
    pub steps_offset: usize,
    /// How the model output is interpreted during a step (predicted noise,
    /// v-prediction, or predicted sample).
    pub prediction_type: PredictionType,
    /// Number of entries in the beta schedule; divided by the number of
    /// inference steps to obtain the stride between consecutive timesteps.
    pub train_timesteps: usize,
    /// How the inference timesteps are laid out over the training range.
    pub timestep_spacing: TimestepSpacing,
}
37
38impl Default for DDIMSchedulerConfig {
39 fn default() -> Self {
40 Self {
41 beta_start: 0.00085f64,
42 beta_end: 0.012f64,
43 beta_schedule: BetaSchedule::ScaledLinear,
44 eta: 0.,
45 steps_offset: 1,
46 prediction_type: PredictionType::Epsilon,
47 train_timesteps: 1000,
48 timestep_spacing: TimestepSpacing::Leading,
49 }
50 }
51}
52
53impl SchedulerConfig for DDIMSchedulerConfig {
54 fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>> {
55 Ok(Box::new(DDIMScheduler::new(inference_steps, *self)?))
56 }
57}
58
/// The DDIM scheduler.
#[derive(Debug, Clone)]
pub struct DDIMScheduler {
    /// Timesteps to run at inference time, as returned by `timesteps()`.
    timesteps: Vec<usize>,
    /// Running product of `1 - beta` over the beta schedule, indexed by
    /// training timestep.
    alphas_cumprod: Vec<f64>,
    /// Stride between consecutive inference timesteps
    /// (`train_timesteps / inference_steps`, integer division).
    step_ratio: usize,
    /// Value returned by `init_noise_sigma()`; set to `1.` at construction.
    init_noise_sigma: f64,
    /// The configuration this scheduler was built from.
    pub config: DDIMSchedulerConfig,
}
68
69impl DDIMScheduler {
71 fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> {
75 let step_ratio = config.train_timesteps / inference_steps;
76 let timesteps: Vec<usize> = match config.timestep_spacing {
77 TimestepSpacing::Leading => (0..(inference_steps))
78 .map(|s| s * step_ratio + config.steps_offset)
79 .rev()
80 .collect(),
81 TimestepSpacing::Trailing => std::iter::successors(Some(config.train_timesteps), |n| {
82 if *n > step_ratio {
83 Some(n - step_ratio)
84 } else {
85 None
86 }
87 })
88 .map(|n| n - 1)
89 .collect(),
90 TimestepSpacing::Linspace => {
91 super::utils::linspace(0.0, (config.train_timesteps - 1) as f64, inference_steps)?
92 .to_vec1::<f64>()?
93 .iter()
94 .map(|&f| f as usize)
95 .rev()
96 .collect()
97 }
98 };
99
100 let betas = match config.beta_schedule {
101 BetaSchedule::ScaledLinear => super::utils::linspace(
102 config.beta_start.sqrt(),
103 config.beta_end.sqrt(),
104 config.train_timesteps,
105 )?
106 .sqr()?,
107 BetaSchedule::Linear => {
108 super::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
109 }
110 BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?,
111 };
112 let betas = betas.to_vec1::<f64>()?;
113 let mut alphas_cumprod = Vec::with_capacity(betas.len());
114 for &beta in betas.iter() {
115 let alpha = 1.0 - beta;
116 alphas_cumprod.push(alpha * *alphas_cumprod.last().unwrap_or(&1f64))
117 }
118 Ok(Self {
119 alphas_cumprod,
120 timesteps,
121 step_ratio,
122 init_noise_sigma: 1.,
123 config,
124 })
125 }
126}
127
128impl Scheduler for DDIMScheduler {
129 fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
131 let timestep = if timestep >= self.alphas_cumprod.len() {
132 timestep - 1
133 } else {
134 timestep
135 };
136 let prev_timestep = if timestep > self.step_ratio {
138 timestep - self.step_ratio
139 } else {
140 0
141 };
142
143 let alpha_prod_t = self.alphas_cumprod[timestep];
144 let alpha_prod_t_prev = self.alphas_cumprod[prev_timestep];
145 let beta_prod_t = 1. - alpha_prod_t;
146 let beta_prod_t_prev = 1. - alpha_prod_t_prev;
147
148 let (pred_original_sample, pred_epsilon) = match self.config.prediction_type {
149 PredictionType::Epsilon => {
150 let pred_original_sample = ((sample - (model_output * beta_prod_t.sqrt())?)?
151 * (1. / alpha_prod_t.sqrt()))?;
152 (pred_original_sample, model_output.clone())
153 }
154 PredictionType::VPrediction => {
155 let pred_original_sample =
156 ((sample * alpha_prod_t.sqrt())? - (model_output * beta_prod_t.sqrt())?)?;
157 let pred_epsilon =
158 ((model_output * alpha_prod_t.sqrt())? + (sample * beta_prod_t.sqrt())?)?;
159 (pred_original_sample, pred_epsilon)
160 }
161 PredictionType::Sample => {
162 let pred_original_sample = model_output.clone();
163 let pred_epsilon = ((sample - &pred_original_sample * alpha_prod_t.sqrt())?
164 * (1. / beta_prod_t.sqrt()))?;
165 (pred_original_sample, pred_epsilon)
166 }
167 };
168
169 let variance = (beta_prod_t_prev / beta_prod_t) * (1. - alpha_prod_t / alpha_prod_t_prev);
170 let std_dev_t = self.config.eta * variance.sqrt();
171
172 let pred_sample_direction =
173 (pred_epsilon * (1. - alpha_prod_t_prev - std_dev_t * std_dev_t).sqrt())?;
174 let prev_sample =
175 ((pred_original_sample * alpha_prod_t_prev.sqrt())? + pred_sample_direction)?;
176 if self.config.eta > 0. {
177 &prev_sample
178 + Tensor::randn(
179 0f32,
180 std_dev_t as f32,
181 prev_sample.shape(),
182 prev_sample.device(),
183 )?
184 } else {
185 Ok(prev_sample)
186 }
187 }
188
189 fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> {
192 Ok(sample)
193 }
194
195 fn timesteps(&self) -> &[usize] {
196 self.timesteps.as_slice()
197 }
198
199 fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
200 let timestep = if timestep >= self.alphas_cumprod.len() {
201 timestep - 1
202 } else {
203 timestep
204 };
205 let sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt();
206 let sqrt_one_minus_alpha_prod = (1.0 - self.alphas_cumprod[timestep]).sqrt();
207 (original * sqrt_alpha_prod)? + (noise * sqrt_one_minus_alpha_prod)?
208 }
209
210 fn init_noise_sigma(&self) -> f64 {
211 self.init_noise_sigma
212 }
213}