candle_transformers/models/stable_diffusion/
uni_pc.rs

1//! # UniPC Scheduler
2//!
3//! UniPC is a training-free framework designed for the fast sampling of diffusion models, which consists of a
4//! corrector (UniC) and a predictor (UniP) that share a unified analytical form and support arbitrary orders.
5//!
6//! UniPC is by design model-agnostic, supporting pixel-space/latent-space DPMs on unconditional/conditional
7//! sampling. It can also be applied to both noise prediction and data prediction models. Compared with prior
8//! methods, UniPC converges faster thanks to the increased order of accuracy. Both quantitative and qualitative
9//! results show UniPC can improve sampling quality, especially at very low step counts (5~10).
10//!
11//! For more information, see the original publication:
12//! UniPC: A Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models, W. Zhao et al, 2023.
13//! https://arxiv.org/abs/2302.04867
14//!
15//! This work is based largely on UniPC implementation from the diffusers python package:
16//! https://raw.githubusercontent.com/huggingface/diffusers/e8aacda762e311505ba05ae340af23b149e37af3/src/diffusers/schedulers/scheduling_unipc_multistep.py
17use std::collections::HashSet;
18use std::ops::Neg;
19
20use super::schedulers::PredictionType;
21use super::{
22    schedulers::{Scheduler, SchedulerConfig},
23    utils::{interp, linspace},
24};
25use candle::{Error, IndexOp, Result, Tensor};
26
27#[derive(Debug, Clone, Copy)]
28pub enum SigmaSchedule {
29    Karras(KarrasSigmaSchedule),
30    Exponential(ExponentialSigmaSchedule),
31}
32
33impl SigmaSchedule {
34    fn sigma_t(&self, t: f64) -> f64 {
35        match self {
36            Self::Karras(x) => x.sigma_t(t),
37            Self::Exponential(x) => x.sigma_t(t),
38        }
39    }
40}
41
42impl Default for SigmaSchedule {
43    fn default() -> Self {
44        Self::Karras(KarrasSigmaSchedule::default())
45    }
46}
47
/// Karras-style sigma ramp: `sigma(t) = (max^(1/rho) + (1 - t) * (min^(1/rho) - max^(1/rho)))^rho`.
#[derive(Debug, Clone, Copy)]
pub struct KarrasSigmaSchedule {
    /// Noise level reached at `t = 0`.
    pub sigma_min: f64,
    /// Noise level reached at `t = 1`.
    pub sigma_max: f64,
    /// Curvature of the ramp; larger values spend more steps at low noise.
    pub rho: f64,
}

impl KarrasSigmaSchedule {
    /// Evaluates the ramp at `t`.
    fn sigma_t(&self, t: f64) -> f64 {
        let inv_rho = 1.0 / self.rho;
        let lo = self.sigma_min.powf(inv_rho);
        let hi = self.sigma_max.powf(inv_rho);
        let ramp = hi + ((1.0 - t) * (lo - hi));
        ramp.powf(self.rho)
    }
}

impl Default for KarrasSigmaSchedule {
    /// `sigma_min = 0.1`, `sigma_max = 10.0`, `rho = 4.0`.
    fn default() -> Self {
        KarrasSigmaSchedule {
            sigma_min: 0.1,
            sigma_max: 10.0,
            rho: 4.0,
        }
    }
}
75
/// Log-linear sigma ramp: `sigma(t) = exp(ln(sigma_min) + t * (ln(sigma_max) - ln(sigma_min)))`.
///
/// The fields are public, like those of [`KarrasSigmaSchedule`], so callers can configure a
/// non-default range.
#[derive(Debug, Clone, Copy)]
pub struct ExponentialSigmaSchedule {
    /// Noise level reached at `t = 0`.
    pub sigma_min: f64,
    /// Noise level reached at `t = 1`.
    pub sigma_max: f64,
}

impl ExponentialSigmaSchedule {
    /// Evaluates the ramp at `t`.
    fn sigma_t(&self, t: f64) -> f64 {
        (t * (self.sigma_max.ln() - self.sigma_min.ln()) + self.sigma_min.ln()).exp()
    }
}

