// candle_transformers/models/parler_tts.rs

//! Parler-TTS model implementation for text-to-speech synthesis.
//!
//! Implements a transformer decoder that autoregressively generates discrete audio
//! tokens conditioned on text: a T5 encoder embeds a text description, the decoder
//! cross-attends to it, and one output head per codebook predicts the quantized audio
//! tokens used by the DAC audio codec.
//!
//! The model architecture includes:
//! - Causal self-attention over audio tokens and cross-attention to the encoded text
//! - Feed-forward networks
//! - Layer normalization
//! - Learned positional embeddings (rotary embeddings are not supported)
//! - Multiple codebook prediction heads
//!
//! The implementation follows the original parler_tts architecture while focusing
//! on audio token generation for text-to-speech synthesis.

18use crate::generation::LogitsProcessor;
19use crate::models::t5;
20use candle::{IndexOp, Result, Tensor};
21use candle_nn::{layer_norm, linear_b as linear, Activation, LayerNorm, Linear, VarBuilder};
22
/// Configuration of the Parler-TTS audio-token decoder.
///
/// Deserialized from the `decoder` section of [`Config`]. Some fields are deserialized
/// but not read by the code in this file: `scale_embedding`, `pad_token_id`,
/// `bos_token_id`, `eos_token_id`, `tie_word_embeddings` and `rope_theta`.
#[derive(serde::Deserialize, Debug, Clone)]
pub struct DecoderConfig {
    /// Output size of each codebook head. Each codebook embedding table has
    /// `vocab_size + 1` rows (presumably to fit an extra special token such as the
    /// decoder start id — TODO confirm).
    pub vocab_size: usize,
    /// Number of rows in the learned position-embedding table.
    pub max_position_embeddings: usize,
    /// Number of stacked decoder layers.
    pub num_hidden_layers: usize,
    /// Inner width of the feed-forward network (`fc1` output / `fc2` input).
    pub ffn_dim: usize,
    /// Number of query heads; the per-head size is `hidden_size / num_attention_heads`.
    pub num_attention_heads: usize,
    /// Key/value heads for self-attention; defaults to `num_attention_heads` when absent.
    pub num_key_value_heads: Option<usize>,
    /// Key/value heads for cross-attention; defaults to the self-attention value when absent.
    pub num_cross_attention_key_value_heads: Option<usize>,
    /// Activation applied between the two feed-forward projections.
    pub activation_function: Activation,
    /// Decoder model width.
    pub hidden_size: usize,
    pub scale_embedding: bool,
    /// Number of parallel audio codebooks; each gets its own embedding table and output head.
    pub num_codebooks: usize,
    pub pad_token_id: usize,
    pub bos_token_id: usize,
    pub eos_token_id: usize,
    pub tie_word_embeddings: bool,
    /// Must be `false`: rotary embeddings are not supported and building an attention
    /// block fails when this is set.
    pub rope_embeddings: bool,
    pub rope_theta: f64,
}

/// Top-level Parler-TTS configuration, combining the decoder, text-encoder and
/// audio-codec settings.
#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
    /// Token id every codebook is seeded with at the start of `Model::generate`.
    pub decoder_start_token_id: u32,
    /// Token id that ends a codebook during `Model::generate`; decoding stops once every
    /// codebook has produced it.
    pub pad_token_id: u32,
    /// Settings for the audio-token decoder.
    pub decoder: DecoderConfig,
    /// Settings for the T5 encoder applied to the description text.
    pub text_encoder: t5::Config,
    /// Number of rows in the prompt embedding table (`embed_prompts`).
    pub vocab_size: usize,
    /// Settings for the DAC audio model.
    pub audio_encoder: crate::models::dac::Config,
}

/// Multi-head attention with grouped key/value heads, used both for decoder
/// self-attention and for cross-attention to the text-encoder output.
///
/// No mask is applied internally; callers pass an additive mask to `forward`.
#[derive(Debug, Clone)]
pub struct Attention {
    k_proj: Linear,
    v_proj: Linear,
    q_proj: Linear,
    out_proj: Linear,
    /// When `true`, keys/values are cached and extended across `forward` calls.
    is_causal: bool,
    /// Keys and values from previous calls, concatenated along the sequence axis.
    kv_cache: Option<(Tensor, Tensor)>,
    /// `head_dim^-0.5`, multiplied into the query projection.
    scaling: f64,
    num_heads: usize,
    num_kv_heads: usize,
    /// Query heads served by each key/value head (`num_heads / num_kv_heads`).
    num_kv_groups: usize,
    head_dim: usize,
}

