candle_transformers/models/
metavoice.rs

1//! MetaVoice Studio ML Models
2//!
3//! See MetaVoice's TTS and voice cloning models:
4//! - [Github](https://github.com/metavoiceio/metavoice-src)
5//! - [Website](https://studio.metavoice.ai/)
6
7use candle::{DType, Device, Error as E, IndexOp, Module, Result, Tensor, D};
8use candle_nn::{embedding, linear_b, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
9
10// Equivalent to torch.repeat_interleave
11pub(crate) fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
12    let img = img.unsqueeze(dim + 1)?;
13    let mut dims = img.dims().to_vec();
14    dims[dim + 1] = repeats;
15    img.broadcast_as(dims)?.flatten(dim, dim + 1)
16}
17pub mod speaker_encoder {
18    use super::*;
19
    /// Hyper-parameters of the speaker encoder.
    #[derive(Debug, Clone, serde::Deserialize)]
    pub struct Config {
        /// Audio sampling rate; presumably in Hz — TODO confirm.
        pub sampling_rate: usize,
        /// Number of mel frames covered by each partial-utterance slice.
        pub partial_n_frames: usize,
        /// Hidden size of each LSTM layer, also the input size of the final linear layer.
        pub model_hidden_size: usize,
        /// Output size of the final linear layer.
        pub model_embedding_size: usize,
        /// Number of LSTM layers.
        pub model_num_layers: usize,
        /// Passed as the FFT size when computing the log-mel spectrogram.
        pub mel_window_length: usize,
        /// Passed as the FFT step when computing the log-mel spectrogram; the slicing code
        /// computes `sampling_rate * mel_window_step / 1000`, so presumably milliseconds.
        pub mel_window_step: usize,
        /// Number of mel channels.
        pub mel_n_channels: usize,
    }
31
32    impl Config {
33        pub fn cfg() -> Self {
34            Self {
35                sampling_rate: 16_000,
36                partial_n_frames: 160,
37                model_hidden_size: 256,
38                model_embedding_size: 256,
39                model_num_layers: 3,
40                mel_window_length: 25,
41                mel_window_step: 10,
42                mel_n_channels: 40,
43            }
44        }
45    }
46
    /// Speaker encoder: a stack of LSTM layers followed by a linear projection.
    pub struct Model {
        /// LSTM layers, applied in order.
        lstms: Vec<candle_nn::LSTM>,
        /// Projection from the LSTM hidden size to the embedding size.
        linear: Linear,
        /// Configuration the model was built from.
        cfg: Config,
    }
52
    /// A `(start, end)` index pair describing a partial-utterance window.
    type Slice = (usize, usize);
54
55    impl Model {
56        pub fn new(cfg: Config, vb: VarBuilder) -> Result<Self> {
57            let mut lstms = Vec::with_capacity(cfg.model_num_layers);
58            let vb_l = vb.pp("lstm");
59            for layer_idx in 0..cfg.model_num_layers {
60                let c = candle_nn::LSTMConfig {
61                    layer_idx,
62                    ..Default::default()
63                };
64                let lstm = candle_nn::lstm(
65                    cfg.mel_n_channels,
66                    cfg.model_hidden_size,
67                    c,
68                    vb_l.pp(layer_idx),
69                )?;
70                lstms.push(lstm)
71            }
72            let linear = linear_b(
73                cfg.model_hidden_size,
74                cfg.model_embedding_size,
75                true,
76                vb.pp("linear"),
77            )?;
78            Ok(Self { lstms, linear, cfg })
79        }
80
81        fn compute_partial_slices(
82            &self,
83            n_samples: usize,
84            rate: f64,
85            min_coverage: f64,
86        ) -> (Vec<Slice>, Vec<Slice>) {
87            let c = &self.cfg;
88            // Compute how many frames separate two partial utterances
89            let samples_per_frame = c.sampling_rate * c.mel_window_step / 1000;
90            let n_frames = n_samples / samples_per_frame + 1;
91            let frame_step =
92                (c.sampling_rate as f64 / rate / samples_per_frame as f64).round() as usize;
93            let steps = (n_frames + frame_step).saturating_sub(c.partial_n_frames) + 1;
94            // Compute the slices.
95            let mut wav_slices = vec![];
96            let mut mel_slices = vec![];
97            for i in (0..steps).step_by(frame_step) {
98                let mel_range = (i, i + c.partial_n_frames);
99                let wav_range = (
100                    i * samples_per_frame,
101                    (i + c.partial_n_frames) * samples_per_frame,
102                );
103                mel_slices.push(mel_range);
104                wav_slices.push(wav_range);
105            }
106            // Evaluate whether extra padding is warranted or not.
107            let last_wav_range = match wav_slices.last() {
108                None => return (wav_slices, mel_slices),
109                Some(l) => *l,
110            };
111            let coverage = (n_samples - last_wav_range.0) as f64
112                / (last_wav_range.1 - last_wav_range.0) as f64;
113            if coverage > min_coverage && mel_slices.len() > 1 {
114                mel_slices.pop();
115                wav_slices.pop();
116            }
117            (wav_slices, mel_slices)
118        }
119
120        pub fn embed_utterance(
121            &self,
122            wav: &[f32],
123            mel_filters: &[f32],
124            rate: f64,
125            min_c: f64,
126            device: &Device,
127        ) -> Result<Tensor> {
128            let (wav_slices, mel_slices) = self.compute_partial_slices(wav.len(), rate, min_c);
129            let max_wave_length = match wav_slices.last() {
130                Some(v) => v.1,
131                None => candle::bail!("empty wav slices"),
132            };
133            let wav = if max_wave_length > wav.len() {
134                let mut wav = wav.to_vec();
135                wav.resize(max_wave_length - wav.len(), 0.0);
136                std::borrow::Cow::Owned(wav)
137            } else {
138                std::borrow::Cow::Borrowed(wav)
139            };
140            let mel = crate::models::whisper::audio::log_mel_spectrogram_(
141                wav.as_ref(),
142                mel_filters,
143                /* fft_size */ self.cfg.mel_window_length,
144                /* fft_step */ self.cfg.mel_window_step,
145                self.cfg.mel_n_channels,
146                false,
147            );
148            let mels = mel_slices
149                .iter()
150                .flat_map(|s| [mel[s.0], mel[s.1]])
151                .collect::<Vec<_>>();
152            let mels = Tensor::from_vec(mels, (mel_slices.len(), 2), device)?;
153            let partial_embeds = self.forward(&mels)?;
154            let raw_embed = partial_embeds.mean(0)?;
155            let norm = raw_embed.sqr()?.sum_all()?.sqrt()?;
156            raw_embed.broadcast_div(&norm)
157        }
158    }
159
160    impl Module for Model {
161        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
162            use candle_nn::RNN;
163
164            // This is different from the Python transformers version as candle LSTM is batch first.
165            let xs = xs.t()?;
166            let mut xs = xs.clone();
167            for layer in self.lstms.iter() {
168                let states = layer.seq(&xs)?;
169                xs = layer.states_to_tensor(&states)?;
170            }
171            let xs = xs.t()?;
172            let embeds_raw = xs.apply(&self.linear)?.relu()?;
173            let norm = embeds_raw.sqr()?.sum_keepdim(1)?.sqrt()?;
174            embeds_raw.broadcast_div(&norm)
175        }
176    }
177}
178
/// Rank of a byte sequence in the BPE vocabulary, also used as its token id.
type Rank = u32;
180
181pub mod tokenizers {
182    use super::*;
183    use std::collections::HashMap;
184
    /// Byte-pair-encoding tokenizer.
    pub struct BPE {
        /// Regular expression splitting the input text into pieces that are encoded
        /// independently.
        pub re: fancy_regex::Regex,
        /// Id of the end-of-text token, before `offset` is applied.
        pub end_of_text: usize,
        /// Value added to every token id returned by `encode`.
        pub offset: usize,
        /// Maps a byte sequence to its rank.
        pub ranks: HashMap<Vec<u8>, Rank>,
        /// Tracing span entered by `encode`.
        span: tracing::Span,
    }
192
193    impl BPE {
194        pub fn from_json(json: &serde_json::Value, end_of_text: usize) -> Result<Self> {
195            let json = match json.as_object() {
196                None => candle::bail!("json value is not an object"),
197                Some(json) => json,
198            };
199            let re = match json.get("pat_str") {
200                None => candle::bail!("json object has no pat_str field"),
201                Some(pat_str) => match pat_str.as_str() {
202                    None => candle::bail!("pat_str field is not a string"),
203                    Some(pat_str) => fancy_regex::Regex::new(pat_str).map_err(E::wrap)?,
204                },
205            };
206            let offset = match json.get("offset") {
207                None => candle::bail!("json object has no offset field"),
208                Some(offset) => match offset.as_u64() {
209                    None => candle::bail!("offset field is not a positive int"),
210                    Some(offset) => offset as usize,
211                },
212            };
213            let mut ranks = HashMap::new();
214            for id in 0u8..=255 {
215                ranks.insert(vec![id], id as u32);
216            }
217            let mergeable_ranks = match json.get("mergeable_ranks") {
218                None => candle::bail!("json object has no mergeable_ranks field"),
219                Some(mr) => match mr.as_object() {
220                    None => candle::bail!("mergeable_ranks is not an object"),
221                    Some(mr) => mr,
222                },
223            };
224            for (key, value) in mergeable_ranks.iter() {
225                let value = match value.as_u64() {
226                    None => candle::bail!("mergeable_ranks '{key}' is not a u64"),
227                    Some(value) => value as u32,
228                };
229                if value < 256 {
230                    continue;
231                }
232                // No escaping for other keys.
233                let key = key.as_bytes().to_vec();
234                ranks.insert(key, value);
235            }
236            Ok(Self {
237                re,
238                end_of_text,
239                offset,
240                ranks,
241                span: tracing::span!(tracing::Level::TRACE, "bpe"),
242            })
243        }
244
245        // Taken from:
246        // https://github.com/openai/tiktoken/blob/1b9faf2779855124f05174adf1383e53689ed94b/src/lib.rs#L16C1-L82C2
247        fn _byte_pair_merge(&self, piece: &[u8]) -> Vec<(usize, Rank)> {
248            // This is a vector of (start, rank).
249            // The rank is of the pair starting at position start.
250            let mut parts = Vec::with_capacity(piece.len() + 1);
251
252            // Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE
253            // the way we currently do, this is equivalent. An easy way to break this would be to decouple
254            // merge priority from token index or to prevent specific token merges.
255            let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
256            for i in 0..piece.len() - 1 {
257                let rank = *self.ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
258                if rank < min_rank.0 {
259                    min_rank = (rank, i);
260                }
261                parts.push((i, rank));
262            }
263            parts.push((piece.len() - 1, Rank::MAX));
264            parts.push((piece.len(), Rank::MAX));
265
266            let get_rank = {
267                #[inline(always)]
268                |parts: &Vec<(usize, Rank)>, i: usize| {
269                    if (i + 3) < parts.len() {
270                        // Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet deleted
271                        // parts[i + 1], see comment in the main loop.
272                        *self
273                            .ranks
274                            .get(&piece[parts[i].0..parts[i + 3].0])
275                            .unwrap_or(&Rank::MAX)
276                    } else {
277                        Rank::MAX
278                    }
279                }
280            };
281
282            // If you have n parts and m merges, this does O(mn) work.
283            // We could do something with a heap and do O(m log n) work.
284            // n is often very small so considerations like cache-locality outweigh the algorithmic
285            // complexity downsides of the `parts` vector.
286            while min_rank.0 != Rank::MAX {
287                let i = min_rank.1;
288                // Update parts[i] and parts[i - 1] before removing parts[i + 1], since
289                // `parts.remove(i + 1)` will thrash the cache.
290                if i > 0 {
291                    parts[i - 1].1 = get_rank(&parts, i - 1);
292                }
293                parts[i].1 = get_rank(&parts, i);
294                parts.remove(i + 1);
295
296                min_rank = (Rank::MAX, usize::MAX);
297                for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
298                    if rank < min_rank.0 {
299                        min_rank = (rank, i);
300                    }
301                }
302            }
303            parts
304        }
305
306        pub fn byte_pair_encode(&self, piece: &[u8]) -> Vec<Rank> {
307            if piece.is_empty() {
308                return Vec::new();
309            }
310            if piece.len() == 1 {
311                return vec![self.ranks[piece]];
312            }
313            assert!(piece.len() > 1);
314            self._byte_pair_merge(piece)
315                .windows(2)
316                .map(|part| self.ranks[&piece[part[0].0..part[1].0]])
317                .collect()
318        }
319
320        pub fn encode(&self, text: &str) -> Result<Vec<u32>> {
321            let _enter = self.span.enter();
322            let mut bpe_tokens: Vec<u32> = Vec::new();
323            for word in self.re.find_iter(text) {
324                let word = word.map_err(E::wrap)?;
325                let word_tokens = self.byte_pair_encode(word.as_str().as_bytes());
326                for &token in word_tokens.iter() {
327                    bpe_tokens.push(token + self.offset as u32)
328                }
329            }
330            bpe_tokens.push((self.end_of_text + self.offset) as u32);
331            Ok(bpe_tokens)
332        }
333    }
334}
335
336pub mod gpt {
337    use super::*;
338
    /// Kind of normalization layer used by the model.
    #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
    pub enum NormType {
        /// Layer normalization.
        LayerNorm,
        /// RMS normalization.
        RMSNorm,
    }
344
    /// Attention implementation selected by the configuration.
    ///
    /// Only `TorchAttn` is accepted when building the attention layers.
    #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
    pub enum AttnKernelType {
        Fa2,
        TorchAttn,
        Hand,
    }
351
    /// Kind of feed-forward block used inside each transformer block.
    #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
    pub enum NonLinearityType {
        /// Two linear layers with a GELU in between.
        Gelu,
        /// Gated feed-forward block using a SiLU gate.
        Swiglu,
    }
357
    /// Normalization layer, wrapping the variant selected by `NormType`.
    enum Norm {
        RMSNorm(candle_nn::RmsNorm),
        LayerNorm(candle_nn::LayerNorm),
    }
362
363    // https://github.com/metavoiceio/metavoice-src/blob/11550bb4e8a1ad032cc1556cc924f7a4e767cbfa/fam/llm/model.py#L27
    /// Configuration of the GPT model.
    #[derive(Debug, Clone)]
    pub struct Config {
        /// Number of entries of the positional embedding table.
        pub block_size: usize,
        /// Vocabulary size of each input token embedding table.
        pub vocab_sizes: Vec<usize>,
        /// Output size of each language-model head.
        pub target_vocab_sizes: Vec<usize>,
        /// Number of transformer blocks.
        pub n_layer: usize,
        /// Number of attention heads.
        pub n_head: usize,
        /// Embedding (hidden) size.
        pub n_embd: usize,
        /// Whether the linear layers of the blocks have a bias; also used as the `affine`
        /// setting of the layer norms.
        pub bias: bool,
        /// Not read by the code in this module; presumably selects causal attention —
        /// TODO confirm.
        pub causal: bool,
        /// Not read by the code in this module.
        pub spk_emb_on_text: bool,
        /// Normalization layer to use.
        pub norm_type: NormType,
        /// Epsilon used by the RMS norm layers.
        pub rmsnorm_eps: f64,
        /// Feed-forward block variant.
        pub nonlinearity_type: NonLinearityType,
        /// Granularity used when sizing the SwiGLU hidden layer; required for `Swiglu`.
        pub swiglu_multiple_of: Option<usize>,
        /// Attention implementation; only `TorchAttn` is supported.
        pub attn_kernel_type: AttnKernelType,
        /// Has to be `false`, the kv-cache is not supported.
        pub kv_cache_enabled: bool,
    }
382
383    impl Config {
384        pub fn cfg1b_v0_1() -> Self {
385            Self {
386                n_layer: 6,
387                n_head: 6,
388                n_embd: 384,
389                block_size: 1024,
390                bias: false,
391                vocab_sizes: vec![1538, 1025],
392                causal: false,
393                target_vocab_sizes: vec![1025, 1025, 1025, 1025, 1025, 1025],
394                swiglu_multiple_of: Some(256),
395                norm_type: NormType::LayerNorm,
396                kv_cache_enabled: false,
397                attn_kernel_type: AttnKernelType::TorchAttn,
398                spk_emb_on_text: true,
399                nonlinearity_type: NonLinearityType::Gelu,
400                rmsnorm_eps: 1e-5,
401            }
402        }
403    }
404
405    impl Norm {
406        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
407            match cfg.norm_type {
408                NormType::RMSNorm => {
409                    let rms_norm = candle_nn::rms_norm(cfg.n_embd, cfg.rmsnorm_eps, vb)?;
410                    Ok(Self::RMSNorm(rms_norm))
411                }
412                NormType::LayerNorm => {
413                    let ln_cfg = candle_nn::LayerNormConfig {
414                        affine: cfg.bias,
415                        ..Default::default()
416                    };
417                    let layer_norm = candle_nn::layer_norm(cfg.n_embd, ln_cfg, vb)?;
418                    Ok(Self::LayerNorm(layer_norm))
419                }
420            }
421        }
422    }
423
424    impl Module for Norm {
425        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
426            match self {
427                Self::RMSNorm(m) => m.forward(xs),
428                Self::LayerNorm(m) => m.forward(xs),
429            }
430        }
431    }
432
433    // https://github.com/metavoiceio/metavoice-src/blob/11550bb4e8a1ad032cc1556cc924f7a4e767cbfa/fam/llm/layers/attn.py#L18
    /// Multi-head self-attention layer.
    struct SelfAttention {
        /// Fused projection producing the queries, keys and values.
        c_attn: Linear,
        /// Output projection.
        c_proj: Linear,
        /// Number of attention heads.
        n_head: usize,
        /// Tracing span entered by `forward`.
        span: tracing::Span,
    }
