candle_transformers/models/
quantized_rwkv_v6.rs

//! RWKV v6 model implementation with quantization support.
//!
//! RWKV is a linear attention model that combines the efficiency of RNNs
//! with the parallelizable training of Transformers. Version 6 builds on previous
//! versions with further optimizations.
//!
//! Key characteristics:
//! - Linear attention mechanism
//! - Time mixing layers
//! - Channel mixing layers
//! - LayerNorm (blocks, output) and GroupNorm (attention output) for normalization
//! - Support for 8-bit quantization
//!
//! References:
//! - [RWKV Architecture](https://github.com/BlinkDL/RWKV-LM)
//! - [RWKV v6 Release](https://huggingface.co/BlinkDL/rwkv-6)
//!
18
19use crate::{
20    quantized_nn::{layer_norm, linear_no_bias as linear, Embedding, Linear},
21    quantized_var_builder::VarBuilder,
22};
23use candle::{IndexOp, Result, Tensor};
24use candle_nn::{GroupNorm, LayerNorm, Module};
25
26pub use crate::models::rwkv_v5::{Config, State, Tokenizer};
27
28#[derive(Debug, Clone)]
29struct SelfAttention {
30    key: Linear,
31    receptance: Linear,
32    value: Linear,
33    gate: Linear,
34    output: Linear,
35    ln_x: candle_nn::GroupNorm,
36    time_mix_x: Tensor,
37    time_mix_w: Tensor,
38    time_mix_key: Tensor,
39    time_mix_value: Tensor,
40    time_mix_receptance: Tensor,
41    time_decay: Tensor,
42    time_faaaa: Tensor,
43    time_mix_gate: Tensor,
44    time_decay_w1: Tensor,
45    time_decay_w2: Tensor,
46    time_mix_w1: Tensor,
47    time_mix_w2: Tensor,
48    layer_id: usize,
49    n_attn_heads: usize,
50}
51
52impl SelfAttention {
53    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
54        let hidden_size = cfg.hidden_size;
55        let attn_hidden_size = cfg.attention_hidden_size;
56        let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
57        let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
58        let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
59        let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
60        let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;
61
62        let vb_x = vb.pp("ln_x");
63        let ln_x_weight = vb_x.get(hidden_size, "weight")?.dequantize(vb.device())?;
64        let ln_x_bias = vb_x.get(hidden_size, "bias")?.dequantize(vb.device())?;
65
66        let ln_x = GroupNorm::new(
67            ln_x_weight,
68            ln_x_bias,
69            hidden_size,
70            hidden_size / cfg.head_size,
71            1e-5,
72        )?;
73
74        let time_mix_x = vb
75            .get((1, 1, cfg.hidden_size), "time_mix_x")?
76            .dequantize(vb.device())?;
77        let time_mix_w = vb
78            .get((1, 1, cfg.hidden_size), "time_mix_w")?
79            .dequantize(vb.device())?;
80        let time_mix_key = vb
81            .get((1, 1, cfg.hidden_size), "time_mix_key")?
82            .dequantize(vb.device())?;
83        let time_mix_value = vb
84            .get((1, 1, cfg.hidden_size), "time_mix_value")?
85            .dequantize(vb.device())?;
86        let time_mix_receptance = vb
87            .get((1, 1, cfg.hidden_size), "time_mix_receptance")?
88            .dequantize(vb.device())?;
89        let n_attn_heads = cfg.hidden_size / cfg.head_size;
90        let time_decay = vb
91            .get((1, 1, cfg.hidden_size), "time_decay")?
92            .dequantize(vb.device())?;
93        let time_faaaa = vb
94            .get((n_attn_heads, cfg.head_size), "time_faaaa")?
95            .dequantize(vb.device())?;
96        let time_mix_gate = vb
97            .get((1, 1, cfg.hidden_size), "time_mix_gate")?
98            .dequantize(vb.device())?;
99        let time_decay_w1 = vb
100            .get((cfg.hidden_size, n_attn_heads * 2), "time_decay_w1")?
101            .dequantize(vb.device())?;
102        let time_decay_w2 = vb
103            .get((n_attn_heads * 2, cfg.hidden_size), "time_decay_w2")?
104            .dequantize(vb.device())?;
105        let time_mix_w1 = vb
106            .get((cfg.hidden_size, n_attn_heads * 5), "time_mix_w1")?
107            .dequantize(vb.device())?;
108        let time_mix_w2 = vb
109            .get((5, n_attn_heads, cfg.hidden_size), "time_mix_w2")?
110            .dequantize(vb.device())?;
111        Ok(Self {
112            key,
113            value,
114            receptance,
115            gate,
116            output,
117            ln_x,
118            time_mix_x,
119            time_mix_w,
120            time_mix_key,
121            time_mix_value,
122            time_mix_receptance,
123            time_decay,
124            time_faaaa,
125            time_mix_gate,
126            time_decay_w1,
127            time_decay_w2,
128            time_mix_w1,
129            time_mix_w2,
130            layer_id,
131            n_attn_heads,
132        })
133    }
134
135    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
136        let h = self.n_attn_heads;
137        let (b, t, s) = xs.dims3()?;
138        let s = s / h;
139        let (receptance, key, value, gate, w) = {
140            // extract key-value
141            let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
142            let shifted = if shifted.rank() == 2 {
143                shifted.unsqueeze(1)?
144            } else {
145                shifted
146            };
147
148            let sx = (&shifted - xs)?;
149            let xxx = (xs + &sx * &self.time_mix_x)?;
150            let xxx = xxx
151                .broadcast_matmul(&self.time_mix_w1)?
152                .tanh()?
153                .reshape((b * t, 5, ()))?
154                .transpose(0, 1)?;
155
156            let xxx = xxx.matmul(&self.time_mix_w2)?.reshape((5, b, t, ()))?;
157
158            let (mw, mk, mv, mr, mg) = (xxx.i(0)?, xxx.i(1)?, xxx.i(2)?, xxx.i(3)?, xxx.i(4)?);
159
160            let xw = (xs + &sx * (&self.time_mix_w + &mw)?)?;
161            let xk = (xs + &sx * (&self.time_mix_key + &mk)?)?;
162            let xv = (xs + &sx * (&self.time_mix_value + &mv)?)?;
163            let xr = (xs + &sx * (&self.time_mix_receptance + &mr)?)?;
164            let xg = (xs + &sx * (&self.time_mix_gate + &mg)?)?;
165
166            let w = (&self.time_decay
167                + xw.broadcast_matmul(&self.time_decay_w1)?
168                    .tanh()?
169                    .broadcast_matmul(&self.time_decay_w2)?)?
170            .reshape(((), 1, 1))?
171            .reshape((self.n_attn_heads, (), 1))?;
172
173            let key = self.key.forward(&xk)?;
174            let value = self.value.forward(&xv)?;
175            let receptance = self.receptance.forward(&xr)?;
176            let gate = candle_nn::ops::silu(&self.gate.forward(&xg)?)?;
177            state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
178            (receptance, key, value, gate, w)
179        };
180
181        // linear attention
182        let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
183        let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
184        let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
185        let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;
186
187        let w = w.exp()?.neg()?.exp()?;
188
189        let time_faaaa =
190            self.time_faaaa
191                .reshape(((), 1, 1))?
192                .reshape((self.n_attn_heads, (), 1))?;
193
194        let mut out: Vec<Tensor> = Vec::with_capacity(t);
195        for t_ in 0..t {
196            let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
197            let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
198            let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
199            let at = kt.matmul(&vt)?;
200            let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
201            let out_ = rt.matmul(&rhs)?.squeeze(2)?;
202            state_ = (&at + w.broadcast_mul(&state_))?;
203            out.push(out_)
204        }
205        let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
206        let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
207        let out = (out * gate)?.apply(&self.output)?;
208        state.per_layer[self.layer_id].linear_attention = state_;
209        Ok(out)
210    }
211}
212
213#[derive(Debug, Clone)]
214struct FeedForward {
215    time_mix_key: Tensor,
216    time_mix_receptance: Tensor,
217    key: Linear,
218    receptance: Linear,
219    value: Linear,
220    layer_id: usize,
221}
222
223impl FeedForward {
224    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
225        let int_size = cfg
226            .intermediate_size
227            .unwrap_or(((cfg.hidden_size as f64 * 3.5) as usize) / 32 * 32);
228        let key = linear(cfg.hidden_size, int_size, vb.pp("key"))?;
229        let receptance = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("receptance"))?;
230        let value = linear(int_size, cfg.hidden_size, vb.pp("value"))?;
231        let time_mix_key = vb
232            .get((1, 1, cfg.hidden_size), "time_mix_key")?
233            .dequantize(vb.device())?;
234        let time_mix_receptance = vb
235            .get((1, 1, cfg.hidden_size), "time_mix_receptance")?
236            .dequantize(vb.device())?;
237        Ok(Self {
238            key,
239            receptance,
240            value,
241            time_mix_key,
242            time_mix_receptance,
243            layer_id,
244        })
245    }
246
247    fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
248        let shifted = state.per_layer[self.layer_id]
249            .feed_forward
250            .broadcast_sub(xs)?;
251        let key = (xs + shifted.broadcast_mul(&self.time_mix_key)?)?;
252        let receptance = (xs + shifted.broadcast_mul(&self.time_mix_receptance)?)?;
253        let key = key.apply(&self.key)?.relu()?.sqr()?;
254        let value = key.apply(&self.value)?;
255        let receptance = candle_nn::ops::sigmoid(&receptance.apply(&self.receptance)?)?;
256        state.per_layer[self.layer_id].feed_forward = xs.i((.., xs.dim(1)? - 1))?;
257        let xs = (receptance * value)?;
258        Ok(xs)
259    }
260}
261
262#[derive(Debug, Clone)]
263struct Block {
264    pre_ln: Option<LayerNorm>,
265    ln1: LayerNorm,
266    ln2: LayerNorm,
267    attention: SelfAttention,
268    feed_forward: FeedForward,
269}
270
271impl Block {
272    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
273        let ln1 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln1"))?;
274        let ln2 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln2"))?;
275        let pre_ln = if layer_id == 0 {
276            let ln = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("pre_ln"))?;
277            Some(ln)
278        } else {
279            None
280        };
281        let attention = SelfAttention::new(layer_id, cfg, vb.pp("attention"))?;
282        let feed_forward = FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?;
283        Ok(Self {
284            pre_ln,
285            ln1,
286            ln2,
287            attention,
288            feed_forward,
289        })
290    }
291
292    fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
293        let xs = match self.pre_ln.as_ref() {
294            None => xs.clone(),
295            Some(pre_ln) => xs.apply(pre_ln)?,
296        };
297        let attention = self.attention.forward(&xs.apply(&self.ln1)?, state)?;
298        let xs = (xs + attention)?;
299        let feed_forward = self.feed_forward.forward(&xs.apply(&self.ln2)?, state)?;
300        let xs = (xs + feed_forward)?;
301        Ok(xs)
302    }
303}
304
305#[derive(Debug, Clone)]
306pub struct Model {
307    embeddings: Embedding,
308    blocks: Vec<Block>,
309    ln_out: LayerNorm,
310    head: Linear,
311    rescale_every: usize,
312    layers_are_rescaled: bool,
313}
314
315impl Model {
316    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
317        let vb_m = vb.pp("rwkv");
318        let embeddings = Embedding::new(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
319        let mut blocks = Vec::with_capacity(cfg.num_hidden_layers);
320        let vb_b = vb_m.pp("blocks");
321        for block_index in 0..cfg.num_hidden_layers {
322            let block = Block::new(block_index, cfg, vb_b.pp(block_index))?;
323            blocks.push(block)
324        }
325        let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
326        let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
327        Ok(Self {
328            embeddings,
329            blocks,
330            ln_out,
331            head,
332            rescale_every: cfg.rescale_every,
333            layers_are_rescaled: false, // This seem to only happen for the f16/bf16 dtypes.
334        })
335    }
336
337    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
338        let (_b_size, _seq_len) = xs.dims2()?;
339        let mut xs = xs.apply(&self.embeddings)?;
340        for (block_idx, block) in self.blocks.iter().enumerate() {
341            xs = block.forward(&xs, state)?;
342            if self.layers_are_rescaled && (block_idx + 1) % self.rescale_every == 0 {
343                xs = (xs / 2.)?
344            }
345        }
346        let xs = xs.apply(&self.ln_out)?.apply(&self.head)?;
347        state.pos += 1;
348        Ok(xs)
349    }
350}