candle_transformers/generation/
mod.rs

//! Logit Processing and Sampling
//!
//! Sampling strategies and logit post-processing for text generation: greedy (arg-max)
//! decoding, temperature-based sampling, top-k filtering, nucleus sampling (top-p),
//! and combinations thereof.
6use candle::{Context, DType, Error, Result, Tensor};
7use rand::{distr::Distribution, SeedableRng};
8
/// Strategy used by [`LogitsProcessor`] to choose the next token.
///
/// Every variant except [`Sampling::ArgMax`] first divides the logits by
/// `temperature` and applies a softmax over the last dimension; the variant then
/// decides which part of the resulting distribution is sampled from.
#[derive(Debug, Clone, PartialEq)]
pub enum Sampling {
    /// Greedy decoding: pick the index holding the largest logit.
    ArgMax,
    /// Draw from the whole distribution.
    All { temperature: f64 },
    /// Draw only among the `k` most probable tokens.
    TopK { k: usize, temperature: f64 },
    /// Nucleus sampling: draw among the most probable tokens whose cumulative
    /// probability reaches `p`. A `p` that is `<= 0` or `>= 1` samples from the
    /// whole distribution instead.
    TopP { p: f64, temperature: f64 },
    /// Keep the `k` most probable tokens, then apply nucleus sampling with
    /// threshold `p` to the survivors.
    TopKThenTopP { k: usize, p: f64, temperature: f64 },
}
17
18pub struct LogitsProcessor {
19    rng: rand::rngs::StdRng,
20    sampling: Sampling,
21}
22
23impl LogitsProcessor {
24    pub fn from_sampling(seed: u64, sampling: Sampling) -> Self {
25        let rng = rand::rngs::StdRng::seed_from_u64(seed);
26        Self { rng, sampling }
27    }
28
29    pub fn new(seed: u64, temperature: Option<f64>, top_p: Option<f64>) -> Self {
30        let temperature = temperature.and_then(|v| if v < 1e-7 { None } else { Some(v) });
31        let sampling = match temperature {
32            None => Sampling::ArgMax,
33            Some(temperature) => match top_p {
34                None => Sampling::All { temperature },
35                Some(p) => Sampling::TopP { p, temperature },
36            },
37        };
38        Self::from_sampling(seed, sampling)
39    }
40
41    fn sample_argmax(&mut self, logits: Tensor) -> Result<u32> {
42        let logits_v: Vec<f32> = logits.to_vec1()?;
43        let next_token = logits_v
44            .iter()
45            .enumerate()
46            .max_by(|(_, u), (_, v)| u.total_cmp(v))
47            .map(|(i, _)| i as u32)
48            .context("empty logits")?;
49        Ok(next_token)
50    }
51
52    fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
53        let distr = rand::distr::weighted::WeightedIndex::new(prs).map_err(Error::wrap)?;
54        let next_token = distr.sample(&mut self.rng) as u32;
55        Ok(next_token)
56    }
57
58    /// top-p sampling (or "nucleus sampling") samples from the smallest set of tokens that exceed
59    /// probability top_p. This way we never sample tokens that have very low probabilities and are
60    /// less likely to go "off the rails".
61    fn sample_topp(&mut self, prs: &mut Vec<f32>, top_p: f32) -> Result<u32> {
62        let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
63
64        // Sort by descending probability.
65        argsort_indices.sort_by(|&i, &j| prs[j].total_cmp(&prs[i]));
66
67        // Clamp smaller probabilities to zero.
68        let mut cumsum = 0.;
69        for index in &argsort_indices {
70            if cumsum >= top_p {
71                prs[*index] = 0.0;
72            } else {
73                cumsum += prs[*index];
74            }
75        }
76        // Sample with clamped probabilities.
77        self.sample_multinomial(prs)
78    }
79
80    // top-k sampling samples from the k tokens with the largest probabilities.
81    fn sample_topk(&mut self, prs: &mut Vec<f32>, top_k: usize) -> Result<u32> {
82        if top_k >= prs.len() {
83            self.sample_multinomial(prs)
84        } else {
85            let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
86            let (indices, _, _) =
87                argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
88            let prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
89            let index = self.sample_multinomial(&prs)?;
90            Ok(indices[index as usize] as u32)
91        }
92    }
93
94    // top-k sampling samples from the k tokens with the largest probabilities.
95    // then top-p sampling.
96    fn sample_topk_topp(&mut self, prs: &mut Vec<f32>, top_k: usize, top_p: f32) -> Result<u32> {
97        if top_k >= prs.len() {
98            self.sample_topp(prs, top_p)
99        } else {
100            let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
101            let (indices, _, _) =
102                argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
103            let mut prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
104            let sum_p = prs.iter().sum::<f32>();
105            let index = if top_p <= 0.0 || top_p >= sum_p {
106                self.sample_multinomial(&prs)?
107            } else {
108                self.sample_topp(&mut prs, top_p)?
109            };
110            Ok(indices[index as usize] as u32)
111        }
112    }
113
114    pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
115        self.sample_f(logits, |_| {})
116    }
117
118    pub fn sample_f(&mut self, logits: &Tensor, f: impl FnOnce(&mut [f32])) -> Result<u32> {
119        let logits = logits.to_dtype(DType::F32)?;
120        let prs = |temperature: f64| -> Result<Vec<f32>> {
121            let logits = (&logits / temperature)?;
122            let prs = candle_nn::ops::softmax_last_dim(&logits)?;
123            let mut prs = prs.to_vec1()?;
124            f(&mut prs);
125            Ok(prs)
126        };
127
128        let next_token = match &self.sampling {
129            Sampling::ArgMax => self.sample_argmax(logits)?,
130            Sampling::All { temperature } => {
131                let prs = prs(*temperature)?;
132                self.sample_multinomial(&prs)?
133            }
134            Sampling::TopP { p, temperature } => {
135                let mut prs = prs(*temperature)?;
136                if *p <= 0.0 || *p >= 1.0 {
137                    // simply sample from the predicted probability distribution
138                    self.sample_multinomial(&prs)?
139                } else {
140                    // top-p (nucleus) sampling, clamping the least likely tokens to zero
141                    self.sample_topp(&mut prs, *p as f32)?
142                }
143            }
144            Sampling::TopK { k, temperature } => {
145                let mut prs = prs(*temperature)?;
146                self.sample_topk(&mut prs, *k)?
147            }
148            Sampling::TopKThenTopP { k, p, temperature } => {
149                let mut prs = prs(*temperature)?;
150                self.sample_topk_topp(&mut prs, *k, *p as f32)?
151            }
152        };
153        Ok(next_token)
154    }
155}