440
441    impl SelfAttention {
442        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
443            // The different attention variants are likely to be identical but still we only accept
444            // TorchAttn for now.
445            if cfg.attn_kernel_type != AttnKernelType::TorchAttn {
446                candle::bail!("only TorchAttn is supported")
447            }
448            if cfg.kv_cache_enabled {
449                candle::bail!("kv_cache_enabled=true is not supported")
450            }
451            let c_attn = linear_b(cfg.n_embd, cfg.n_embd * 3, cfg.bias, vb.pp("c_attn"))?;
452            let c_proj = linear_b(cfg.n_embd, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?;
453            Ok(Self {
454                c_attn,
455                c_proj,
456                n_head: cfg.n_head,
457                span: tracing::span!(tracing::Level::TRACE, "self-attn"),
458            })
459        }
460    }
461
462    impl Module for SelfAttention {
463        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
464            let _enter = self.span.enter();
465            let (b, t, c) = xs.dims3()?;
466            let c_x = xs
467                .apply(&self.c_attn)?
468                .reshape((b, t, 3, self.n_head, c / self.n_head))?;
469            let q = c_x.i((.., .., 0))?;
470            let k = c_x.i((.., .., 1))?;
471            let v = c_x.i((.., .., 2))?;
472            let q = q.transpose(1, 2)?.contiguous()?;
473            let k = k.transpose(1, 2)?.contiguous()?;
474            let v = v.transpose(1, 2)?.contiguous()?;
475            let att = (q.matmul(&k.t()?)? / (k.dim(D::Minus1)? as f64).sqrt())?;
476            // TODO: causal mask
477            let att = candle_nn::ops::softmax_last_dim(&att)?;
478            let att = att.matmul(&v)?.transpose(1, 2)?;
479            att.reshape((b, t, c))?.apply(&self.c_proj)
480        }
481    }
482
483    // https://github.com/metavoiceio/metavoice-src/blob/11550bb4e8a1ad032cc1556cc924f7a4e767cbfa/fam/llm/layers/layers.py#L43
    /// Feed-forward block of a transformer block, in one of two variants.
    #[allow(clippy::upper_case_acronyms)]
    enum MLP {
        /// `c_proj(gelu(c_fc(x)))`.
        Gelu {
            c_fc: Linear,
            c_proj: Linear,
            span: tracing::Span,
        },
        /// `c_proj(silu(w1(x)) * w3(x))`.
        Swiglu {
            w1: Linear,
            w3: Linear,
            c_proj: Linear,
            span: tracing::Span,
        },
    }
