1use crate::{
19 quantized_nn::{layer_norm, linear_no_bias as linear, Embedding, Linear},
20 quantized_var_builder::VarBuilder,
21};
22use candle::{IndexOp, Result, Tensor};
23use candle_nn::{GroupNorm, LayerNorm, Module};
24
25pub use crate::models::rwkv_v5::{Config, State, Tokenizer};
26
27#[derive(Debug, Clone)]
28struct SelfAttention {
29 key: Linear,
30 receptance: Linear,
31 value: Linear,
32 gate: Linear,
33 output: Linear,
34 ln_x: candle_nn::GroupNorm,
35 time_mix_key: Tensor,
36 time_mix_value: Tensor,
37 time_mix_receptance: Tensor,
38 time_decay: Tensor,
39 time_faaaa: Tensor,
40 time_mix_gate: Tensor,
41 layer_id: usize,
42 n_attn_heads: usize,
43}
44
45impl SelfAttention {
46 fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
47 let hidden_size = cfg.hidden_size;
48 let attn_hidden_size = cfg.attention_hidden_size;
49 let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
50 let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
51 let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
52 let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
53 let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;
54
55 let vb_x = vb.pp("ln_x");
56 let ln_x_weight = vb_x.get(hidden_size, "weight")?.dequantize(vb.device())?;
57 let ln_x_bias = vb_x.get(hidden_size, "bias")?.dequantize(vb.device())?;
58
59 let ln_x = GroupNorm::new(
60 ln_x_weight,
61 ln_x_bias,
62 hidden_size,
63 hidden_size / cfg.head_size,
64 1e-5,
65 )?;
66
67 let time_mix_key = vb
68 .get((1, 1, cfg.hidden_size), "time_mix_key")?
69 .dequantize(vb.device())?;
70 let time_mix_value = vb
71 .get((1, 1, cfg.hidden_size), "time_mix_value")?
72 .dequantize(vb.device())?;
73 let time_mix_receptance = vb
74 .get((1, 1, cfg.hidden_size), "time_mix_receptance")?
75 .dequantize(vb.device())?;
76 let n_attn_heads = cfg.hidden_size / cfg.head_size;
77 let time_decay = vb
78 .get((n_attn_heads, cfg.head_size), "time_decay")?
79 .dequantize(vb.device())?;
80 let time_faaaa = vb
81 .get((n_attn_heads, cfg.head_size), "time_faaaa")?
82 .dequantize(vb.device())?;
83 let time_mix_gate = vb
84 .get((1, 1, cfg.hidden_size), "time_mix_gate")?
85 .dequantize(vb.device())?;
86 Ok(Self {
87 key,
88 value,
89 receptance,
90 gate,
91 output,
92 ln_x,
93 time_mix_key,
94 time_mix_value,
95 time_mix_receptance,
96 time_decay,
97 time_faaaa,
98 time_mix_gate,
99 layer_id,
100 n_attn_heads,
101 })
102 }
103
104 pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
105 let h = self.time_decay.dim(0)?;
106 let (b, t, s) = xs.dims3()?;
107 let s = s / h;
108 let (receptance, key, value, gate) = {
109 let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
111 let shifted = if shifted.rank() == 2 {
112 shifted.unsqueeze(1)?
113 } else {
114 shifted
115 };
116 let key = ((xs * &self.time_mix_key)? + &shifted * (1.0 - &self.time_mix_key)?)?;
117 let value = ((xs * &self.time_mix_value)? + &shifted * (1.0 - &self.time_mix_value)?)?;
118 let receptance = ((xs * &self.time_mix_receptance)?
119 + &shifted * (1.0 - &self.time_mix_receptance)?)?;
120 let gate = ((xs * &self.time_mix_gate)? + &shifted * (1.0 - &self.time_mix_gate)?)?;
121
122 let key = self.key.forward(&key)?;
123 let value = self.value.forward(&value)?;
124 let receptance = self.receptance.forward(&receptance)?;
125 let gate = candle_nn::ops::silu(&self.gate.forward(&gate)?)?;
126 state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
127 (receptance, key, value, gate)
128 };
129 let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
131 let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
132 let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
133 let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;
134
135 let time_decay = self
136 .time_decay
137 .exp()?
138 .neg()?
139 .exp()?
140 .reshape(((), 1, 1))?
141 .reshape((self.n_attn_heads, (), 1))?;
142 let time_faaaa =
143 self.time_faaaa
144 .reshape(((), 1, 1))?
145 .reshape((self.n_attn_heads, (), 1))?;
146
147 let mut out: Vec<Tensor> = Vec::with_capacity(t);
148 for t_ in 0..t {
149 let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
150 let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
151 let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
152 let at = kt.matmul(&vt)?;
153 let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
154 let out_ = rt.matmul(&rhs)?.squeeze(2)?;
155 state_ = (&at + time_decay.broadcast_mul(&state_))?;
156 out.push(out_)
157 }
158 let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
159 let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
160 let out = (out * gate)?.apply(&self.output)?;
161 state.per_layer[self.layer_id].linear_attention = state_;
162 Ok(out)
163 }
164}
165
166#[derive(Debug, Clone)]
167struct FeedForward {
168 time_mix_key: Tensor,
169 time_mix_receptance: Tensor,
170 key: Linear,
171 receptance: Linear,
172 value: Linear,
173 layer_id: usize,
174}
175
176impl FeedForward {
177 fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
178 let int_size = cfg
179 .intermediate_size
180 .unwrap_or(((cfg.hidden_size as f64 * 3.5) as usize) / 32 * 32);
181 let key = linear(cfg.hidden_size, int_size, vb.pp("key"))?;
182 let receptance = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("receptance"))?;
183 let value = linear(int_size, cfg.hidden_size, vb.pp("value"))?;
184 let time_mix_key = vb
185 .get((1, 1, cfg.hidden_size), "time_mix_key")?
186 .dequantize(vb.device())?;
187 let time_mix_receptance = vb
188 .get((1, 1, cfg.hidden_size), "time_mix_receptance")?
189 .dequantize(vb.device())?;
190 Ok(Self {
191 key,
192 receptance,
193 value,
194 time_mix_key,
195 time_mix_receptance,
196 layer_id,
197 })
198 }
199
200 fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
201 let shifted = &state.per_layer[self.layer_id].feed_forward;
202 let key = (xs.broadcast_mul(&self.time_mix_key)?
203 + shifted.broadcast_mul(&(1.0 - &self.time_mix_key)?)?)?;
204 let receptance = (xs.broadcast_mul(&self.time_mix_receptance)?
205 + shifted.broadcast_mul(&(1.0 - &self.time_mix_receptance)?)?)?;
206 let key = key.apply(&self.key)?.relu()?.sqr()?;
207 let value = key.apply(&self.value)?;
208 let receptance = candle_nn::ops::sigmoid(&receptance.apply(&self.receptance)?)?;
209 state.per_layer[self.layer_id].feed_forward = xs.i((.., xs.dim(1)? - 1))?;
210 let xs = (receptance * value)?;
211 Ok(xs)
212 }
213}
214
215#[derive(Debug, Clone)]
216struct Block {
217 pre_ln: Option<LayerNorm>,
218 ln1: LayerNorm,
219 ln2: LayerNorm,
220 attention: SelfAttention,
221 feed_forward: FeedForward,
222}
223
224impl Block {
225 fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
226 let ln1 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln1"))?;
227 let ln2 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln2"))?;
228 let pre_ln = if layer_id == 0 {
229 let ln = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("pre_ln"))?;
230 Some(ln)
231 } else {
232 None
233 };
234 let attention = SelfAttention::new(layer_id, cfg, vb.pp("attention"))?;
235 let feed_forward = FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?;
236 Ok(Self {
237 pre_ln,
238 ln1,
239 ln2,
240 attention,
241 feed_forward,
242 })
243 }
244
245 fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
246 let xs = match self.pre_ln.as_ref() {
247 None => xs.clone(),
248 Some(pre_ln) => xs.apply(pre_ln)?,
249 };
250 let attention = self.attention.forward(&xs.apply(&self.ln1)?, state)?;
251 let xs = (xs + attention)?;
252 let feed_forward = self.feed_forward.forward(&xs.apply(&self.ln2)?, state)?;
253 let xs = (xs + feed_forward)?;
254 Ok(xs)
255 }
256}
257
258#[derive(Debug, Clone)]
259pub struct Model {
260 embeddings: Embedding,
261 blocks: Vec<Block>,
262 ln_out: LayerNorm,
263 head: Linear,
264 rescale_every: usize,
265 layers_are_rescaled: bool,
266}
267
268impl Model {
269 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
270 let vb_m = vb.pp("rwkv");
271 let embeddings = Embedding::new(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
272 let mut blocks = Vec::with_capacity(cfg.num_hidden_layers);
273 let vb_b = vb_m.pp("blocks");
274 for block_index in 0..cfg.num_hidden_layers {
275 let block = Block::new(block_index, cfg, vb_b.pp(block_index))?;
276 blocks.push(block)
277 }
278 let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
279 let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
280 Ok(Self {
281 embeddings,
282 blocks,
283 ln_out,
284 head,
285 rescale_every: cfg.rescale_every,
286 layers_are_rescaled: false, })
288 }
289
290 pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
291 let (_b_size, _seq_len) = xs.dims2()?;
292 let mut xs = xs.apply(&self.embeddings)?;
293 for (block_idx, block) in self.blocks.iter().enumerate() {
294 xs = block.forward(&xs, state)?;
295 if self.layers_are_rescaled && (block_idx + 1) % self.rescale_every == 0 {
296 xs = (xs / 2.)?
297 }
298 }
299 let xs = xs.apply(&self.ln_out)?.apply(&self.head)?;
300 state.pos += 1;
301 Ok(xs)
302 }
303}