impl Default for ExponentialSigmaSchedule {
    /// `sigma_min = 0.1`, `sigma_max = 80.0`.
    fn default() -> Self {
        Self {
            sigma_max: 80.0,
            sigma_min: 0.1,
        }
    }
}
96
/// Selects the `B(h)` coefficient used by the UniP/UniC updates, where `h` is the
/// log-SNR step `lambda_t - lambda_s0`.
#[derive(Debug, Default, Clone, Copy)]
pub enum SolverType {
    /// `B(h) = -h`.
    #[default]
    Bh1,
    /// `B(h) = expm1(-h)`.
    Bh2,
}
103
/// Solver algorithm family.
///
/// NOTE(review): nothing in this scheduler reads this type; confirm whether it should be
/// wired into `UniPCSchedulerConfig` or kept only for API compatibility.
#[derive(Debug, Default, Clone, Copy)]
pub enum AlgorithmType {
    /// DPM-Solver++ (the default).
    #[default]
    DpmSolverPlusPlus,
    /// Stochastic (SDE) variant of DPM-Solver++.
    SdeDpmSolverPlusPlus,
}
110
/// Choice of the final sigma at the end of sampling.
///
/// NOTE(review): nothing in this scheduler reads this type; confirm whether it should be
/// wired into `UniPCSchedulerConfig` or kept only for API compatibility.
#[derive(Debug, Default, Clone, Copy)]
pub enum FinalSigmasType {
    /// End at sigma zero (the default).
    #[default]
    Zero,
    /// End at the schedule's minimum sigma.
    SigmaMin,
}
117
118#[derive(Debug, Clone)]
119pub enum TimestepSchedule {
120    /// Timesteps will be determined by interpolation of sigmas
121    FromSigmas,
122    /// Timesteps will be separated by regular intervals
123    Linspace,
124}
125
126impl TimestepSchedule {
127    fn timesteps(
128        &self,
129        sigma_schedule: &SigmaSchedule,
130        num_inference_steps: usize,
131        num_training_steps: usize,
132    ) -> Result<Vec<usize>> {
133        match self {
134            Self::FromSigmas => {
135                let sigmas: Tensor = linspace(1., 0., num_inference_steps)?
136                    .to_vec1()?
137                    .into_iter()
138                    .map(|t| sigma_schedule.sigma_t(t))
139                    .collect::<Vec<f64>>()
140                    .try_into()?;
141                let log_sigmas = sigmas.log()?.to_vec1::<f64>()?;
142                let timesteps = interp(
143                    &log_sigmas.iter().copied().rev().collect::<Vec<_>>(),
144                    &linspace(
145                        log_sigmas[log_sigmas.len() - 1] - 0.001,
146                        log_sigmas[0] + 0.001,
147                        num_inference_steps,
148                    )?
149                    .to_vec1::<f64>()?,
150                    &linspace(0., num_training_steps as f64, num_inference_steps)?
151                        .to_vec1::<f64>()?,
152                )
153                .into_iter()
154                .map(|f| (num_training_steps - 1) - (f as usize))
155                .collect::<Vec<_>>();
156
157                Ok(timesteps)
158            }
159
160            Self::Linspace => {
161                Ok(
162                    linspace((num_training_steps - 1) as f64, 0., num_inference_steps)?
163                        .to_vec1::<f64>()?
164                        .into_iter()
165                        .map(|f| f as usize)
166                        .collect(),
167                )
168            }
169        }
170    }
171}
172
/// Controls the UniC corrector, which `step` applies before the predictor.
///
/// Whatever the configuration, the corrector never runs on step index 0 because it needs the
/// previous step's sample.
#[derive(Debug, Clone)]
pub enum CorrectorConfiguration {
    /// Never run the corrector.
    Disabled,
    /// Run the corrector on every step index except those listed in `skip_steps`.
    Enabled { skip_steps: HashSet<usize> },
}

impl Default for CorrectorConfiguration {
    /// Enabled, but skipped on step indices 0, 1 and 2.
    fn default() -> Self {
        Self::new([0, 1, 2])
    }
}