69impl Attention {
70    fn new(
71        num_kv_heads: usize,
72        is_causal: bool,
73        cfg: &DecoderConfig,
74        vb: VarBuilder,
75    ) -> Result<Self> {
76        if cfg.rope_embeddings {
77            candle::bail!("rope embeddings are not supported");
78        }
79        let embed_dim = cfg.hidden_size;
80        let head_dim = embed_dim / cfg.num_attention_heads;
81        let kv_out_dim = num_kv_heads * head_dim;
82        let k_proj = linear(embed_dim, kv_out_dim, false, vb.pp("k_proj"))?;
83        let v_proj = linear(embed_dim, kv_out_dim, false, vb.pp("v_proj"))?;
84        let q_proj = linear(embed_dim, embed_dim, false, vb.pp("q_proj"))?;
85        let out_proj = linear(embed_dim, embed_dim, false, vb.pp("out_proj"))?;
86        Ok(Self {
87            k_proj,
88            v_proj,
89            q_proj,
90            out_proj,
91            is_causal,
92            kv_cache: None,
93            scaling: (head_dim as f64).powf(-0.5),
94            num_heads: cfg.num_attention_heads,
95            num_kv_heads,
96            num_kv_groups: cfg.num_attention_heads / num_kv_heads,
97            head_dim,
98        })
99    }
100
101    fn forward(
102        &mut self,
103        xs: &Tensor,
104        key_value_states: Option<&Tensor>,
105        attention_mask: Option<&Tensor>,
106    ) -> Result<Tensor> {
107        let (b_sz, tgt_len, _) = xs.dims3()?;
108        let query_states = (xs.apply(&self.q_proj)? * self.scaling)?
109            .reshape((b_sz, tgt_len, self.num_heads, self.head_dim))?
110            .transpose(1, 2)?
111            .contiguous()?;
112        let key_states = match key_value_states {
113            Some(states) => states.apply(&self.k_proj)?,
114            None => xs.apply(&self.k_proj)?,
115        };
116        let key_states = key_states
117            .reshape((b_sz, (), self.num_kv_heads, self.head_dim))?
118            .transpose(1, 2)?
119            .contiguous()?;
120        let value_states = match key_value_states {
121            Some(states) => states.apply(&self.v_proj)?,
122            None => xs.apply(&self.v_proj)?,
123        };
124        let value_states = value_states
125            .reshape((b_sz, (), self.num_kv_heads, self.head_dim))?
126            .transpose(1, 2)?
127            .contiguous()?;
128
129        let (key_states, value_states) = match &self.kv_cache {
130            None => (key_states, value_states),
131            Some((prev_k, prev_v)) => {
132                let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
133                let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
134                (key_states, value_states)
135            }
136        };
137        if self.is_causal {
138            self.kv_cache = Some((key_states.clone(), value_states.clone()));
139        }
140
141        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
142        let value_states =
143            crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
144
145        let attn_weights = query_states.matmul(&key_states.transpose(2, 3)?)?;
146        let attn_weights = match attention_mask {
147            None => attn_weights,
148            Some(mask) => attn_weights.broadcast_add(mask)?,
149        };
150        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
151        let attn_output = attn_weights.matmul(&value_states)?;
152        attn_output
153            .transpose(1, 2)?
154            .reshape((b_sz, tgt_len, ()))?
155            .apply(&self.out_proj)
156    }
157
158    fn clear_kv_cache(&mut self) {
159        self.kv_cache = None
160    }
161}
162
/// One pre-norm decoder block.
///
/// It applies self-attention, then cross-attention over the encoder output, then a
/// two-layer feed-forward network. Each sub-block is preceded by its own layer norm and
/// wrapped in a residual connection.
#[derive(Debug, Clone)]
pub struct DecoderLayer {
    self_attn: Attention,
    self_attn_layer_norm: LayerNorm,
    encoder_attn: Attention,
    encoder_attn_layer_norm: LayerNorm,
    fc1: Linear,
    fc2: Linear,
    final_layer_norm: LayerNorm,
    /// Activation applied between `fc1` and `fc2`.
    activation: Activation,
}

