candle_transformers/generation/
mod.rs1use candle::{Context, DType, Error, Result, Tensor};
7use rand::{distr::Distribution, SeedableRng};
8
/// Strategy used by [`LogitsProcessor`] to turn a logits vector into a token id.
///
/// Every variant except [`Sampling::ArgMax`] carries a `temperature`: the logits are
/// divided by it before the softmax that produces the sampling probabilities.
#[derive(Clone, PartialEq, Debug)]
pub enum Sampling {
    /// Deterministically pick the index holding the largest logit.
    ArgMax,
    /// Sample from the full softmax distribution.
    All {
        /// Divisor applied to the logits before the softmax.
        temperature: f64,
    },
    /// Sample among the `k` most probable tokens only.
    TopK {
        /// Number of highest-probability tokens kept as candidates.
        k: usize,
        /// Divisor applied to the logits before the softmax.
        temperature: f64,
    },
    /// Sample after zeroing the probabilities that come once the cumulative mass of
    /// the descending-sorted probabilities has reached `p`.
    TopP {
        /// Cumulative probability threshold; values `<= 0` or `>= 1` make the
        /// processor sample from the full distribution instead.
        p: f64,
        /// Divisor applied to the logits before the softmax.
        temperature: f64,
    },
    /// Restrict to the `k` most probable tokens first, then apply the top-p cut.
    TopKThenTopP {
        /// Number of highest-probability tokens kept as candidates.
        k: usize,
        /// Cumulative probability threshold applied to the kept candidates.
        p: f64,
        /// Divisor applied to the logits before the softmax.
        temperature: f64,
    },
}
17
18pub struct LogitsProcessor {
19 rng: rand::rngs::StdRng,
20 sampling: Sampling,
21}
22
23impl LogitsProcessor {
24 pub fn from_sampling(seed: u64, sampling: Sampling) -> Self {
25 let rng = rand::rngs::StdRng::seed_from_u64(seed);
26 Self { rng, sampling }
27 }
28
29 pub fn new(seed: u64, temperature: Option<f64>, top_p: Option<f64>) -> Self {
30 let temperature = temperature.and_then(|v| if v < 1e-7 { None } else { Some(v) });
31 let sampling = match temperature {
32 None => Sampling::ArgMax,
33 Some(temperature) => match top_p {
34 None => Sampling::All { temperature },
35 Some(p) => Sampling::TopP { p, temperature },
36 },
37 };
38 Self::from_sampling(seed, sampling)
39 }
40
41 fn sample_argmax(&mut self, logits: Tensor) -> Result<u32> {
42 let logits_v: Vec<f32> = logits.to_vec1()?;
43 let next_token = logits_v
44 .iter()
45 .enumerate()
46 .max_by(|(_, u), (_, v)| u.total_cmp(v))
47 .map(|(i, _)| i as u32)
48 .context("empty logits")?;
49 Ok(next_token)
50 }
51
52 fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
53 let distr = rand::distr::weighted::WeightedIndex::new(prs).map_err(Error::wrap)?;
54 let next_token = distr.sample(&mut self.rng) as u32;
55 Ok(next_token)
56 }
57
58 fn sample_topp(&mut self, prs: &mut Vec<f32>, top_p: f32) -> Result<u32> {
62 let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
63
64 argsort_indices.sort_by(|&i, &j| prs[j].total_cmp(&prs[i]));
66
67 let mut cumsum = 0.;
69 for index in &argsort_indices {
70 if cumsum >= top_p {
71 prs[*index] = 0.0;
72 } else {
73 cumsum += prs[*index];
74 }
75 }
76 self.sample_multinomial(prs)
78 }
79
80 fn sample_topk(&mut self, prs: &mut Vec<f32>, top_k: usize) -> Result<u32> {
82 if top_k >= prs.len() {
83 self.sample_multinomial(prs)
84 } else {
85 let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
86 let (indices, _, _) =
87 argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
88 let prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
89 let index = self.sample_multinomial(&prs)?;
90 Ok(indices[index as usize] as u32)
91 }
92 }
93
94 fn sample_topk_topp(&mut self, prs: &mut Vec<f32>, top_k: usize, top_p: f32) -> Result<u32> {
97 if top_k >= prs.len() {
98 self.sample_topp(prs, top_p)
99 } else {
100 let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
101 let (indices, _, _) =
102 argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
103 let mut prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
104 let sum_p = prs.iter().sum::<f32>();
105 let index = if top_p <= 0.0 || top_p >= sum_p {
106 self.sample_multinomial(&prs)?
107 } else {
108 self.sample_topp(&mut prs, top_p)?
109 };
110 Ok(indices[index as usize] as u32)
111 }
112 }
113
114 pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
115 self.sample_f(logits, |_| {})
116 }
117
118 pub fn sample_f(&mut self, logits: &Tensor, f: impl FnOnce(&mut [f32])) -> Result<u32> {
119 let logits = logits.to_dtype(DType::F32)?;
120 let prs = |temperature: f64| -> Result<Vec<f32>> {
121 let logits = (&logits / temperature)?;
122 let prs = candle_nn::ops::softmax_last_dim(&logits)?;
123 let mut prs = prs.to_vec1()?;
124 f(&mut prs);
125 Ok(prs)
126 };
127
128 let next_token = match &self.sampling {
129 Sampling::ArgMax => self.sample_argmax(logits)?,
130 Sampling::All { temperature } => {
131 let prs = prs(*temperature)?;
132 self.sample_multinomial(&prs)?
133 }
134 Sampling::TopP { p, temperature } => {
135 let mut prs = prs(*temperature)?;
136 if *p <= 0.0 || *p >= 1.0 {
137 self.sample_multinomial(&prs)?
139 } else {
140 self.sample_topp(&mut prs, *p as f32)?
142 }
143 }
144 Sampling::TopK { k, temperature } => {
145 let mut prs = prs(*temperature)?;
146 self.sample_topk(&mut prs, *k)?
147 }
148 Sampling::TopKThenTopP { k, p, temperature } => {
149 let mut prs = prs(*temperature)?;
150 self.sample_topk_topp(&mut prs, *k, *p as f32)?
151 }
152 };
153 Ok(next_token)
154 }
155}