candle_transformers/models/
quantized_metavoice.rs

1//! Quantized MetaVoice model implementation.
2//!
3//! MetaVoice is a conditional text-to-speech model based on a transformer architecture.
4//! This implementation provides quantization for reduced memory and compute.
5//!
6//! Key characteristics:
7//! - Transformer-based autoregressive decoder
8//! - Speaker conditioning
9//! - Support for 8-bit quantization
10//! - Key-value caching for efficient inference
11//! - RMS normalization layers
12//!
13//! References:
14//! - [MetaVoice Code](https://github.com/metavoiceio/metavoice)
15//!
16
17use crate::quantized_nn::{linear_b, Embedding, Linear, RmsNorm};
18pub use crate::quantized_var_builder::VarBuilder;
19
20use crate::models::metavoice::repeat_interleave;
21use candle::{Module, Result, Tensor, D};
22
23pub mod transformer {
24    use super::*;
25
26    type Config = crate::models::metavoice::transformer::Config;
27
28    #[derive(Debug, Clone)]
29    struct FeedForward {
30        w1: Linear,
31        w2: Linear,
32        w3: Linear,
33        span: tracing::Span,
34    }
35
36    impl FeedForward {
37        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
38            let i_size = cfg.intermediate_size();
39            let w1 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w1"))?;
40            let w2 = linear_b(i_size, cfg.dim, false, vb.pp("w2"))?;
41            let w3 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w3"))?;
42            Ok(Self {
43                w1,
44                w2,
45                w3,
46                span: tracing::span!(tracing::Level::TRACE, "feed-forward"),
47            })
48        }
49    }
50
51    impl Module for FeedForward {
52        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
53            let _enter = self.span.enter();
54            let swiglu = (candle_nn::ops::silu(&xs.apply(&self.w1)?)? * xs.apply(&self.w3))?;
55            swiglu.apply(&self.w2)
56        }
57    }
58
59    #[derive(Debug, Clone)]
60    struct Attention {
61        wqkv: Linear,
62        wo: Linear,
63        dim: usize,
64        kv_size: usize,
65        n_local_heads: usize,
66        head_dim: usize,
67        n_head: usize,
68        kv_cache: Option<(Tensor, Tensor)>,
69        span: tracing::Span,
70    }
71
72    impl Attention {
73        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
74            let n_local_heads = cfg.n_local_heads();
75            let head_dim = cfg.head_dim();
76            let total_head_dim = (cfg.n_head + 2 * n_local_heads) * head_dim;
77            let wqkv = linear_b(cfg.dim, total_head_dim, false, vb.pp("wqkv"))?;
78            let wo = linear_b(cfg.dim, cfg.dim, false, vb.pp("wo"))?;
79            Ok(Self {
80                wqkv,
81                wo,
82                dim: cfg.dim,
83                kv_size: n_local_heads * head_dim,
84                n_local_heads,
85                head_dim,
86                n_head: cfg.n_head,
87                kv_cache: None,
88                span: tracing::span!(tracing::Level::TRACE, "attention"),
89            })
90        }
91
92        fn forward(&mut self, xs: &Tensor, _pos: usize, mask: &Tensor) -> Result<Tensor> {
93            let _enter = self.span.enter();
94            let (b_sz, seqlen, _) = xs.dims3()?;
95
96            let qkv = xs.apply(&self.wqkv)?;
97            let q = qkv.narrow(D::Minus1, 0, self.dim)?;
98            let k = qkv.narrow(D::Minus1, self.dim, self.kv_size)?;
99            let v = qkv.narrow(D::Minus1, self.dim + self.kv_size, self.kv_size)?;
100            let q = q
101                .reshape((b_sz, seqlen, self.n_head, self.head_dim))?
102                .transpose(1, 2)?
103                .contiguous()?;
104            let k = k
105                .reshape((b_sz, seqlen, self.n_local_heads, self.head_dim))?
106                .transpose(1, 2)?;
107            let v = v
108                .reshape((b_sz, seqlen, self.n_local_heads, self.head_dim))?
109                .transpose(1, 2)?;
110
111            let (k, v) = match &self.kv_cache {
112                None => (k, v),
113                Some((prev_k, prev_v)) => {
114                    let k = Tensor::cat(&[prev_k, &k], 2)?;
115                    let v = Tensor::cat(&[prev_v, &v], 2)?;
116                    (k, v)
117                }
118            };
119            self.kv_cache = Some((k.clone(), v.clone()));
120
121            let k = repeat_interleave(&k, self.n_head / self.n_local_heads, 1)?;
122            let v = repeat_interleave(&v, self.n_head / self.n_local_heads, 1)?;
123
124            let scale = 1f64 / f64::sqrt(self.head_dim as f64);
125            let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
126
127            let attn_weights = attn_weights.broadcast_add(mask)?;
128            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
129            let attn_output = attn_weights.matmul(&v)?;
130            attn_output
131                .transpose(1, 2)?
132                .reshape((b_sz, seqlen, self.dim))?
133                .apply(&self.wo)
134        }
135
136        fn clear_kv_cache(&mut self) {
137            self.kv_cache = None
138        }
139    }
140
141    #[derive(Debug, Clone)]
142    struct Block {
143        attention: Attention,
144        feed_forward: FeedForward,
145        ffn_norm: RmsNorm,
146        attention_norm: RmsNorm,
147        span: tracing::Span,
148    }
149
150    impl Block {
151        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
152            let attention = Attention::new(cfg, vb.pp("attention"))?;
153            let feed_forward = FeedForward::new(cfg, vb.pp("feed_forward"))?;
154            let ffn_norm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("ffn_norm"))?;
155            let attention_norm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("attention_norm"))?;
156            Ok(Self {
157                attention,
158                feed_forward,
159                ffn_norm,
160                attention_norm,
161                span: tracing::span!(tracing::Level::TRACE, "block"),
162            })
163        }
164
165        fn forward(&mut self, xs: &Tensor, pos: usize, mask: &Tensor) -> Result<Tensor> {
166            let _enter = self.span.enter();
167            let hs = xs.apply(&self.attention_norm)?;
168            let hs = (xs + self.attention.forward(&hs, pos, mask))?;
169            &hs + hs.apply(&self.ffn_norm)?.apply(&self.feed_forward)
170        }
171
172        fn clear_kv_cache(&mut self) {
173            self.attention.clear_kv_cache()
174        }
175    }
176
177    #[derive(Debug, Clone)]
178    pub struct Model {
179        tok_embeddings: Embedding,
180        pos_embeddings: Embedding,
181        speaker_cond_pos: Linear,
182        layers: Vec<Block>,
183        norm: RmsNorm,
184        output: Linear,
185        spk_cond_mask: Tensor,
186        span: tracing::Span,
187    }
188
189    impl Model {
190        pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
191            let tok_embeddings = Embedding::new(cfg.vocab_size, cfg.dim, vb.pp("tok_embeddings"))?;
192            let pos_embeddings = Embedding::new(cfg.block_size, cfg.dim, vb.pp("pos_embeddings"))?;
193            let speaker_cond_pos = linear_b(
194                cfg.speaker_emb_dim,
195                cfg.dim,
196                false,
197                vb.pp("speaker_cond_pos"),
198            )?;
199            let mut layers = Vec::with_capacity(cfg.n_layer);
200            let vb_l = vb.pp("layers");
201            for layer_idx in 0..cfg.n_layer {
202                let layer = Block::new(cfg, vb_l.pp(layer_idx))?;
203                layers.push(layer)
204            }
205            let norm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("norm"))?;
206            let output = linear_b(cfg.dim, cfg.vocab_size, false, vb.pp("output"))?;
207            let spk_cond_mask = Tensor::cat(
208                &[
209                    Tensor::ones((1, 1, cfg.dim), candle::DType::F32, vb.device())?,
210                    Tensor::zeros((1, 1, cfg.dim), candle::DType::F32, vb.device())?,
211                ],
212                0,
213            )?;
214            Ok(Self {
215                tok_embeddings,
216                pos_embeddings,
217                speaker_cond_pos,
218                layers,
219                norm,
220                output,
221                spk_cond_mask,
222                span: tracing::span!(tracing::Level::TRACE, "qtransformer"),
223            })
224        }
225
226        pub fn clear_kv_cache(&mut self) {
227            for layer in self.layers.iter_mut() {
228                layer.clear_kv_cache()
229            }
230        }
231
232        pub fn forward(&mut self, xs: &Tensor, spk_emb: &Tensor, pos: usize) -> Result<Tensor> {
233            let _enter = self.span.enter();
234            let (_b_sz, seqlen) = xs.dims2()?;
235            let mask: Vec<_> = (0..seqlen)
236                .flat_map(|i| (0..seqlen).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
237                .collect();
238            let mask = Tensor::from_slice(&mask, (1, 1, seqlen, seqlen), xs.device())?;
239            let input_pos = Tensor::arange(pos as u32, (pos + seqlen) as u32, xs.device())?;
240            let tok_embeddings = xs.apply(&self.tok_embeddings)?;
241            let pos_embeddings = input_pos.apply(&self.pos_embeddings)?;
242            let mut xs = tok_embeddings
243                .broadcast_add(&pos_embeddings)?
244                .broadcast_add(
245                    &spk_emb
246                        .apply(&self.speaker_cond_pos)?
247                        .broadcast_mul(&self.spk_cond_mask)?,
248                )?;
249            let mask = mask.to_dtype(xs.dtype())?;
250            for layer in self.layers.iter_mut() {
251                xs = layer.forward(&xs, pos, &mask)?
252            }
253            xs.narrow(1, seqlen - 1, 1)?
254                .contiguous()?
255                .apply(&self.norm)?
256                .apply(&self.output)
257        }
258    }
259}