candle_transformers/models/
rwkv_v5.rs

1//! RWKV v5 model implementation.
2//!
//! The [RWKV model](https://wiki.rwkv.com/) is a recurrent neural network model
//! with performance on par with transformer architectures. Several variants are
//! available; candle implements the v5 and v6 versions, and this v5 implementation
//! can be used with Eagle 7B ([blog post](https://blog.rwkv.com/p/eagle-7b-soaring-past-transformers)).
7//!
8//! Key characteristics:
9//! - Time-mix attention mechanism
10//! - Channel-mix feed-forward network
11//! - Linear attention
12//! - Group normalization
13//! - Token shift mechanism
14//!
15//! References:
16//! - [RWKV Language Model](https://github.com/BlinkDL/RWKV-LM)
17//! - [RWKV v5 Release](https://github.com/BlinkDL/ChatRWKV/tree/main)
18//!
19//! # Example
20//!
21//! ```bash
22//! cargo run --example rwkv --release -- \
23//!   --prompt "The smallest prime is "
24//!
25//! > avx: true, neon: false, simd128: false, f16c: true
26//! > temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64
27//! > The smallest prime is ϕ(2) = 2.
28//! > The smallest composite is ϕ(3) = 3.
29//! > The smallest perfect number is ϕ(5) = 5.
30//! > The smallest perfect square is ϕ(4) = 4.
31//! > The smallest perfect cube is ϕ(6) = 6.
32//! ```
33
34use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear};
35use candle::{DType, Device, IndexOp, Result, Tensor};
36use candle_nn::{embedding, Embedding, Module, VarBuilder};
37use std::collections::{HashMap, HashSet};
38
/// Serde fallback for `Config::num_attention_heads` when the key is missing from the
/// JSON config.
fn default_num_attention_heads() -> usize {
    const DEFAULT_NUM_ATTENTION_HEADS: usize = 64;
    DEFAULT_NUM_ATTENTION_HEADS
}
42
43// https://huggingface.co/RWKV/HF_v5-Eagle-7B/blob/main/configuration_rwkv5.py
44#[derive(Debug, Clone, serde::Deserialize)]
45pub struct Config {
46    pub vocab_size: usize,
47    pub hidden_size: usize,
48    pub num_hidden_layers: usize,
49    pub attention_hidden_size: usize,
50    #[serde(default = "default_num_attention_heads")]
51    pub num_attention_heads: usize,
52    pub head_size: usize,
53    pub intermediate_size: Option<usize>,
54    pub layer_norm_epsilon: f64,
55    pub rescale_every: usize,
56}
57
58pub struct StatePerLayer {
59    pub extract_key_value: Tensor,
60    pub linear_attention: Tensor,
61    pub feed_forward: Tensor,
62}
63
64pub struct State {
65    pub per_layer: Vec<StatePerLayer>,
66    pub pos: usize,
67}
68
69impl State {
70    pub fn new(batch_size: usize, cfg: &Config, dev: &Device) -> Result<Self> {
71        let mut per_layer = Vec::with_capacity(cfg.num_hidden_layers);
72        // Certainly a weird convention but taken from modeling_rwkv5.py
73        let num_attention_heads = cfg.hidden_size / cfg.num_attention_heads;
74        for _layer_idx in 0..cfg.num_hidden_layers {
75            let extract_key_value = Tensor::zeros((batch_size, cfg.hidden_size), DType::F32, dev)?;
76            let linear_attention = Tensor::zeros(
77                (
78                    batch_size,
79                    num_attention_heads,
80                    cfg.hidden_size / num_attention_heads,
81                    cfg.hidden_size / num_attention_heads,
82                ),
83                DType::F32,
84                dev,
85            )?;
86            let feed_forward = Tensor::zeros((batch_size, cfg.hidden_size), DType::F32, dev)?;
87            per_layer.push(StatePerLayer {
88                extract_key_value,
89                linear_attention,
90                feed_forward,
91            });
92        }
93        Ok(Self { per_layer, pos: 0 })
94    }
95}
96
97#[derive(Debug, Clone)]
98struct SelfAttention {
99    key: Linear,
100    receptance: Linear,
101    value: Linear,
102    gate: Linear,
103    output: Linear,
104    ln_x: candle_nn::GroupNorm,
105    time_mix_key: Tensor,
106    time_mix_value: Tensor,
107    time_mix_receptance: Tensor,
108    time_decay: Tensor,
109    time_faaaa: Tensor,
110    time_mix_gate: Tensor,
111    layer_id: usize,
112    n_attn_heads: usize,
113}
114
115impl SelfAttention {
116    pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
117        let hidden_size = cfg.hidden_size;
118        let attn_hidden_size = cfg.attention_hidden_size;
119        let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
120        let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
121        let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
122        let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
123        let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;
124        let ln_x = candle_nn::group_norm(
125            hidden_size / cfg.head_size,
126            hidden_size,
127            1e-5,
128            vb.pp("ln_x"),
129        )?;
130        let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
131        let time_mix_value = vb.get((1, 1, cfg.hidden_size), "time_mix_value")?;
132        let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
133        let n_attn_heads = cfg.hidden_size / cfg.head_size;
134        let time_decay = vb.get((n_attn_heads, cfg.head_size), "time_decay")?;
135        let time_faaaa = vb.get((n_attn_heads, cfg.head_size), "time_faaaa")?;
136        let time_mix_gate = vb.get((1, 1, cfg.hidden_size), "time_mix_gate")?;
137        Ok(Self {
138            key,
139            value,
140            receptance,
141            gate,
142            output,
143            ln_x,
144            time_mix_key,
145            time_mix_value,
146            time_mix_receptance,
147            time_decay,
148            time_faaaa,
149            time_mix_gate,
150            layer_id,
151            n_attn_heads,
152        })
153    }
154
155    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
156        let h = self.time_decay.dim(0)?;
157        let (b, t, s) = xs.dims3()?;
158        let s = s / h;
159        let (receptance, key, value, gate) = {
160            // extract key-value
161            let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
162            let shifted = if shifted.rank() == 2 {
163                shifted.unsqueeze(1)?
164            } else {
165                shifted
166            };
167            let key = ((xs * &self.time_mix_key)? + &shifted * (1.0 - &self.time_mix_key)?)?;
168            let value = ((xs * &self.time_mix_value)? + &shifted * (1.0 - &self.time_mix_value)?)?;
169            let receptance = ((xs * &self.time_mix_receptance)?
170                + &shifted * (1.0 - &self.time_mix_receptance)?)?;
171            let gate = ((xs * &self.time_mix_gate)? + &shifted * (1.0 - &self.time_mix_gate)?)?;
172
173            let key = self.key.forward(&key)?;
174            let value = self.value.forward(&value)?;
175            let receptance = self.receptance.forward(&receptance)?;
176            let gate = candle_nn::ops::silu(&self.gate.forward(&gate)?)?;
177            state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
178            (receptance, key, value, gate)
179        };
180        // linear attention
181        let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
182        let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
183        let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
184        let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;
185
186        let time_decay = self
187            .time_decay
188            .exp()?
189            .neg()?
190            .exp()?
191            .reshape(((), 1, 1))?
192            .reshape((self.n_attn_heads, (), 1))?;
193        let time_faaaa =
194            self.time_faaaa
195                .reshape(((), 1, 1))?
196                .reshape((self.n_attn_heads, (), 1))?;
197
198        let mut out: Vec<Tensor> = Vec::with_capacity(t);
199        for t_ in 0..t {
200            let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
201            let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
202            let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
203            let at = kt.matmul(&vt)?;
204            let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
205            let out_ = rt.matmul(&rhs)?.squeeze(2)?;
206            state_ = (&at + time_decay.broadcast_mul(&state_))?;
207            out.push(out_)
208        }
209        let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
210        let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
211        let out = (out * gate)?.apply(&self.output)?;
212        state.per_layer[self.layer_id].linear_attention = state_;
213        Ok(out)
214    }
215}
216
217#[derive(Debug, Clone)]
218struct FeedForward {
219    time_mix_key: Tensor,
220    time_mix_receptance: Tensor,
221    key: Linear,
222    receptance: Linear,
223    value: Linear,
224    layer_id: usize,
225}
226
227impl FeedForward {
228    pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
229        let int_size = cfg
230            .intermediate_size
231            .unwrap_or(((cfg.hidden_size as f64 * 3.5) as usize) / 32 * 32);
232        let key = linear(cfg.hidden_size, int_size, vb.pp("key"))?;
233        let receptance = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("receptance"))?;
234        let value = linear(int_size, cfg.hidden_size, vb.pp("value"))?;
235        let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
236        let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
237        Ok(Self {
238            key,
239            receptance,
240            value,
241            time_mix_key,
242            time_mix_receptance,
243            layer_id,
244        })
245    }
246
247    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
248        let shifted = &state.per_layer[self.layer_id].feed_forward;
249        let key = (xs.broadcast_mul(&self.time_mix_key)?
250            + shifted.broadcast_mul(&(1.0 - &self.time_mix_key)?)?)?;
251        let receptance = (xs.broadcast_mul(&self.time_mix_receptance)?
252            + shifted.broadcast_mul(&(1.0 - &self.time_mix_receptance)?)?)?;
253        let key = key.apply(&self.key)?.relu()?.sqr()?;
254        let value = key.apply(&self.value)?;
255        let receptance = candle_nn::ops::sigmoid(&receptance.apply(&self.receptance)?)?;
256        state.per_layer[self.layer_id].feed_forward = xs.i((.., xs.dim(1)? - 1))?;
257        let xs = (receptance * value)?;
258        Ok(xs)
259    }
260}
261
262#[derive(Debug, Clone)]
263struct Block {
264    pre_ln: Option<LayerNorm>,
265    ln1: LayerNorm,
266    ln2: LayerNorm,
267    attention: SelfAttention,
268    feed_forward: FeedForward,
269}
270
271impl Block {
272    pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
273        let ln1 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln1"))?;
274        let ln2 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln2"))?;
275        let pre_ln = if layer_id == 0 {
276            let ln = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("pre_ln"))?;
277            Some(ln)
278        } else {
279            None
280        };
281        let attention = SelfAttention::new(layer_id, cfg, vb.pp("attention"))?;
282        let feed_forward = FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?;
283        Ok(Self {
284            pre_ln,
285            ln1,
286            ln2,
287            attention,
288            feed_forward,
289        })
290    }
291
292    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
293        let xs = match self.pre_ln.as_ref() {
294            None => xs.clone(),
295            Some(pre_ln) => xs.apply(pre_ln)?,
296        };
297        let attention = self.attention.forward(&xs.apply(&self.ln1)?, state)?;
298        let xs = (xs + attention)?;
299        let feed_forward = self.feed_forward.forward(&xs.apply(&self.ln2)?, state)?;
300        let xs = (xs + feed_forward)?;
301        Ok(xs)
302    }
303}
304
305#[derive(Debug, Clone)]
306pub struct Model {
307    embeddings: Embedding,
308    blocks: Vec<Block>,
309    ln_out: LayerNorm,
310    head: Linear,
311    rescale_every: usize,
312    layers_are_rescaled: bool,
313}
314
315impl Model {
316    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
317        let vb_m = vb.pp("rwkv");
318        let embeddings = embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
319        let mut blocks = Vec::with_capacity(cfg.num_hidden_layers);
320        let vb_b = vb_m.pp("blocks");
321        for block_index in 0..cfg.num_hidden_layers {
322            let block = Block::new(block_index, cfg, vb_b.pp(block_index))?;
323            blocks.push(block)
324        }
325        let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
326        let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
327        Ok(Self {
328            embeddings,
329            blocks,
330            ln_out,
331            head,
332            rescale_every: cfg.rescale_every,
333            layers_are_rescaled: false, // This seem to only happen for the f16/bf16 dtypes.
334        })
335    }
336
337    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
338        let (_b_size, _seq_len) = xs.dims2()?;
339        let mut xs = xs.apply(&self.embeddings)?;
340        for (block_idx, block) in self.blocks.iter().enumerate() {
341            xs = block.forward(&xs, state)?;
342            if self.layers_are_rescaled && (block_idx + 1) % self.rescale_every == 0 {
343                xs = (xs / 2.)?
344            }
345        }
346        let xs = xs.apply(&self.ln_out)?.apply(&self.head)?;
347        state.pos += 1;
348        Ok(xs)
349    }
350}
351
/// Raw bytes of one vocabulary entry.
type Bytes = Vec<u8>;