498
499    impl MLP {
500        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
501            let hidden_dim = 4 * cfg.n_embd;
502            let slf = match cfg.nonlinearity_type {
503                NonLinearityType::Gelu => {
504                    let c_fc = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("c_fc"))?;
505                    let c_proj = linear_b(hidden_dim, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?;
506                    Self::Gelu {
507                        c_fc,
508                        c_proj,
509                        span: tracing::span!(tracing::Level::TRACE, "mlp-gelu"),
510                    }
511                }
512                NonLinearityType::Swiglu => {
513                    let hidden_dim = (2 * hidden_dim) / 3;
514                    let swiglu_multiple_of = match cfg.swiglu_multiple_of {
515                        None => candle::bail!("swiglu-multiple-of has to be set"),
516                        Some(smo) => smo,
517                    };
518                    let hidden_dim = swiglu_multiple_of * (hidden_dim + swiglu_multiple_of - 1)
519                        / swiglu_multiple_of;
520                    let w1 = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("w1"))?;
521                    let w3 = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("w3"))?;
522                    let c_proj = linear_b(hidden_dim, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?;
523                    Self::Swiglu {
524                        w1,
525                        w3,
526                        c_proj,
527                        span: tracing::span!(tracing::Level::TRACE, "mlp-swiglu"),
528                    }
529                }
530            };
531            Ok(slf)
532        }
533    }
534
535    impl Module for MLP {
536        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
537            match self {
538                Self::Gelu { c_fc, c_proj, span } => {
539                    let _enter = span.enter();
540                    xs.apply(c_fc)?.gelu()?.apply(c_proj)
541                }
542                Self::Swiglu {
543                    w1,
544                    w3,
545                    c_proj,
546                    span,
547                } => {
548                    let _enter = span.enter();
549                    let w1 = xs.apply(w1)?;
550                    let w3 = xs.apply(w3)?;
551                    (w1.silu()? * w3)?.apply(c_proj)
552                }
553            }
554        }
555    }
556
557    // https://github.com/metavoiceio/metavoice-src/blob/11550bb4e8a1ad032cc1556cc924f7a4e767cbfa/fam/llm/layers/combined.py#L7
    /// Transformer block: pre-norm self-attention and pre-norm feed-forward, each with a
    /// residual connection.
    struct Block {
        /// Normalization applied before the attention layer.
        ln_1: Norm,
        /// Normalization applied before the feed-forward layer.
        ln_2: Norm,
        attn: SelfAttention,
        mlp: MLP,
        /// Tracing span entered by `forward`.
        span: tracing::Span,
    }