175impl DecoderLayer {
176    fn new(cfg: &DecoderConfig, vb: VarBuilder) -> Result<Self> {
177        let kv_heads = cfg.num_key_value_heads.unwrap_or(cfg.num_attention_heads);
178        let kv_heads_cross = cfg.num_cross_attention_key_value_heads.unwrap_or(kv_heads);
179
180        let self_attn = Attention::new(kv_heads, true, cfg, vb.pp("self_attn"))?;
181        let encoder_attn = Attention::new(kv_heads_cross, false, cfg, vb.pp("encoder_attn"))?;
182        let self_attn_layer_norm =
183            layer_norm(cfg.hidden_size, 1e-5, vb.pp("self_attn_layer_norm"))?;
184        let encoder_attn_layer_norm =
185            layer_norm(cfg.hidden_size, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
186        let fc1 = linear(cfg.hidden_size, cfg.ffn_dim, false, vb.pp("fc1"))?;
187        let fc2 = linear(cfg.ffn_dim, cfg.hidden_size, false, vb.pp("fc2"))?;
188        let final_layer_norm = layer_norm(cfg.hidden_size, 1e-5, vb.pp("final_layer_norm"))?;
189        Ok(Self {
190            self_attn,
191            self_attn_layer_norm,
192            encoder_attn,
193            encoder_attn_layer_norm,
194            fc1,
195            fc2,
196            final_layer_norm,
197            activation: cfg.activation_function,
198        })
199    }
200
201    fn forward(
202        &mut self,
203        xs: &Tensor,
204        attention_mask: Option<&Tensor>,
205        encoder_xs: &Tensor,
206        encoder_attention_mask: Option<&Tensor>,
207    ) -> Result<Tensor> {
208        // Self attention
209        let residual = xs;
210        let xs = xs.apply(&self.self_attn_layer_norm)?;
211        let xs = self.self_attn.forward(&xs, None, attention_mask)?;
212        let xs = (residual + xs)?;
213
214        // Cross attention
215        let residual = &xs;
216        let xs = xs.apply(&self.encoder_attn_layer_norm)?;
217        let xs = self
218            .encoder_attn
219            .forward(&xs, Some(encoder_xs), encoder_attention_mask)?;
220        let xs = (residual + xs)?;
221
222        // Fully connected
223        let residual = &xs;
224        let xs = xs
225            .apply(&self.final_layer_norm)?
226            .apply(&self.fc1)?
227            .apply(&self.activation)?
228            .apply(&self.fc2)?;
229        residual + xs
230    }
231
232    fn clear_kv_cache(&mut self) {
233        self.self_attn.clear_kv_cache();
234        self.encoder_attn.clear_kv_cache();
235    }
236}
237
/// Parler-TTS audio-token decoder.
///
/// It is a stack of [`DecoderLayer`]s with one embedding table and one output head per
/// audio codebook.
#[derive(Debug, Clone)]
pub struct Decoder {
    /// One embedding table per codebook; their outputs are summed to form the input.
    embed_tokens: Vec<candle_nn::Embedding>,
    /// Learned position embeddings of shape `(max_position_embeddings, hidden_size)`.
    embed_positions: Tensor,
    layers: Vec<DecoderLayer>,
    /// Final layer norm applied before the output heads.
    layer_norm: LayerNorm,
    num_codebooks: usize,
    hidden_size: usize,
    /// One bias-free `hidden_size -> vocab_size` projection per codebook.
    lm_heads: Vec<Linear>,
    /// DType of the loaded weights, used for the zero tensor the codebook embeddings
    /// are summed into.
    dtype: candle::DType,
}

250impl Decoder {
251    pub fn new(cfg: &DecoderConfig, vb: VarBuilder) -> Result<Self> {
252        let vb_d = vb.pp("model.decoder");
253        let mut embed_tokens = Vec::with_capacity(cfg.num_codebooks);
254        let vb_e = vb_d.pp("embed_tokens");
255        for embed_idx in 0..cfg.num_codebooks {
256            let e = candle_nn::embedding(cfg.vocab_size + 1, cfg.hidden_size, vb_e.pp(embed_idx))?;
257            embed_tokens.push(e)
258        }
259        let embed_positions = vb_d.get(
260            (cfg.max_position_embeddings, cfg.hidden_size),
261            "embed_positions.weights",
262        )?;
263        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
264        let vb_l = vb_d.pp("layers");
265        for layer_idx in 0..cfg.num_hidden_layers {
266            let layer = DecoderLayer::new(cfg, vb_l.pp(layer_idx))?;
267            layers.push(layer)
268        }
269        let layer_norm = layer_norm(cfg.hidden_size, 1e-5, vb_d.pp("layer_norm"))?;
270
271        let mut lm_heads = Vec::with_capacity(cfg.num_codebooks);
272        let vb_l = vb.pp("lm_heads");
273        for lm_idx in 0..cfg.num_codebooks {
274            let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, vb_l.pp(lm_idx))?;
275            lm_heads.push(lm_head)
276        }
277        Ok(Self {
278            embed_tokens,
279            embed_positions,
280            layers,
281            layer_norm,
282            num_codebooks: cfg.num_codebooks,
283            lm_heads,
284            hidden_size: cfg.hidden_size,
285            dtype: vb.dtype(),
286        })
287    }
288
289    pub fn forward(
290        &mut self,
291        input_ids: &Tensor,
292        prompt_hidden_states: Option<&Tensor>,
293        attention_mask: Option<&Tensor>,
294        encoder_xs: &Tensor,
295        encoder_attention_mask: Option<&Tensor>,
296        seqlen_offset: usize,
297    ) -> Result<Vec<Tensor>> {
298        let (b_sz, num_codebooks, seq_len) = input_ids.dims3()?;
299        if num_codebooks != self.num_codebooks {
300            candle::bail!("unexpected num codebooks in input {:?}", input_ids.shape())
301        }
302        let mut inputs_embeds = Tensor::zeros(
303            (b_sz, seq_len, self.hidden_size),
304            self.dtype,
305            input_ids.device(),
306        )?;
307        for (idx, embs) in self.embed_tokens.iter().enumerate() {
308            let e = input_ids.i((.., idx))?.apply(embs)?;
309            inputs_embeds = (inputs_embeds + e)?
310        }
311        let inputs_embeds = match prompt_hidden_states {
312            None => inputs_embeds,
313            Some(pis) => Tensor::cat(&[pis, &inputs_embeds], 1)?,
314        };
315        let embed_positions = self
316            .embed_positions
317            .i(seqlen_offset..seqlen_offset + inputs_embeds.dim(1)?)?;
318        let mut xs = (inputs_embeds + embed_positions.unsqueeze(0))?;
319        for layer in self.layers.iter_mut() {
320            xs = layer.forward(&xs, attention_mask, encoder_xs, encoder_attention_mask)?;
321        }
322        let xs = xs.apply(&self.layer_norm)?;
323        let mut lm_logits = Vec::with_capacity(self.num_codebooks);
324        for lm_head in self.lm_heads.iter() {
325            let logits = xs.apply(lm_head)?;
326            lm_logits.push(logits)
327        }
328        Ok(lm_logits)
329    }
330
331    pub fn clear_kv_cache(&mut self) {
332        for layer in self.layers.iter_mut() {
333            layer.clear_kv_cache()
334        }
335    }
336}
337
/// Complete Parler-TTS model.
///
/// It bundles the T5 text encoder, the prompt embeddings, the audio-token decoder and
/// the DAC audio model.
#[derive(Debug, Clone)]
pub struct Model {
    /// Embeds `prompt_tokens`; the result is prepended to the decoder input on the
    /// first generation step.
    pub embed_prompts: candle_nn::Embedding,
    /// Projects text-encoder outputs to the decoder width. `None` when the T5 `d_model`
    /// already equals the decoder `hidden_size`.
    pub enc_to_dec_proj: Option<Linear>,
    pub decoder: Decoder,
    /// T5 encoder applied to the description tokens.
    pub text_encoder: t5::T5EncoderModel,
    pub decoder_start_token_id: u32,
    pub pad_token_id: u32,
    /// DAC audio model loaded from the `audio_encoder` weights; not used by `generate`.
    pub audio_encoder: crate::models::dac::Model,
}

