1use crate::{
20 quantized_nn::{layer_norm, linear_no_bias as linear, Embedding, Linear},
21 quantized_var_builder::VarBuilder,
22};
23use candle::{IndexOp, Result, Tensor};
24use candle_nn::{GroupNorm, LayerNorm, Module};
25
26pub use crate::models::rwkv_v5::{Config, State, Tokenizer};
27
/// RWKV-6 time-mixing ("attention") sub-layer: data-dependent token shift and decay
/// feeding a per-head linear-attention recurrence.
#[derive(Debug, Clone)]
struct SelfAttention {
    /// Bias-free projections built with `quantized_nn::linear_no_bias`.
    key: Linear,
    receptance: Linear,
    value: Linear,
    gate: Linear,
    output: Linear,
    /// Group norm over the attention output, one group per head.
    ln_x: candle_nn::GroupNorm,
    /// Static token-shift mixing coefficients, each loaded as `(1, 1, hidden_size)`.
    time_mix_x: Tensor,
    time_mix_w: Tensor,
    time_mix_key: Tensor,
    time_mix_value: Tensor,
    time_mix_receptance: Tensor,
    /// Base decay logits, loaded as `(1, 1, hidden_size)`.
    time_decay: Tensor,
    /// Per-head weight on the current token's key/value contribution,
    /// loaded as `(n_attn_heads, head_size)`.
    time_faaaa: Tensor,
    /// Static gate mixing coefficient, loaded as `(1, 1, hidden_size)`.
    time_mix_gate: Tensor,
    /// Low-rank decay adapter: `(hidden_size, 2 * n_attn_heads)` then
    /// `(2 * n_attn_heads, hidden_size)`.
    time_decay_w1: Tensor,
    time_decay_w2: Tensor,
    /// Low-rank mixing adapter producing five offsets (w, k, v, r, g):
    /// `(hidden_size, 5 * n_attn_heads)` then `(5, n_attn_heads, hidden_size)`.
    time_mix_w1: Tensor,
    time_mix_w2: Tensor,
    /// Index into `State::per_layer`.
    layer_id: usize,
    /// `hidden_size / head_size`.
    n_attn_heads: usize,
}

impl SelfAttention {
    /// Loads the time-mixing weights of layer `layer_id` from `vb`.
    ///
    /// Projection matrices stay as `Linear` layers; every other tensor is dequantized
    /// onto `vb.device()`.
    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let attn_hidden_size = cfg.attention_hidden_size;
        let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
        let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
        let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
        let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
        let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;

        let vb_x = vb.pp("ln_x");
        let ln_x_weight = vb_x.get(hidden_size, "weight")?.dequantize(vb.device())?;
        let ln_x_bias = vb_x.get(hidden_size, "bias")?.dequantize(vb.device())?;

        // num_channels = hidden_size, num_groups = one per head.
        let ln_x = GroupNorm::new(
            ln_x_weight,
            ln_x_bias,
            hidden_size,
            hidden_size / cfg.head_size,
            1e-5,
        )?;