impl CorrectorConfiguration {
    /// Enables the corrector, skipping it on every step index produced by `disabled_steps`.
    pub fn new(disabled_steps: impl IntoIterator<Item = usize>) -> Self {
        let skip_steps = HashSet::from_iter(disabled_steps);
        Self::Enabled { skip_steps }
    }
}
194
195#[derive(Debug, Clone)]
196pub struct UniPCSchedulerConfig {
197    /// Configure the UNIC corrector. By default it is disabled
198    pub corrector: CorrectorConfiguration,
199    /// Determines how sigma relates to a given timestep
200    pub sigma_schedule: SigmaSchedule,
201    /// Determines the points
202    pub timestep_schedule: TimestepSchedule,
203    /// The solver order which can be `1` or higher. It is recommended to use `solver_order=2` for guided
204    /// sampling, and `solver_order=3` for unconditional sampling.
205    pub solver_order: usize,
206    /// Prediction type of the scheduler function
207    pub prediction_type: PredictionType,
208    pub num_training_timesteps: usize,
209    /// Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
210    /// as Stable Diffusion.
211    pub thresholding: bool,
212    /// The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
213    pub dynamic_thresholding_ratio: f64,
214    /// The threshold value for dynamic thresholding.
215    pub sample_max_value: f64,
216    pub solver_type: SolverType,
217    /// Whether to use lower-order solvers in the final steps.
218    pub lower_order_final: bool,
219}
220
221impl Default for UniPCSchedulerConfig {
222    fn default() -> Self {
223        Self {
224            corrector: Default::default(),
225            timestep_schedule: TimestepSchedule::FromSigmas,
226            sigma_schedule: SigmaSchedule::Karras(Default::default()),
227            prediction_type: PredictionType::Epsilon,
228            num_training_timesteps: 1000,
229            solver_order: 2,
230            thresholding: false,
231            dynamic_thresholding_ratio: 0.995,
232            sample_max_value: 1.0,
233            solver_type: SolverType::Bh1,
234            lower_order_final: true,
235        }
236    }
237}
238
239impl SchedulerConfig for UniPCSchedulerConfig {
240    fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>> {
241        Ok(Box::new(EdmDpmMultistepScheduler::new(
242            self.clone(),
243            inference_steps,
244        )?))
245    }
246}
247
248struct State {
249    model_outputs: Vec<Option<Tensor>>,
250    lower_order_nums: usize,
251    order: usize,
252    last_sample: Option<Tensor>,
253}
254
255impl State {
256    fn new(solver_order: usize) -> Self {
257        Self {
258            model_outputs: vec![None; solver_order],
259            lower_order_nums: 0,
260            order: 0,
261            last_sample: None,
262        }
263    }
264
265    fn lower_order_nums(&self) -> usize {
266        self.lower_order_nums
267    }
268
269    fn update_lower_order_nums(&mut self, n: usize) {
270        self.lower_order_nums = n;
271    }
272
273    fn model_outputs(&self) -> &[Option<Tensor>] {
274        self.model_outputs.as_slice()
275    }
276
277    fn update_model_output(&mut self, idx: usize, output: Option<Tensor>) {
278        self.model_outputs[idx] = output;
279    }
280
281    fn last_sample(&self) -> Option<&Tensor> {
282        self.last_sample.as_ref()
283    }
284
285    fn update_last_sample(&mut self, sample: Tensor) {
286        let _ = self.last_sample.replace(sample);
287    }
288
289    fn order(&self) -> usize {
290        self.order
291    }
292
293    fn update_order(&mut self, order: usize) {
294        self.order = order;
295    }
296}
297
298pub struct EdmDpmMultistepScheduler {
299    schedule: Schedule,
300    config: UniPCSchedulerConfig,
301    state: State,
302}
303
304impl EdmDpmMultistepScheduler {
305    pub fn new(config: UniPCSchedulerConfig, num_inference_steps: usize) -> Result<Self> {
306        let schedule = Schedule::new(
307            config.timestep_schedule.clone(),
308            config.sigma_schedule,
309            num_inference_steps,
310            config.num_training_timesteps,
311        )?;
312
313        Ok(Self {
314            schedule,
315            state: State::new(config.solver_order),
316            config,
317        })
318    }
319
320    fn step_index(&self, timestep: usize) -> usize {
321        let index_candidates = self
322            .schedule
323            .timesteps()
324            .iter()
325            .enumerate()
326            .filter(|(_, t)| (*t == &timestep))
327            .map(|(i, _)| i)
328            .collect::<Vec<_>>();
329
330        match index_candidates.len() {
331            0 => 0,
332            1 => index_candidates[0],
333            _ => index_candidates[1],
334        }
335    }
336
337    fn timestep(&self, step_idx: usize) -> usize {
338        self.schedule
339            .timesteps()
340            .get(step_idx)
341            .copied()
342            .unwrap_or(0)
343    }
344
345    fn convert_model_output(
346        &self,
347        model_output: &Tensor,
348        sample: &Tensor,
349        timestep: usize,
350    ) -> Result<Tensor> {
351        let (alpha_t, sigma_t) = (
352            self.schedule.alpha_t(timestep),
353            self.schedule.sigma_t(timestep),
354        );
355
356        let x0_pred = match self.config.prediction_type {
357            PredictionType::Epsilon => ((sample - (model_output * sigma_t))? / alpha_t)?,
358            PredictionType::Sample => model_output.clone(),
359            PredictionType::VPrediction => ((alpha_t * sample)? - (sigma_t * model_output)?)?,
360        };
361
362        if self.config.thresholding {
363            self.threshold_sample(x0_pred)
364        } else {
365            Ok(x0_pred)
366        }
367    }
368
369    fn threshold_sample(&self, sample: Tensor) -> Result<Tensor> {
370        let shape = sample.shape().clone().into_dims();
371        let v = sample
372            .abs()?
373            .reshape((shape[0], shape[1] * shape[2..].iter().product::<usize>()))?
374            .to_dtype(candle::DType::F64)?
375            .to_vec2::<f64>()?;
376        let q = stats::Quantile::new(self.config.dynamic_thresholding_ratio)
377            .with_samples(v.into_iter().flatten());
378        let (threshold, max) = (q.quantile().max(self.config.sample_max_value), q.max());
379
380        sample.clamp(-threshold, threshold)? / (threshold / max).sqrt().min(1.)
381    }
382
383    fn multistep_uni_p_bh_update(&self, sample: &Tensor, timestep: usize) -> Result<Tensor> {
384        let step_index = self.step_index(timestep);
385        let ns = &self.schedule;
386        let model_outputs = self.state.model_outputs();
387        let Some(m0) = &model_outputs[model_outputs.len() - 1] else {
388            return Err(Error::Msg(
389                "Expected model output for predictor update".to_string(),
390            ));
391        };
392
393        let (t0, tt) = (timestep, self.timestep(self.step_index(timestep) + 1));
394        let (sigma_t, sigma_s0) = (ns.sigma_t(tt), ns.sigma_t(t0));
395        let (alpha_t, _alpha_s0) = (ns.alpha_t(tt), ns.alpha_t(t0));
396        let (lambda_t, lambda_s0) = (ns.lambda_t(tt), ns.lambda_t(t0));
397
398        let h = lambda_t - lambda_s0;
399        let device = sample.device();
400
401        let (mut rks, mut d1s) = (vec![], vec![]);
402        for i in 1..self.state.order() {
403            let ti = self.timestep(step_index.saturating_sub(i + 1));
404            let Some(mi) = model_outputs
405                .get(model_outputs.len().saturating_sub(i + 1))
406                .into_iter()
407                .flatten()
408                .next()
409            else {
410                return Err(Error::Msg(
411                    "Expected model output for predictor update".to_string(),
412                ));
413            };
414            let (alpha_si, sigma_si) = (ns.alpha_t(ti), ns.sigma_t(ti));
415            let lambda_si = alpha_si.ln() - sigma_si.ln();
416            let rk = (lambda_si - lambda_s0) / h;
417            rks.push(rk);
418            d1s.push(((mi - m0)? / rk)?);
419        }
420        rks.push(1.0);
421        let rks = Tensor::new(rks, device)?;
422        let (mut r, mut b) = (vec![], vec![]);
423
424        let hh = h.neg();
425        let h_phi_1 = hh.exp_m1();
426        let mut h_phi_k = h_phi_1 / hh - 1.;
427        let mut factorial_i = 1.;
428
429        let b_h = match self.config.solver_type {
430            SolverType::Bh1 => hh,
431            SolverType::Bh2 => hh.exp_m1(),
432        };
433
434        for i in 1..self.state.order() + 1 {
435            r.push(rks.powf(i as f64 - 1.)?);
436            b.push(h_phi_k * factorial_i / b_h);
437            factorial_i = i as f64 + 1.;
438            h_phi_k = h_phi_k / hh - 1. / factorial_i;
439        }
440
441        let (r, b) = (Tensor::stack(&r, 0)?, Tensor::new(b, device)?);
442        let (d1s, rhos_p) = match d1s.len() {
443            0 => (None, None),
444            _ => {
445                let rhos_p = match self.state.order() {
446                    2 => Tensor::new(&[0.5f64], m0.device())?.to_dtype(m0.dtype())?,
447                    _ => {
448                        let ((r1, r2), b1) = (r.dims2()?, b.dims1()?);
449                        let inverse = linalg::inverse(&r.i((..(r1 - 1), ..(r2 - 1)))?)?;
450                        let b = b.i(..(b1 - 1))?;
451                        b.broadcast_mul(&inverse)?.sum(1)?.to_dtype(m0.dtype())?
452                    }
453                };
454
455                (Some(Tensor::stack(&d1s, 1)?), Some(rhos_p))
456            }
457        };
458
459        let x_t_ = ((sigma_t / sigma_s0 * sample)? - (alpha_t * h_phi_1 * m0)?)?;
460        if let (Some(d1s), Some(rhos_p)) = (d1s, rhos_p) {
461            use linalg::{Permutation, TensordotFixedPosition, TensordotGeneral};
462            let output_shape = m0.shape().clone();
463            let pred_res = TensordotGeneral {
464                lhs_permutation: Permutation { dims: vec![0] },
465                rhs_permutation: Permutation {
466                    dims: vec![1, 0, 2, 3, 4],
467                },
468                tensordot_fixed_position: TensordotFixedPosition {
469                    len_uncontracted_lhs: 1,
470                    len_uncontracted_rhs: output_shape.dims().iter().product::<usize>(),
471                    len_contracted_axes: d1s.dim(1)?,
472                    output_shape,
473                },
474                output_permutation: Permutation {
475                    dims: vec![0, 1, 2, 3],
476                },
477            }
478            .eval(&rhos_p, &d1s)?;
479            x_t_ - (alpha_t * b_h * pred_res)?
480        } else {
481            Ok(x_t_)
482        }
483    }
484
485    fn multistep_uni_c_bh_update(
486        &self,
487        model_output: &Tensor,
488        model_outputs: &[Option<Tensor>],
489        last_sample: &Tensor,
490        sample: &Tensor,
491        timestep: usize,
492    ) -> Result<Tensor> {
493        let step_index = self.step_index(timestep);
494        let Some(m0) = model_outputs.last().into_iter().flatten().next() else {
495            return Err(Error::Msg(
496                "Expected model output for corrector update".to_string(),
497            ));
498        };
499        let model_t = model_output;
500        let (x, _xt) = (last_sample, sample);
501
502        let (t0, tt, ns) = (
503            self.timestep(self.step_index(timestep) - 1),
504            timestep,
505            &self.schedule,
506        );
507        let (sigma_t, sigma_s0) = (ns.sigma_t(tt), ns.sigma_t(t0));
508        let (alpha_t, _alpha_s0) = (ns.alpha_t(tt), ns.alpha_t(t0));
509        let (lambda_t, lambda_s0) = (ns.lambda_t(tt), ns.lambda_t(t0));
510
511        let h = lambda_t - lambda_s0;
512        let device = sample.device();
513
514        let (mut rks, mut d1s) = (vec![], vec![]);
515        for i in 1..self.state.order() {
516            let ti = self.timestep(step_index.saturating_sub(i + 1));
517            let Some(mi) = model_outputs
518                .get(model_outputs.len().saturating_sub(i + 1))
519                .into_iter()
520                .flatten()
521                .next()
522            else {
523                return Err(Error::Msg(
524                    "Expected model output for corrector update".to_string(),
525                ));
526            };
527            let (alpha_si, sigma_si) = (ns.alpha_t(ti), ns.sigma_t(ti));
528            let lambda_si = alpha_si.ln() - sigma_si.ln();
529            let rk = (lambda_si - lambda_s0) / h;
530            rks.push(rk);
531            d1s.push(((mi - m0)? / rk)?);
532        }
533        rks.push(1.0);
534        let rks = Tensor::new(rks, device)?;
535        let (mut r, mut b) = (vec![], vec![]);
536
537        let hh = h.neg();
538        let h_phi_1 = hh.exp_m1();
539        let mut h_phi_k = h_phi_1 / hh - 1.;
540        let mut factorial_i = 1.;
541
542        let b_h = match self.config.solver_type {
543            SolverType::Bh1 => hh,
544            SolverType::Bh2 => hh.exp_m1(),
545        };
546
547        for i in 1..self.state.order() + 1 {
548            r.push(rks.powf(i as f64 - 1.)?);
549            b.push(h_phi_k * factorial_i / b_h);
550            factorial_i = i as f64 + 1.;
551            h_phi_k = h_phi_k / hh - 1. / factorial_i;
552        }
553
554        let (r, b) = (Tensor::stack(&r, 0)?, Tensor::new(b, device)?);
555        let d1s = match d1s.len() {
556            0 => None,
557            _ => Some(Tensor::stack(&d1s, 1)?),
558        };
559        let rhos_c = match self.state.order() {
560            1 => Tensor::new(&[0.5f64], m0.device())?.to_dtype(m0.dtype())?,
561            _ => {
562                let inverse = linalg::inverse(&r)?;
563                b.broadcast_mul(&inverse)?.sum(1)?.to_dtype(m0.dtype())?
564            }
565        };
566
567        let x_t_ = ((sigma_t / sigma_s0 * x)? - (alpha_t * h_phi_1 * m0)?)?;
568        let corr_res = d1s
569            .map(|d1s| {
570                use linalg::{Permutation, TensordotFixedPosition, TensordotGeneral};
571                let output_shape = x_t_.shape().clone();
572                TensordotGeneral {
573                    lhs_permutation: Permutation { dims: vec![0] },
574                    rhs_permutation: Permutation {
575                        dims: vec![1, 0, 2, 3, 4],
576                    },
577                    tensordot_fixed_position: TensordotFixedPosition {
578                        len_uncontracted_lhs: 1,
579                        len_uncontracted_rhs: output_shape.dims().iter().product::<usize>(),
580                        len_contracted_axes: d1s.dim(1)?,
581                        output_shape,
582                    },
583                    output_permutation: Permutation {
584                        dims: vec![0, 1, 2, 3],
585                    },
586                }
587                .eval(&rhos_c.i(..rhos_c.dims()[0] - 1)?, &d1s)
588            })
589            .unwrap_or_else(|| Tensor::zeros_like(m0))?;
590
591        let d1_t = (model_t - m0)?;
592        let x_t = (x_t_
593            - (alpha_t
594                * b_h
595                * (corr_res + rhos_c.i(rhos_c.dims()[0] - 1)?.broadcast_mul(&d1_t)?)?)?)?;
596
597        Ok(x_t)
598    }
599}
600
601impl Scheduler for EdmDpmMultistepScheduler {
602    fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
603        let step_index = self.step_index(timestep);
604        let model_output_converted = &self.convert_model_output(model_output, sample, timestep)?;
605        let sample = match (&self.config.corrector, self.state.last_sample()) {
606            (CorrectorConfiguration::Enabled { skip_steps: s }, Some(last_sample))
607                if !s.contains(&step_index) && step_index > 0 =>
608            {
609                &self.multistep_uni_c_bh_update(
610                    model_output_converted,
611                    self.state.model_outputs(),
612                    last_sample,
613                    sample,
614                    timestep,
615                )?
616            }
617            (CorrectorConfiguration::Enabled { .. }, _) | (CorrectorConfiguration::Disabled, _) => {
618                sample
619            }
620        };
621
622        let mut model_outputs = self.state.model_outputs().to_vec();
623        for i in 0..self.config.solver_order.saturating_sub(1) {
624            self.state
625                .update_model_output(i, model_outputs[i + 1].take());
626        }
627        self.state.update_model_output(
628            model_outputs.len() - 1,
629            Some(model_output_converted.clone()),
630        );
631
632        let mut this_order = self.config.solver_order;
633        if self.config.lower_order_final {
634            this_order = self
635                .config
636                .solver_order
637                .min(self.schedule.timesteps.len() - step_index);
638        }
639        self.state
640            .update_order(this_order.min(self.state.lower_order_nums() + 1));
641
642        self.state.update_last_sample(sample.clone());
643        let prev_sample = self.multistep_uni_p_bh_update(sample, timestep)?;
644
645        let lower_order_nums = self.state.lower_order_nums();
646        if lower_order_nums < self.config.solver_order {
647            self.state.update_lower_order_nums(lower_order_nums + 1);
648        }
649
650        Ok(prev_sample)
651    }
652
653    fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> {
654        Ok(sample)
655    }
656
657    fn timesteps(&self) -> &[usize] {
658        &self.schedule.timesteps
659    }
660
661    fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
662        let (alpha_t, sigma_t) = (
663            self.schedule.alpha_t(timestep),
664            self.schedule.sigma_t(timestep),
665        );
666
667        (alpha_t * original)? + (sigma_t * noise)?
668    }
669
670    fn init_noise_sigma(&self) -> f64 {
671        self.schedule.sigma_t(self.schedule.num_training_steps())
672    }
673}
674
675#[derive(Debug, Clone)]
676struct Schedule {
677    timesteps: Vec<usize>,
678    num_training_steps: usize,
679    sigma_schedule: SigmaSchedule,
680    #[allow(unused)]
681    timestep_schedule: TimestepSchedule,
682}
683
684impl Schedule {
685    fn new(
686        timestep_schedule: TimestepSchedule,
687        sigma_schedule: SigmaSchedule,
688        num_inference_steps: usize,
689        num_training_steps: usize,
690    ) -> Result<Self> {
691        Ok(Self {
692            timesteps: timestep_schedule.timesteps(
693                &sigma_schedule,
694                num_inference_steps,
695                num_training_steps,
696            )?,
697            timestep_schedule,
698            sigma_schedule,
699            num_training_steps,
700        })
701    }
702
703    fn timesteps(&self) -> &[usize] {
704        &self.timesteps
705    }
706
707    fn num_training_steps(&self) -> usize {
708        self.num_training_steps
709    }
710
711    fn t(&self, step: usize) -> f64 {
712        (step as f64 + 1.) / self.num_training_steps as f64
713    }
714
715    fn alpha_t(&self, t: usize) -> f64 {
716        (1. / (self.sigma_schedule.sigma_t(self.t(t)).powi(2) + 1.)).sqrt()
717    }
718
719    fn sigma_t(&self, t: usize) -> f64 {
720        self.sigma_schedule.sigma_t(self.t(t)) * self.alpha_t(t)
721    }
722
723    fn lambda_t(&self, t: usize) -> f64 {
724        self.alpha_t(t).ln() - self.sigma_t(t).ln()
725    }
726}
727
mod stats {
    //! This is a slightly modified form of the P² quantile implementation from https://github.com/vks/average.
    //! Also see: http://www.cs.wustl.edu/~jain/papers/ftp/psqr.pdf

