1use std::collections::HashSet;
18use std::ops::Neg;
19
20use super::schedulers::PredictionType;
21use super::{
22 schedulers::{Scheduler, SchedulerConfig},
23 utils::{interp, linspace},
24};
25use candle::{Error, IndexOp, Result, Tensor};
26
/// Maps a normalised time `t` to a noise level sigma.
///
/// Both variants return `sigma_max` at `t = 1` and `sigma_min` at `t = 0`.
#[derive(Debug, Clone, Copy)]
pub enum SigmaSchedule {
    /// Linear interpolation in `sigma^(1/rho)` space.
    Karras(KarrasSigmaSchedule),
    /// Linear interpolation in `ln(sigma)` space.
    Exponential(ExponentialSigmaSchedule),
}

impl SigmaSchedule {
    /// Sigma at normalised time `t`, delegating to the selected schedule.
    fn sigma_t(&self, t: f64) -> f64 {
        match self {
            Self::Karras(x) => x.sigma_t(t),
            Self::Exponential(x) => x.sigma_t(t),
        }
    }
}

impl Default for SigmaSchedule {
    /// A Karras schedule with its default parameters.
    fn default() -> Self {
        Self::Karras(KarrasSigmaSchedule::default())
    }
}

/// Karras-style schedule: interpolates linearly between `sigma_max^(1/rho)`
/// (at `t = 1`) and `sigma_min^(1/rho)` (at `t = 0`), then raises the result
/// to the power `rho`.
#[derive(Debug, Clone, Copy)]
pub struct KarrasSigmaSchedule {
    /// Sigma returned at `t = 0`.
    pub sigma_min: f64,
    /// Sigma returned at `t = 1`.
    pub sigma_max: f64,
    /// Interpolation exponent.
    pub rho: f64,
}

impl KarrasSigmaSchedule {
    fn sigma_t(&self, t: f64) -> f64 {
        let (min_inv_rho, max_inv_rho) = (
            self.sigma_min.powf(1.0 / self.rho),
            self.sigma_max.powf(1.0 / self.rho),
        );

        (max_inv_rho + ((1.0 - t) * (min_inv_rho - max_inv_rho))).powf(self.rho)
    }
}

impl Default for KarrasSigmaSchedule {
    /// `sigma_min = 0.1`, `sigma_max = 10.0`, `rho = 4.0`.
    fn default() -> Self {
        Self {
            sigma_max: 10.0,
            sigma_min: 0.1,
            rho: 4.0,
        }
    }
}

/// Exponential (log-linear) schedule:
/// `sigma(t) = exp(ln(sigma_min) + t * (ln(sigma_max) - ln(sigma_min)))`.
///
/// The fields are public, like those of [`KarrasSigmaSchedule`], so callers can
/// build a non-default schedule.
#[derive(Debug, Clone, Copy)]
pub struct ExponentialSigmaSchedule {
    /// Sigma returned at `t = 0`.
    pub sigma_min: f64,
    /// Sigma returned at `t = 1`.
    pub sigma_max: f64,
}

impl ExponentialSigmaSchedule {
    fn sigma_t(&self, t: f64) -> f64 {
        (t * (self.sigma_max.ln() - self.sigma_min.ln()) + self.sigma_min.ln()).exp()
    }
}

impl Default for ExponentialSigmaSchedule {
    /// `sigma_min = 0.1`, `sigma_max = 80.0`.
    fn default() -> Self {
        Self {
            sigma_max: 80.0,
            sigma_min: 0.1,
        }
    }
}
96
/// Choice of the `B(h)` scaling used by the UniPC predictor and corrector.
#[derive(Debug, Default, Clone, Copy)]
pub enum SolverType {
    /// `B(h) = -h`.
    #[default]
    Bh1,
    /// `B(h) = e^(-h) - 1`.
    Bh2,
}

/// Solver algorithm family.
///
/// NOTE(review): not referenced anywhere else in this file; kept for API
/// compatibility.
#[derive(Debug, Default, Clone, Copy)]
pub enum AlgorithmType {
    #[default]
    DpmSolverPlusPlus,
    SdeDpmSolverPlusPlus,
}