        let time_mix_x = vb
            .get((1, 1, cfg.hidden_size), "time_mix_x")?
            .dequantize(vb.device())?;
        let time_mix_w = vb
            .get((1, 1, cfg.hidden_size), "time_mix_w")?
            .dequantize(vb.device())?;
        let time_mix_key = vb
            .get((1, 1, cfg.hidden_size), "time_mix_key")?
            .dequantize(vb.device())?;
        let time_mix_value = vb
            .get((1, 1, cfg.hidden_size), "time_mix_value")?
            .dequantize(vb.device())?;
        let time_mix_receptance = vb
            .get((1, 1, cfg.hidden_size), "time_mix_receptance")?
            .dequantize(vb.device())?;
        let n_attn_heads = cfg.hidden_size / cfg.head_size;
        let time_decay = vb
            .get((1, 1, cfg.hidden_size), "time_decay")?
            .dequantize(vb.device())?;
        let time_faaaa = vb
            .get((n_attn_heads, cfg.head_size), "time_faaaa")?
            .dequantize(vb.device())?;
        let time_mix_gate = vb
            .get((1, 1, cfg.hidden_size), "time_mix_gate")?
            .dequantize(vb.device())?;
        // Low-rank adapters: rank 2 * heads for the decay, 5 * heads for the five mixes.
        let time_decay_w1 = vb
            .get((cfg.hidden_size, n_attn_heads * 2), "time_decay_w1")?
            .dequantize(vb.device())?;
        let time_decay_w2 = vb
            .get((n_attn_heads * 2, cfg.hidden_size), "time_decay_w2")?
            .dequantize(vb.device())?;
        let time_mix_w1 = vb
            .get((cfg.hidden_size, n_attn_heads * 5), "time_mix_w1")?
            .dequantize(vb.device())?;
        let time_mix_w2 = vb
            .get((5, n_attn_heads, cfg.hidden_size), "time_mix_w2")?
            .dequantize(vb.device())?;
        Ok(Self {
            key,
            value,
            receptance,
            gate,
            output,
            ln_x,
            time_mix_x,
            time_mix_w,
            time_mix_key,
            time_mix_value,
            time_mix_receptance,
            time_decay,
            time_faaaa,
            time_mix_gate,
            time_decay_w1,
            time_decay_w2,
            time_mix_w1,
            time_mix_w2,
            layer_id,
            n_attn_heads,
        })
    }

    /// Runs the time-mixing step on `xs` (`(b, t, hidden)`, per `dims3`).
    ///
    /// Side effects on `state.per_layer[layer_id]`:
    /// - `extract_key_value` is set to the last token of `xs`.
    /// - `linear_attention` is set to the updated recurrent state.
    ///
    /// NOTE(review): `&shifted - xs` and `&sx * &self.time_mix_x` are non-broadcasting ops,
    /// and `w` is flattened across batch and time before being reshaped per head. Together
    /// these appear to restrict this path to `b == 1 && t == 1` (one token per call), with
    /// other shapes surfacing as shape errors. Confirm this before feeding longer sequences.
    /// The reshape of `key` to `(b, t, h, s)` uses `s` derived from `xs`, which also
    /// presumes `attention_hidden_size == hidden_size`.
    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        let h = self.n_attn_heads;
        let (b, t, s) = xs.dims3()?;
        // From here on `s` is the per-head channel count.
        let s = s / h;
        let (receptance, key, value, gate, w) = {
            // Token shift: the token cached by the previous call.
            let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
            // A (b, hidden) cache gets a time axis so it lines up with `xs`.
            let shifted = if shifted.rank() == 2 {
                shifted.unsqueeze(1)?
            } else {
                shifted
            };

            let sx = (&shifted - xs)?;
            // Low-rank tanh adapter producing five data-dependent mixing offsets.
            let xxx = (xs + &sx * &self.time_mix_x)?;
            // (b, t, hidden) x (hidden, 5h) -> (b * t, 5, h) -> (5, b * t, h)
            let xxx = xxx
                .broadcast_matmul(&self.time_mix_w1)?
                .tanh()?
                .reshape((b * t, 5, ()))?
                .transpose(0, 1)?;

            // (5, b * t, h) x (5, h, hidden) -> (5, b, t, hidden)
            let xxx = xxx.matmul(&self.time_mix_w2)?.reshape((5, b, t, ()))?;

            let (mw, mk, mv, mr, mg) = (xxx.i(0)?, xxx.i(1)?, xxx.i(2)?, xxx.i(3)?, xxx.i(4)?);

            // Interpolate current/previous token with static + data-dependent coefficients.
            let xw = (xs + &sx * (&self.time_mix_w + &mw)?)?;
            let xk = (xs + &sx * (&self.time_mix_key + &mk)?)?;
            let xv = (xs + &sx * (&self.time_mix_value + &mv)?)?;
            let xr = (xs + &sx * (&self.time_mix_receptance + &mr)?)?;
            let xg = (xs + &sx * (&self.time_mix_gate + &mg)?)?;

            // Data-dependent decay logits, regrouped as (n_attn_heads, (), 1) so they can be
            // broadcast against the per-head recurrent state below.
            let w = (&self.time_decay
                + xw.broadcast_matmul(&self.time_decay_w1)?
                    .tanh()?
                    .broadcast_matmul(&self.time_decay_w2)?)?
            .reshape(((), 1, 1))?
            .reshape((self.n_attn_heads, (), 1))?;

            let key = self.key.forward(&xk)?;
            let value = self.value.forward(&xv)?;
            let receptance = self.receptance.forward(&xr)?;
            let gate = candle_nn::ops::silu(&self.gate.forward(&xg)?)?;
            // Cache the last token for the next call's token shift.
            state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
            (receptance, key, value, gate, w)
        };

        // Linear attention over the recurrent per-head state.
        let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
        // key: (b, h, s, t); value and receptance: (b, h, t, s)
        let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
        let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
        let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;

        // Map decay logits into (0, 1): w = exp(-exp(w)).
        let w = w.exp()?.neg()?.exp()?;

        // (n_attn_heads, head_size) -> (n_attn_heads, head_size, 1)
        let time_faaaa =
            self.time_faaaa
                .reshape(((), 1, 1))?
                .reshape((self.n_attn_heads, (), 1))?;

        let mut out: Vec<Tensor> = Vec::with_capacity(t);
        for t_ in 0..t {
            // rt: (b, h, 1, s), kt: (b, h, s, 1), vt: (b, h, 1, s)
            let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
            let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
            let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
            // Outer product of the current key and value: (b, h, s, s).
            let at = kt.matmul(&vt)?;
            // The current token is weighted by `time_faaaa` on top of the stored state.
            let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
            let out_ = rt.matmul(&rhs)?.squeeze(2)?;
            // Recurrence: state <- at + w * state.
            state_ = (&at + w.broadcast_mul(&state_))?;
            out.push(out_)
        }
        // Per-head group norm over (b * t, h * s, 1), then gate and output projection.
        let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
        let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
        let out = (out * gate)?.apply(&self.output)?;
        state.per_layer[self.layer_id].linear_attention = state_;
        Ok(out)
    }
}
212
213#[derive(Debug, Clone)]
214struct FeedForward {
215 time_mix_key: Tensor,
216 time_mix_receptance: Tensor,
217 key: Linear,
218 receptance: Linear,
219 value: Linear,
220 layer_id: usize,
221}
222
223impl FeedForward {
224 fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
225 let int_size = cfg
226 .intermediate_size
227 .unwrap_or(((cfg.hidden_size as f64 * 3.5) as usize) / 32 * 32);
228 let key = linear(cfg.hidden_size, int_size, vb.pp("key"))?;
229 let receptance = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("receptance"))?;
230 let value = linear(int_size, cfg.hidden_size, vb.pp("value"))?;
231 let time_mix_key = vb
232 .get((1, 1, cfg.hidden_size), "time_mix_key")?
233 .dequantize(vb.device())?;
234 let time_mix_receptance = vb
235 .get((1, 1, cfg.hidden_size), "time_mix_receptance")?
236 .dequantize(vb.device())?;
237 Ok(Self {
238 key,
239 receptance,
240 value,
241 time_mix_key,
242 time_mix_receptance,
243 layer_id,
244 })
245 }
246
247 fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
248 let shifted = state.per_layer[self.layer_id]
249 .feed_forward
250 .broadcast_sub(xs)?;
251 let key = (xs + shifted.broadcast_mul(&self.time_mix_key)?)?;
252 let receptance = (xs + shifted.broadcast_mul(&self.time_mix_receptance)?)?;
253 let key = key.apply(&self.key)?.relu()?.sqr()?;
254 let value = key.apply(&self.value)?;
255 let receptance = candle_nn::ops::sigmoid(&receptance.apply(&self.receptance)?)?;
256 state.per_layer[self.layer_id].feed_forward = xs.i((.., xs.dim(1)? - 1))?;
257 let xs = (receptance * value)?;
258 Ok(xs)
259 }
260}
261
262#[derive(Debug, Clone)]
263struct Block {
264 pre_ln: Option<LayerNorm>,
265 ln1: LayerNorm,
266 ln2: LayerNorm,
267 attention: SelfAttention,
268 feed_forward: FeedForward,
269}
270
271impl Block {
272 fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
273 let ln1 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln1"))?;
274 let ln2 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln2"))?;
275 let pre_ln = if layer_id == 0 {
276 let ln = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("pre_ln"))?;
277 Some(ln)
278 } else {
279 None
280 };
281 let attention = SelfAttention::new(layer_id, cfg, vb.pp("attention"))?;
282 let feed_forward = FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?;
283 Ok(Self {
284 pre_ln,
285 ln1,
286 ln2,
287 attention,
288 feed_forward,
289 })
290 }
291
292 fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
293 let xs = match self.pre_ln.as_ref() {
294 None => xs.clone(),
295 Some(pre_ln) => xs.apply(pre_ln)?,
296 };
297 let attention = self.attention.forward(&xs.apply(&self.ln1)?, state)?;
298 let xs = (xs + attention)?;
299 let feed_forward = self.feed_forward.forward(&xs.apply(&self.ln2)?, state)?;
300 let xs = (xs + feed_forward)?;
301 Ok(xs)
302 }
303}
304
305#[derive(Debug, Clone)]
306pub struct Model {
307 embeddings: Embedding,
308 blocks: Vec<Block>,
309 ln_out: LayerNorm,
310 head: Linear,
311 rescale_every: usize,
312 layers_are_rescaled: bool,
313}
314
315impl Model {
316 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
317 let vb_m = vb.pp("rwkv");
318 let embeddings = Embedding::new(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
319 let mut blocks = Vec::with_capacity(cfg.num_hidden_layers);
320 let vb_b = vb_m.pp("blocks");
321 for block_index in 0..cfg.num_hidden_layers {
322 let block = Block::new(block_index, cfg, vb_b.pp(block_index))?;
323 blocks.push(block)
324 }
325 let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
326 let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
327 Ok(Self {
328 embeddings,
329 blocks,
330 ln_out,
331 head,
332 rescale_every: cfg.rescale_every,
333 layers_are_rescaled: false, })
335 }
336
337 pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
338 let (_b_size, _seq_len) = xs.dims2()?;
339 let mut xs = xs.apply(&self.embeddings)?;
340 for (block_idx, block) in self.blocks.iter().enumerate() {
341 xs = block.forward(&xs, state)?;
342 if self.layers_are_rescaled && (block_idx + 1) % self.rescale_every == 0 {
343 xs = (xs / 2.)?
344 }
345 }
346 let xs = xs.apply(&self.ln_out)?.apply(&self.head)?;
347 state.pos += 1;
348 Ok(xs)
349 }
350}