/// Byte-level greedy tokenizer, ported from the tokenizer in
/// https://github.com/BlinkDL/ChatRWKV/blob/095e812aef15a1f74107f6c39d13578a2412dc46/RWKV_v5_demo.py#L14
pub struct Tokenizer {
    /// `table[b0][b1]`: vocabulary entries of length >= 2 starting with bytes `b0, b1`,
    /// in decreasing id order.
    table: Vec<Vec<Vec<Bytes>>>,
    /// `good[b0]`: the second bytes `b1` for which `table[b0][b1]` is non-empty.
    good: Vec<HashSet<u8>>,
    /// Token id -> token bytes.
    idx2token: HashMap<u32, Vec<u8>>,
    /// Token bytes -> token id.
    token2idx: HashMap<Vec<u8>, u32>,
}
361
362impl Tokenizer {
363    pub fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
364        let file = std::fs::File::open(p)?;
365        let token2idx: HashMap<String, u32> =
366            serde_json::from_reader(file).map_err(candle::Error::wrap)?;
367        let token2idx = token2idx
368            .into_iter()
369            .map(|(key, value)| (key.into_bytes(), value))
370            .collect::<HashMap<_, _>>();
371        let idx2token = token2idx
372            .iter()
373            .map(|(key, value)| (*value, key.to_vec()))
374            .collect::<HashMap<_, _>>();
375
376        let max_idx = token2idx.values().copied().max().unwrap_or(0);
377
378        let mut table = vec![vec![vec![]; 256]; 256];
379        let mut good = vec![HashSet::new(); 256];
380        for idx in (0..(1 + max_idx)).rev() {
381            let s = match idx2token.get(&idx) {
382                None => continue,
383                Some(s) => s,
384            };
385            if s.len() >= 2 {
386                let (s0, s1) = (s[0], s[1]);
387                table[s0 as usize][s1 as usize].push(s.to_vec());
388                good[s0 as usize].insert(s1);
389            }
390        }
391        Ok(Self {
392            table,
393            good,
394            idx2token,
395            token2idx,
396        })
397    }
398
399    pub fn decode_bytes(&self, tokens: &[u32]) -> Vec<u8> {
400        let mut v = Vec::new();
401        for token_id in tokens.iter() {
402            if let Some(token) = self.idx2token.get(token_id) {
403                v.extend_from_slice(token.as_slice())
404            }
405        }
406        v
407    }
408
409    pub fn decode(&self, tokens: &[u32]) -> Result<String> {
410        let bytes = self.decode_bytes(tokens);
411        String::from_utf8(bytes).map_err(candle::Error::wrap)
412    }
413
414    pub fn encode_bytes(&self, bytes: &[u8]) -> Result<Vec<u32>> {
415        let mut tokens = Vec::new();
416        let mut i = 0;
417        while i < bytes.len() {
418            let mut s = vec![bytes[i]];
419            if i + 1 < bytes.len() && self.good[bytes[i] as usize].contains(&bytes[i + 1]) {
420                let table = &self.table[bytes[i] as usize][bytes[i + 1] as usize];
421                for table_elem in table.iter() {
422                    if bytes[i..].starts_with(table_elem) {
423                        s = table_elem.to_vec();
424                        break;
425                    }
426                }
427            }
428            i += s.len();
429            let token = match self.token2idx.get(&s) {
430                None => candle::bail!("unexpected token '{}' {s:?}", String::from_utf8_lossy(&s)),
431                Some(token) => *token,
432            };
433            tokens.push(token)
434        }
435        Ok(tokens)
436    }
437
438    pub fn encode(&self, str: &str) -> Result<Vec<u32>> {
439        self.encode_bytes(str.as_bytes())
440    }
441}