1use candle::{DType, Device, Error as E, IndexOp, Module, Result, Tensor, D};
8use candle_nn::{embedding, linear_b, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
9
10pub(crate) fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
12 let img = img.unsqueeze(dim + 1)?;
13 let mut dims = img.dims().to_vec();
14 dims[dim + 1] = repeats;
15 img.broadcast_as(dims)?.flatten(dim, dim + 1)
16}
17pub mod speaker_encoder {
18 use super::*;
19
20 #[derive(Debug, Clone, serde::Deserialize)]
21 pub struct Config {
22 pub sampling_rate: usize,
23 pub partial_n_frames: usize,
24 pub model_hidden_size: usize,
25 pub model_embedding_size: usize,
26 pub model_num_layers: usize,
27 pub mel_window_length: usize,
28 pub mel_window_step: usize,
29 pub mel_n_channels: usize,
30 }
31
32 impl Config {
33 pub fn cfg() -> Self {
34 Self {
35 sampling_rate: 16_000,
36 partial_n_frames: 160,
37 model_hidden_size: 256,
38 model_embedding_size: 256,
39 model_num_layers: 3,
40 mel_window_length: 25,
41 mel_window_step: 10,
42 mel_n_channels: 40,
43 }
44 }
45 }
46
47 pub struct Model {
48 lstms: Vec<candle_nn::LSTM>,
49 linear: Linear,
50 cfg: Config,
51 }
52
53 type Slice = (usize, usize);
54
55 impl Model {
56 pub fn new(cfg: Config, vb: VarBuilder) -> Result<Self> {
57 let mut lstms = Vec::with_capacity(cfg.model_num_layers);
58 let vb_l = vb.pp("lstm");
59 for layer_idx in 0..cfg.model_num_layers {
60 let c = candle_nn::LSTMConfig {
61 layer_idx,
62 ..Default::default()
63 };
64 let lstm = candle_nn::lstm(
65 cfg.mel_n_channels,
66 cfg.model_hidden_size,
67 c,
68 vb_l.pp(layer_idx),
69 )?;
70 lstms.push(lstm)
71 }
72 let linear = linear_b(
73 cfg.model_hidden_size,
74 cfg.model_embedding_size,
75 true,
76 vb.pp("linear"),
77 )?;
78 Ok(Self { lstms, linear, cfg })
79 }
80
81 fn compute_partial_slices(
82 &self,
83 n_samples: usize,
84 rate: f64,
85 min_coverage: f64,
86 ) -> (Vec<Slice>, Vec<Slice>) {
87 let c = &self.cfg;
88 let samples_per_frame = c.sampling_rate * c.mel_window_step / 1000;
90 let n_frames = n_samples / samples_per_frame + 1;
91 let frame_step =
92 (c.sampling_rate as f64 / rate / samples_per_frame as f64).round() as usize;
93 let steps = (n_frames + frame_step).saturating_sub(c.partial_n_frames) + 1;
94 let mut wav_slices = vec![];
96 let mut mel_slices = vec![];
97 for i in (0..steps).step_by(frame_step) {
98 let mel_range = (i, i + c.partial_n_frames);
99 let wav_range = (
100 i * samples_per_frame,
101 (i + c.partial_n_frames) * samples_per_frame,
102 );
103 mel_slices.push(mel_range);
104 wav_slices.push(wav_range);
105 }
106 let last_wav_range = match wav_slices.last() {
108 None => return (wav_slices, mel_slices),
109 Some(l) => *l,
110 };
111 let coverage = (n_samples - last_wav_range.0) as f64
112 / (last_wav_range.1 - last_wav_range.0) as f64;
113 if coverage > min_coverage && mel_slices.len() > 1 {
114 mel_slices.pop();
115 wav_slices.pop();
116 }
117 (wav_slices, mel_slices)
118 }
119
120 pub fn embed_utterance(
121 &self,
122 wav: &[f32],
123 mel_filters: &[f32],
124 rate: f64,
125 min_c: f64,
126 device: &Device,
127 ) -> Result<Tensor> {
128 let (wav_slices, mel_slices) = self.compute_partial_slices(wav.len(), rate, min_c);
129 let max_wave_length = match wav_slices.last() {
130 Some(v) => v.1,
131 None => candle::bail!("empty wav slices"),
132 };
133 let wav = if max_wave_length > wav.len() {
134 let mut wav = wav.to_vec();
135 wav.resize(max_wave_length - wav.len(), 0.0);
136 std::borrow::Cow::Owned(wav)
137 } else {
138 std::borrow::Cow::Borrowed(wav)
139 };
140 let mel = crate::models::whisper::audio::log_mel_spectrogram_(
141 wav.as_ref(),
142 mel_filters,
143 self.cfg.mel_window_length,
144 self.cfg.mel_window_step,
145 self.cfg.mel_n_channels,
146 false,
147 );
148 let mels = mel_slices
149 .iter()
150 .flat_map(|s| [mel[s.0], mel[s.1]])
151 .collect::<Vec<_>>();
152 let mels = Tensor::from_vec(mels, (mel_slices.len(), 2), device)?;
153 let partial_embeds = self.forward(&mels)?;
154 let raw_embed = partial_embeds.mean(0)?;
155 let norm = raw_embed.sqr()?.sum_all()?.sqrt()?;
156 raw_embed.broadcast_div(&norm)
157 }
158 }
159
160 impl Module for Model {
161 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
162 use candle_nn::RNN;
163
164 let xs = xs.t()?;
166 let mut xs = xs.clone();
167 for layer in self.lstms.iter() {
168 let states = layer.seq(&xs)?;
169 xs = layer.states_to_tensor(&states)?;
170 }
171 let xs = xs.t()?;
172 let embeds_raw = xs.apply(&self.linear)?.relu()?;
173 let norm = embeds_raw.sqr()?.sum_keepdim(1)?.sqrt()?;
174 embeds_raw.broadcast_div(&norm)
175 }
176 }
177}
178
/// Rank of a byte sequence in the BPE merge table; lower ranks are merged first, and the rank
/// of a byte sequence is also the token id emitted for it.
type Rank = u32;
180
181pub mod tokenizers {
182 use super::*;
183 use std::collections::HashMap;
184
    /// Byte-level BPE tokenizer built from a tiktoken-style JSON description.
    pub struct BPE {
        /// Pre-tokenization regex (`pat_str`) that splits text into words before BPE runs.
        pub re: fancy_regex::Regex,
        /// End-of-text id; `encode` appends `end_of_text + offset` to every output.
        pub end_of_text: usize,
        /// Constant added to every emitted token id.
        pub offset: usize,
        /// Byte sequence -> rank. The rank doubles as the token id of that sequence.
        pub ranks: HashMap<Vec<u8>, Rank>,
        span: tracing::Span,
    }

    impl BPE {
        /// Builds the tokenizer from a JSON object with `pat_str` (string), `offset`
        /// (unsigned int) and `mergeable_ranks` (object mapping a key to an unsigned int).
        ///
        /// The 256 single bytes always get rank = byte value. Entries of `mergeable_ranks`
        /// with a value below 256 are skipped so they cannot override those.
        ///
        /// # Errors
        /// Fails when a field is missing or has the wrong JSON type, or when `pat_str` is not
        /// a valid regex.
        pub fn from_json(json: &serde_json::Value, end_of_text: usize) -> Result<Self> {
            let json = match json.as_object() {
                None => candle::bail!("json value is not an object"),
                Some(json) => json,
            };
            let re = match json.get("pat_str") {
                None => candle::bail!("json object has no pat_str field"),
                Some(pat_str) => match pat_str.as_str() {
                    None => candle::bail!("pat_str field is not a string"),
                    Some(pat_str) => fancy_regex::Regex::new(pat_str).map_err(E::wrap)?,
                },
            };
            let offset = match json.get("offset") {
                None => candle::bail!("json object has no offset field"),
                Some(offset) => match offset.as_u64() {
                    None => candle::bail!("offset field is not a positive int"),
                    Some(offset) => offset as usize,
                },
            };
            // Seed the table with every single byte, ranked by its own value.
            let mut ranks = HashMap::new();
            for id in 0u8..=255 {
                ranks.insert(vec![id], id as u32);
            }
            let mergeable_ranks = match json.get("mergeable_ranks") {
                None => candle::bail!("json object has no mergeable_ranks field"),
                Some(mr) => match mr.as_object() {
                    None => candle::bail!("mergeable_ranks is not an object"),
                    Some(mr) => mr,
                },
            };
            for (key, value) in mergeable_ranks.iter() {
                // NOTE(review): `as u32` silently truncates values above u32::MAX.
                let value = match value.as_u64() {
                    None => candle::bail!("mergeable_ranks '{key}' is not a u64"),
                    Some(value) => value as u32,
                };
                if value < 256 {
                    continue;
                }
                // Keys are stored as their raw UTF-8 bytes with no decoding.
                // NOTE(review): confirm the JSON stores merges as plain strings rather than
                // an encoded (e.g. base64) form.
                let key = key.as_bytes().to_vec();
                ranks.insert(key, value);
            }
            Ok(Self {
                re,
                end_of_text,
                offset,
                ranks,
                span: tracing::span!(tracing::Level::TRACE, "bpe"),
            })
        }

        /// Greedy BPE merge of `piece` (tiktoken's algorithm).
        ///
        /// Returns `(start, rank)` boundaries: consecutive starts delimit the final tokens,
        /// and the last entry's start is `piece.len()`. Requires `piece.len() >= 2`; shorter
        /// input would underflow `piece.len() - 1`, and `byte_pair_encode` guarantees this.
        fn _byte_pair_merge(&self, piece: &[u8]) -> Vec<(usize, Rank)> {
            // parts[k] = (start offset of part k, rank of merging part k with part k + 1).
            let mut parts = Vec::with_capacity(piece.len() + 1);

            let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
            for i in 0..piece.len() - 1 {
                let rank = *self.ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
                if rank < min_rank.0 {
                    min_rank = (rank, i);
                }
                parts.push((i, rank));
            }
            // Sentinels: the last byte (nothing to merge with) and the end-of-piece boundary.
            parts.push((piece.len() - 1, Rank::MAX));
            parts.push((piece.len(), Rank::MAX));

            // Rank of the pair starting at part `i` once parts i and i + 1 are merged.
            // It uses `i + 3` because parts[i + 1] has not been removed yet when this is called.
            let get_rank = {
                #[inline(always)]
                |parts: &Vec<(usize, Rank)>, i: usize| {
                    if (i + 3) < parts.len() {
                        *self
                            .ranks
                            .get(&piece[parts[i].0..parts[i + 3].0])
                            .unwrap_or(&Rank::MAX)
                    } else {
                        Rank::MAX
                    }
                }
            };

            // Repeatedly merge the lowest-ranked adjacent pair until no known merge remains.
            // Each iteration rescans `parts`, so the total work is O(merges * parts).
            while min_rank.0 != Rank::MAX {
                let i = min_rank.1;
                // Refresh the ranks that involve the merged part before removing parts[i + 1].
                if i > 0 {
                    parts[i - 1].1 = get_rank(&parts, i - 1);
                }
                parts[i].1 = get_rank(&parts, i);
                parts.remove(i + 1);

                min_rank = (Rank::MAX, usize::MAX);
                for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
                    if rank < min_rank.0 {
                        min_rank = (rank, i);
                    }
                }
            }
            parts
        }

        /// Encodes one pre-tokenized word into token ids (ranks, without `offset`).
        ///
        /// # Panics
        /// Panics if a resulting byte sequence is missing from `ranks`. This cannot happen for
        /// single bytes, which are always seeded by `from_json`.
        pub fn byte_pair_encode(&self, piece: &[u8]) -> Vec<Rank> {
            if piece.is_empty() {
                return Vec::new();
            }
            if piece.len() == 1 {
                return vec![self.ranks[piece]];
            }
            assert!(piece.len() > 1);
            // Each window of two consecutive boundaries delimits one merged token.
            self._byte_pair_merge(piece)
                .windows(2)
                .map(|part| self.ranks[&piece[part[0].0..part[1].0]])
                .collect()
        }

        /// Tokenizes `text`: splits it with `re`, BPE-encodes each match, shifts every id by
        /// `offset`, and appends `end_of_text + offset`.
        ///
        /// # Errors
        /// Propagates regex matching errors from `fancy_regex`.
        pub fn encode(&self, text: &str) -> Result<Vec<u32>> {
            let _enter = self.span.enter();
            let mut bpe_tokens: Vec<u32> = Vec::new();
            for word in self.re.find_iter(text) {
                let word = word.map_err(E::wrap)?;
                let word_tokens = self.byte_pair_encode(word.as_str().as_bytes());
                for &token in word_tokens.iter() {
                    bpe_tokens.push(token + self.offset as u32)
                }
            }
            bpe_tokens.push((self.end_of_text + self.offset) as u32);
            Ok(bpe_tokens)
        }
    }
