candle_transformers/models/
quantized_rwkv_v5.rs

1//! RWKV v5 model implementation with quantization support.
2//!
3//! RWKV v5 is an attention-free language model optimized for efficiency.
4//! This implementation provides quantization for reduced memory and compute.
5//!
6//! Key characteristics:
7//! - Linear attention mechanism
8//! - GroupNorm layer normalization
9//! - Time-mixing layers
10//! - State-based sequential processing
11//! - Support for 8-bit quantization
12//!
13//! References:
14//! - [RWKV Model](https://github.com/BlinkDL/RWKV-LM)
15//! - [RWKV v5 Architecture](https://www.rwkv.com/v5)
16//!
17
18use crate::{
19    quantized_nn::{layer_norm, linear_no_bias as linear, Embedding, Linear},
20    quantized_var_builder::VarBuilder,
21};
22use candle::{IndexOp, Result, Tensor};
23use candle_nn::{GroupNorm, LayerNorm, Module};
24
25pub use crate::models::rwkv_v5::{Config, State, Tokenizer};
26
27#[derive(Debug, Clone)]
28struct SelfAttention {
29    key: Linear,
30    receptance: Linear,
31    value: Linear,
32    gate: Linear,
33    output: Linear,
34    ln_x: candle_nn::GroupNorm,
35    time_mix_key: Tensor,
36    time_mix_value: Tensor,
37    time_mix_receptance: Tensor,
38    time_decay: Tensor,
39    time_faaaa: Tensor,
40    time_mix_gate: Tensor,
41    layer_id: usize,
42    n_attn_heads: usize,
43}
44
45impl SelfAttention {
46    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
47        let hidden_size = cfg.hidden_size;
48        let attn_hidden_size = cfg.attention_hidden_size;
49        let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
50        let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
51        let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
52        let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
53        let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;
54
55        let vb_x = vb.pp("ln_x");
56        let ln_x_weight = vb_x.get(hidden_size, "weight")?.dequantize(vb.device())?;
57        let ln_x_bias = vb_x.get(hidden_size, "bias")?.dequantize(vb.device())?;
58
59        let ln_x = GroupNorm::new(
60            ln_x_weight,
61            ln_x_bias,
62            hidden_size,
63            hidden_size / cfg.head_size,
64            1e-5,
65        )?;
66
67        let time_mix_key = vb
68            .get((1, 1, cfg.hidden_size), "time_mix_key")?
69            .dequantize(vb.device())?;
70        let time_mix_value = vb
71            .get((1, 1, cfg.hidden_size), "time_mix_value")?
72            .dequantize(vb.device())?;
73        let time_mix_receptance = vb
74            .get((1, 1, cfg.hidden_size), "time_mix_receptance")?
75            .dequantize(vb.device())?;
76        let n_attn_heads = cfg.hidden_size / cfg.head_size;
77        let time_decay = vb
78            .get((n_attn_heads, cfg.head_size), "time_decay")?
79            .dequantize(vb.device())?;
80        let time_faaaa = vb
81            .get((n_attn_heads, cfg.head_size), "time_faaaa")?
82            .dequantize(vb.device())?;
83        let time_mix_gate = vb
84            .get((1, 1, cfg.hidden_size), "time_mix_gate")?
85            .dequantize(vb.device())?;
86        Ok(Self {
87            key,
88            value,
89            receptance,
90            gate,
91            output,
92            ln_x,
93            time_mix_key,
94            time_mix_value,
95            time_mix_receptance,
96            time_decay,
97            time_faaaa,
98            time_mix_gate,
99            layer_id,
100            n_attn_heads,
101        })
102    }
103
104    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
105        let h = self.time_decay.dim(0)?;
106        let (b, t, s) = xs.dims3()?;
107        let s = s / h;
108        let (receptance, key, value, gate) = {
109            // extract key-value
110            let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
111            let shifted = if shifted.rank() == 2 {
112                shifted.unsqueeze(1)?
113            } else {
114                shifted
115            };
116            let key = ((xs * &self.time_mix_key)? + &shifted * (1.0 - &self.time_mix_key)?)?;
117            let value = ((xs * &self.time_mix_value)? + &shifted * (1.0 - &self.time_mix_value)?)?;
118            let receptance = ((xs * &self.time_mix_receptance)?
119                + &shifted * (1.0 - &self.time_mix_receptance)?)?;
120            let gate = ((xs * &self.time_mix_gate)? + &shifted * (1.0 - &self.time_mix_gate)?)?;
121
122            let key = self.key.forward(&key)?;
123            let value = self.value.forward(&value)?;
124            let receptance = self.receptance.forward(&receptance)?;
125            let gate = candle_nn::ops::silu(&self.gate.forward(&gate)?)?;
126            state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
127            (receptance, key, value, gate)
128        };
129        // linear attention
130        let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
131        let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
132        let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
133        let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;
134
135        let time_decay = self
136            .time_decay
137            .exp()?
138            .neg()?
139            .exp()?
140            .reshape(((), 1, 1))?
141            .reshape((self.n_attn_heads, (), 1))?;
142        let time_faaaa =
143            self.time_faaaa
144                .reshape(((), 1, 1))?
145                .reshape((self.n_attn_heads, (), 1))?;
146
147        let mut out: Vec<Tensor> = Vec::with_capacity(t);
148        for t_ in 0..t {
149            let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
150            let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
151            let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
152            let at = kt.matmul(&vt)?;
153            let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
154            let out_ = rt.matmul(&rhs)?.squeeze(2)?;
155            state_ = (&at + time_decay.broadcast_mul(&state_))?;
156            out.push(out_)
157        }
158        let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
159        let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
160        let out = (out * gate)?.apply(&self.output)?;
161        state.per_layer[self.layer_id].linear_attention = state_;
162        Ok(out)
163    }
164}
165
166#[derive(Debug, Clone)]
167struct FeedForward {
168    time_mix_key: Tensor,
169    time_mix_receptance: Tensor,
170    key: Linear,
171    receptance: Linear,
172    value: Linear,
173    layer_id: usize,
174}
175
176impl FeedForward {
177    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
178        let int_size = cfg
179            .intermediate_size
180            .unwrap_or(((cfg.hidden_size as f64 * 3.5) as usize) / 32 * 32);
181        let key = linear(cfg.hidden_size, int_size, vb.pp("key"))?;
182        let receptance = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("receptance"))?;
183        let value = linear(int_size, cfg.hidden_size, vb.pp("value"))?;
184        let time_mix_key = vb
185            .get((1, 1, cfg.hidden_size), "time_mix_key")?
186            .dequantize(vb.device())?;
187        let time_mix_receptance = vb
188            .get((1, 1, cfg.hidden_size), "time_mix_receptance")?
189            .dequantize(vb.device())?;
190        Ok(Self {
191            key,
192            receptance,
193            value,
194            time_mix_key,
195            time_mix_receptance,
196            layer_id,
197        })
198    }
199
200    fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
201        let shifted = &state.per_layer[self.layer_id].feed_forward;
202        let key = (xs.broadcast_mul(&self.time_mix_key)?
203            + shifted.broadcast_mul(&(1.0 - &self.time_mix_key)?)?)?;
204        let receptance = (xs.broadcast_mul(&self.time_mix_receptance)?
205            + shifted.broadcast_mul(&(1.0 - &self.time_mix_receptance)?)?)?;
206        let key = key.apply(&self.key)?.relu()?.sqr()?;
207        let value = key.apply(&self.value)?;
208        let receptance = candle_nn::ops::sigmoid(&receptance.apply(&self.receptance)?)?;
209        state.per_layer[self.layer_id].feed_forward = xs.i((.., xs.dim(1)? - 1))?;
210        let xs = (receptance * value)?;
211        Ok(xs)
212    }
213}
214
215#[derive(Debug, Clone)]
216struct Block {
217    pre_ln: Option<LayerNorm>,
218    ln1: LayerNorm,
219    ln2: LayerNorm,
220    attention: SelfAttention,
221    feed_forward: FeedForward,
222}
223
224impl Block {
225    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
226        let ln1 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln1"))?;
227        let ln2 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln2"))?;
228        let pre_ln = if layer_id == 0 {
229            let ln = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("pre_ln"))?;
230            Some(ln)
231        } else {
232            None
233        };
234        let attention = SelfAttention::new(layer_id, cfg, vb.pp("attention"))?;
235        let feed_forward = FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?;
236        Ok(Self {
237            pre_ln,
238            ln1,
239            ln2,
240            attention,
241            feed_forward,
242        })
243    }
244
245    fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
246        let xs = match self.pre_ln.as_ref() {
247            None => xs.clone(),
248            Some(pre_ln) => xs.apply(pre_ln)?,
249        };
250        let attention = self.attention.forward(&xs.apply(&self.ln1)?, state)?;
251        let xs = (xs + attention)?;
252        let feed_forward = self.feed_forward.forward(&xs.apply(&self.ln2)?, state)?;
253        let xs = (xs + feed_forward)?;
254        Ok(xs)
255    }
256}
257
258#[derive(Debug, Clone)]
259pub struct Model {
260    embeddings: Embedding,
261    blocks: Vec<Block>,
262    ln_out: LayerNorm,
263    head: Linear,
264    rescale_every: usize,
265    layers_are_rescaled: bool,
266}
267
268impl Model {
269    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
270        let vb_m = vb.pp("rwkv");
271        let embeddings = Embedding::new(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
272        let mut blocks = Vec::with_capacity(cfg.num_hidden_layers);
273        let vb_b = vb_m.pp("blocks");
274        for block_index in 0..cfg.num_hidden_layers {
275            let block = Block::new(block_index, cfg, vb_b.pp(block_index))?;
276            blocks.push(block)
277        }
278        let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
279        let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
280        Ok(Self {
281            embeddings,
282            blocks,
283            ln_out,
284            head,
285            rescale_every: cfg.rescale_every,
286            layers_are_rescaled: false, // This seem to only happen for the f16/bf16 dtypes.
287        })
288    }
289
290    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
291        let (_b_size, _seq_len) = xs.dims2()?;
292        let mut xs = xs.apply(&self.embeddings)?;
293        for (block_idx, block) in self.blocks.iter().enumerate() {
294            xs = block.forward(&xs, state)?;
295            if self.layers_are_rescaled && (block_idx + 1) % self.rescale_every == 0 {
296                xs = (xs / 2.)?
297            }
298        }
299        let xs = xs.apply(&self.ln_out)?.apply(&self.head)?;
300        state.pos += 1;
301        Ok(xs)
302    }
303}