    /// Streaming estimator of a single quantile using the P² algorithm. It keeps five markers
    /// instead of storing every sample.
    #[derive(Debug, Clone)]
    pub struct Quantile {
        /// Marker heights. Until five samples have been seen, these hold the raw samples.
        q: [f64; 5],
        /// Actual (1-based) marker positions. While fewer than five samples have been seen,
        /// `n[4]` doubles as the sample count.
        n: [i64; 5],
        /// Desired marker positions.
        m: [f64; 5],
        /// Per-sample increments of the desired marker positions.
        dm: [f64; 5],
        /// Largest sample seen so far.
        max: Option<f64>,
    }

    impl Quantile {
        /// Creates an estimator for the `p`-quantile.
        ///
        /// # Panics
        /// Panics if `p` is outside `[0, 1]`.
        pub fn new(p: f64) -> Quantile {
            assert!((0. ..=1.).contains(&p));
            Quantile {
                q: [0.; 5],
                n: [1, 2, 3, 4, 0],
                m: [1., 1. + 2. * p, 1. + 4. * p, 3. + 2. * p, 5.],
                dm: [0., p / 2., p, (1. + p) / 2., 1.],
                max: None,
            }
        }

        /// Largest sample seen so far, or NaN if nothing has been added.
        pub fn max(&self) -> f64 {
            self.max.unwrap_or(f64::NAN)
        }

