candle_transformers/models/stable_diffusion/
schedulers.rs1#![allow(dead_code)]
2use candle::{Result, Tensor};
7
8pub trait SchedulerConfig: std::fmt::Debug + Send + Sync {
9 fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>>;
10}
11
12pub trait Scheduler {
14 fn timesteps(&self) -> &[usize];
15
16 fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor>;
17
18 fn init_noise_sigma(&self) -> f64;
19
20 fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor>;
21
22 fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor>;
23}
24
/// Selects how the sequence of beta values is generated across the diffusion timesteps.
///
/// The variant names appear to mirror the diffusers `beta_schedule` options
/// (`linear`, `scaled_linear`, `squaredcos_cap_v2`); the exact formulas live in the
/// scheduler implementations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BetaSchedule {
    /// Linearly spaced betas.
    Linear,
    /// Linear spacing in a rescaled domain (in diffusers: linear in `sqrt(beta)`, then squared).
    ScaledLinear,
    /// Squared-cosine schedule with a cap on the maximum beta.
    SquaredcosCapV2,
}
36
/// Selects what quantity the denoising model's output is interpreted as.
// NOTE(review): meanings below follow the usual diffusers conventions for these names —
// confirm against the scheduler `step` implementations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PredictionType {
    /// The model predicts the added noise (epsilon).
    Epsilon,
    /// The model predicts the "v" parameterization (velocity).
    VPrediction,
    /// The model predicts the denoised sample directly.
    Sample,
}
43
/// Selects how inference timesteps are spaced over the training timestep range.
///
/// Defaults to [`TimestepSpacing::Leading`].
// NOTE(review): the names appear to match the "leading" / "linspace" / "trailing"
// options described in "Common Diffusion Noise Schedules and Sample Steps are Flawed"
// (arXiv:2305.08891, Table 2) — confirm against the scheduler implementations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TimestepSpacing {
    /// "Leading" spacing.
    Leading,
    /// Evenly spaced ("linspace") timesteps.
    Linspace,
    /// "Trailing" spacing.
    Trailing,
}
53
54impl Default for TimestepSpacing {
55 fn default() -> Self {
56 Self::Leading
57 }
58}
59
60pub(crate) fn betas_for_alpha_bar(num_diffusion_timesteps: usize, max_beta: f64) -> Result<Tensor> {
66 let alpha_bar = |time_step: usize| {
67 f64::cos((time_step as f64 + 0.008) / 1.008 * std::f64::consts::FRAC_PI_2).powi(2)
68 };
69 let mut betas = Vec::with_capacity(num_diffusion_timesteps);
70 for i in 0..num_diffusion_timesteps {
71 let t1 = i / num_diffusion_timesteps;
72 let t2 = (i + 1) / num_diffusion_timesteps;
73 betas.push((1.0 - alpha_bar(t2) / alpha_bar(t1)).min(max_beta));
74 }
75 let betas_len = betas.len();
76 Tensor::from_vec(betas, betas_len, &candle::Device::Cpu)
77}