334}
335
336pub mod gpt {
337 use super::*;
338
    /// Which normalisation layer the GPT blocks use (`Config::norm_type`).
    #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
    pub enum NormType {
        LayerNorm,
        RMSNorm,
    }

    /// Attention kernel requested by the config; `SelfAttention::new` only accepts `TorchAttn`.
    #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
    pub enum AttnKernelType {
        Fa2,
        TorchAttn,
        Hand,
    }

    /// Non-linearity used by the block MLP: plain GELU or a gated SwiGLU.
    #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
    pub enum NonLinearityType {
        Gelu,
        Swiglu,
    }

    /// Loaded normalisation layer, built from a `NormType` by `Norm::new`.
    enum Norm {
        RMSNorm(candle_nn::RmsNorm),
        LayerNorm(candle_nn::LayerNorm),
    }
362
363 #[derive(Debug, Clone)]
365 pub struct Config {
366 pub block_size: usize,
367 pub vocab_sizes: Vec<usize>,
368 pub target_vocab_sizes: Vec<usize>,
369 pub n_layer: usize,
370 pub n_head: usize,
371 pub n_embd: usize,
372 pub bias: bool,
373 pub causal: bool,
374 pub spk_emb_on_text: bool,
375 pub norm_type: NormType,
376 pub rmsnorm_eps: f64,
377 pub nonlinearity_type: NonLinearityType,
378 pub swiglu_multiple_of: Option<usize>,
379 pub attn_kernel_type: AttnKernelType,
380 pub kv_cache_enabled: bool,
381 }
382
383 impl Config {
384 pub fn cfg1b_v0_1() -> Self {
385 Self {
386 n_layer: 6,
387 n_head: 6,
388 n_embd: 384,
389 block_size: 1024,
390 bias: false,
391 vocab_sizes: vec![1538, 1025],
392 causal: false,
393 target_vocab_sizes: vec![1025, 1025, 1025, 1025, 1025, 1025],
394 swiglu_multiple_of: Some(256),
395 norm_type: NormType::LayerNorm,
396 kv_cache_enabled: false,
397 attn_kernel_type: AttnKernelType::TorchAttn,
398 spk_emb_on_text: true,
399 nonlinearity_type: NonLinearityType::Gelu,
400 rmsnorm_eps: 1e-5,
401 }
402 }
403 }
404
405 impl Norm {
406 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
407 match cfg.norm_type {
408 NormType::RMSNorm => {
409 let rms_norm = candle_nn::rms_norm(cfg.n_embd, cfg.rmsnorm_eps, vb)?;
410 Ok(Self::RMSNorm(rms_norm))
411 }
412 NormType::LayerNorm => {
413 let ln_cfg = candle_nn::LayerNormConfig {
414 affine: cfg.bias,
415 ..Default::default()
416 };
417 let layer_norm = candle_nn::layer_norm(cfg.n_embd, ln_cfg, vb)?;
418 Ok(Self::LayerNorm(layer_norm))
419 }
420 }
421 }
422 }
423
424 impl Module for Norm {
425 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
426 match self {
427 Self::RMSNorm(m) => m.forward(xs),
428 Self::LayerNorm(m) => m.forward(xs),
429 }
430 }
431 }
432
433 struct SelfAttention {
435 c_attn: Linear,
436 c_proj: Linear,
437 n_head: usize,
438 span: tracing::Span,
439 }
440
441 impl SelfAttention {
442 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
443 if cfg.attn_kernel_type != AttnKernelType::TorchAttn {
446 candle::bail!("only TorchAttn is supported")
447 }
448 if cfg.kv_cache_enabled {
449 candle::bail!("kv_cache_enabled=true is not supported")
450 }
451 let c_attn = linear_b(cfg.n_embd, cfg.n_embd * 3, cfg.bias, vb.pp("c_attn"))?;
452 let c_proj = linear_b(cfg.n_embd, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?;
453 Ok(Self {
454 c_attn,
455 c_proj,
456 n_head: cfg.n_head,
457 span: tracing::span!(tracing::Level::TRACE, "self-attn"),
458 })
459 }
460 }
461
462 impl Module for SelfAttention {
463 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
464 let _enter = self.span.enter();
465 let (b, t, c) = xs.dims3()?;
466 let c_x = xs
467 .apply(&self.c_attn)?
468 .reshape((b, t, 3, self.n_head, c / self.n_head))?;
469 let q = c_x.i((.., .., 0))?;
470 let k = c_x.i((.., .., 1))?;
471 let v = c_x.i((.., .., 2))?;
472 let q = q.transpose(1, 2)?.contiguous()?;
473 let k = k.transpose(1, 2)?.contiguous()?;
474 let v = v.transpose(1, 2)?.contiguous()?;
475 let att = (q.matmul(&k.t()?)? / (k.dim(D::Minus1)? as f64).sqrt())?;
476 let att = candle_nn::ops::softmax_last_dim(&att)?;
478 let att = att.matmul(&v)?.transpose(1, 2)?;
479 att.reshape((b, t, c))?.apply(&self.c_proj)
480 }
481 }
482
483 #[allow(clippy::upper_case_acronyms)]
485 enum MLP {
486 Gelu {
487 c_fc: Linear,
488 c_proj: Linear,
489 span: tracing::Span,
490 },
491 Swiglu {
492 w1: Linear,
493 w3: Linear,
494 c_proj: Linear,
495 span: tracing::Span,
496 },
497 }
498
499 impl MLP {
500 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
501 let hidden_dim = 4 * cfg.n_embd;
502 let slf = match cfg.nonlinearity_type {
503 NonLinearityType::Gelu => {
504 let c_fc = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("c_fc"))?;
505 let c_proj = linear_b(hidden_dim, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?;
506 Self::Gelu {
507 c_fc,
508 c_proj,
509 span: tracing::span!(tracing::Level::TRACE, "mlp-gelu"),
510 }
511 }
512 NonLinearityType::Swiglu => {
513 let hidden_dim = (2 * hidden_dim) / 3;
514 let swiglu_multiple_of = match cfg.swiglu_multiple_of {
515 None => candle::bail!("swiglu-multiple-of has to be set"),
516 Some(smo) => smo,
517 };
518 let hidden_dim = swiglu_multiple_of * (hidden_dim + swiglu_multiple_of - 1)
519 / swiglu_multiple_of;
520 let w1 = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("w1"))?;
521 let w3 = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("w3"))?;
522 let c_proj = linear_b(hidden_dim, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?;
523 Self::Swiglu {
524 w1,
525 w3,
526 c_proj,
527 span: tracing::span!(tracing::Level::TRACE, "mlp-swiglu"),
528 }
529 }
530 };
531 Ok(slf)
532 }
533 }
534
535 impl Module for MLP {
536 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
537 match self {
538 Self::Gelu { c_fc, c_proj, span } => {
539 let _enter = span.enter();
540 xs.apply(c_fc)?.gelu()?.apply(c_proj)
541 }
542 Self::Swiglu {
543 w1,
544 w3,
545 c_proj,
546 span,
547 } => {
548 let _enter = span.enter();
549 let w1 = xs.apply(w1)?;
550 let w3 = xs.apply(w3)?;
551 (w1.silu()? * w3)?.apply(c_proj)
552 }
553 }
554 }
555 }
556
557 struct Block {
559 ln_1: Norm,
560 ln_2: Norm,
561 attn: SelfAttention,
562 mlp: MLP,
563 span: tracing::Span,
564 }
565
566 impl Block {
567 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
568 let ln_1 = Norm::new(cfg, vb.pp("ln_1"))?;
569 let ln_2 = Norm::new(cfg, vb.pp("ln_2"))?;
570 let attn = SelfAttention::new(cfg, vb.pp("attn"))?;
571 let mlp = MLP::new(cfg, vb.pp("mlp"))?;
572 Ok(Block {
573 ln_1,
574 ln_2,
575 attn,
576 mlp,
577 span: tracing::span!(tracing::Level::TRACE, "gpt-block"),
578 })
579 }
580 }
581
582 impl Module for Block {
583 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
584 let _enter = self.span.enter();
585 let xs = (xs + xs.apply(&self.ln_1)?.apply(&self.attn))?;
586 let xs = (&xs + xs.apply(&self.ln_2)?.apply(&self.mlp))?;
587 Ok(xs)
588 }
589 }
590
591 #[allow(clippy::upper_case_acronyms)]
593 pub struct Model {
594 wtes: Vec<candle_nn::Embedding>,
595 wpe: candle_nn::Embedding,
596 h: Vec<Block>,
597 ln_f: Norm,
598 lm_heads: Vec<Linear>,
599 cfg: Config,
600 dtype: DType,
601 span: tracing::Span,
602 }
603
604 impl Model {
605 pub fn new(cfg: Config, vb: VarBuilder) -> Result<Self> {
606 let vb_t = vb.pp("transformer");
607 let ln_f = Norm::new(&cfg, vb_t.pp("ln_f"))?;
608 let mut wtes = Vec::with_capacity(cfg.vocab_sizes.len());
609 let vb_w = vb_t.pp("wtes");
610 for (idx, vocab_size) in cfg.vocab_sizes.iter().enumerate() {
611 let wte = candle_nn::embedding(*vocab_size, cfg.n_embd, vb_w.pp(idx))?;
612 wtes.push(wte)
613 }
614 let wpe = candle_nn::embedding(cfg.block_size, cfg.n_embd, vb_t.pp("wpe"))?;
615
616 let mut h = Vec::with_capacity(cfg.n_layer);
617 let vb_h = vb_t.pp("h");
618 for idx in 0..cfg.n_layer {
619 let block = Block::new(&cfg, vb_h.pp(idx))?;
620 h.push(block)
621 }
622
623 let mut lm_heads = Vec::with_capacity(cfg.target_vocab_sizes.len());
624 let vb_l = vb.pp("lm_heads");
625 for (idx, vocab_size) in cfg.target_vocab_sizes.iter().enumerate() {
626 let head = linear_b(cfg.n_embd, *vocab_size, false, vb_l.pp(idx))?;
627 lm_heads.push(head)
628 }
629 Ok(Self {
630 wtes,
631 wpe,
632 h,
633 ln_f,
634 lm_heads,
635 cfg,
636 dtype: vb.dtype(),
637 span: tracing::span!(tracing::Level::TRACE, "gpt"),
638 })
639 }
640
641 pub fn config(&self) -> &Config {
642 &self.cfg
643 }
644
645 pub fn forward(&self, idx: &Tensor) -> Result<Vec<Tensor>> {
646 let _enter = self.span.enter();
647 let device = idx.device();
648 let (b, _num_hierarchies, t) = idx.dims3()?;
649 let pos = Tensor::arange(0u32, t as u32, device)?;
650 let pos_emb = pos.apply(&self.wpe)?;
651 let mut tok_emb = Tensor::zeros((b, t, self.cfg.n_embd), self.dtype, device)?;
652 for (wte_idx, wte) in self.wtes.iter().enumerate() {
653 let emb = idx.i((.., wte_idx, ..))?.apply(wte)?;
654 tok_emb = (tok_emb + emb)?;
655 }
656 let spk_emb = 0f64;
658 let mut xs = (pos_emb.broadcast_add(&tok_emb)? + spk_emb)?;
659 for block in self.h.iter() {
660 xs = xs.apply(block)?
661 }
662 let xs = xs.apply(&self.ln_f)?;
663 let mut logits = Vec::with_capacity(self.lm_heads.len());
664 for lm_head in self.lm_heads.iter() {
665 let ys = xs.apply(lm_head)?;
667 logits.push(ys)
668 }
669 Ok(logits)
670 }
671 }
672}
673
674pub mod transformer {
675 use super::*;
676
677 #[derive(Debug, Clone, serde::Deserialize)]
678 pub struct Config {
679 pub block_size: usize,
680 pub vocab_size: usize,
681 pub n_layer: usize,
682 pub n_head: usize,
683 pub dim: usize,
684 pub speaker_emb_dim: usize,
685 pub intermediate_size: Option<usize>,
686 pub n_local_heads: Option<usize>,
687 pub norm_eps: f64,
688 }
689
690 impl Config {
691 pub fn cfg1b_v0_1() -> Self {
692 Self {
693 n_layer: 24,
694 n_head: 16,
695 dim: 2048,
696 vocab_size: 2562,
697 speaker_emb_dim: 256,
698 block_size: 2048,
699 intermediate_size: None,
700 n_local_heads: None,
701 norm_eps: 1e-5,
702 }
703 }
704
705 pub(crate) fn n_local_heads(&self) -> usize {
706 self.n_local_heads.unwrap_or(self.n_head)
707 }
708
709 pub(crate) fn head_dim(&self) -> usize {
710 self.dim / self.n_head
711 }
712
713 pub(crate) fn intermediate_size(&self) -> usize {
714 match self.intermediate_size {
715 Some(intermediate_size) => intermediate_size,
716 None => {
717 let hidden_dim = self.dim * 4;
718 let n_hidden = ((2 * hidden_dim) as f64 / 3.) as usize;
719 (n_hidden + 255) / 256 * 256
720 }
721 }
722 }
723 }
724
725 #[derive(Debug, Clone)]
726 struct FeedForward {
727 w1: Linear,
728 w2: Linear,
729 w3: Linear,
730 span: tracing::Span,
731 }
732
733 impl FeedForward {
734 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
735 let i_size = cfg.intermediate_size();
736 let w1 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w1"))?;
737 let w2 = linear_b(i_size, cfg.dim, false, vb.pp("w2"))?;
738 let w3 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w3"))?;
739 Ok(Self {
740 w1,
741 w2,
742 w3,
743 span: tracing::span!(tracing::Level::TRACE, "feed-forward"),
744 })
745 }
746 }
747
748 impl Module for FeedForward {
749 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
750 let _enter = self.span.enter();
751 let swiglu = (candle_nn::ops::silu(&xs.apply(&self.w1)?)? * xs.apply(&self.w3))?;
752 swiglu.apply(&self.w2)
753 }
754 }
755
756 #[derive(Debug, Clone)]
757 struct Attention {
758 wqkv: Linear,
759 wo: Linear,
760 dim: usize,
761 kv_size: usize,
762 n_local_heads: usize,
763 head_dim: usize,
764 n_head: usize,
765 kv_cache: Option<(Tensor, Tensor)>,
766 span: tracing::Span,
767 }
768
769 impl Attention {
770 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
771 let n_local_heads = cfg.n_local_heads();
772 let head_dim = cfg.head_dim();
773 let total_head_dim = (cfg.n_head + 2 * n_local_heads) * head_dim;
774 let wqkv = linear_b(cfg.dim, total_head_dim, false, vb.pp("wqkv"))?;
775 let wo = linear_b(cfg.dim, cfg.dim, false, vb.pp("wo"))?;
776 Ok(Self {
777 wqkv,
778 wo,
779 dim: cfg.dim,
780 kv_size: n_local_heads * head_dim,
781 n_local_heads,
782 head_dim,
783 n_head: cfg.n_head,
784 kv_cache: None,
785 span: tracing::span!(tracing::Level::TRACE, "feed-forward"),
786 })
787 }
788
789 fn forward(&mut self, xs: &Tensor, _pos: usize, mask: &Tensor) -> Result<Tensor> {
790 let _enter = self.span.enter();
791 let (b_sz, seqlen, _) = xs.dims3()?;
792
793 let qkv = xs.apply(&self.wqkv)?;
794 let q = qkv.narrow(D::Minus1, 0, self.dim)?;
795 let k = qkv.narrow(D::Minus1, self.dim, self.kv_size)?;
796 let v = qkv.narrow(D::Minus1, self.dim + self.kv_size, self.kv_size)?;
797 let q = q
798 .reshape((b_sz, seqlen, self.n_head, self.head_dim))?
799 .transpose(1, 2)?
800 .contiguous()?;
801 let k = k
802 .reshape((b_sz, seqlen, self.n_local_heads, self.head_dim))?
803 .transpose(1, 2)?;
804 let v = v
805 .reshape((b_sz, seqlen, self.n_local_heads, self.head_dim))?
806 .transpose(1, 2)?;
807
808 let (k, v) = match &self.kv_cache {
809 None => (k, v),
810 Some((prev_k, prev_v)) => {
811 let k = Tensor::cat(&[prev_k, &k], 2)?;
812 let v = Tensor::cat(&[prev_v, &v], 2)?;
813 (k, v)
814 }
815 };
816 self.kv_cache = Some((k.clone(), v.clone()));
817
818 let k = repeat_interleave(&k, self.n_head / self.n_local_heads, 1)?;
819 let v = repeat_interleave(&v, self.n_head / self.n_local_heads, 1)?;
820
821 let scale = 1f64 / f64::sqrt(self.head_dim as f64);
822 let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
823
824 let attn_weights = attn_weights.broadcast_add(mask)?;
825 let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
826 let attn_output = attn_weights.matmul(&v)?;
827 attn_output
828 .transpose(1, 2)?
829 .reshape((b_sz, seqlen, self.dim))?
830 .apply(&self.wo)
831 }
832
833 fn clear_kv_cache(&mut self) {
834 self.kv_cache = None
835 }
836 }
837
838 #[derive(Debug, Clone)]
839 struct Block {
840 attention: Attention,
841 feed_forward: FeedForward,
842 ffn_norm: RmsNorm,
843 attention_norm: RmsNorm,
844 span: tracing::Span,
845 }
846
847 impl Block {
848 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
849 let attention = Attention::new(cfg, vb.pp("attention"))?;
850 let feed_forward = FeedForward::new(cfg, vb.pp("feed_forward"))?;
851 let ffn_norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("ffn_norm"))?;
852 let attention_norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("attention_norm"))?;
853 Ok(Self {
854 attention,
855 feed_forward,
856 ffn_norm,
857 attention_norm,
858 span: tracing::span!(tracing::Level::TRACE, "block"),
859 })
860 }
861
862 fn forward(&mut self, xs: &Tensor, pos: usize, mask: &Tensor) -> Result<Tensor> {
863 let _enter = self.span.enter();
864 let hs = xs.apply(&self.attention_norm)?;
865 let hs = (xs + self.attention.forward(&hs, pos, mask))?;
866 &hs + hs.apply(&self.ffn_norm)?.apply(&self.feed_forward)
867 }
868
869 fn clear_kv_cache(&mut self) {
870 self.attention.clear_kv_cache()
871 }
872 }
873
874 #[derive(Debug, Clone)]
875 pub struct Model {
876 tok_embeddings: Embedding,
877 pos_embeddings: Embedding,
878 speaker_cond_pos: Linear,
879 layers: Vec<Block>,
880 norm: RmsNorm,
881 output: Linear,
882 spk_cond_mask: Tensor,
883 span: tracing::Span,
884 }
885
886 impl Model {
887 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
888 let tok_embeddings = embedding(cfg.vocab_size, cfg.dim, vb.pp("tok_embeddings"))?;
889 let pos_embeddings = embedding(cfg.block_size, cfg.dim, vb.pp("pos_embeddings"))?;
890 let speaker_cond_pos = linear_b(
891 cfg.speaker_emb_dim,
892 cfg.dim,
893 false,
894 vb.pp("speaker_cond_pos"),
895 )?;
896 let mut layers = Vec::with_capacity(cfg.n_layer);
897 let vb_l = vb.pp("layers");
898 for layer_idx in 0..cfg.n_layer {
899 let layer = Block::new(cfg, vb_l.pp(layer_idx))?;
900 layers.push(layer)
901 }
902 let norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("norm"))?;
903 let output = linear_b(cfg.dim, cfg.vocab_size, false, vb.pp("output"))?;
904 let dtype = vb.dtype();
905 let spk_cond_mask = Tensor::cat(
906 &[
907 Tensor::ones((1, 1, cfg.dim), dtype, vb.device())?,
908 Tensor::zeros((1, 1, cfg.dim), dtype, vb.device())?,
909 ],
910 0,
911 )?;
912 Ok(Self {
913 tok_embeddings,
914 pos_embeddings,
915 speaker_cond_pos,
916 layers,
917 norm,
918 output,
919 spk_cond_mask,
920 span: tracing::span!(tracing::Level::TRACE, "transformer"),
921 })
922 }
923
924 pub fn clear_kv_cache(&mut self) {
925 for layer in self.layers.iter_mut() {
926 layer.clear_kv_cache()
927 }
928 }
929
930 pub fn forward(&mut self, xs: &Tensor, spk_emb: &Tensor, pos: usize) -> Result<Tensor> {
931 let _enter = self.span.enter();
932 let (_b_sz, seqlen) = xs.dims2()?;
933 let mask: Vec<_> = (0..seqlen)
934 .flat_map(|i| (0..seqlen).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
935 .collect();
936 let mask = Tensor::from_slice(&mask, (1, 1, seqlen, seqlen), xs.device())?;
937 let input_pos = Tensor::arange(pos as u32, (pos + seqlen) as u32, xs.device())?;
938 let tok_embeddings = xs.apply(&self.tok_embeddings)?;
939 let pos_embeddings = input_pos.apply(&self.pos_embeddings)?;
940 let mut xs = tok_embeddings
941 .broadcast_add(&pos_embeddings)?
942 .broadcast_add(
943 &spk_emb
944 .apply(&self.speaker_cond_pos)?
945 .broadcast_mul(&self.spk_cond_mask)?,
946 )?;
947 let mask = mask.to_dtype(xs.dtype())?;
948 for layer in self.layers.iter_mut() {
949 xs = layer.forward(&xs, pos, &mask)?
950 }
951 xs.narrow(1, seqlen - 1, 1)?
952 .apply(&self.norm)?
953 .apply(&self.output)
954 }
955 }
956}
957
958pub mod adapters {
959 pub struct TiltedEncodec {
961 end_of_audio_token: u32,
962 span: tracing::Span,
963 }
964
965 impl TiltedEncodec {
966 pub fn new(end_of_audio_token: u32) -> Self {
967 Self {
968 end_of_audio_token,
969 span: tracing::span!(tracing::Level::TRACE, "tilted-encodec"),
970 }
971 }
972
973 pub fn decode(&self, tokens: &[Vec<u32>]) -> (Vec<u32>, Vec<Vec<u32>>) {
974 let _enter = self.span.enter();
975 let mut text_ids = vec![];
976 let mut extracted_audio_ids = vec![];
977 let mut min_audio_ids_len = usize::MAX;
978 for (book_id, tokens) in tokens.iter().enumerate() {
979 let mut audio_ids = vec![];
980 for &t in tokens.iter() {
981 #[allow(clippy::comparison_chain)]
982 if t > self.end_of_audio_token {
983 if book_id == 0 {
984 text_ids.push(t)
985 }
986 } else if t < self.end_of_audio_token {
987 audio_ids.push(t)
988 }
989 }
990 min_audio_ids_len = usize::min(min_audio_ids_len, audio_ids.len());
991 extracted_audio_ids.push(audio_ids)
992 }
993 for audio_ids in extracted_audio_ids.iter_mut() {
994 audio_ids.truncate(min_audio_ids_len)
995 }
996 (text_ids, extracted_audio_ids)
997 }
998 }
999
1000 pub struct FlattenedInterleavedEncodec2Codebook {
1002 end_of_audio_token: u32,
1003 span: tracing::Span,
1004 }
1005
1006 impl FlattenedInterleavedEncodec2Codebook {
1007 pub fn new(end_of_audio_token: u32) -> Self {
1008 Self {
1009 end_of_audio_token,
1010 span: tracing::span!(tracing::Level::TRACE, "encodec2codebook"),
1011 }
1012 }
1013
1014 pub fn decode(&self, tokens: &[u32]) -> (Vec<u32>, Vec<u32>, Vec<u32>) {
1015 let _enter = self.span.enter();
1016 let mut text_ids = vec![];
1017 let mut audio_ids1 = vec![];
1018 let mut audio_ids2 = vec![];
1019 for &t in tokens.iter() {
1020 #[allow(clippy::comparison_chain)]
1021 if t < self.end_of_audio_token {
1022 audio_ids1.push(t)
1023 } else if t < 2 * self.end_of_audio_token {
1024 audio_ids2.push(t - self.end_of_audio_token)
1025 } else {
1026 text_ids.push(t)
1027 }
1028 }
1029 (text_ids, audio_ids1, audio_ids2)
1030 }
1031 }
1032}