/// Which sigma the final step ends at.
///
/// NOTE(review): not referenced anywhere else in this file; kept for API
/// compatibility.
#[derive(Debug, Default, Clone, Copy)]
pub enum FinalSigmasType {
    #[default]
    Zero,
    SigmaMin,
}
117
118#[derive(Debug, Clone)]
119pub enum TimestepSchedule {
120 FromSigmas,
122 Linspace,
124}
125
126impl TimestepSchedule {
127 fn timesteps(
128 &self,
129 sigma_schedule: &SigmaSchedule,
130 num_inference_steps: usize,
131 num_training_steps: usize,
132 ) -> Result<Vec<usize>> {
133 match self {
134 Self::FromSigmas => {
135 let sigmas: Tensor = linspace(1., 0., num_inference_steps)?
136 .to_vec1()?
137 .into_iter()
138 .map(|t| sigma_schedule.sigma_t(t))
139 .collect::<Vec<f64>>()
140 .try_into()?;
141 let log_sigmas = sigmas.log()?.to_vec1::<f64>()?;
142 let timesteps = interp(
143 &log_sigmas.iter().copied().rev().collect::<Vec<_>>(),
144 &linspace(
145 log_sigmas[log_sigmas.len() - 1] - 0.001,
146 log_sigmas[0] + 0.001,
147 num_inference_steps,
148 )?
149 .to_vec1::<f64>()?,
150 &linspace(0., num_training_steps as f64, num_inference_steps)?
151 .to_vec1::<f64>()?,
152 )
153 .into_iter()
154 .map(|f| (num_training_steps - 1) - (f as usize))
155 .collect::<Vec<_>>();
156
157 Ok(timesteps)
158 }
159
160 Self::Linspace => {
161 Ok(
162 linspace((num_training_steps - 1) as f64, 0., num_inference_steps)?
163 .to_vec1::<f64>()?
164 .into_iter()
165 .map(|f| f as usize)
166 .collect(),
167 )
168 }
169 }
170 }
171}
172
/// Controls the UniC corrector that `step` may apply before the predictor.
#[derive(Debug, Clone)]
pub enum CorrectorConfiguration {
    /// Never run the corrector.
    Disabled,
    /// Run the corrector, except on the listed step indices.
    Enabled { skip_steps: HashSet<usize> },
}

impl Default for CorrectorConfiguration {
    /// Corrector enabled, skipping step indices 0, 1 and 2.
    fn default() -> Self {
        Self::new([0, 1, 2])
    }
}

