candle_transformers/models/mimi/quantization.rs

1// Copyright (c) Kyutai, all rights reserved.
2// This source code is licensed under the license found in the
3// LICENSE file in the root directory of this source tree.
4
5use candle::{IndexOp, Layout, Result, Shape, Tensor, D};
6use candle_nn::{linear, Linear, VarBuilder};
7
8struct CodebookEncode;
9
10impl candle::CustomOp2 for CodebookEncode {
11    fn name(&self) -> &'static str {
12        "cb"
13    }
14
15    fn cpu_fwd(
16        &self,
17        lhs_storage: &candle::CpuStorage,
18        lhs_layout: &Layout,
19        rhs_storage: &candle::CpuStorage,
20        rhs_layout: &Layout,
21    ) -> Result<(candle::CpuStorage, Shape)> {
22        use rayon::prelude::*;
23
24        let (lhs_dim1, lhs_dim2) = lhs_layout.shape().dims2()?;
25        let (rhs_dim1, rhs_dim2) = rhs_layout.shape().dims2()?;
26        if lhs_dim2 != rhs_dim2 {
27            candle::bail!("CodebookEncode, mismatch on last dim, {lhs_layout:?} {rhs_layout:?}");
28        }
29        if lhs_dim2 == 0 {
30            candle::bail!("CodebookEncode, empty last dim {lhs_layout:?}")
31        }
32        let lhs = match lhs_layout.contiguous_offsets() {
33            None => candle::bail!("CodebookEncode, lhs has to be contiguous, got {lhs_layout:?}"),
34            Some((o1, o2)) => {
35                let slice = lhs_storage.as_slice::<f32>()?;
36                &slice[o1..o2]
37            }
38        };
39        let rhs = match rhs_layout.contiguous_offsets() {
40            None => candle::bail!("CodebookEncode, rhs has to be contiguous, got {rhs_layout:?}"),
41            Some((o1, o2)) => {
42                let slice = rhs_storage.as_slice::<f32>()?;
43                &slice[o1..o2]
44            }
45        };
46        let dst = (0..lhs_dim1)
47            .into_par_iter()
48            .map(|idx1| {
49                let mut where_min = 0;
50                let mut min_dist = f32::INFINITY;
51                let lhs = &lhs[idx1 * lhs_dim2..(idx1 + 1) * lhs_dim2];
52                for idx2 in 0..rhs_dim1 {
53                    let rhs = &rhs[idx2 * rhs_dim2..(idx2 + 1) * rhs_dim2];
54                    let mut dist = 0f32;
55                    for (a, b) in lhs.iter().zip(rhs.iter()) {
56                        dist += (a - b) * (a - b)
57                    }
58                    if dist < min_dist {
59                        min_dist = dist;
60                        where_min = idx2;
61                    }
62                }
63                where_min as u32
64            })
65            .collect();
66        let storage = candle::WithDType::to_cpu_storage_owned(dst);
67        Ok((storage, (lhs_dim1,).into()))
68    }
69}
70
71#[allow(unused)]
72#[derive(Debug, Clone)]
73pub struct EuclideanCodebook {
74    initialized: Tensor,
75    cluster_usage: Tensor,
76    embedding_sum: Tensor,
77    embedding: Tensor,
78    c2: Tensor,
79    epsilon: f64,
80    dim: usize,
81    span_encode: tracing::Span,
82    span_decode: tracing::Span,
83}
84
85impl EuclideanCodebook {
86    pub fn new(dim: usize, codebook_size: usize, vb: VarBuilder) -> Result<Self> {
87        let epsilon = 1e-5;
88        let initialized = vb.get(1, "initialized")?;
89        let cluster_usage = vb.get(codebook_size, "cluster_usage")?;
90        let embedding_sum = vb.get((codebook_size, dim), "embed_sum")?;
91        let embedding = {
92            let cluster_usage = cluster_usage.maximum(epsilon)?.unsqueeze(1)?;
93            embedding_sum.broadcast_div(&cluster_usage)?
94        };
95        let c2 = ((&embedding * &embedding)?.sum(D::Minus1)? / 2.0)?;
96        Ok(Self {
97            initialized,
98            cluster_usage,
99            embedding_sum,
100            embedding,
101            c2,
102            epsilon,
103            dim,
104            span_encode: tracing::span!(tracing::Level::TRACE, "euclidean-encode"),
105            span_decode: tracing::span!(tracing::Level::TRACE, "euclidean-encode"),
106        })
107    }
108
109    pub fn encode_very_slow(&self, xs: &Tensor) -> Result<Tensor> {
110        let _enter = self.span_encode.enter();
111        let mut target_shape = xs.dims().to_vec();
112        target_shape.pop();
113        let xs = xs.flatten_to(D::Minus2)?;
114        let _ = xs.dims2()?;
115        // TODO: avoid repeating this.
116        let cluster_usage = self.cluster_usage.maximum(self.epsilon)?.unsqueeze(1)?;
117        let embedding = self.embedding_sum.broadcast_div(&cluster_usage)?;
118        // Manual cdist implementation.
119        let diff = xs.unsqueeze(1)?.broadcast_sub(&embedding.unsqueeze(0)?)?;
120        let dists = diff.sqr()?.sum(D::Minus1)?;
121        let codes = dists.argmin(D::Minus1)?;
122        codes.reshape(target_shape)
123    }
124
125    pub fn encode_slow(&self, xs: &Tensor) -> Result<Tensor> {
126        let _enter = self.span_encode.enter();
127        let mut target_shape = xs.dims().to_vec();
128        target_shape.pop();
129        let xs = xs.flatten_to(D::Minus2)?;
130        let _ = xs.dims2()?;
131        let dot_prod = xs.matmul(&self.embedding.t()?)?;
132        let codes = self.c2.broadcast_sub(&dot_prod)?.argmin(D::Minus1)?;
133        codes.reshape(target_shape)
134    }
135
136    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
137        let _enter = self.span_encode.enter();
138        let mut target_shape = xs.dims().to_vec();
139        target_shape.pop();
140        let xs = xs.flatten_to(D::Minus2)?;
141        let _ = xs.dims2()?;
142        let codes = Tensor::apply_op2(&xs, &self.embedding, CodebookEncode)?;
143        codes.reshape(target_shape)
144    }
145
146    pub fn decode(&self, indexes: &Tensor) -> Result<Tensor> {
147        let _enter = self.span_decode.enter();
148        // let ys = candle_nn::Embedding::new(self.embedding.clone(), self.dim).forward(xs)?;
149        let mut final_dims = indexes.dims().to_vec();
150        final_dims.push(self.dim);
151        let indexes = indexes.flatten_all()?;
152        let values = self.embedding.index_select(&indexes, 0)?;
153        let values = values.reshape(final_dims)?;
154        Ok(values)
155    }
156}
157
158#[allow(unused)]
159#[derive(Debug, Clone)]
160pub struct VectorQuantization {
161    project_in: Option<Linear>,
162    project_out: Option<Linear>,
163    codebook: EuclideanCodebook,
164}
165
166impl VectorQuantization {
167    pub fn new(
168        dim: usize,
169        codebook_size: usize,
170        codebook_dim: Option<usize>,
171        vb: VarBuilder,
172    ) -> Result<Self> {
173        let codebook_dim = codebook_dim.unwrap_or(dim);
174        let (project_in, project_out) = if codebook_dim == dim {
175            (None, None)
176        } else {
177            let p_in = linear(dim, codebook_dim, vb.pp("project_in"))?;
178            let p_out = linear(codebook_dim, dim, vb.pp("project_out"))?;
179            (Some(p_in), Some(p_out))
180        };
181        let codebook = EuclideanCodebook::new(codebook_dim, codebook_size, vb.pp("codebook"))?;
182        Ok(Self {
183            project_in,
184            project_out,
185            codebook,
186        })
187    }
188
189    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
190        let xs = xs.t()?.apply(&self.project_in.as_ref())?;
191        self.codebook.encode_slow(&xs)
192    }
193
194    pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
195        let quantized = self.codebook.decode(codes)?;
196        let quantized = match &self.project_out {
197            None => quantized,
198            Some(p) => quantized.apply(p)?,
199        };
200        quantized.t()
201    }
202}
203
204#[derive(Debug, Clone)]
205pub struct ResidualVectorQuantization {
206    layers: Vec<VectorQuantization>,
207}
208
209impl ResidualVectorQuantization {
210    pub fn new(
211        n_q: usize,
212        dim: usize,
213        codebook_size: usize,
214        codebook_dim: Option<usize>,
215        vb: VarBuilder,
216    ) -> Result<Self> {
217        let vb = vb.pp("layers");
218        let mut layers = Vec::with_capacity(n_q);
219        for i in 0..n_q {
220            let layer = VectorQuantization::new(dim, codebook_size, codebook_dim, vb.pp(i))?;
221            layers.push(layer)
222        }
223        Ok(Self { layers })
224    }
225
226    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
227        let mut codes = Vec::with_capacity(self.layers.len());
228        let mut residual = xs.clone();
229        for layer in self.layers.iter() {
230            let indices = layer.encode(&residual)?;
231            let quantized = layer.decode(&indices)?;
232            residual = (residual - quantized)?;
233            codes.push(indices)
234        }
235        Tensor::stack(&codes, 0)
236    }
237
238    pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
239        if self.layers.is_empty() {
240            candle::bail!("empty layers in ResidualVectorQuantization")
241        }
242        if self.layers.len() != xs.dim(0)? {
243            candle::bail!(
244                "mismatch between the number of layers {} and the code shape {:?}",
245                self.layers.len(),
246                xs.shape()
247            )
248        }
249        let mut quantized = self.layers[0].decode(&xs.i(0)?)?;
250        for (i, layer) in self.layers.iter().enumerate().skip(1) {
251            let xs = xs.i(i)?;
252            quantized = (quantized + layer.decode(&xs))?
253        }
254        Ok(quantized)
255    }
256}
257
258#[allow(unused)]
259#[derive(Debug, Clone)]
260pub struct ResidualVectorQuantizer {
261    vq: ResidualVectorQuantization,
262    input_proj: Option<candle_nn::Conv1d>,
263    output_proj: Option<candle_nn::Conv1d>,
264}
265
266impl ResidualVectorQuantizer {
267    pub fn new(
268        dim: usize,
269        input_dim: Option<usize>,
270        output_dim: Option<usize>,
271        n_q: usize,
272        bins: usize,
273        force_projection: bool,
274        vb: VarBuilder,
275    ) -> Result<Self> {
276        let input_dim = input_dim.unwrap_or(dim);
277        let output_dim = output_dim.unwrap_or(dim);
278
279        let input_proj = if input_dim == dim && !force_projection {
280            None
281        } else {
282            let c = candle_nn::conv1d_no_bias(
283                input_dim,
284                dim,
285                1,
286                Default::default(),
287                vb.pp("input_proj"),
288            )?;
289            Some(c)
290        };
291        let output_proj = if output_dim == dim && !force_projection {
292            None
293        } else {
294            let c = candle_nn::conv1d_no_bias(
295                dim,
296                output_dim,
297                1,
298                Default::default(),
299                vb.pp("output_proj"),
300            )?;
301            Some(c)
302        };
303
304        let vq = ResidualVectorQuantization::new(
305            n_q, dim, /* codebook_size */ bins, /* codebook_dim */ None, vb,
306        )?;
307        Ok(Self {
308            vq,
309            input_proj,
310            output_proj,
311        })
312    }
313
314    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
315        let codes = self.vq.encode(&xs.apply(&self.input_proj.as_ref())?)?;
316        codes.transpose(0, 1)
317    }
318
319    pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
320        // codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
321        let codes = codes.transpose(0, 1)?;
322        let quantized = self.vq.decode(&codes)?;
323        match &self.output_proj {
324            None => Ok(quantized),
325            Some(p) => quantized.apply(p),
326        }
327    }
328}
329
330// we do not use any codebook_offset at the moment. When reconstructing the codes, we could just
331// concatenate the indexes.
332#[derive(Debug, Clone)]
333pub struct SplitResidualVectorQuantizer {
334    rvq_first: ResidualVectorQuantizer,
335    rvq_rest: ResidualVectorQuantizer,
336    n_q: usize,
337    span_encode: tracing::Span,
338    span_decode: tracing::Span,
339}
340
341impl SplitResidualVectorQuantizer {
342    pub fn new(
343        dim: usize,
344        input_dim: Option<usize>,
345        output_dim: Option<usize>,
346        n_q: usize,
347        bins: usize,
348        vb: VarBuilder,
349    ) -> Result<Self> {
350        let rvq_first = ResidualVectorQuantizer::new(
351            dim,
352            input_dim,
353            output_dim,
354            1,
355            bins,
356            true,
357            vb.pp("semantic_residual_vector_quantizer"),
358        )?;
359        let rvq_rest = ResidualVectorQuantizer::new(
360            dim,
361            input_dim,
362            output_dim,
363            n_q - 1,
364            bins,
365            true,
366            vb.pp("acoustic_residual_vector_quantizer"),
367        )?;
368        let span_encode = tracing::span!(tracing::Level::TRACE, "split-rvq-encode");
369        let span_decode = tracing::span!(tracing::Level::TRACE, "split-rvq-decode");
370        Ok(Self {
371            rvq_first,
372            rvq_rest,
373            n_q,
374            span_encode,
375            span_decode,
376        })
377    }
378
379    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
380        let _enter = self.span_encode.enter();
381        let codes = self.rvq_first.encode(xs)?;
382        if self.n_q > 1 {
383            // We encode xs again here rather than the residual. The decomposition is not
384            // hierarchical but rather having semantic tokens for rvq_first and the acoustic tokens
385            // for rvq_rest.
386            let rest_codes = self.rvq_rest.encode(xs)?;
387            Tensor::cat(&[codes, rest_codes], 1)
388        } else {
389            Ok(codes)
390        }
391    }
392
393    pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
394        // codes is [B, K, T], with T frames, K nb of codebooks.
395        let _enter = self.span_decode.enter();
396        let quantized = self.rvq_first.decode(&codes.i((.., ..1))?)?;
397        let quantized = if self.n_q > 1 {
398            (quantized + self.rvq_rest.decode(&codes.i((.., 1..))?))?
399        } else {
400            quantized
401        };
402        Ok(quantized)
403    }
404}