565
566    impl Block {
567        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
568            let ln_1 = Norm::new(cfg, vb.pp("ln_1"))?;
569            let ln_2 = Norm::new(cfg, vb.pp("ln_2"))?;
570            let attn = SelfAttention::new(cfg, vb.pp("attn"))?;
571            let mlp = MLP::new(cfg, vb.pp("mlp"))?;
572            Ok(Block {
573                ln_1,
574                ln_2,
575                attn,
576                mlp,
577                span: tracing::span!(tracing::Level::TRACE, "gpt-block"),
578            })
579        }
580    }
581
582    impl Module for Block {
583        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
584            let _enter = self.span.enter();
585            let xs = (xs + xs.apply(&self.ln_1)?.apply(&self.attn))?;
586            let xs = (&xs + xs.apply(&self.ln_2)?.apply(&self.mlp))?;
587            Ok(xs)
588        }
589    }
590
591    // https://github.com/metavoiceio/metavoice-src/blob/11550bb4e8a1ad032cc1556cc924f7a4e767cbfa/fam/llm/model.py#L79
    /// GPT model with one token embedding table per input hierarchy and one output head
    /// per target vocabulary.
    #[allow(clippy::upper_case_acronyms)]
    pub struct Model {
        /// Token embedding tables, one per entry of `Config::vocab_sizes`.
        wtes: Vec<candle_nn::Embedding>,
        /// Positional embedding table.
        wpe: candle_nn::Embedding,
        /// Transformer blocks, applied in order.
        h: Vec<Block>,
        /// Final normalization layer.
        ln_f: Norm,
        /// Output heads, one per entry of `Config::target_vocab_sizes`.
        lm_heads: Vec<Linear>,
        cfg: Config,
        /// Dtype of the var-builder the model was created from.
        dtype: DType,
        /// Tracing span entered by `forward`.
        span: tracing::Span,
    }