        /// The quantile being estimated.
        fn p(&self) -> f64 {
            self.dm[2]
        }

        /// Piecewise-parabolic prediction of marker `i`'s height after moving it by `d` (±1).
        fn parabolic(&self, i: usize, d: f64) -> f64 {
            let s = d.round() as i64;
            self.q[i]
                + d / (self.n[i + 1] - self.n[i - 1]) as f64
                    * ((self.n[i] - self.n[i - 1] + s) as f64 * (self.q[i + 1] - self.q[i])
                        / (self.n[i + 1] - self.n[i]) as f64
                        + (self.n[i + 1] - self.n[i] - s) as f64 * (self.q[i] - self.q[i - 1])
                            / (self.n[i] - self.n[i - 1]) as f64)
        }

        /// Linear fallback prediction of marker `i`'s height after moving it by `d` (±1).
        fn linear(&self, i: usize, d: f64) -> f64 {
            let j = if d < 0. { i - 1 } else { i + 1 };
            self.q[i] + d * (self.q[j] - self.q[i]) / (self.n[j] - self.n[i]) as f64
        }

        /// Current estimate of the quantile; NaN if no samples were added.
        pub fn quantile(&self) -> f64 {
            if self.len() >= 5 {
                return self.q[2];
            }
            if self.is_empty() {
                return f64::NAN;
            }

            // Fewer than five samples: `q` still holds the raw, unsorted observations, so
            // compute the quantile from a sorted copy.
            let len = self.len() as usize;
            let mut buffer = [0f64; 4];
            buffer[..len].copy_from_slice(&self.q[..len]);
            let heights = &mut buffer[..len];
            sort_floats(heights);

            let desired_index = (len as f64) * self.p() - 1.;
            let index = desired_index.ceil();
            if desired_index == index && index >= 0. {
                let index = index as usize;
                if index < len - 1 {
                    // Both neighbours are equally valid estimates; by convention average them.
                    return 0.5 * heights[index] + 0.5 * heights[index + 1];
                }
            }
            let index = (index.max(0.) as usize).min(len - 1);
            heights[index]
        }

