1use crate::quantized_nn::{linear_b, Embedding, Linear, RmsNorm};
18pub use crate::quantized_var_builder::VarBuilder;
19
20use crate::models::metavoice::repeat_interleave;
21use candle::{Module, Result, Tensor, D};
22
23pub mod transformer {
24 use super::*;
25
    /// Configuration of this quantized transformer: it reuses the hyper-parameter type of the
    /// non-quantized MetaVoice transformer, so both variants are built from the same config.
    type Config = crate::models::metavoice::transformer::Config;
27
28 #[derive(Debug, Clone)]
29 struct FeedForward {
30 w1: Linear,
31 w2: Linear,
32 w3: Linear,
33 span: tracing::Span,
34 }
35
36 impl FeedForward {
37 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
38 let i_size = cfg.intermediate_size();
39 let w1 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w1"))?;
40 let w2 = linear_b(i_size, cfg.dim, false, vb.pp("w2"))?;
41 let w3 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w3"))?;
42 Ok(Self {
43 w1,
44 w2,
45 w3,
46 span: tracing::span!(tracing::Level::TRACE, "feed-forward"),
47 })
48 }
49 }
50
51 impl Module for FeedForward {
52 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
53 let _enter = self.span.enter();
54 let swiglu = (candle_nn::ops::silu(&xs.apply(&self.w1)?)? * xs.apply(&self.w3))?;
55 swiglu.apply(&self.w2)
56 }
57 }
58
59 #[derive(Debug, Clone)]
60 struct Attention {
61 wqkv: Linear,
62 wo: Linear,
63 dim: usize,
64 kv_size: usize,
65 n_local_heads: usize,
66 head_dim: usize,
67 n_head: usize,
68 kv_cache: Option<(Tensor, Tensor)>,
69 span: tracing::Span,
70 }
71
72 impl Attention {
73 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
74 let n_local_heads = cfg.n_local_heads();
75 let head_dim = cfg.head_dim();
76 let total_head_dim = (cfg.n_head + 2 * n_local_heads) * head_dim;
77 let wqkv = linear_b(cfg.dim, total_head_dim, false, vb.pp("wqkv"))?;
78 let wo = linear_b(cfg.dim, cfg.dim, false, vb.pp("wo"))?;
79 Ok(Self {
80 wqkv,
81 wo,
82 dim: cfg.dim,
83 kv_size: n_local_heads * head_dim,
84 n_local_heads,
85 head_dim,
86 n_head: cfg.n_head,
87 kv_cache: None,
88 span: tracing::span!(tracing::Level::TRACE, "attention"),
89 })
90 }
91
92 fn forward(&mut self, xs: &Tensor, _pos: usize, mask: &Tensor) -> Result<Tensor> {
93 let _enter = self.span.enter();
94 let (b_sz, seqlen, _) = xs.dims3()?;
95
96 let qkv = xs.apply(&self.wqkv)?;
97 let q = qkv.narrow(D::Minus1, 0, self.dim)?;
98 let k = qkv.narrow(D::Minus1, self.dim, self.kv_size)?;
99 let v = qkv.narrow(D::Minus1, self.dim + self.kv_size, self.kv_size)?;
100 let q = q
101 .reshape((b_sz, seqlen, self.n_head, self.head_dim))?
102 .transpose(1, 2)?
103 .contiguous()?;
104 let k = k
105 .reshape((b_sz, seqlen, self.n_local_heads, self.head_dim))?
106 .transpose(1, 2)?;
107 let v = v
108 .reshape((b_sz, seqlen, self.n_local_heads, self.head_dim))?
109 .transpose(1, 2)?;
110
111 let (k, v) = match &self.kv_cache {
112 None => (k, v),
113 Some((prev_k, prev_v)) => {
114 let k = Tensor::cat(&[prev_k, &k], 2)?;
115 let v = Tensor::cat(&[prev_v, &v], 2)?;
116 (k, v)
117 }
118 };
119 self.kv_cache = Some((k.clone(), v.clone()));
120
121 let k = repeat_interleave(&k, self.n_head / self.n_local_heads, 1)?;
122 let v = repeat_interleave(&v, self.n_head / self.n_local_heads, 1)?;
123
124 let scale = 1f64 / f64::sqrt(self.head_dim as f64);
125 let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
126
127 let attn_weights = attn_weights.broadcast_add(mask)?;
128 let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
129 let attn_output = attn_weights.matmul(&v)?;
130 attn_output
131 .transpose(1, 2)?
132 .reshape((b_sz, seqlen, self.dim))?
133 .apply(&self.wo)
134 }
135
136 fn clear_kv_cache(&mut self) {
137 self.kv_cache = None
138 }
139 }
140
141 #[derive(Debug, Clone)]
142 struct Block {
143 attention: Attention,
144 feed_forward: FeedForward,
145 ffn_norm: RmsNorm,
146 attention_norm: RmsNorm,
147 span: tracing::Span,
148 }
149
150 impl Block {
151 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
152 let attention = Attention::new(cfg, vb.pp("attention"))?;
153 let feed_forward = FeedForward::new(cfg, vb.pp("feed_forward"))?;
154 let ffn_norm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("ffn_norm"))?;
155 let attention_norm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("attention_norm"))?;
156 Ok(Self {
157 attention,
158 feed_forward,
159 ffn_norm,
160 attention_norm,
161 span: tracing::span!(tracing::Level::TRACE, "block"),
162 })
163 }
164
165 fn forward(&mut self, xs: &Tensor, pos: usize, mask: &Tensor) -> Result<Tensor> {
166 let _enter = self.span.enter();
167 let hs = xs.apply(&self.attention_norm)?;
168 let hs = (xs + self.attention.forward(&hs, pos, mask))?;
169 &hs + hs.apply(&self.ffn_norm)?.apply(&self.feed_forward)
170 }
171
172 fn clear_kv_cache(&mut self) {
173 self.attention.clear_kv_cache()
174 }
175 }
176
177 #[derive(Debug, Clone)]
178 pub struct Model {
179 tok_embeddings: Embedding,
180 pos_embeddings: Embedding,
181 speaker_cond_pos: Linear,
182 layers: Vec<Block>,
183 norm: RmsNorm,
184 output: Linear,
185 spk_cond_mask: Tensor,
186 span: tracing::Span,
187 }
188
189 impl Model {
190 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
191 let tok_embeddings = Embedding::new(cfg.vocab_size, cfg.dim, vb.pp("tok_embeddings"))?;
192 let pos_embeddings = Embedding::new(cfg.block_size, cfg.dim, vb.pp("pos_embeddings"))?;
193 let speaker_cond_pos = linear_b(
194 cfg.speaker_emb_dim,
195 cfg.dim,
196 false,
197 vb.pp("speaker_cond_pos"),
198 )?;
199 let mut layers = Vec::with_capacity(cfg.n_layer);
200 let vb_l = vb.pp("layers");
201 for layer_idx in 0..cfg.n_layer {
202 let layer = Block::new(cfg, vb_l.pp(layer_idx))?;
203 layers.push(layer)
204 }
205 let norm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("norm"))?;
206 let output = linear_b(cfg.dim, cfg.vocab_size, false, vb.pp("output"))?;
207 let spk_cond_mask = Tensor::cat(
208 &[
209 Tensor::ones((1, 1, cfg.dim), candle::DType::F32, vb.device())?,
210 Tensor::zeros((1, 1, cfg.dim), candle::DType::F32, vb.device())?,
211 ],
212 0,
213 )?;
214 Ok(Self {
215 tok_embeddings,
216 pos_embeddings,
217 speaker_cond_pos,
218 layers,
219 norm,
220 output,
221 spk_cond_mask,
222 span: tracing::span!(tracing::Level::TRACE, "qtransformer"),
223 })
224 }
225
226 pub fn clear_kv_cache(&mut self) {
227 for layer in self.layers.iter_mut() {
228 layer.clear_kv_cache()
229 }
230 }
231
232 pub fn forward(&mut self, xs: &Tensor, spk_emb: &Tensor, pos: usize) -> Result<Tensor> {
233 let _enter = self.span.enter();
234 let (_b_sz, seqlen) = xs.dims2()?;
235 let mask: Vec<_> = (0..seqlen)
236 .flat_map(|i| (0..seqlen).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
237 .collect();
238 let mask = Tensor::from_slice(&mask, (1, 1, seqlen, seqlen), xs.device())?;
239 let input_pos = Tensor::arange(pos as u32, (pos + seqlen) as u32, xs.device())?;
240 let tok_embeddings = xs.apply(&self.tok_embeddings)?;
241 let pos_embeddings = input_pos.apply(&self.pos_embeddings)?;
242 let mut xs = tok_embeddings
243 .broadcast_add(&pos_embeddings)?
244 .broadcast_add(
245 &spk_emb
246 .apply(&self.speaker_cond_pos)?
247 .broadcast_mul(&self.spk_cond_mask)?,
248 )?;
249 let mask = mask.to_dtype(xs.dtype())?;
250 for layer in self.layers.iter_mut() {
251 xs = layer.forward(&xs, pos, &mask)?
252 }
253 xs.narrow(1, seqlen - 1, 1)?
254 .contiguous()?
255 .apply(&self.norm)?
256 .apply(&self.output)
257 }
258 }
259}