impl CorrectorConfiguration {
    /// Enables the corrector, skipping every step index yielded by
    /// `disabled_steps` (duplicates are collapsed).
    pub fn new(disabled_steps: impl IntoIterator<Item = usize>) -> Self {
        let skip_steps: HashSet<usize> = disabled_steps.into_iter().collect();
        Self::Enabled { skip_steps }
    }
}
194
195#[derive(Debug, Clone)]
196pub struct UniPCSchedulerConfig {
197 pub corrector: CorrectorConfiguration,
199 pub sigma_schedule: SigmaSchedule,
201 pub timestep_schedule: TimestepSchedule,
203 pub solver_order: usize,
206 pub prediction_type: PredictionType,
208 pub num_training_timesteps: usize,
209 pub thresholding: bool,
212 pub dynamic_thresholding_ratio: f64,
214 pub sample_max_value: f64,
216 pub solver_type: SolverType,
217 pub lower_order_final: bool,
219}
220
221impl Default for UniPCSchedulerConfig {
222 fn default() -> Self {
223 Self {
224 corrector: Default::default(),
225 timestep_schedule: TimestepSchedule::FromSigmas,
226 sigma_schedule: SigmaSchedule::Karras(Default::default()),
227 prediction_type: PredictionType::Epsilon,
228 num_training_timesteps: 1000,
229 solver_order: 2,
230 thresholding: false,
231 dynamic_thresholding_ratio: 0.995,
232 sample_max_value: 1.0,
233 solver_type: SolverType::Bh1,
234 lower_order_final: true,
235 }
236 }
237}
238
239impl SchedulerConfig for UniPCSchedulerConfig {
240 fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>> {
241 Ok(Box::new(EdmDpmMultistepScheduler::new(
242 self.clone(),
243 inference_steps,
244 )?))
245 }
246}
247
248struct State {
249 model_outputs: Vec<Option<Tensor>>,
250 lower_order_nums: usize,
251 order: usize,
252 last_sample: Option<Tensor>,
253}
254
255impl State {
256 fn new(solver_order: usize) -> Self {
257 Self {
258 model_outputs: vec![None; solver_order],
259 lower_order_nums: 0,
260 order: 0,
261 last_sample: None,
262 }
263 }
264
265 fn lower_order_nums(&self) -> usize {
266 self.lower_order_nums
267 }
268
269 fn update_lower_order_nums(&mut self, n: usize) {
270 self.lower_order_nums = n;
271 }
272
273 fn model_outputs(&self) -> &[Option<Tensor>] {
274 self.model_outputs.as_slice()
275 }
276
277 fn update_model_output(&mut self, idx: usize, output: Option<Tensor>) {
278 self.model_outputs[idx] = output;
279 }
280
281 fn last_sample(&self) -> Option<&Tensor> {
282 self.last_sample.as_ref()
283 }
284
285 fn update_last_sample(&mut self, sample: Tensor) {
286 let _ = self.last_sample.replace(sample);
287 }
288
289 fn order(&self) -> usize {
290 self.order
291 }
292
293 fn update_order(&mut self, order: usize) {
294 self.order = order;
295 }
296}
297
298pub struct EdmDpmMultistepScheduler {
299 schedule: Schedule,
300 config: UniPCSchedulerConfig,
301 state: State,
302}
303
304impl EdmDpmMultistepScheduler {
305 pub fn new(config: UniPCSchedulerConfig, num_inference_steps: usize) -> Result<Self> {
306 let schedule = Schedule::new(
307 config.timestep_schedule.clone(),
308 config.sigma_schedule,
309 num_inference_steps,
310 config.num_training_timesteps,
311 )?;
312
313 Ok(Self {
314 schedule,
315 state: State::new(config.solver_order),
316 config,
317 })
318 }
319
320 fn step_index(&self, timestep: usize) -> usize {
321 let index_candidates = self
322 .schedule
323 .timesteps()
324 .iter()
325 .enumerate()
326 .filter(|(_, t)| (*t == ×tep))
327 .map(|(i, _)| i)
328 .collect::<Vec<_>>();
329
330 match index_candidates.len() {
331 0 => 0,
332 1 => index_candidates[0],
333 _ => index_candidates[1],
334 }
335 }
336
337 fn timestep(&self, step_idx: usize) -> usize {
338 self.schedule
339 .timesteps()
340 .get(step_idx)
341 .copied()
342 .unwrap_or(0)
343 }
344
345 fn convert_model_output(
346 &self,
347 model_output: &Tensor,
348 sample: &Tensor,
349 timestep: usize,
350 ) -> Result<Tensor> {
351 let (alpha_t, sigma_t) = (
352 self.schedule.alpha_t(timestep),
353 self.schedule.sigma_t(timestep),
354 );
355
356 let x0_pred = match self.config.prediction_type {
357 PredictionType::Epsilon => ((sample - (model_output * sigma_t))? / alpha_t)?,
358 PredictionType::Sample => model_output.clone(),
359 PredictionType::VPrediction => ((alpha_t * sample)? - (sigma_t * model_output)?)?,
360 };
361
362 if self.config.thresholding {
363 self.threshold_sample(x0_pred)
364 } else {
365 Ok(x0_pred)
366 }
367 }
368
369 fn threshold_sample(&self, sample: Tensor) -> Result<Tensor> {
370 let shape = sample.shape().clone().into_dims();
371 let v = sample
372 .abs()?
373 .reshape((shape[0], shape[1] * shape[2..].iter().product::<usize>()))?
374 .to_dtype(candle::DType::F64)?
375 .to_vec2::<f64>()?;
376 let q = stats::Quantile::new(self.config.dynamic_thresholding_ratio)
377 .with_samples(v.into_iter().flatten());
378 let (threshold, max) = (q.quantile().max(self.config.sample_max_value), q.max());
379
380 sample.clamp(-threshold, threshold)? / (threshold / max).sqrt().min(1.)
381 }
382
383 fn multistep_uni_p_bh_update(&self, sample: &Tensor, timestep: usize) -> Result<Tensor> {
384 let step_index = self.step_index(timestep);
385 let ns = &self.schedule;
386 let model_outputs = self.state.model_outputs();
387 let Some(m0) = &model_outputs[model_outputs.len() - 1] else {
388 return Err(Error::Msg(
389 "Expected model output for predictor update".to_string(),
390 ));
391 };
392
393 let (t0, tt) = (timestep, self.timestep(self.step_index(timestep) + 1));
394 let (sigma_t, sigma_s0) = (ns.sigma_t(tt), ns.sigma_t(t0));
395 let (alpha_t, _alpha_s0) = (ns.alpha_t(tt), ns.alpha_t(t0));
396 let (lambda_t, lambda_s0) = (ns.lambda_t(tt), ns.lambda_t(t0));
397
398 let h = lambda_t - lambda_s0;
399 let device = sample.device();
400
401 let (mut rks, mut d1s) = (vec![], vec![]);
402 for i in 1..self.state.order() {
403 let ti = self.timestep(step_index.saturating_sub(i + 1));
404 let Some(mi) = model_outputs
405 .get(model_outputs.len().saturating_sub(i + 1))
406 .into_iter()
407 .flatten()
408 .next()
409 else {
410 return Err(Error::Msg(
411 "Expected model output for predictor update".to_string(),
412 ));
413 };
414 let (alpha_si, sigma_si) = (ns.alpha_t(ti), ns.sigma_t(ti));
415 let lambda_si = alpha_si.ln() - sigma_si.ln();
416 let rk = (lambda_si - lambda_s0) / h;
417 rks.push(rk);
418 d1s.push(((mi - m0)? / rk)?);
419 }
420 rks.push(1.0);
421 let rks = Tensor::new(rks, device)?;
422 let (mut r, mut b) = (vec![], vec![]);
423
424 let hh = h.neg();
425 let h_phi_1 = hh.exp_m1();
426 let mut h_phi_k = h_phi_1 / hh - 1.;
427 let mut factorial_i = 1.;
428
429 let b_h = match self.config.solver_type {
430 SolverType::Bh1 => hh,
431 SolverType::Bh2 => hh.exp_m1(),
432 };
433
434 for i in 1..self.state.order() + 1 {
435 r.push(rks.powf(i as f64 - 1.)?);
436 b.push(h_phi_k * factorial_i / b_h);
437 factorial_i = i as f64 + 1.;
438 h_phi_k = h_phi_k / hh - 1. / factorial_i;
439 }
440
441 let (r, b) = (Tensor::stack(&r, 0)?, Tensor::new(b, device)?);
442 let (d1s, rhos_p) = match d1s.len() {
443 0 => (None, None),
444 _ => {
445 let rhos_p = match self.state.order() {
446 2 => Tensor::new(&[0.5f64], m0.device())?.to_dtype(m0.dtype())?,
447 _ => {
448 let ((r1, r2), b1) = (r.dims2()?, b.dims1()?);
449 let inverse = linalg::inverse(&r.i((..(r1 - 1), ..(r2 - 1)))?)?;
450 let b = b.i(..(b1 - 1))?;
451 b.broadcast_mul(&inverse)?.sum(1)?.to_dtype(m0.dtype())?
452 }
453 };
454
455 (Some(Tensor::stack(&d1s, 1)?), Some(rhos_p))
456 }
457 };
458
459 let x_t_ = ((sigma_t / sigma_s0 * sample)? - (alpha_t * h_phi_1 * m0)?)?;
460 if let (Some(d1s), Some(rhos_p)) = (d1s, rhos_p) {
461 use linalg::{Permutation, TensordotFixedPosition, TensordotGeneral};
462 let output_shape = m0.shape().clone();
463 let pred_res = TensordotGeneral {
464 lhs_permutation: Permutation { dims: vec![0] },
465 rhs_permutation: Permutation {
466 dims: vec![1, 0, 2, 3, 4],
467 },
468 tensordot_fixed_position: TensordotFixedPosition {
469 len_uncontracted_lhs: 1,
470 len_uncontracted_rhs: output_shape.dims().iter().product::<usize>(),
471 len_contracted_axes: d1s.dim(1)?,
472 output_shape,
473 },
474 output_permutation: Permutation {
475 dims: vec![0, 1, 2, 3],
476 },
477 }
478 .eval(&rhos_p, &d1s)?;
479 x_t_ - (alpha_t * b_h * pred_res)?
480 } else {
481 Ok(x_t_)
482 }
483 }
484
485 fn multistep_uni_c_bh_update(
486 &self,
487 model_output: &Tensor,
488 model_outputs: &[Option<Tensor>],
489 last_sample: &Tensor,
490 sample: &Tensor,
491 timestep: usize,
492 ) -> Result<Tensor> {
493 let step_index = self.step_index(timestep);
494 let Some(m0) = model_outputs.last().into_iter().flatten().next() else {
495 return Err(Error::Msg(
496 "Expected model output for corrector update".to_string(),
497 ));
498 };
499 let model_t = model_output;
500 let (x, _xt) = (last_sample, sample);
501
502 let (t0, tt, ns) = (
503 self.timestep(self.step_index(timestep) - 1),
504 timestep,
505 &self.schedule,
506 );
507 let (sigma_t, sigma_s0) = (ns.sigma_t(tt), ns.sigma_t(t0));
508 let (alpha_t, _alpha_s0) = (ns.alpha_t(tt), ns.alpha_t(t0));
509 let (lambda_t, lambda_s0) = (ns.lambda_t(tt), ns.lambda_t(t0));
510
511 let h = lambda_t - lambda_s0;
512 let device = sample.device();
513
514 let (mut rks, mut d1s) = (vec![], vec![]);
515 for i in 1..self.state.order() {
516 let ti = self.timestep(step_index.saturating_sub(i + 1));
517 let Some(mi) = model_outputs
518 .get(model_outputs.len().saturating_sub(i + 1))
519 .into_iter()
520 .flatten()
521 .next()
522 else {
523 return Err(Error::Msg(
524 "Expected model output for corrector update".to_string(),
525 ));
526 };
527 let (alpha_si, sigma_si) = (ns.alpha_t(ti), ns.sigma_t(ti));
528 let lambda_si = alpha_si.ln() - sigma_si.ln();
529 let rk = (lambda_si - lambda_s0) / h;
530 rks.push(rk);
531 d1s.push(((mi - m0)? / rk)?);
532 }
533 rks.push(1.0);
534 let rks = Tensor::new(rks, device)?;
535 let (mut r, mut b) = (vec![], vec![]);
536
537 let hh = h.neg();
538 let h_phi_1 = hh.exp_m1();
539 let mut h_phi_k = h_phi_1 / hh - 1.;
540 let mut factorial_i = 1.;
541
542 let b_h = match self.config.solver_type {
543 SolverType::Bh1 => hh,
544 SolverType::Bh2 => hh.exp_m1(),
545 };
546
547 for i in 1..self.state.order() + 1 {
548 r.push(rks.powf(i as f64 - 1.)?);
549 b.push(h_phi_k * factorial_i / b_h);
550 factorial_i = i as f64 + 1.;
551 h_phi_k = h_phi_k / hh - 1. / factorial_i;
552 }
553
554 let (r, b) = (Tensor::stack(&r, 0)?, Tensor::new(b, device)?);
555 let d1s = match d1s.len() {
556 0 => None,
557 _ => Some(Tensor::stack(&d1s, 1)?),
558 };
559 let rhos_c = match self.state.order() {
560 1 => Tensor::new(&[0.5f64], m0.device())?.to_dtype(m0.dtype())?,
561 _ => {
562 let inverse = linalg::inverse(&r)?;
563 b.broadcast_mul(&inverse)?.sum(1)?.to_dtype(m0.dtype())?
564 }
565 };
566
567 let x_t_ = ((sigma_t / sigma_s0 * x)? - (alpha_t * h_phi_1 * m0)?)?;
568 let corr_res = d1s
569 .map(|d1s| {
570 use linalg::{Permutation, TensordotFixedPosition, TensordotGeneral};
571 let output_shape = x_t_.shape().clone();
572 TensordotGeneral {
573 lhs_permutation: Permutation { dims: vec![0] },
574 rhs_permutation: Permutation {
575 dims: vec![1, 0, 2, 3, 4],
576 },
577 tensordot_fixed_position: TensordotFixedPosition {
578 len_uncontracted_lhs: 1,
579 len_uncontracted_rhs: output_shape.dims().iter().product::<usize>(),
580 len_contracted_axes: d1s.dim(1)?,
581 output_shape,
582 },
583 output_permutation: Permutation {
584 dims: vec![0, 1, 2, 3],
585 },
586 }
587 .eval(&rhos_c.i(..rhos_c.dims()[0] - 1)?, &d1s)
588 })
589 .unwrap_or_else(|| Tensor::zeros_like(m0))?;
590
591 let d1_t = (model_t - m0)?;
592 let x_t = (x_t_
593 - (alpha_t
594 * b_h
595 * (corr_res + rhos_c.i(rhos_c.dims()[0] - 1)?.broadcast_mul(&d1_t)?)?)?)?;
596
597 Ok(x_t)
598 }
599}
600
601impl Scheduler for EdmDpmMultistepScheduler {
602 fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
603 let step_index = self.step_index(timestep);
604 let model_output_converted = &self.convert_model_output(model_output, sample, timestep)?;
605 let sample = match (&self.config.corrector, self.state.last_sample()) {
606 (CorrectorConfiguration::Enabled { skip_steps: s }, Some(last_sample))
607 if !s.contains(&step_index) && step_index > 0 =>
608 {
609 &self.multistep_uni_c_bh_update(
610 model_output_converted,
611 self.state.model_outputs(),
612 last_sample,
613 sample,
614 timestep,
615 )?
616 }
617 (CorrectorConfiguration::Enabled { .. }, _) | (CorrectorConfiguration::Disabled, _) => {
618 sample
619 }
620 };
621
622 let mut model_outputs = self.state.model_outputs().to_vec();
623 for i in 0..self.config.solver_order.saturating_sub(1) {
624 self.state
625 .update_model_output(i, model_outputs[i + 1].take());
626 }
627 self.state.update_model_output(
628 model_outputs.len() - 1,
629 Some(model_output_converted.clone()),
630 );
631
632 let mut this_order = self.config.solver_order;
633 if self.config.lower_order_final {
634 this_order = self
635 .config
636 .solver_order
637 .min(self.schedule.timesteps.len() - step_index);
638 }
639 self.state
640 .update_order(this_order.min(self.state.lower_order_nums() + 1));
641
642 self.state.update_last_sample(sample.clone());
643 let prev_sample = self.multistep_uni_p_bh_update(sample, timestep)?;
644
645 let lower_order_nums = self.state.lower_order_nums();
646 if lower_order_nums < self.config.solver_order {
647 self.state.update_lower_order_nums(lower_order_nums + 1);
648 }
649
650 Ok(prev_sample)
651 }
652
653 fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> {
654 Ok(sample)
655 }
656
657 fn timesteps(&self) -> &[usize] {
658 &self.schedule.timesteps
659 }
660
661 fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
662 let (alpha_t, sigma_t) = (
663 self.schedule.alpha_t(timestep),
664 self.schedule.sigma_t(timestep),
665 );
666
667 (alpha_t * original)? + (sigma_t * noise)?
668 }
669
670 fn init_noise_sigma(&self) -> f64 {
671 self.schedule.sigma_t(self.schedule.num_training_steps())
672 }
673}
674
675#[derive(Debug, Clone)]
676struct Schedule {
677 timesteps: Vec<usize>,
678 num_training_steps: usize,
679 sigma_schedule: SigmaSchedule,
680 #[allow(unused)]
681 timestep_schedule: TimestepSchedule,
682}
683
684impl Schedule {
685 fn new(
686 timestep_schedule: TimestepSchedule,
687 sigma_schedule: SigmaSchedule,
688 num_inference_steps: usize,
689 num_training_steps: usize,
690 ) -> Result<Self> {
691 Ok(Self {
692 timesteps: timestep_schedule.timesteps(
693 &sigma_schedule,
694 num_inference_steps,
695 num_training_steps,
696 )?,
697 timestep_schedule,
698 sigma_schedule,
699 num_training_steps,
700 })
701 }
702
703 fn timesteps(&self) -> &[usize] {
704 &self.timesteps
705 }
706
707 fn num_training_steps(&self) -> usize {
708 self.num_training_steps
709 }
710
711 fn t(&self, step: usize) -> f64 {
712 (step as f64 + 1.) / self.num_training_steps as f64
713 }
714
715 fn alpha_t(&self, t: usize) -> f64 {
716 (1. / (self.sigma_schedule.sigma_t(self.t(t)).powi(2) + 1.)).sqrt()
717 }
718
719 fn sigma_t(&self, t: usize) -> f64 {
720 self.sigma_schedule.sigma_t(self.t(t)) * self.alpha_t(t)
721 }
722
723 fn lambda_t(&self, t: usize) -> f64 {
724 self.alpha_t(t).ln() - self.sigma_t(t).ln()
725 }
726}
727
mod stats {
    //! Streaming single-quantile estimation with the P² algorithm (five
    //! markers), plus tracking of the running maximum.