        /// Number of samples added so far.
        fn len(&self) -> u64 {
            self.n[4] as u64
        }

        fn is_empty(&self) -> bool {
            self.len() == 0
        }

        /// Adds one observation.
        pub fn add(&mut self, x: f64) {
            self.max = self.max.map(|y| y.max(x)).or(Some(x));

            // Store the first five samples verbatim, then sort them to initialise the markers.
            if self.n[4] < 5 {
                self.q[self.n[4] as usize] = x;
                self.n[4] += 1;
                if self.n[4] == 5 {
                    sort_floats(&mut self.q);
                }
                return;
            }

            // Find the first marker index `k` whose position must shift right. The minimum
            // marker always stays at position 1, so a new minimum replaces q[0] and uses k = 1.
            let k = if x < self.q[0] {
                self.q[0] = x;
                1
            } else {
                let k = (1..5).find(|&i| x < self.q[i]).unwrap_or(4);
                if self.q[4] < x {
                    self.q[4] = x;
                }
                k
            };

            for n in &mut self.n[k..] {
                *n += 1;
            }
            for (m, dm) in self.m.iter_mut().zip(self.dm) {
                *m += dm;
            }

            // Move the three interior markers toward their desired positions when needed.
            for i in 1..4 {
                let d = self.m[i] - self.n[i] as f64;
                if (d >= 1. && self.n[i + 1] - self.n[i] > 1)
                    || (d <= -1. && self.n[i - 1] - self.n[i] < -1)
                {
                    let d = d.signum();
                    let q_new = self.parabolic(i, d);
                    if self.q[i - 1] < q_new && q_new < self.q[i + 1] {
                        self.q[i] = q_new;
                    } else {
                        self.q[i] = self.linear(i, d);
                    }
                    let delta = d.round() as i64;
                    debug_assert_eq!(delta.abs(), 1);
                    self.n[i] += delta;
                }
            }
        }

