1use crate::generation::LogitsProcessor;
19use crate::models::t5;
20use candle::{IndexOp, Result, Tensor};
21use candle_nn::{layer_norm, linear_b as linear, Activation, LayerNorm, Linear, VarBuilder};
22
/// Configuration of the multi-codebook audio decoder, deserialized from the model config.
#[derive(serde::Deserialize, Debug, Clone)]
pub struct DecoderConfig {
    /// Per-codebook output vocabulary: each lm head produces this many logits, and each
    /// codebook embedding table has `vocab_size + 1` rows.
    pub vocab_size: usize,
    /// Number of rows in the learned/stored position table (`embed_positions.weights`).
    pub max_position_embeddings: usize,
    /// Number of decoder layers.
    pub num_hidden_layers: usize,
    /// Inner width of the feed-forward block (`fc1` output / `fc2` input).
    pub ffn_dim: usize,
    /// Number of query heads in both self- and cross-attention.
    pub num_attention_heads: usize,
    /// Key/value heads for self-attention; defaults to `num_attention_heads` when absent.
    pub num_key_value_heads: Option<usize>,
    /// Key/value heads for cross-attention; defaults to the self-attention value when absent.
    pub num_cross_attention_key_value_heads: Option<usize>,
    /// Activation applied between `fc1` and `fc2`.
    pub activation_function: Activation,
    /// Model width.
    pub hidden_size: usize,
    /// Not read by this module.
    pub scale_embedding: bool,
    /// Number of parallel codebook streams; one embedding table and one lm head each.
    pub num_codebooks: usize,
    /// Not read by this module (generation uses `Config::pad_token_id`).
    pub pad_token_id: usize,
    /// Not read by this module.
    pub bos_token_id: usize,
    /// Not read by this module.
    pub eos_token_id: usize,
    /// Not read by this module; the lm heads are always loaded as separate weights.
    pub tie_word_embeddings: bool,
    /// Must be `false`: `Attention::new` bails when rope embeddings are requested.
    pub rope_embeddings: bool,
    /// Not read by this module (rope is unsupported).
    pub rope_theta: f64,
}
43
/// Top-level model configuration: text encoder, audio decoder and audio codec settings.
#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
    /// Token fed to every codebook at the first generation step.
    pub decoder_start_token_id: u32,
    /// Once a codebook emits this token it stops being sampled; generation ends when every
    /// codebook has emitted it.
    pub pad_token_id: u32,
    /// Audio decoder configuration.
    pub decoder: DecoderConfig,
    /// T5 text (description) encoder configuration.
    pub text_encoder: t5::Config,
    /// Number of rows of the prompt embedding table (`embed_prompts`).
    pub vocab_size: usize,
    /// Configuration of the DAC audio model loaded alongside the decoder.
    pub audio_encoder: crate::models::dac::Config,
}
53
/// Multi-head attention with optional grouped key/value heads, used for both decoder
/// self-attention and cross-attention over the encoder output.
#[derive(Debug, Clone)]
pub struct Attention {
    k_proj: Linear,
    v_proj: Linear,
    q_proj: Linear,
    out_proj: Linear,
    /// When true, projected keys/values are kept in `kv_cache` between calls.
    is_causal: bool,
    /// Cached `(keys, values)`, each `(batch, num_kv_heads, seq, head_dim)`, stored before the
    /// heads are repeated for grouped-query attention.
    kv_cache: Option<(Tensor, Tensor)>,
    /// `head_dim^-0.5`, applied to the queries.
    scaling: f64,
    num_heads: usize,
    num_kv_heads: usize,
    /// `num_heads / num_kv_heads`: how many query heads share each key/value head.
    num_kv_groups: usize,
    head_dim: usize,
}
68
69impl Attention {
70 fn new(
71 num_kv_heads: usize,
72 is_causal: bool,
73 cfg: &DecoderConfig,
74 vb: VarBuilder,
75 ) -> Result<Self> {
76 if cfg.rope_embeddings {
77 candle::bail!("rope embeddings are not supported");
78 }
79 let embed_dim = cfg.hidden_size;
80 let head_dim = embed_dim / cfg.num_attention_heads;
81 let kv_out_dim = num_kv_heads * head_dim;
82 let k_proj = linear(embed_dim, kv_out_dim, false, vb.pp("k_proj"))?;
83 let v_proj = linear(embed_dim, kv_out_dim, false, vb.pp("v_proj"))?;
84 let q_proj = linear(embed_dim, embed_dim, false, vb.pp("q_proj"))?;
85 let out_proj = linear(embed_dim, embed_dim, false, vb.pp("out_proj"))?;
86 Ok(Self {
87 k_proj,
88 v_proj,
89 q_proj,
90 out_proj,
91 is_causal,
92 kv_cache: None,
93 scaling: (head_dim as f64).powf(-0.5),
94 num_heads: cfg.num_attention_heads,
95 num_kv_heads,
96 num_kv_groups: cfg.num_attention_heads / num_kv_heads,
97 head_dim,
98 })
99 }
100
101 fn forward(
102 &mut self,
103 xs: &Tensor,
104 key_value_states: Option<&Tensor>,
105 attention_mask: Option<&Tensor>,
106 ) -> Result<Tensor> {
107 let (b_sz, tgt_len, _) = xs.dims3()?;
108 let query_states = (xs.apply(&self.q_proj)? * self.scaling)?
109 .reshape((b_sz, tgt_len, self.num_heads, self.head_dim))?
110 .transpose(1, 2)?
111 .contiguous()?;
112 let key_states = match key_value_states {
113 Some(states) => states.apply(&self.k_proj)?,
114 None => xs.apply(&self.k_proj)?,
115 };
116 let key_states = key_states
117 .reshape((b_sz, (), self.num_kv_heads, self.head_dim))?
118 .transpose(1, 2)?
119 .contiguous()?;
120 let value_states = match key_value_states {
121 Some(states) => states.apply(&self.v_proj)?,
122 None => xs.apply(&self.v_proj)?,
123 };
124 let value_states = value_states
125 .reshape((b_sz, (), self.num_kv_heads, self.head_dim))?
126 .transpose(1, 2)?
127 .contiguous()?;
128
129 let (key_states, value_states) = match &self.kv_cache {
130 None => (key_states, value_states),
131 Some((prev_k, prev_v)) => {
132 let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
133 let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
134 (key_states, value_states)
135 }
136 };
137 if self.is_causal {
138 self.kv_cache = Some((key_states.clone(), value_states.clone()));
139 }
140
141 let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
142 let value_states =
143 crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
144
145 let attn_weights = query_states.matmul(&key_states.transpose(2, 3)?)?;
146 let attn_weights = match attention_mask {
147 None => attn_weights,
148 Some(mask) => attn_weights.broadcast_add(mask)?,
149 };
150 let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
151 let attn_output = attn_weights.matmul(&value_states)?;
152 attn_output
153 .transpose(1, 2)?
154 .reshape((b_sz, tgt_len, ()))?
155 .apply(&self.out_proj)
156 }
157
158 fn clear_kv_cache(&mut self) {
159 self.kv_cache = None
160 }
161}
162
/// One pre-norm decoder layer: self-attention, cross-attention over the encoder output, and
/// a feed-forward block, each wrapped in a residual connection.
#[derive(Debug, Clone)]
pub struct DecoderLayer {
    /// Causal self-attention (keeps a key/value cache).
    self_attn: Attention,
    self_attn_layer_norm: LayerNorm,
    /// Non-causal attention whose keys/values come from the encoder output.
    encoder_attn: Attention,
    encoder_attn_layer_norm: LayerNorm,
    /// `hidden_size -> ffn_dim` projection.
    fc1: Linear,
    /// `ffn_dim -> hidden_size` projection.
    fc2: Linear,
    /// Norm applied before the feed-forward block.
    final_layer_norm: LayerNorm,
    /// Activation applied between `fc1` and `fc2`.
    activation: Activation,
}
174
175impl DecoderLayer {
176 fn new(cfg: &DecoderConfig, vb: VarBuilder) -> Result<Self> {
177 let kv_heads = cfg.num_key_value_heads.unwrap_or(cfg.num_attention_heads);
178 let kv_heads_cross = cfg.num_cross_attention_key_value_heads.unwrap_or(kv_heads);
179
180 let self_attn = Attention::new(kv_heads, true, cfg, vb.pp("self_attn"))?;
181 let encoder_attn = Attention::new(kv_heads_cross, false, cfg, vb.pp("encoder_attn"))?;
182 let self_attn_layer_norm =
183 layer_norm(cfg.hidden_size, 1e-5, vb.pp("self_attn_layer_norm"))?;
184 let encoder_attn_layer_norm =
185 layer_norm(cfg.hidden_size, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
186 let fc1 = linear(cfg.hidden_size, cfg.ffn_dim, false, vb.pp("fc1"))?;
187 let fc2 = linear(cfg.ffn_dim, cfg.hidden_size, false, vb.pp("fc2"))?;
188 let final_layer_norm = layer_norm(cfg.hidden_size, 1e-5, vb.pp("final_layer_norm"))?;
189 Ok(Self {
190 self_attn,
191 self_attn_layer_norm,
192 encoder_attn,
193 encoder_attn_layer_norm,
194 fc1,
195 fc2,
196 final_layer_norm,
197 activation: cfg.activation_function,
198 })
199 }
200
201 fn forward(
202 &mut self,
203 xs: &Tensor,
204 attention_mask: Option<&Tensor>,
205 encoder_xs: &Tensor,
206 encoder_attention_mask: Option<&Tensor>,
207 ) -> Result<Tensor> {
208 let residual = xs;
210 let xs = xs.apply(&self.self_attn_layer_norm)?;
211 let xs = self.self_attn.forward(&xs, None, attention_mask)?;
212 let xs = (residual + xs)?;
213
214 let residual = &xs;
216 let xs = xs.apply(&self.encoder_attn_layer_norm)?;
217 let xs = self
218 .encoder_attn
219 .forward(&xs, Some(encoder_xs), encoder_attention_mask)?;
220 let xs = (residual + xs)?;
221
222 let residual = &xs;
224 let xs = xs
225 .apply(&self.final_layer_norm)?
226 .apply(&self.fc1)?
227 .apply(&self.activation)?
228 .apply(&self.fc2)?;
229 residual + xs
230 }
231
232 fn clear_kv_cache(&mut self) {
233 self.self_attn.clear_kv_cache();
234 self.encoder_attn.clear_kv_cache();
235 }
236}
237
/// Multi-codebook audio decoder: summed per-codebook token embeddings plus positions, a
/// stack of [`DecoderLayer`]s, a final norm and one lm head per codebook.
#[derive(Debug, Clone)]
pub struct Decoder {
    /// One embedding table per codebook, each with `vocab_size + 1` rows.
    embed_tokens: Vec<candle_nn::Embedding>,
    /// Position table of shape `(max_position_embeddings, hidden_size)`.
    embed_positions: Tensor,
    layers: Vec<DecoderLayer>,
    /// Norm applied after the last layer, before the lm heads.
    layer_norm: LayerNorm,
    num_codebooks: usize,
    hidden_size: usize,
    /// One `hidden_size -> vocab_size` projection per codebook.
    lm_heads: Vec<Linear>,
    /// Dtype of the VarBuilder the decoder was loaded from.
    dtype: candle::DType,
}
249
250impl Decoder {
251 pub fn new(cfg: &DecoderConfig, vb: VarBuilder) -> Result<Self> {
252 let vb_d = vb.pp("model.decoder");
253 let mut embed_tokens = Vec::with_capacity(cfg.num_codebooks);
254 let vb_e = vb_d.pp("embed_tokens");
255 for embed_idx in 0..cfg.num_codebooks {
256 let e = candle_nn::embedding(cfg.vocab_size + 1, cfg.hidden_size, vb_e.pp(embed_idx))?;
257 embed_tokens.push(e)
258 }
259 let embed_positions = vb_d.get(
260 (cfg.max_position_embeddings, cfg.hidden_size),
261 "embed_positions.weights",
262 )?;
263 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
264 let vb_l = vb_d.pp("layers");
265 for layer_idx in 0..cfg.num_hidden_layers {
266 let layer = DecoderLayer::new(cfg, vb_l.pp(layer_idx))?;
267 layers.push(layer)
268 }
269 let layer_norm = layer_norm(cfg.hidden_size, 1e-5, vb_d.pp("layer_norm"))?;
270
271 let mut lm_heads = Vec::with_capacity(cfg.num_codebooks);
272 let vb_l = vb.pp("lm_heads");
273 for lm_idx in 0..cfg.num_codebooks {
274 let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, vb_l.pp(lm_idx))?;
275 lm_heads.push(lm_head)
276 }
277 Ok(Self {
278 embed_tokens,
279 embed_positions,
280 layers,
281 layer_norm,
282 num_codebooks: cfg.num_codebooks,
283 lm_heads,
284 hidden_size: cfg.hidden_size,
285 dtype: vb.dtype(),
286 })
287 }
288
289 pub fn forward(
290 &mut self,
291 input_ids: &Tensor,
292 prompt_hidden_states: Option<&Tensor>,
293 attention_mask: Option<&Tensor>,
294 encoder_xs: &Tensor,
295 encoder_attention_mask: Option<&Tensor>,
296 seqlen_offset: usize,
297 ) -> Result<Vec<Tensor>> {
298 let (b_sz, num_codebooks, seq_len) = input_ids.dims3()?;
299 if num_codebooks != self.num_codebooks {
300 candle::bail!("unexpected num codebooks in input {:?}", input_ids.shape())
301 }
302 let mut inputs_embeds = Tensor::zeros(
303 (b_sz, seq_len, self.hidden_size),
304 self.dtype,
305 input_ids.device(),
306 )?;
307 for (idx, embs) in self.embed_tokens.iter().enumerate() {
308 let e = input_ids.i((.., idx))?.apply(embs)?;
309 inputs_embeds = (inputs_embeds + e)?
310 }
311 let inputs_embeds = match prompt_hidden_states {
312 None => inputs_embeds,
313 Some(pis) => Tensor::cat(&[pis, &inputs_embeds], 1)?,
314 };
315 let embed_positions = self
316 .embed_positions
317 .i(seqlen_offset..seqlen_offset + inputs_embeds.dim(1)?)?;
318 let mut xs = (inputs_embeds + embed_positions.unsqueeze(0))?;
319 for layer in self.layers.iter_mut() {
320 xs = layer.forward(&xs, attention_mask, encoder_xs, encoder_attention_mask)?;
321 }
322 let xs = xs.apply(&self.layer_norm)?;
323 let mut lm_logits = Vec::with_capacity(self.num_codebooks);
324 for lm_head in self.lm_heads.iter() {
325 let logits = xs.apply(lm_head)?;
326 lm_logits.push(logits)
327 }
328 Ok(lm_logits)
329 }
330
331 pub fn clear_kv_cache(&mut self) {
332 for layer in self.layers.iter_mut() {
333 layer.clear_kv_cache()
334 }
335 }
336}
337
/// Full text-to-audio-token model: a T5 description encoder, a prompt embedding, the
/// multi-codebook decoder, and a DAC audio model.
#[derive(Debug, Clone)]
pub struct Model {
    /// Embeds prompt tokens. The result is prepended to the decoder input on the first
    /// generation step.
    pub embed_prompts: candle_nn::Embedding,
    /// Present only when the text encoder's `d_model` differs from the decoder `hidden_size`.
    pub enc_to_dec_proj: Option<Linear>,
    pub decoder: Decoder,
    pub text_encoder: t5::T5EncoderModel,
    /// Initial token for every codebook during generation.
    pub decoder_start_token_id: u32,
    /// Token that ends a codebook stream during generation.
    pub pad_token_id: u32,
    /// Loaded here but not used by `generate`. Presumably it decodes the generated codes
    /// to audio — confirm against callers.
    pub audio_encoder: crate::models::dac::Model,
}
348
349impl Model {
350 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
351 let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.text_encoder)?;
352 let decoder = Decoder::new(&cfg.decoder, vb.pp("decoder"))?;
353 let embed_prompts = candle_nn::embedding(
354 cfg.vocab_size,
355 cfg.decoder.hidden_size,
356 vb.pp("embed_prompts"),
357 )?;
358 let enc_to_dec_proj = if cfg.text_encoder.d_model != cfg.decoder.hidden_size {
359 let proj = linear(
360 cfg.text_encoder.d_model,
361 cfg.decoder.hidden_size,
362 true,
363 vb.pp("enc_to_dec_proj"),
364 )?;
365 Some(proj)
366 } else {
367 None
368 };
369 let audio_encoder =
370 crate::models::dac::Model::new(&cfg.audio_encoder, vb.pp("audio_encoder"))?;
371 Ok(Self {
372 decoder,
373 text_encoder,
374 embed_prompts,
375 enc_to_dec_proj,
376 decoder_start_token_id: cfg.decoder_start_token_id,
377 pad_token_id: cfg.pad_token_id,
378 audio_encoder,
379 })
380 }
381
382 pub fn generate(
384 &mut self,
385 prompt_tokens: &Tensor,
386 description_tokens: &Tensor,
387 mut lp: LogitsProcessor,
388 max_steps: usize,
389 ) -> Result<Tensor> {
390 self.decoder.clear_kv_cache();
391 self.text_encoder.clear_kv_cache();
392 let encoded = self.text_encoder.forward(description_tokens)?;
393 let encoded = match self.enc_to_dec_proj.as_ref() {
394 None => encoded,
395 Some(proj) => encoded.apply(proj)?,
396 };
397 let prompt_hidden_states = prompt_tokens.apply(&self.embed_prompts)?;
398 let num_codebooks = self.decoder.num_codebooks;
399 let mut audio_tokens = vec![self.decoder_start_token_id; num_codebooks];
400 let mut all_audio_tokens = vec![vec![]; num_codebooks];
401 let prompt_len = prompt_hidden_states.dim(1)?;
402 for step in 0..max_steps {
403 let input_ids = Tensor::from_slice(
404 audio_tokens.as_slice(),
405 (1, num_codebooks, 1),
406 prompt_tokens.device(),
407 )?;
408 let (prompt_hidden_states, pos) = if step == 0 {
409 (Some(&prompt_hidden_states), 0)
410 } else {
411 (None, step + prompt_len)
412 };
413 let causal_mask = if pos == 0 {
414 self.prepare_causal_mask(prompt_len + 1, prompt_len + 1, input_ids.device())?
415 } else {
416 self.prepare_causal_mask(1, pos + 1, input_ids.device())?
417 };
418 let logits = self.decoder.forward(
419 &input_ids,
420 prompt_hidden_states,
421 Some(&causal_mask),
422 &encoded,
423 None,
424 pos,
425 )?;
426 for (logit_idx, logit) in logits.iter().enumerate() {
427 if logit_idx > step {
428 break;
429 }
430 if audio_tokens[logit_idx] != self.pad_token_id {
431 let logit = logit.i((0, logit.dim(1)? - 1))?;
432 let token = lp.sample(&logit)?;
433 audio_tokens[logit_idx] = token
434 }
435 }
436 if audio_tokens.iter().all(|v| v == &self.pad_token_id) {
437 break;
438 }
439 for (cb_idx, &token) in audio_tokens.iter().enumerate() {
440 if token != self.decoder_start_token_id && token != self.pad_token_id {
441 all_audio_tokens[cb_idx].push(token)
442 }
443 }
444 }
445
446 let min_len = all_audio_tokens.iter().map(|v| v.len()).min().unwrap_or(0);
447 all_audio_tokens.iter_mut().for_each(|v| {
448 v.resize(min_len, 0);
449 });
450 let all_audio_tokens = Tensor::new(all_audio_tokens, &candle::Device::Cpu)?;
451 Ok(all_audio_tokens)
452 }
453
454 fn prepare_causal_mask(
455 &self,
456 q_len: usize,
457 kv_len: usize,
458 device: &candle::Device,
459 ) -> Result<Tensor> {
460 let mask: Vec<_> = (0..q_len)
461 .flat_map(|i| {
462 (0..kv_len).map(move |j| {
463 if i + kv_len < j + q_len {
464 f32::NEG_INFINITY
465 } else {
466 0.
467 }
468 })
469 })
470 .collect();
471 Tensor::from_slice(&mask, (q_len, kv_len), device)
472 }
473}