    /// Streaming estimator of the `p`-quantile.
    #[derive(Debug, Clone)]
    pub struct Quantile {
        /// Marker heights.
        q: [f64; 5],
        /// Marker positions (1-based ranks). Until five samples have been seen,
        /// `n[4]` doubles as the sample counter.
        n: [i64; 5],
        /// Desired marker positions.
        m: [f64; 5],
        /// Per-sample increments of the desired positions.
        dm: [f64; 5],
        /// Largest sample seen so far.
        max: Option<f64>,
    }

    impl Quantile {
        /// Creates an estimator for the `p`-quantile.
        ///
        /// # Panics
        /// Panics if `p` is not within `[0, 1]`.
        pub fn new(p: f64) -> Quantile {
            assert!((0. ..=1.).contains(&p));
            Quantile {
                q: [0.; 5],
                n: [1, 2, 3, 4, 0],
                m: [1., 1. + 2. * p, 1. + 4. * p, 3. + 2. * p, 5.],
                dm: [0., p / 2., p, (1. + p) / 2., 1.],
                max: None,
            }
        }

        /// Largest sample added so far, or `NaN` if none.
        pub fn max(&self) -> f64 {
            self.max.unwrap_or(f64::NAN)
        }

        fn p(&self) -> f64 {
            self.dm[2]
        }