        /// Adds every sample from `samples` and returns the estimator.
        pub fn with_samples(mut self, samples: impl IntoIterator<Item = f64>) -> Self {
            for sample in samples {
                self.add(sample);
            }

            self
        }
    }

    /// Sorts floats ascending using IEEE-754 total ordering.
    fn sort_floats(v: &mut [f64]) {
        v.sort_unstable_by(|a, b| a.total_cmp(b));
    }
}
883
884mod linalg {
885    use candle::{IndexOp, Result, Shape, Tensor};
886
887    pub fn inverse(m: &Tensor) -> Result<Tensor> {
888        adjoint(m)? / determinant(m)?.to_scalar::<f64>()?
889    }
890
891    pub fn adjoint(m: &Tensor) -> Result<Tensor> {
892        cofactor(m)?.transpose(0, 1)
893    }
894
895    pub fn cofactor(m: &Tensor) -> Result<Tensor> {
896        let s = m.shape().dim(0)?;
897        if s == 2 {
898            let mut v = vec![];
899            for i in 0..2 {
900                let mut x = vec![];
901                for j in 0..2 {
902                    x.push((m.i((i, j))? * (-1.0f64).powi(i as i32 + j as i32))?)
903                }
904                v.push(Tensor::stack(&x, 0)?.unsqueeze(0)?);
905            }
906            return Tensor::stack(&v, 1)?.squeeze(0);
907        }
908
909        let minors = minors(m)?;
910        let mut v = vec![];
911        for i in 0..s {
912            let mut x = vec![];
913            for j in 0..s {
914                let det = (determinant(&minors.i((i, j))?)?
915                    * ((-1.0f64).powi(i as i32) * (-1.0f64).powi(j as i32)))?;
916                x.push(det);
917            }
918            v.push(Tensor::stack(&x, 0)?.unsqueeze(0)?);
919        }
920
921        Tensor::stack(&v, 1)?.squeeze(0)
922    }
923
924    pub fn determinant(m: &Tensor) -> Result<Tensor> {
925        let s = m.shape().dim(0)?;
926        if s == 2 {
927            return (m.i((0, 0))? * m.i((1, 1))?)? - (m.i((0, 1))? * m.i((1, 0))?);
928        }
929
930        let cofactor = cofactor(m)?;
931        let m0 = m.i((0, 0))?;
932        let det = (0..s)
933            .map(|i| (m.i((0, i))? * cofactor.i((0, i))?))
934            .try_fold(m0.zeros_like()?, |acc, cur| (acc + cur?))?;
935
936        Ok(det)
937    }
938
939    pub fn minors(m: &Tensor) -> Result<Tensor> {
940        let s = m.shape().dim(0)?;
941        if s == 1 {
942            return m.i((0, 0));
943        }
944
945        let mut v = vec![];
946        for i in 0..s {
947            let msub = Tensor::cat(&[m.i((..i, ..))?, m.i(((i + 1).., ..))?], 0)?;
948            let mut x = vec![];
949            for j in 0..s {
950                let t = Tensor::cat(&[msub.i((.., ..j))?, msub.i((.., (j + 1)..))?], 1)?;
951                x.push(t);
952            }
953            v.push(Tensor::stack(&x, 0)?.unsqueeze(0)?);
954        }
955
956        Tensor::stack(&v, 1)?.squeeze(0)
957    }
958
959    #[derive(Debug)]
960    pub struct TensordotGeneral {
961        pub lhs_permutation: Permutation,
962        pub rhs_permutation: Permutation,
963        pub tensordot_fixed_position: TensordotFixedPosition,
964        pub output_permutation: Permutation,
965    }
966
967    impl TensordotGeneral {
968        pub fn eval(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor> {
969            let permuted_lhs = self.lhs_permutation.eval(lhs)?;
970            let permuted_rhs = self.rhs_permutation.eval(rhs)?;
971            let tensordotted = self
972                .tensordot_fixed_position
973                .eval(&permuted_lhs, &permuted_rhs)?;
974            self.output_permutation.eval(&tensordotted)
975        }
976    }
977
978    #[derive(Debug)]
979    pub struct TensordotFixedPosition {
980        pub len_uncontracted_lhs: usize,
981        pub len_uncontracted_rhs: usize,
982        pub len_contracted_axes: usize,
983        pub output_shape: Shape,
984    }
985
986    impl TensordotFixedPosition {
987        fn eval(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor> {
988            let lhs_view = lhs.reshape((self.len_uncontracted_lhs, self.len_contracted_axes))?;
989            let rhs_view = rhs.reshape((self.len_contracted_axes, self.len_uncontracted_rhs))?;
990
991            lhs_view.matmul(&rhs_view)?.reshape(&self.output_shape)
992        }
993    }
994
995    #[derive(Debug)]
996    pub struct Permutation {
997        pub dims: Vec<usize>,
998    }
999
1000    impl Permutation {
1001        fn eval(&self, tensor: &Tensor) -> Result<Tensor> {
1002            tensor.permute(self.dims.as_slice())
1003        }
1004    }
1005}