candle_transformers/models/stable_diffusion/
euler_ancestral_discrete.rs1use super::{
6 schedulers::{
7 betas_for_alpha_bar, BetaSchedule, PredictionType, Scheduler, SchedulerConfig,
8 TimestepSpacing,
9 },
10 utils::interp,
11};
12use candle::{bail, Error, Result, Tensor};
13
14#[derive(Debug, Clone, Copy)]
16pub struct EulerAncestralDiscreteSchedulerConfig {
17 pub beta_start: f64,
19 pub beta_end: f64,
21 pub beta_schedule: BetaSchedule,
23 pub steps_offset: usize,
25 pub prediction_type: PredictionType,
29 pub train_timesteps: usize,
31 pub timestep_spacing: TimestepSpacing,
33}
34
35impl Default for EulerAncestralDiscreteSchedulerConfig {
36 fn default() -> Self {
37 Self {
38 beta_start: 0.00085f64,
39 beta_end: 0.012f64,
40 beta_schedule: BetaSchedule::ScaledLinear,
41 steps_offset: 1,
42 prediction_type: PredictionType::Epsilon,
43 train_timesteps: 1000,
44 timestep_spacing: TimestepSpacing::Leading,
45 }
46 }
47}
48
49impl SchedulerConfig for EulerAncestralDiscreteSchedulerConfig {
50 fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>> {
51 Ok(Box::new(EulerAncestralDiscreteScheduler::new(
52 inference_steps,
53 *self,
54 )?))
55 }
56}
57
58#[derive(Debug, Clone)]
60pub struct EulerAncestralDiscreteScheduler {
61 timesteps: Vec<usize>,
62 sigmas: Vec<f64>,
63 init_noise_sigma: f64,
64 pub config: EulerAncestralDiscreteSchedulerConfig,
65}
66
67impl EulerAncestralDiscreteScheduler {
69 pub fn new(
73 inference_steps: usize,
74 config: EulerAncestralDiscreteSchedulerConfig,
75 ) -> Result<Self> {
76 let step_ratio = config.train_timesteps / inference_steps;
77 let timesteps: Vec<usize> = match config.timestep_spacing {
78 TimestepSpacing::Leading => (0..(inference_steps))
79 .map(|s| s * step_ratio + config.steps_offset)
80 .rev()
81 .collect(),
82 TimestepSpacing::Trailing => std::iter::successors(Some(config.train_timesteps), |n| {
83 if *n > step_ratio {
84 Some(n - step_ratio)
85 } else {
86 None
87 }
88 })
89 .map(|n| n - 1)
90 .collect(),
91 TimestepSpacing::Linspace => {
92 super::utils::linspace(0.0, (config.train_timesteps - 1) as f64, inference_steps)?
93 .to_vec1::<f64>()?
94 .iter()
95 .map(|&f| f as usize)
96 .rev()
97 .collect()
98 }
99 };
100
101 let betas = match config.beta_schedule {
102 BetaSchedule::ScaledLinear => super::utils::linspace(
103 config.beta_start.sqrt(),
104 config.beta_end.sqrt(),
105 config.train_timesteps,
106 )?
107 .sqr()?,
108 BetaSchedule::Linear => {
109 super::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
110 }
111 BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?,
112 };
113 let betas = betas.to_vec1::<f64>()?;
114 let mut alphas_cumprod = Vec::with_capacity(betas.len());
115 for &beta in betas.iter() {
116 let alpha = 1.0 - beta;
117 alphas_cumprod.push(alpha * *alphas_cumprod.last().unwrap_or(&1f64))
118 }
119 let sigmas: Vec<f64> = alphas_cumprod
120 .iter()
121 .map(|&f| ((1. - f) / f).sqrt())
122 .collect();
123
124 let sigmas_xa: Vec<_> = (0..sigmas.len()).map(|i| i as f64).collect();
125
126 let mut sigmas_int = interp(
127 ×teps.iter().map(|&t| t as f64).collect::<Vec<_>>(),
128 &sigmas_xa,
129 &sigmas,
130 );
131 sigmas_int.push(0.0);
132
133 let init_noise_sigma = *sigmas_int
136 .iter()
137 .chain(std::iter::once(&0.0))
138 .reduce(|a, b| if a > b { a } else { b })
139 .expect("init_noise_sigma could not be reduced from sigmas - this should never happen");
140
141 Ok(Self {
142 sigmas: sigmas_int,
143 timesteps,
144 init_noise_sigma,
145 config,
146 })
147 }
148}
149
150impl Scheduler for EulerAncestralDiscreteScheduler {
151 fn timesteps(&self) -> &[usize] {
152 self.timesteps.as_slice()
153 }
154
155 fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Result<Tensor> {
160 let step_index = match self.timesteps.iter().position(|&t| t == timestep) {
161 Some(i) => i,
162 None => bail!("timestep out of this schedulers bounds: {timestep}"),
163 };
164
165 let sigma = self
166 .sigmas
167 .get(step_index)
168 .expect("step_index out of sigma bounds - this shouldn't happen");
169
170 sample / ((sigma.powi(2) + 1.).sqrt())
171 }
172
173 fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
175 let step_index = self
176 .timesteps
177 .iter()
178 .position(|&p| p == timestep)
179 .ok_or_else(|| Error::Msg("timestep out of this schedulers bounds".to_string()))?;
180
181 let sigma_from = &self.sigmas[step_index];
182 let sigma_to = &self.sigmas[step_index + 1];
183
184 let pred_original_sample = match self.config.prediction_type {
186 PredictionType::Epsilon => (sample - (model_output * *sigma_from))?,
187 PredictionType::VPrediction => {
188 ((model_output * (-sigma_from / (sigma_from.powi(2) + 1.0).sqrt()))?
189 + (sample / (sigma_from.powi(2) + 1.0))?)?
190 }
191 PredictionType::Sample => bail!("prediction_type not implemented yet: sample"),
192 };
193
194 let sigma_up = (sigma_to.powi(2) * (sigma_from.powi(2) - sigma_to.powi(2))
195 / sigma_from.powi(2))
196 .sqrt();
197 let sigma_down = (sigma_to.powi(2) - sigma_up.powi(2)).sqrt();
198
199 let derivative = ((sample - pred_original_sample)? / *sigma_from)?;
201 let dt = sigma_down - *sigma_from;
202 let prev_sample = (sample + derivative * dt)?;
203
204 let noise = prev_sample.randn_like(0.0, 1.0)?;
205
206 prev_sample + noise * sigma_up
207 }
208
209 fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
210 let step_index = self
211 .timesteps
212 .iter()
213 .position(|&p| p == timestep)
214 .ok_or_else(|| Error::Msg("timestep out of this schedulers bounds".to_string()))?;
215
216 let sigma = self
217 .sigmas
218 .get(step_index)
219 .expect("step_index out of sigma bounds - this shouldn't happen");
220
221 original + (noise * *sigma)?
222 }
223
224 fn init_noise_sigma(&self) -> f64 {
225 match self.config.timestep_spacing {
226 TimestepSpacing::Trailing | TimestepSpacing::Linspace => self.init_noise_sigma,
227 TimestepSpacing::Leading => (self.init_noise_sigma.powi(2) + 1.0).sqrt(),
228 }
229 }
230}