        /// Piecewise-parabolic prediction of marker `i` moved by `d` (±1).
        fn parabolic(&self, i: usize, d: f64) -> f64 {
            let s = d.round() as i64;
            let (n_prev, n_cur, n_next) = (self.n[i - 1], self.n[i], self.n[i + 1]);
            self.q[i]
                + d / (n_next - n_prev) as f64
                    * ((n_cur - n_prev + s) as f64 * (self.q[i + 1] - self.q[i])
                        / (n_next - n_cur) as f64
                        + (n_next - n_cur - s) as f64 * (self.q[i] - self.q[i - 1])
                            / (n_cur - n_prev) as f64)
        }

        /// Linear prediction of marker `i` moved by `d` (±1), used when the
        /// parabolic one would break the marker ordering.
        fn linear(&self, i: usize, d: f64) -> f64 {
            let neighbour = if d < 0. { i - 1 } else { i + 1 };
            self.q[i]
                + d * (self.q[neighbour] - self.q[i])
                    / (self.n[neighbour] - self.n[i]) as f64
        }

        /// Current estimate of the `p`-quantile (`NaN` when empty).
        ///
        /// With fewer than five samples the answer is computed exactly from the
        /// sorted samples.
        pub fn quantile(&self) -> f64 {
            if self.len() >= 5 {
                return self.q[2];
            }
            if self.is_empty() {
                return f64::NAN;
            }

            let len = self.len() as usize;
            let mut heights: [f64; 4] = [self.q[0], self.q[1], self.q[2], self.q[3]];
            let heights = &mut heights[..len];
            sort_floats(heights);

            let desired_index = (len as f64) * self.p() - 1.;
            let mut index = desired_index.ceil();
            if desired_index == index && index >= 0. {
                // The quantile falls exactly between two samples: average them.
                let index = index.round() as usize;
                if index < len - 1 {
                    return 0.5 * heights[index] + 0.5 * heights[index + 1];
                }
            }
            index = index.max(0.);
            let index = (index.round() as usize).min(len - 1);
            heights[index]
        }

