1use candle::{IndexOp, Layout, Result, Shape, Tensor, D};
6use candle_nn::{linear, Linear, VarBuilder};
7
8struct CodebookEncode;
9
10impl candle::CustomOp2 for CodebookEncode {
11 fn name(&self) -> &'static str {
12 "cb"
13 }
14
15 fn cpu_fwd(
16 &self,
17 lhs_storage: &candle::CpuStorage,
18 lhs_layout: &Layout,
19 rhs_storage: &candle::CpuStorage,
20 rhs_layout: &Layout,
21 ) -> Result<(candle::CpuStorage, Shape)> {
22 use rayon::prelude::*;
23
24 let (lhs_dim1, lhs_dim2) = lhs_layout.shape().dims2()?;
25 let (rhs_dim1, rhs_dim2) = rhs_layout.shape().dims2()?;
26 if lhs_dim2 != rhs_dim2 {
27 candle::bail!("CodebookEncode, mismatch on last dim, {lhs_layout:?} {rhs_layout:?}");
28 }
29 if lhs_dim2 == 0 {
30 candle::bail!("CodebookEncode, empty last dim {lhs_layout:?}")
31 }
32 let lhs = match lhs_layout.contiguous_offsets() {
33 None => candle::bail!("CodebookEncode, lhs has to be contiguous, got {lhs_layout:?}"),
34 Some((o1, o2)) => {
35 let slice = lhs_storage.as_slice::<f32>()?;
36 &slice[o1..o2]
37 }
38 };
39 let rhs = match rhs_layout.contiguous_offsets() {
40 None => candle::bail!("CodebookEncode, rhs has to be contiguous, got {rhs_layout:?}"),
41 Some((o1, o2)) => {
42 let slice = rhs_storage.as_slice::<f32>()?;
43 &slice[o1..o2]
44 }
45 };
46 let dst = (0..lhs_dim1)
47 .into_par_iter()
48 .map(|idx1| {
49 let mut where_min = 0;
50 let mut min_dist = f32::INFINITY;
51 let lhs = &lhs[idx1 * lhs_dim2..(idx1 + 1) * lhs_dim2];
52 for idx2 in 0..rhs_dim1 {
53 let rhs = &rhs[idx2 * rhs_dim2..(idx2 + 1) * rhs_dim2];
54 let mut dist = 0f32;
55 for (a, b) in lhs.iter().zip(rhs.iter()) {
56 dist += (a - b) * (a - b)
57 }
58 if dist < min_dist {
59 min_dist = dist;
60 where_min = idx2;
61 }
62 }
63 where_min as u32
64 })
65 .collect();
66 let storage = candle::WithDType::to_cpu_storage_owned(dst);
67 Ok((storage, (lhs_dim1,).into()))
68 }
69}
70
71#[allow(unused)]
72#[derive(Debug, Clone)]
73pub struct EuclideanCodebook {
74 initialized: Tensor,
75 cluster_usage: Tensor,
76 embedding_sum: Tensor,
77 embedding: Tensor,
78 c2: Tensor,
79 epsilon: f64,
80 dim: usize,
81 span_encode: tracing::Span,
82 span_decode: tracing::Span,
83}
84
85impl EuclideanCodebook {
86 pub fn new(dim: usize, codebook_size: usize, vb: VarBuilder) -> Result<Self> {
87 let epsilon = 1e-5;
88 let initialized = vb.get(1, "initialized")?;
89 let cluster_usage = vb.get(codebook_size, "cluster_usage")?;
90 let embedding_sum = vb.get((codebook_size, dim), "embed_sum")?;
91 let embedding = {
92 let cluster_usage = cluster_usage.maximum(epsilon)?.unsqueeze(1)?;
93 embedding_sum.broadcast_div(&cluster_usage)?
94 };
95 let c2 = ((&embedding * &embedding)?.sum(D::Minus1)? / 2.0)?;
96 Ok(Self {
97 initialized,
98 cluster_usage,
99 embedding_sum,
100 embedding,
101 c2,
102 epsilon,
103 dim,
104 span_encode: tracing::span!(tracing::Level::TRACE, "euclidean-encode"),
105 span_decode: tracing::span!(tracing::Level::TRACE, "euclidean-encode"),
106 })
107 }
108
109 pub fn encode_very_slow(&self, xs: &Tensor) -> Result<Tensor> {
110 let _enter = self.span_encode.enter();
111 let mut target_shape = xs.dims().to_vec();
112 target_shape.pop();
113 let xs = xs.flatten_to(D::Minus2)?;
114 let _ = xs.dims2()?;
115 let cluster_usage = self.cluster_usage.maximum(self.epsilon)?.unsqueeze(1)?;
117 let embedding = self.embedding_sum.broadcast_div(&cluster_usage)?;
118 let diff = xs.unsqueeze(1)?.broadcast_sub(&embedding.unsqueeze(0)?)?;
120 let dists = diff.sqr()?.sum(D::Minus1)?;
121 let codes = dists.argmin(D::Minus1)?;
122 codes.reshape(target_shape)
123 }
124
125 pub fn encode_slow(&self, xs: &Tensor) -> Result<Tensor> {
126 let _enter = self.span_encode.enter();
127 let mut target_shape = xs.dims().to_vec();
128 target_shape.pop();
129 let xs = xs.flatten_to(D::Minus2)?;
130 let _ = xs.dims2()?;
131 let dot_prod = xs.matmul(&self.embedding.t()?)?;
132 let codes = self.c2.broadcast_sub(&dot_prod)?.argmin(D::Minus1)?;
133 codes.reshape(target_shape)
134 }
135
136 pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
137 let _enter = self.span_encode.enter();
138 let mut target_shape = xs.dims().to_vec();
139 target_shape.pop();
140 let xs = xs.flatten_to(D::Minus2)?;
141 let _ = xs.dims2()?;
142 let codes = Tensor::apply_op2(&xs, &self.embedding, CodebookEncode)?;
143 codes.reshape(target_shape)
144 }
145
146 pub fn decode(&self, indexes: &Tensor) -> Result<Tensor> {
147 let _enter = self.span_decode.enter();
148 let mut final_dims = indexes.dims().to_vec();
150 final_dims.push(self.dim);
151 let indexes = indexes.flatten_all()?;
152 let values = self.embedding.index_select(&indexes, 0)?;
153 let values = values.reshape(final_dims)?;
154 Ok(values)
155 }
156}
157
158#[allow(unused)]
159#[derive(Debug, Clone)]
160pub struct VectorQuantization {
161 project_in: Option<Linear>,
162 project_out: Option<Linear>,
163 codebook: EuclideanCodebook,
164}
165
166impl VectorQuantization {
167 pub fn new(
168 dim: usize,
169 codebook_size: usize,
170 codebook_dim: Option<usize>,
171 vb: VarBuilder,
172 ) -> Result<Self> {
173 let codebook_dim = codebook_dim.unwrap_or(dim);
174 let (project_in, project_out) = if codebook_dim == dim {
175 (None, None)
176 } else {
177 let p_in = linear(dim, codebook_dim, vb.pp("project_in"))?;
178 let p_out = linear(codebook_dim, dim, vb.pp("project_out"))?;
179 (Some(p_in), Some(p_out))
180 };
181 let codebook = EuclideanCodebook::new(codebook_dim, codebook_size, vb.pp("codebook"))?;
182 Ok(Self {
183 project_in,
184 project_out,
185 codebook,
186 })
187 }
188
189 pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
190 let xs = xs.t()?.apply(&self.project_in.as_ref())?;
191 self.codebook.encode_slow(&xs)
192 }
193
194 pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
195 let quantized = self.codebook.decode(codes)?;
196 let quantized = match &self.project_out {
197 None => quantized,
198 Some(p) => quantized.apply(p)?,
199 };
200 quantized.t()
201 }
202}
203
204#[derive(Debug, Clone)]
205pub struct ResidualVectorQuantization {
206 layers: Vec<VectorQuantization>,
207}
208
209impl ResidualVectorQuantization {
210 pub fn new(
211 n_q: usize,
212 dim: usize,
213 codebook_size: usize,
214 codebook_dim: Option<usize>,
215 vb: VarBuilder,
216 ) -> Result<Self> {
217 let vb = vb.pp("layers");
218 let mut layers = Vec::with_capacity(n_q);
219 for i in 0..n_q {
220 let layer = VectorQuantization::new(dim, codebook_size, codebook_dim, vb.pp(i))?;
221 layers.push(layer)
222 }
223 Ok(Self { layers })
224 }
225
226 pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
227 let mut codes = Vec::with_capacity(self.layers.len());
228 let mut residual = xs.clone();
229 for layer in self.layers.iter() {
230 let indices = layer.encode(&residual)?;
231 let quantized = layer.decode(&indices)?;
232 residual = (residual - quantized)?;
233 codes.push(indices)
234 }
235 Tensor::stack(&codes, 0)
236 }
237
238 pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
239 if self.layers.is_empty() {
240 candle::bail!("empty layers in ResidualVectorQuantization")
241 }
242 if self.layers.len() != xs.dim(0)? {
243 candle::bail!(
244 "mismatch between the number of layers {} and the code shape {:?}",
245 self.layers.len(),
246 xs.shape()
247 )
248 }
249 let mut quantized = self.layers[0].decode(&xs.i(0)?)?;
250 for (i, layer) in self.layers.iter().enumerate().skip(1) {
251 let xs = xs.i(i)?;
252 quantized = (quantized + layer.decode(&xs))?
253 }
254 Ok(quantized)
255 }
256}
257
258#[allow(unused)]
259#[derive(Debug, Clone)]
260pub struct ResidualVectorQuantizer {
261 vq: ResidualVectorQuantization,
262 input_proj: Option<candle_nn::Conv1d>,
263 output_proj: Option<candle_nn::Conv1d>,
264}
265
266impl ResidualVectorQuantizer {
267 pub fn new(
268 dim: usize,
269 input_dim: Option<usize>,
270 output_dim: Option<usize>,
271 n_q: usize,
272 bins: usize,
273 force_projection: bool,
274 vb: VarBuilder,
275 ) -> Result<Self> {
276 let input_dim = input_dim.unwrap_or(dim);
277 let output_dim = output_dim.unwrap_or(dim);
278
279 let input_proj = if input_dim == dim && !force_projection {
280 None
281 } else {
282 let c = candle_nn::conv1d_no_bias(
283 input_dim,
284 dim,
285 1,
286 Default::default(),
287 vb.pp("input_proj"),
288 )?;
289 Some(c)
290 };
291 let output_proj = if output_dim == dim && !force_projection {
292 None
293 } else {
294 let c = candle_nn::conv1d_no_bias(
295 dim,
296 output_dim,
297 1,
298 Default::default(),
299 vb.pp("output_proj"),
300 )?;
301 Some(c)
302 };
303
304 let vq = ResidualVectorQuantization::new(
305 n_q, dim, bins, None, vb,
306 )?;
307 Ok(Self {
308 vq,
309 input_proj,
310 output_proj,
311 })
312 }
313
314 pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
315 let codes = self.vq.encode(&xs.apply(&self.input_proj.as_ref())?)?;
316 codes.transpose(0, 1)
317 }
318
319 pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
320 let codes = codes.transpose(0, 1)?;
322 let quantized = self.vq.decode(&codes)?;
323 match &self.output_proj {
324 None => Ok(quantized),
325 Some(p) => quantized.apply(p),
326 }
327 }
328}
329
330#[derive(Debug, Clone)]
333pub struct SplitResidualVectorQuantizer {
334 rvq_first: ResidualVectorQuantizer,
335 rvq_rest: ResidualVectorQuantizer,
336 n_q: usize,
337 span_encode: tracing::Span,
338 span_decode: tracing::Span,
339}
340
341impl SplitResidualVectorQuantizer {
342 pub fn new(
343 dim: usize,
344 input_dim: Option<usize>,
345 output_dim: Option<usize>,
346 n_q: usize,
347 bins: usize,
348 vb: VarBuilder,
349 ) -> Result<Self> {
350 let rvq_first = ResidualVectorQuantizer::new(
351 dim,
352 input_dim,
353 output_dim,
354 1,
355 bins,
356 true,
357 vb.pp("semantic_residual_vector_quantizer"),
358 )?;
359 let rvq_rest = ResidualVectorQuantizer::new(
360 dim,
361 input_dim,
362 output_dim,
363 n_q - 1,
364 bins,
365 true,
366 vb.pp("acoustic_residual_vector_quantizer"),
367 )?;
368 let span_encode = tracing::span!(tracing::Level::TRACE, "split-rvq-encode");
369 let span_decode = tracing::span!(tracing::Level::TRACE, "split-rvq-decode");
370 Ok(Self {
371 rvq_first,
372 rvq_rest,
373 n_q,
374 span_encode,
375 span_decode,
376 })
377 }
378
379 pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
380 let _enter = self.span_encode.enter();
381 let codes = self.rvq_first.encode(xs)?;
382 if self.n_q > 1 {
383 let rest_codes = self.rvq_rest.encode(xs)?;
387 Tensor::cat(&[codes, rest_codes], 1)
388 } else {
389 Ok(codes)
390 }
391 }
392
393 pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
394 let _enter = self.span_decode.enter();
396 let quantized = self.rvq_first.decode(&codes.i((.., ..1))?)?;
397 let quantized = if self.n_q > 1 {
398 (quantized + self.rvq_rest.decode(&codes.i((.., 1..))?))?
399 } else {
400 quantized
401 };
402 Ok(quantized)
403 }
404}