candle_transformers/models/
rwkv_v6.rs

//! RWKV v6 model implementation.
//!
//! The [RWKV model](https://wiki.rwkv.com/) is a recurrent neural network model
//! with performance on par with transformer architectures. Several variants are
//! available; candle implements the v5 and v6 versions, which can be used with
//! Eagle 7B ([blog post](https://blog.rwkv.com/p/eagle-7b-soaring-past-transformers)).
//!
//! Key characteristics:
//! - Linear attention mechanism
//! - Time-mixing for temporal dependencies
//! - Group normalization
//! - Feed forward gating
//! - State recycling for efficient inference
//!
//! # Example
//!
//! ```bash
//! cargo run --example rwkv --release -- \
//!   --prompt "The smallest prime is "
//!
//! > avx: true, neon: false, simd128: false, f16c: true
//! > temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64
//! > The smallest prime is ϕ(2) = 2.
//! > The smallest composite is ϕ(3) = 3.
//! > The smallest perfect number is ϕ(5) = 5.
//! > The smallest perfect square is ϕ(4) = 4.
//! > The smallest perfect cube is ϕ(6) = 6.
//! ```
29
30use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear};
31use candle::{IndexOp, Result, Tensor};
32use candle_nn::{embedding, Embedding, Module, VarBuilder};
33
34pub use crate::models::rwkv_v5::{Config, State, Tokenizer};
35
36#[derive(Debug, Clone)]
37struct SelfAttention {
38    key: Linear,
39    receptance: Linear,
40    value: Linear,
41    gate: Linear,
42    output: Linear,
43    ln_x: candle_nn::GroupNorm,
44    time_mix_x: Tensor,
45    time_mix_w: Tensor,
46    time_mix_key: Tensor,
47    time_mix_value: Tensor,
48    time_mix_receptance: Tensor,
49    time_decay: Tensor,
50    time_faaaa: Tensor,
51    time_mix_gate: Tensor,
52    time_decay_w1: Tensor,
53    time_decay_w2: Tensor,
54    time_mix_w1: Tensor,
55    time_mix_w2: Tensor,
56    layer_id: usize,
57    n_attn_heads: usize,
58}
59
60impl SelfAttention {
61    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
62        let hidden_size = cfg.hidden_size;
63        let attn_hidden_size = cfg.attention_hidden_size;
64        let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
65        let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
66        let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
67        let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
68        let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;
69        let ln_x = candle_nn::group_norm(
70            hidden_size / cfg.head_size,
71            hidden_size,
72            1e-5,
73            vb.pp("ln_x"),
74        )?;
75
76        let time_mix_x = vb.get((1, 1, cfg.hidden_size), "time_mix_x")?;
77        let time_mix_w = vb.get((1, 1, cfg.hidden_size), "time_mix_w")?;
78        let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
79        let time_mix_value = vb.get((1, 1, cfg.hidden_size), "time_mix_value")?;
80        let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
81        let n_attn_heads = cfg.hidden_size / cfg.head_size;
82        let time_decay = vb.get((1, 1, cfg.hidden_size), "time_decay")?;
83        let time_faaaa = vb.get((n_attn_heads, cfg.head_size), "time_faaaa")?;
84        let time_mix_gate = vb.get((1, 1, cfg.hidden_size), "time_mix_gate")?;
85        let time_decay_w1 = vb.get((cfg.hidden_size, n_attn_heads * 2), "time_decay_w1")?;
86        let time_decay_w2 = vb.get((n_attn_heads * 2, cfg.hidden_size), "time_decay_w2")?;
87        let time_mix_w1 = vb.get((cfg.hidden_size, n_attn_heads * 5), "time_mix_w1")?;
88        let time_mix_w2 = vb.get((5, n_attn_heads, cfg.hidden_size), "time_mix_w2")?;
89        Ok(Self {
90            key,
91            value,
92            receptance,
93            gate,
94            output,
95            ln_x,
96            time_mix_x,
97            time_mix_w,
98            time_mix_key,
99            time_mix_value,
100            time_mix_receptance,
101            time_decay,
102            time_faaaa,
103            time_mix_gate,
104            time_decay_w1,
105            time_decay_w2,
106            time_mix_w1,
107            time_mix_w2,
108            layer_id,
109            n_attn_heads,
110        })
111    }
112
113    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
114        let h = self.n_attn_heads;
115        let (b, t, s) = xs.dims3()?;
116        let s = s / h;
117        let (receptance, key, value, gate, w) = {
118            // extract key-value
119            let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
120            let shifted = if shifted.rank() == 2 {
121                shifted.unsqueeze(1)?
122            } else {
123                shifted
124            };
125
126            let sx = (&shifted - xs)?;
127            let xxx = (xs + &sx * &self.time_mix_x)?;
128            let xxx = xxx
129                .broadcast_matmul(&self.time_mix_w1)?
130                .tanh()?
131                .reshape((b * t, 5, ()))?
132                .transpose(0, 1)?;
133
134            let xxx = xxx.matmul(&self.time_mix_w2)?.reshape((5, b, t, ()))?;
135
136            let (mw, mk, mv, mr, mg) = (xxx.i(0)?, xxx.i(1)?, xxx.i(2)?, xxx.i(3)?, xxx.i(4)?);
137
138            let xw = (xs + &sx * (&self.time_mix_w + &mw)?)?;
139            let xk = (xs + &sx * (&self.time_mix_key + &mk)?)?;
140            let xv = (xs + &sx * (&self.time_mix_value + &mv)?)?;
141            let xr = (xs + &sx * (&self.time_mix_receptance + &mr)?)?;
142            let xg = (xs + &sx * (&self.time_mix_gate + &mg)?)?;
143
144            let w = (&self.time_decay
145                + xw.broadcast_matmul(&self.time_decay_w1)?
146                    .tanh()?
147                    .broadcast_matmul(&self.time_decay_w2)?)?
148            .reshape(((), 1, 1))?
149            .reshape((self.n_attn_heads, (), 1))?;
150
151            let key = self.key.forward(&xk)?;
152            let value = self.value.forward(&xv)?;
153            let receptance = self.receptance.forward(&xr)?;
154            let gate = candle_nn::ops::silu(&self.gate.forward(&xg)?)?;
155            state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
156            (receptance, key, value, gate, w)
157        };
158
159        // linear attention
160        let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
161        let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
162        let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
163        let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;
164
165        let w = w.exp()?.neg()?.exp()?;
166
167        let time_faaaa =
168            self.time_faaaa
169                .reshape(((), 1, 1))?
170                .reshape((self.n_attn_heads, (), 1))?;
171
172        let mut out: Vec<Tensor> = Vec::with_capacity(t);
173        for t_ in 0..t {
174            let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
175            let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
176            let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
177            let at = kt.matmul(&vt)?;
178            let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
179            let out_ = rt.matmul(&rhs)?.squeeze(2)?;
180            state_ = (&at + w.broadcast_mul(&state_))?;
181            out.push(out_)
182        }
183        let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
184        let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
185        let out = (out * gate)?.apply(&self.output)?;
186        state.per_layer[self.layer_id].linear_attention = state_;
187        Ok(out)
188    }
189}
190
191#[derive(Debug, Clone)]
192struct FeedForward {
193    time_mix_key: Tensor,
194    time_mix_receptance: Tensor,
195    key: Linear,
196    receptance: Linear,
197    value: Linear,
198    layer_id: usize,
199}
200
201impl FeedForward {
202    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
203        let int_size = cfg
204            .intermediate_size
205            .unwrap_or(((cfg.hidden_size as f64 * 3.5) as usize) / 32 * 32);
206        let key = linear(cfg.hidden_size, int_size, vb.pp("key"))?;
207        let receptance = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("receptance"))?;
208        let value = linear(int_size, cfg.hidden_size, vb.pp("value"))?;
209        let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
210        let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
211        Ok(Self {
212            key,
213            receptance,
214            value,
215            time_mix_key,
216            time_mix_receptance,
217            layer_id,
218        })
219    }
220
221    fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
222        let shifted = state.per_layer[self.layer_id]
223            .feed_forward
224            .broadcast_sub(xs)?;
225        let key = (xs + shifted.broadcast_mul(&self.time_mix_key)?)?;
226        let receptance = (xs + shifted.broadcast_mul(&self.time_mix_receptance)?)?;
227        let key = key.apply(&self.key)?.relu()?.sqr()?;
228        let value = key.apply(&self.value)?;
229        let receptance = candle_nn::ops::sigmoid(&receptance.apply(&self.receptance)?)?;
230        state.per_layer[self.layer_id].feed_forward = xs.i((.., xs.dim(1)? - 1))?;
231        let xs = (receptance * value)?;
232        Ok(xs)
233    }
234}
235
236#[derive(Debug, Clone)]
237struct Block {
238    pre_ln: Option<LayerNorm>,
239    ln1: LayerNorm,
240    ln2: LayerNorm,
241    attention: SelfAttention,
242    feed_forward: FeedForward,
243}
244
245impl Block {
246    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
247        let ln1 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln1"))?;
248        let ln2 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln2"))?;
249        let pre_ln = if layer_id == 0 {
250            let ln = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("pre_ln"))?;
251            Some(ln)
252        } else {
253            None
254        };
255        let attention = SelfAttention::new(layer_id, cfg, vb.pp("attention"))?;
256        let feed_forward = FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?;
257        Ok(Self {
258            pre_ln,
259            ln1,
260            ln2,
261            attention,
262            feed_forward,
263        })
264    }
265
266    fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
267        let xs = match self.pre_ln.as_ref() {
268            None => xs.clone(),
269            Some(pre_ln) => xs.apply(pre_ln)?,
270        };
271        let attention = self.attention.forward(&xs.apply(&self.ln1)?, state)?;
272        let xs = (xs + attention)?;
273        let feed_forward = self.feed_forward.forward(&xs.apply(&self.ln2)?, state)?;
274        let xs = (xs + feed_forward)?;
275        Ok(xs)
276    }
277}
278
279#[derive(Debug, Clone)]
280pub struct Model {
281    embeddings: Embedding,
282    blocks: Vec<Block>,
283    ln_out: LayerNorm,
284    head: Linear,
285    rescale_every: usize,
286    layers_are_rescaled: bool,
287}
288
289impl Model {
290    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
291        let vb_m = vb.pp("rwkv");
292        let embeddings = embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
293        let mut blocks = Vec::with_capacity(cfg.num_hidden_layers);
294        let vb_b = vb_m.pp("blocks");
295        for block_index in 0..cfg.num_hidden_layers {
296            let block = Block::new(block_index, cfg, vb_b.pp(block_index))?;
297            blocks.push(block)
298        }
299        let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
300        let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
301        Ok(Self {
302            embeddings,
303            blocks,
304            ln_out,
305            head,
306            rescale_every: cfg.rescale_every,
307            layers_are_rescaled: false, // This seem to only happen for the f16/bf16 dtypes.
308        })
309    }
310
311    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
312        let (_b_size, _seq_len) = xs.dims2()?;
313        let mut xs = xs.apply(&self.embeddings)?;
314        for (block_idx, block) in self.blocks.iter().enumerate() {
315            xs = block.forward(&xs, state)?;
316            if self.layers_are_rescaled && (block_idx + 1) % self.rescale_every == 0 {
317                xs = (xs / 2.)?
318            }
319        }
320        let xs = xs.apply(&self.ln_out)?.apply(&self.head)?;
321        state.pos += 1;
322        Ok(xs)
323    }
324}