        fn len(&self) -> u64 {
            self.n[4] as u64
        }

        fn is_empty(&self) -> bool {
            self.len() == 0
        }

        /// Feeds one sample into the estimator.
        pub fn add(&mut self, x: f64) {
            self.max = Some(self.max.map_or(x, |y| y.max(x)));

            // Warm-up: store the first five samples verbatim.
            if self.n[4] < 5 {
                self.q[self.n[4] as usize] = x;
                self.n[4] += 1;
                if self.n[4] == 5 {
                    sort_floats(&mut self.q);
                }
                return;
            }

            // Find the first marker k whose position must shift right. The
            // minimum marker (index 0) always stays at rank 1, so k is never 0.
            let k = if x < self.q[0] {
                self.q[0] = x;
                1
            } else {
                let k = (1..5).find(|&i| x < self.q[i]).unwrap_or(4);
                if self.q[4] < x {
                    self.q[4] = x;
                }
                k
            };

            for pos in &mut self.n[k..] {
                *pos += 1;
            }
            for (desired, step) in self.m.iter_mut().zip(&self.dm) {
                *desired += *step;
            }

            // Nudge the three middle markers towards their desired positions.
            for i in 1..4 {
                let d = self.m[i] - self.n[i] as f64;
                if (d >= 1. && self.n[i + 1] - self.n[i] > 1)
                    || (d <= -1. && self.n[i - 1] - self.n[i] < -1)
                {
                    let d = d.signum();
                    let q_new = self.parabolic(i, d);
                    self.q[i] = if self.q[i - 1] < q_new && q_new < self.q[i + 1] {
                        q_new
                    } else {
                        self.linear(i, d)
                    };
                    let delta = d.round() as i64;
                    debug_assert_eq!(delta.abs(), 1);
                    self.n[i] += delta;
                }
            }
        }