349impl Model {
350    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
351        let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.text_encoder)?;
352        let decoder = Decoder::new(&cfg.decoder, vb.pp("decoder"))?;
353        let embed_prompts = candle_nn::embedding(
354            cfg.vocab_size,
355            cfg.decoder.hidden_size,
356            vb.pp("embed_prompts"),
357        )?;
358        let enc_to_dec_proj = if cfg.text_encoder.d_model != cfg.decoder.hidden_size {
359            let proj = linear(
360                cfg.text_encoder.d_model,
361                cfg.decoder.hidden_size,
362                true,
363                vb.pp("enc_to_dec_proj"),
364            )?;
365            Some(proj)
366        } else {
367            None
368        };
369        let audio_encoder =
370            crate::models::dac::Model::new(&cfg.audio_encoder, vb.pp("audio_encoder"))?;
371        Ok(Self {
372            decoder,
373            text_encoder,
374            embed_prompts,
375            enc_to_dec_proj,
376            decoder_start_token_id: cfg.decoder_start_token_id,
377            pad_token_id: cfg.pad_token_id,
378            audio_encoder,
379        })
380    }
381
382    /// Note that the returned tensor uses the CPU device.
383    pub fn generate(
384        &mut self,
385        prompt_tokens: &Tensor,
386        description_tokens: &Tensor,
387        mut lp: LogitsProcessor,
388        max_steps: usize,
389    ) -> Result<Tensor> {
390        self.decoder.clear_kv_cache();
391        self.text_encoder.clear_kv_cache();
392        let encoded = self.text_encoder.forward(description_tokens)?;
393        let encoded = match self.enc_to_dec_proj.as_ref() {
394            None => encoded,
395            Some(proj) => encoded.apply(proj)?,
396        };
397        let prompt_hidden_states = prompt_tokens.apply(&self.embed_prompts)?;
398        let num_codebooks = self.decoder.num_codebooks;
399        let mut audio_tokens = vec![self.decoder_start_token_id; num_codebooks];
400        let mut all_audio_tokens = vec![vec![]; num_codebooks];
401        let prompt_len = prompt_hidden_states.dim(1)?;
402        for step in 0..max_steps {
403            let input_ids = Tensor::from_slice(
404                audio_tokens.as_slice(),
405                (1, num_codebooks, 1),
406                prompt_tokens.device(),
407            )?;
408            let (prompt_hidden_states, pos) = if step == 0 {
409                (Some(&prompt_hidden_states), 0)
410            } else {
411                (None, step + prompt_len)
412            };
413            let causal_mask = if pos == 0 {
414                self.prepare_causal_mask(prompt_len + 1, prompt_len + 1, input_ids.device())?
415            } else {
416                self.prepare_causal_mask(1, pos + 1, input_ids.device())?
417            };
418            let logits = self.decoder.forward(
419                &input_ids,
420                prompt_hidden_states,
421                Some(&causal_mask),
422                &encoded,
423                None,
424                pos,
425            )?;
426            for (logit_idx, logit) in logits.iter().enumerate() {
427                if logit_idx > step {
428                    break;
429                }
430                if audio_tokens[logit_idx] != self.pad_token_id {
431                    let logit = logit.i((0, logit.dim(1)? - 1))?;
432                    let token = lp.sample(&logit)?;
433                    audio_tokens[logit_idx] = token
434                }
435            }
436            if audio_tokens.iter().all(|v| v == &self.pad_token_id) {
437                break;
438            }
439            for (cb_idx, &token) in audio_tokens.iter().enumerate() {
440                if token != self.decoder_start_token_id && token != self.pad_token_id {
441                    all_audio_tokens[cb_idx].push(token)
442                }
443            }
444        }
445
446        let min_len = all_audio_tokens.iter().map(|v| v.len()).min().unwrap_or(0);
447        all_audio_tokens.iter_mut().for_each(|v| {
448            v.resize(min_len, 0);
449        });
450        let all_audio_tokens = Tensor::new(all_audio_tokens, &candle::Device::Cpu)?;
451        Ok(all_audio_tokens)
452    }
453
454    fn prepare_causal_mask(
455        &self,
456        q_len: usize,
457        kv_len: usize,
458        device: &candle::Device,
459    ) -> Result<Tensor> {
460        let mask: Vec<_> = (0..q_len)
461            .flat_map(|i| {
462                (0..kv_len).map(move |j| {
463                    if i + kv_len < j + q_len {
464                        f32::NEG_INFINITY
465                    } else {
466                        0.
467                    }
468                })
469            })
470            .collect();
471        Tensor::from_slice(&mask, (q_len, kv_len), device)
472    }
473}