603
604    impl Model {
605        pub fn new(cfg: Config, vb: VarBuilder) -> Result<Self> {
606            let vb_t = vb.pp("transformer");
607            let ln_f = Norm::new(&cfg, vb_t.pp("ln_f"))?;
608            let mut wtes = Vec::with_capacity(cfg.vocab_sizes.len());
609            let vb_w = vb_t.pp("wtes");
610            for (idx, vocab_size) in cfg.vocab_sizes.iter().enumerate() {
611                let wte = candle_nn::embedding(*vocab_size, cfg.n_embd, vb_w.pp(idx))?;
612                wtes.push(wte)
613            }
614            let wpe = candle_nn::embedding(cfg.block_size, cfg.n_embd, vb_t.pp("wpe"))?;
615
616            let mut h = Vec::with_capacity(cfg.n_layer);
617            let vb_h = vb_t.pp("h");
618            for idx in 0..cfg.n_layer {
619                let block = Block::new(&cfg, vb_h.pp(idx))?;
620                h.push(block)
621            }
622
623            let mut lm_heads = Vec::with_capacity(cfg.target_vocab_sizes.len());
624            let vb_l = vb.pp("lm_heads");
625            for (idx, vocab_size) in cfg.target_vocab_sizes.iter().enumerate() {
626                let head = linear_b(cfg.n_embd, *vocab_size, false, vb_l.pp(idx))?;
627                lm_heads.push(head)
628            }
629            Ok(Self {
630                wtes,
631                wpe,
632                h,
633                ln_f,
634                lm_heads,
635                cfg,
636                dtype: vb.dtype(),
637                span: tracing::span!(tracing::Level::TRACE, "gpt"),
638            })
639        }
640
        /// Returns the configuration the model was built from.
        pub fn config(&self) -> &Config {
            &self.cfg
        }
644
645        pub fn forward(&self, idx: &Tensor) -> Result<Vec<Tensor>> {
646            let _enter = self.span.enter();
647            let device = idx.device();
648            let (b, _num_hierarchies, t) = idx.dims3()?;
649            let pos = Tensor::arange(0u32, t as u32, device)?;
650            let pos_emb = pos.apply(&self.wpe)?;
651            let mut tok_emb = Tensor::zeros((b, t, self.cfg.n_embd), self.dtype, device)?;
652            for (wte_idx, wte) in self.wtes.iter().enumerate() {
653                let emb = idx.i((.., wte_idx, ..))?.apply(wte)?;
654                tok_emb = (tok_emb + emb)?;
655            }
656            // TODO: speaker embs.
657            let spk_emb = 0f64;
658            let mut xs = (pos_emb.broadcast_add(&tok_emb)? + spk_emb)?;
659            for block in self.h.iter() {
660                xs = xs.apply(block)?
661            }
662            let xs = xs.apply(&self.ln_f)?;
663            let mut logits = Vec::with_capacity(self.lm_heads.len());
664            for lm_head in self.lm_heads.iter() {
665                // non-causal mode only.
666                let ys = xs.apply(lm_head)?;
667                logits.push(ys)
668            }
669            Ok(logits)
670        }
671    }
672}
673
674pub mod transformer {
675    use super::*;
676
    /// Configuration of the transformer model.
    #[derive(Debug, Clone, serde::Deserialize)]
    pub struct Config {
        /// Presumably the maximal sequence length — not read in the visible code, TODO confirm.
        pub block_size: usize,
        /// Presumably the token vocabulary size — not read in the visible code, TODO confirm.
        pub vocab_size: usize,
        /// Presumably the number of transformer blocks — not read in the visible code.
        pub n_layer: usize,
        /// Number of query heads.
        pub n_head: usize,
        /// Hidden size of the model.
        pub dim: usize,
        /// Presumably the size of the speaker embedding — not read in the visible code.
        pub speaker_emb_dim: usize,
        /// Hidden size of the feed-forward layers; derived from `dim` when unset.
        pub intermediate_size: Option<usize>,
        /// Number of key/value heads; defaults to `n_head` when unset.
        pub n_local_heads: Option<usize>,
        /// Epsilon used by the RMS norm layers.
        pub norm_eps: f64,
    }
689
690    impl Config {
691        pub fn cfg1b_v0_1() -> Self {
692            Self {
693                n_layer: 24,
694                n_head: 16,
695                dim: 2048,
696                vocab_size: 2562,
697                speaker_emb_dim: 256,
698                block_size: 2048,
699                intermediate_size: None,
700                n_local_heads: None,
701                norm_eps: 1e-5,
702            }
703        }
704
705        pub(crate) fn n_local_heads(&self) -> usize {
706            self.n_local_heads.unwrap_or(self.n_head)
707        }
708
        /// Size of a single attention head, i.e. `dim / n_head` (integer division).
        pub(crate) fn head_dim(&self) -> usize {
            self.dim / self.n_head
        }
712
713        pub(crate) fn intermediate_size(&self) -> usize {
714            match self.intermediate_size {
715                Some(intermediate_size) => intermediate_size,
716                None => {
717                    let hidden_dim = self.dim * 4;
718                    let n_hidden = ((2 * hidden_dim) as f64 / 3.) as usize;
719                    (n_hidden + 255) / 256 * 256
720                }
721            }
722        }
723    }
724
    /// Gated feed-forward layer computing `w2(silu(w1(x)) * w3(x))`.
    #[derive(Debug, Clone)]
    struct FeedForward {
        /// Gate projection.
        w1: Linear,
        /// Output projection.
        w2: Linear,
        /// Value projection, multiplied with the gate.
        w3: Linear,
        /// Tracing span entered by `forward`.
        span: tracing::Span,
    }
732
733    impl FeedForward {
734        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
735            let i_size = cfg.intermediate_size();
736            let w1 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w1"))?;
737            let w2 = linear_b(i_size, cfg.dim, false, vb.pp("w2"))?;
738            let w3 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w3"))?;
739            Ok(Self {
740                w1,
741                w2,
742                w3,
743                span: tracing::span!(tracing::Level::TRACE, "feed-forward"),
744            })
745        }
746    }
747
748    impl Module for FeedForward {
749        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
750            let _enter = self.span.enter();
751            let swiglu = (candle_nn::ops::silu(&xs.apply(&self.w1)?)? * xs.apply(&self.w3))?;
752            swiglu.apply(&self.w2)
753        }
754    }
755
    /// Attention layer with a fused query/key/value projection and a key/value cache.
    ///
    /// The keys and values use `n_local_heads` heads, which get repeated to match the
    /// `n_head` query heads.
    #[derive(Debug, Clone)]
    struct Attention {
        /// Fused projection producing the queries, keys and values.
        wqkv: Linear,
        /// Output projection.
        wo: Linear,
        /// Hidden size of the model, also the total size of the queries.
        dim: usize,
        /// Total size of the keys (and of the values), `n_local_heads * head_dim`.
        kv_size: usize,
        /// Number of key/value heads.
        n_local_heads: usize,
        /// Size of a single head.
        head_dim: usize,
        /// Number of query heads.
        n_head: usize,
        /// Keys and values of the previous `forward` calls, if any.
        kv_cache: Option<(Tensor, Tensor)>,
        /// Tracing span entered by `forward`.
        span: tracing::Span,
    }