        /// Adds every sample from `samples` and returns the estimator.
        pub fn with_samples(mut self, samples: impl IntoIterator<Item = f64>) -> Self {
            for sample in samples {
                self.add(sample);
            }

            self
        }
    }

    fn sort_floats(v: &mut [f64]) {
        v.sort_unstable_by(|a, b| a.total_cmp(b));
    }
}
883
884mod linalg {
885 use candle::{IndexOp, Result, Shape, Tensor};
886
887 pub fn inverse(m: &Tensor) -> Result<Tensor> {
888 adjoint(m)? / determinant(m)?.to_scalar::<f64>()?
889 }
890
891 pub fn adjoint(m: &Tensor) -> Result<Tensor> {
892 cofactor(m)?.transpose(0, 1)
893 }
894
895 pub fn cofactor(m: &Tensor) -> Result<Tensor> {
896 let s = m.shape().dim(0)?;
897 if s == 2 {
898 let mut v = vec![];
899 for i in 0..2 {
900 let mut x = vec![];
901 for j in 0..2 {
902 x.push((m.i((i, j))? * (-1.0f64).powi(i as i32 + j as i32))?)
903 }
904 v.push(Tensor::stack(&x, 0)?.unsqueeze(0)?);
905 }
906 return Tensor::stack(&v, 1)?.squeeze(0);
907 }
908
909 let minors = minors(m)?;
910 let mut v = vec![];
911 for i in 0..s {
912 let mut x = vec![];
913 for j in 0..s {
914 let det = (determinant(&minors.i((i, j))?)?
915 * ((-1.0f64).powi(i as i32) * (-1.0f64).powi(j as i32)))?;
916 x.push(det);
917 }
918 v.push(Tensor::stack(&x, 0)?.unsqueeze(0)?);
919 }
920
921 Tensor::stack(&v, 1)?.squeeze(0)
922 }
923
924 pub fn determinant(m: &Tensor) -> Result<Tensor> {
925 let s = m.shape().dim(0)?;
926 if s == 2 {
927 return (m.i((0, 0))? * m.i((1, 1))?)? - (m.i((0, 1))? * m.i((1, 0))?);
928 }
929
930 let cofactor = cofactor(m)?;
931 let m0 = m.i((0, 0))?;
932 let det = (0..s)
933 .map(|i| (m.i((0, i))? * cofactor.i((0, i))?))
934 .try_fold(m0.zeros_like()?, |acc, cur| (acc + cur?))?;
935
936 Ok(det)
937 }
938
939 pub fn minors(m: &Tensor) -> Result<Tensor> {
940 let s = m.shape().dim(0)?;
941 if s == 1 {
942 return m.i((0, 0));
943 }
944
945 let mut v = vec![];
946 for i in 0..s {
947 let msub = Tensor::cat(&[m.i((..i, ..))?, m.i(((i + 1).., ..))?], 0)?;
948 let mut x = vec![];
949 for j in 0..s {
950 let t = Tensor::cat(&[msub.i((.., ..j))?, msub.i((.., (j + 1)..))?], 1)?;
951 x.push(t);
952 }
953 v.push(Tensor::stack(&x, 0)?.unsqueeze(0)?);
954 }
955
956 Tensor::stack(&v, 1)?.squeeze(0)
957 }
958
959 #[derive(Debug)]
960 pub struct TensordotGeneral {
961 pub lhs_permutation: Permutation,
962 pub rhs_permutation: Permutation,
963 pub tensordot_fixed_position: TensordotFixedPosition,
964 pub output_permutation: Permutation,
965 }
966
967 impl TensordotGeneral {
968 pub fn eval(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor> {
969 let permuted_lhs = self.lhs_permutation.eval(lhs)?;
970 let permuted_rhs = self.rhs_permutation.eval(rhs)?;
971 let tensordotted = self
972 .tensordot_fixed_position
973 .eval(&permuted_lhs, &permuted_rhs)?;
974 self.output_permutation.eval(&tensordotted)
975 }
976 }
977
978 #[derive(Debug)]
979 pub struct TensordotFixedPosition {
980 pub len_uncontracted_lhs: usize,
981 pub len_uncontracted_rhs: usize,
982 pub len_contracted_axes: usize,
983 pub output_shape: Shape,
984 }
985
986 impl TensordotFixedPosition {
987 fn eval(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor> {
988 let lhs_view = lhs.reshape((self.len_uncontracted_lhs, self.len_contracted_axes))?;
989 let rhs_view = rhs.reshape((self.len_contracted_axes, self.len_uncontracted_rhs))?;
990
991 lhs_view.matmul(&rhs_view)?.reshape(&self.output_shape)
992 }
993 }
994
995 #[derive(Debug)]
996 pub struct Permutation {
997 pub dims: Vec<usize>,
998 }
999
1000 impl Permutation {
1001 fn eval(&self, tensor: &Tensor) -> Result<Tensor> {
1002 tensor.permute(self.dims.as_slice())
1003 }
1004 }
1005}