1use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear};
31use candle::{IndexOp, Result, Tensor};
32use candle_nn::{embedding, Embedding, Module, VarBuilder};
33
34pub use crate::models::rwkv_v5::{Config, State, Tokenizer};
35
36#[derive(Debug, Clone)]
37struct SelfAttention {
38 key: Linear,
39 receptance: Linear,
40 value: Linear,
41 gate: Linear,
42 output: Linear,
43 ln_x: candle_nn::GroupNorm,
44 time_mix_x: Tensor,
45 time_mix_w: Tensor,
46 time_mix_key: Tensor,
47 time_mix_value: Tensor,
48 time_mix_receptance: Tensor,
49 time_decay: Tensor,
50 time_faaaa: Tensor,
51 time_mix_gate: Tensor,
52 time_decay_w1: Tensor,
53 time_decay_w2: Tensor,
54 time_mix_w1: Tensor,
55 time_mix_w2: Tensor,
56 layer_id: usize,
57 n_attn_heads: usize,
58}
59
60impl SelfAttention {
61 fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
62 let hidden_size = cfg.hidden_size;
63 let attn_hidden_size = cfg.attention_hidden_size;
64 let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
65 let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
66 let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
67 let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
68 let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;
69 let ln_x = candle_nn::group_norm(
70 hidden_size / cfg.head_size,
71 hidden_size,
72 1e-5,
73 vb.pp("ln_x"),
74 )?;
75
76 let time_mix_x = vb.get((1, 1, cfg.hidden_size), "time_mix_x")?;
77 let time_mix_w = vb.get((1, 1, cfg.hidden_size), "time_mix_w")?;
78 let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
79 let time_mix_value = vb.get((1, 1, cfg.hidden_size), "time_mix_value")?;
80 let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
81 let n_attn_heads = cfg.hidden_size / cfg.head_size;
82 let time_decay = vb.get((1, 1, cfg.hidden_size), "time_decay")?;
83 let time_faaaa = vb.get((n_attn_heads, cfg.head_size), "time_faaaa")?;
84 let time_mix_gate = vb.get((1, 1, cfg.hidden_size), "time_mix_gate")?;
85 let time_decay_w1 = vb.get((cfg.hidden_size, n_attn_heads * 2), "time_decay_w1")?;
86 let time_decay_w2 = vb.get((n_attn_heads * 2, cfg.hidden_size), "time_decay_w2")?;
87 let time_mix_w1 = vb.get((cfg.hidden_size, n_attn_heads * 5), "time_mix_w1")?;
88 let time_mix_w2 = vb.get((5, n_attn_heads, cfg.hidden_size), "time_mix_w2")?;
89 Ok(Self {
90 key,
91 value,
92 receptance,
93 gate,
94 output,
95 ln_x,
96 time_mix_x,
97 time_mix_w,
98 time_mix_key,
99 time_mix_value,
100 time_mix_receptance,
101 time_decay,
102 time_faaaa,
103 time_mix_gate,
104 time_decay_w1,
105 time_decay_w2,
106 time_mix_w1,
107 time_mix_w2,
108 layer_id,
109 n_attn_heads,
110 })
111 }
112
113 pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
114 let h = self.n_attn_heads;
115 let (b, t, s) = xs.dims3()?;
116 let s = s / h;
117 let (receptance, key, value, gate, w) = {
118 let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
120 let shifted = if shifted.rank() == 2 {
121 shifted.unsqueeze(1)?
122 } else {
123 shifted
124 };
125
126 let sx = (&shifted - xs)?;
127 let xxx = (xs + &sx * &self.time_mix_x)?;
128 let xxx = xxx
129 .broadcast_matmul(&self.time_mix_w1)?
130 .tanh()?
131 .reshape((b * t, 5, ()))?
132 .transpose(0, 1)?;
133
134 let xxx = xxx.matmul(&self.time_mix_w2)?.reshape((5, b, t, ()))?;
135
136 let (mw, mk, mv, mr, mg) = (xxx.i(0)?, xxx.i(1)?, xxx.i(2)?, xxx.i(3)?, xxx.i(4)?);
137
138 let xw = (xs + &sx * (&self.time_mix_w + &mw)?)?;
139 let xk = (xs + &sx * (&self.time_mix_key + &mk)?)?;
140 let xv = (xs + &sx * (&self.time_mix_value + &mv)?)?;
141 let xr = (xs + &sx * (&self.time_mix_receptance + &mr)?)?;
142 let xg = (xs + &sx * (&self.time_mix_gate + &mg)?)?;
143
144 let w = (&self.time_decay
145 + xw.broadcast_matmul(&self.time_decay_w1)?
146 .tanh()?
147 .broadcast_matmul(&self.time_decay_w2)?)?
148 .reshape(((), 1, 1))?
149 .reshape((self.n_attn_heads, (), 1))?;
150
151 let key = self.key.forward(&xk)?;
152 let value = self.value.forward(&xv)?;
153 let receptance = self.receptance.forward(&xr)?;
154 let gate = candle_nn::ops::silu(&self.gate.forward(&xg)?)?;
155 state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
156 (receptance, key, value, gate, w)
157 };
158
159 let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
161 let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
162 let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
163 let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;
164
165 let w = w.exp()?.neg()?.exp()?;
166
167 let time_faaaa =
168 self.time_faaaa
169 .reshape(((), 1, 1))?
170 .reshape((self.n_attn_heads, (), 1))?;
171
172 let mut out: Vec<Tensor> = Vec::with_capacity(t);
173 for t_ in 0..t {
174 let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
175 let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
176 let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
177 let at = kt.matmul(&vt)?;
178 let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
179 let out_ = rt.matmul(&rhs)?.squeeze(2)?;
180 state_ = (&at + w.broadcast_mul(&state_))?;
181 out.push(out_)
182 }
183 let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
184 let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
185 let out = (out * gate)?.apply(&self.output)?;
186 state.per_layer[self.layer_id].linear_attention = state_;
187 Ok(out)
188 }
189}
190
191#[derive(Debug, Clone)]
192struct FeedForward {
193 time_mix_key: Tensor,
194 time_mix_receptance: Tensor,
195 key: Linear,
196 receptance: Linear,
197 value: Linear,
198 layer_id: usize,
199}
200
201impl FeedForward {
202 fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
203 let int_size = cfg
204 .intermediate_size
205 .unwrap_or(((cfg.hidden_size as f64 * 3.5) as usize) / 32 * 32);
206 let key = linear(cfg.hidden_size, int_size, vb.pp("key"))?;
207 let receptance = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("receptance"))?;
208 let value = linear(int_size, cfg.hidden_size, vb.pp("value"))?;
209 let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
210 let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
211 Ok(Self {
212 key,
213 receptance,
214 value,
215 time_mix_key,
216 time_mix_receptance,
217 layer_id,
218 })
219 }
220
221 fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
222 let shifted = state.per_layer[self.layer_id]
223 .feed_forward
224 .broadcast_sub(xs)?;
225 let key = (xs + shifted.broadcast_mul(&self.time_mix_key)?)?;
226 let receptance = (xs + shifted.broadcast_mul(&self.time_mix_receptance)?)?;
227 let key = key.apply(&self.key)?.relu()?.sqr()?;
228 let value = key.apply(&self.value)?;
229 let receptance = candle_nn::ops::sigmoid(&receptance.apply(&self.receptance)?)?;
230 state.per_layer[self.layer_id].feed_forward = xs.i((.., xs.dim(1)? - 1))?;
231 let xs = (receptance * value)?;
232 Ok(xs)
233 }
234}
235
236#[derive(Debug, Clone)]
237struct Block {
238 pre_ln: Option<LayerNorm>,
239 ln1: LayerNorm,
240 ln2: LayerNorm,
241 attention: SelfAttention,
242 feed_forward: FeedForward,
243}
244
245impl Block {
246 fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
247 let ln1 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln1"))?;
248 let ln2 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln2"))?;
249 let pre_ln = if layer_id == 0 {
250 let ln = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("pre_ln"))?;
251 Some(ln)
252 } else {
253 None
254 };
255 let attention = SelfAttention::new(layer_id, cfg, vb.pp("attention"))?;
256 let feed_forward = FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?;
257 Ok(Self {
258 pre_ln,
259 ln1,
260 ln2,
261 attention,
262 feed_forward,
263 })
264 }
265
266 fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
267 let xs = match self.pre_ln.as_ref() {
268 None => xs.clone(),
269 Some(pre_ln) => xs.apply(pre_ln)?,
270 };
271 let attention = self.attention.forward(&xs.apply(&self.ln1)?, state)?;
272 let xs = (xs + attention)?;
273 let feed_forward = self.feed_forward.forward(&xs.apply(&self.ln2)?, state)?;
274 let xs = (xs + feed_forward)?;
275 Ok(xs)
276 }
277}
278
279#[derive(Debug, Clone)]
280pub struct Model {
281 embeddings: Embedding,
282 blocks: Vec<Block>,
283 ln_out: LayerNorm,
284 head: Linear,
285 rescale_every: usize,
286 layers_are_rescaled: bool,
287}
288
289impl Model {
290 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
291 let vb_m = vb.pp("rwkv");
292 let embeddings = embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
293 let mut blocks = Vec::with_capacity(cfg.num_hidden_layers);
294 let vb_b = vb_m.pp("blocks");
295 for block_index in 0..cfg.num_hidden_layers {
296 let block = Block::new(block_index, cfg, vb_b.pp(block_index))?;
297 blocks.push(block)
298 }
299 let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
300 let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
301 Ok(Self {
302 embeddings,
303 blocks,
304 ln_out,
305 head,
306 rescale_every: cfg.rescale_every,
307 layers_are_rescaled: false, })
309 }
310
311 pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
312 let (_b_size, _seq_len) = xs.dims2()?;
313 let mut xs = xs.apply(&self.embeddings)?;
314 for (block_idx, block) in self.blocks.iter().enumerate() {
315 xs = block.forward(&xs, state)?;
316 if self.layers_are_rescaled && (block_idx + 1) % self.rescale_every == 0 {
317 xs = (xs / 2.)?
318 }
319 }
320 let xs = xs.apply(&self.ln_out)?.apply(&self.head)?;
321 state.pos += 1;
322 Ok(xs)
323 }
324}