768
769    impl Attention {
770        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
771            let n_local_heads = cfg.n_local_heads();
772            let head_dim = cfg.head_dim();
773            let total_head_dim = (cfg.n_head + 2 * n_local_heads) * head_dim;
774            let wqkv = linear_b(cfg.dim, total_head_dim, false, vb.pp("wqkv"))?;
775            let wo = linear_b(cfg.dim, cfg.dim, false, vb.pp("wo"))?;
776            Ok(Self {
777                wqkv,
778                wo,
779                dim: cfg.dim,
780                kv_size: n_local_heads * head_dim,
781                n_local_heads,
782                head_dim,
783                n_head: cfg.n_head,
784                kv_cache: None,
785                span: tracing::span!(tracing::Level::TRACE, "feed-forward"),
786            })
787        }
788
789        fn forward(&mut self, xs: &Tensor, _pos: usize, mask: &Tensor) -> Result<Tensor> {
790            let _enter = self.span.enter();
791            let (b_sz, seqlen, _) = xs.dims3()?;
792
793            let qkv = xs.apply(&self.wqkv)?;
794            let q = qkv.narrow(D::Minus1, 0, self.dim)?;
795            let k = qkv.narrow(D::Minus1, self.dim, self.kv_size)?;
796            let v = qkv.narrow(D::Minus1, self.dim + self.kv_size, self.kv_size)?;
797            let q = q
798                .reshape((b_sz, seqlen, self.n_head, self.head_dim))?
799                .transpose(1, 2)?
800                .contiguous()?;
801            let k = k
802                .reshape((b_sz, seqlen, self.n_local_heads, self.head_dim))?
803                .transpose(1, 2)?;
804            let v = v
805                .reshape((b_sz, seqlen, self.n_local_heads, self.head_dim))?
806                .transpose(1, 2)?;
807
808            let (k, v) = match &self.kv_cache {
809                None => (k, v),
810                Some((prev_k, prev_v)) => {
811                    let k = Tensor::cat(&[prev_k, &k], 2)?;
812                    let v = Tensor::cat(&[prev_v, &v], 2)?;
813                    (k, v)
814                }
815            };
816            self.kv_cache = Some((k.clone(), v.clone()));
817
818            let k = repeat_interleave(&k, self.n_head / self.n_local_heads, 1)?;
819            let v = repeat_interleave(&v, self.n_head / self.n_local_heads, 1)?;
820
821            let scale = 1f64 / f64::sqrt(self.head_dim as f64);
822            let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
823
824            let attn_weights = attn_weights.broadcast_add(mask)?;
825            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
826            let attn_output = attn_weights.matmul(&v)?;
827            attn_output
828                .transpose(1, 2)?
829                .reshape((b_sz, seqlen, self.dim))?
830                .apply(&self.wo)
831        }
832
833        fn clear_kv_cache(&mut self) {
834            self.kv_cache = None
835        }
836    }
837
838    #[derive(Debug, Clone)]
839    struct Block {
840        attention: Attention,
841        feed_forward: FeedForward,
842        ffn_norm: RmsNorm,
843        attention_norm: RmsNorm,
844        span: tracing::Span,
845    }
846
847    impl Block {
848        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
849            let attention = Attention::new(cfg, vb.pp("attention"))?;
850            let feed_forward = FeedForward::new(cfg, vb.pp("feed_forward"))?;
851            let ffn_norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("ffn_norm"))?;
852            let attention_norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("attention_norm"))?;
853            Ok(Self {
854                attention,
855                feed_forward,
856                ffn_norm,
857                attention_norm,
858                span: tracing::span!(tracing::Level::TRACE, "block"),
859            })
860        }
861
862        fn forward(&mut self, xs: &Tensor, pos: usize, mask: &Tensor) -> Result<Tensor> {
863            let _enter = self.span.enter();
864            let hs = xs.apply(&self.attention_norm)?;
865            let hs = (xs + self.attention.forward(&hs, pos, mask))?;
866            &hs + hs.apply(&self.ffn_norm)?.apply(&self.feed_forward)
867        }
868
869        fn clear_kv_cache(&mut self) {
870            self.attention.clear_kv_cache()
871        }
872    }
873
874    #[derive(Debug, Clone)]
875    pub struct Model {
876        tok_embeddings: Embedding,
877        pos_embeddings: Embedding,
878        speaker_cond_pos: Linear,
879        layers: Vec<Block>,
880        norm: RmsNorm,
881        output: Linear,
882        spk_cond_mask: Tensor,
883        span: tracing::Span,
884    }
885
886    impl Model {
887        pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
888            let tok_embeddings = embedding(cfg.vocab_size, cfg.dim, vb.pp("tok_embeddings"))?;
889            let pos_embeddings = embedding(cfg.block_size, cfg.dim, vb.pp("pos_embeddings"))?;
890            let speaker_cond_pos = linear_b(
891                cfg.speaker_emb_dim,
892                cfg.dim,
893                false,
894                vb.pp("speaker_cond_pos"),
895            )?;
896            let mut layers = Vec::with_capacity(cfg.n_layer);
897            let vb_l = vb.pp("layers");
898            for layer_idx in 0..cfg.n_layer {
899                let layer = Block::new(cfg, vb_l.pp(layer_idx))?;
900                layers.push(layer)
901            }
902            let norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("norm"))?;
903            let output = linear_b(cfg.dim, cfg.vocab_size, false, vb.pp("output"))?;
904            let dtype = vb.dtype();
905            let spk_cond_mask = Tensor::cat(
906                &[
907                    Tensor::ones((1, 1, cfg.dim), dtype, vb.device())?,
908                    Tensor::zeros((1, 1, cfg.dim), dtype, vb.device())?,
909                ],
910                0,
911            )?;
912            Ok(Self {
913                tok_embeddings,
914                pos_embeddings,
915                speaker_cond_pos,
916                layers,
917                norm,
918                output,
919                spk_cond_mask,
920                span: tracing::span!(tracing::Level::TRACE, "transformer"),
921            })
922        }
923
924        pub fn clear_kv_cache(&mut self) {
925            for layer in self.layers.iter_mut() {
926                layer.clear_kv_cache()
927            }
928        }
929
930        pub fn forward(&mut self, xs: &Tensor, spk_emb: &Tensor, pos: usize) -> Result<Tensor> {
931            let _enter = self.span.enter();
932            let (_b_sz, seqlen) = xs.dims2()?;
933            let mask: Vec<_> = (0..seqlen)
934                .flat_map(|i| (0..seqlen).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
935                .collect();
936            let mask = Tensor::from_slice(&mask, (1, 1, seqlen, seqlen), xs.device())?;
937            let input_pos = Tensor::arange(pos as u32, (pos + seqlen) as u32, xs.device())?;
938            let tok_embeddings = xs.apply(&self.tok_embeddings)?;
939            let pos_embeddings = input_pos.apply(&self.pos_embeddings)?;
940            let mut xs = tok_embeddings
941                .broadcast_add(&pos_embeddings)?
942                .broadcast_add(
943                    &spk_emb
944                        .apply(&self.speaker_cond_pos)?
945                        .broadcast_mul(&self.spk_cond_mask)?,
946                )?;
947            let mask = mask.to_dtype(xs.dtype())?;
948            for layer in self.layers.iter_mut() {
949                xs = layer.forward(&xs, pos, &mask)?
950            }
951            xs.narrow(1, seqlen - 1, 1)?
952                .apply(&self.norm)?
953                .apply(&self.output)
954        }
955    }
956}
957
958pub mod adapters {
959    // https://github.com/metavoiceio/metavoice-src/blob/9078234c496d76adbec06df789b6b04b1875f129/fam/llm/adapters/tilted_encodec.py
960    pub struct TiltedEncodec {
961        end_of_audio_token: u32,
962        span: tracing::Span,
963    }
964
965    impl TiltedEncodec {
966        pub fn new(end_of_audio_token: u32) -> Self {
967            Self {
968                end_of_audio_token,
969                span: tracing::span!(tracing::Level::TRACE, "tilted-encodec"),
970            }
971        }
972
973        pub fn decode(&self, tokens: &[Vec<u32>]) -> (Vec<u32>, Vec<Vec<u32>>) {
974            let _enter = self.span.enter();
975            let mut text_ids = vec![];
976            let mut extracted_audio_ids = vec![];
977            let mut min_audio_ids_len = usize::MAX;
978            for (book_id, tokens) in tokens.iter().enumerate() {
979                let mut audio_ids = vec![];
980                for &t in tokens.iter() {
981                    #[allow(clippy::comparison_chain)]
982                    if t > self.end_of_audio_token {
983                        if book_id == 0 {
984                            text_ids.push(t)
985                        }
986                    } else if t < self.end_of_audio_token {
987                        audio_ids.push(t)
988                    }
989                }
990                min_audio_ids_len = usize::min(min_audio_ids_len, audio_ids.len());
991                extracted_audio_ids.push(audio_ids)
992            }
993            for audio_ids in extracted_audio_ids.iter_mut() {
994                audio_ids.truncate(min_audio_ids_len)
995            }
996            (text_ids, extracted_audio_ids)
997        }
998    }
999
1000    // https://github.com/metavoiceio/metavoice-src/blob/9078234c496d76adbec06df789b6b04b1875f129/fam/llm/adapters/flattened_encodec.py#L4
1001    pub struct FlattenedInterleavedEncodec2Codebook {
1002        end_of_audio_token: u32,
1003        span: tracing::Span,
1004    }
1005
1006    impl FlattenedInterleavedEncodec2Codebook {
1007        pub fn new(end_of_audio_token: u32) -> Self {
1008            Self {
1009                end_of_audio_token,
1010                span: tracing::span!(tracing::Level::TRACE, "encodec2codebook"),
1011            }
1012        }
1013
1014        pub fn decode(&self, tokens: &[u32]) -> (Vec<u32>, Vec<u32>, Vec<u32>) {
1015            let _enter = self.span.enter();
1016            let mut text_ids = vec![];
1017            let mut audio_ids1 = vec![];
1018            let mut audio_ids2 = vec![];
1019            for &t in tokens.iter() {
1020                #[allow(clippy::comparison_chain)]
1021                if t < self.end_of_audio_token {
1022                    audio_ids1.push(t)
1023                } else if t < 2 * self.end_of_audio_token {
1024                    audio_ids2.push(t - self.end_of_audio_token)
1025                } else {
1026                    text_ids.push(t)
1027                }
1028            }
1029            (text_ids, audio_ids1, audio_ids2)
1030        }
1031    }
1032}