1use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear};
35use candle::{DType, Device, IndexOp, Result, Tensor};
36use candle_nn::{embedding, Embedding, Module, VarBuilder};
37use std::collections::{HashMap, HashSet};
38
/// Serde fallback for `Config::num_attention_heads` when the JSON config omits the field.
fn default_num_attention_heads() -> usize {
    const FALLBACK_NUM_ATTENTION_HEADS: usize = 64;
    FALLBACK_NUM_ATTENTION_HEADS
}
42
/// RWKV v5 hyper-parameters, deserialized from the checkpoint's JSON config.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
    /// Number of embedding rows and of outputs produced by the LM head.
    pub vocab_size: usize,
    /// Width of the residual stream.
    pub hidden_size: usize,
    /// Number of `Block`s in the model.
    pub num_hidden_layers: usize,
    /// Output width of the attention key/receptance/value/gate projections
    /// (and input width of the attention output projection).
    pub attention_hidden_size: usize,
    /// Defaults to 64 when absent from the JSON config.
    #[serde(default = "default_num_attention_heads")]
    pub num_attention_heads: usize,
    /// Channels per attention head; the attention uses `hidden_size / head_size` heads.
    pub head_size: usize,
    /// Feed-forward inner width; when `None`, `hidden_size * 3.5` rounded down to a
    /// multiple of 32 is used.
    pub intermediate_size: Option<usize>,
    /// Epsilon of the per-block layer norms (`ln1`, `ln2` and block 0's `pre_ln`).
    pub layer_norm_epsilon: f64,
    /// Activations are halved after every `rescale_every` blocks when the model's layer
    /// rescaling is enabled.
    pub rescale_every: usize,
}
57
/// Recurrent state carried between forward calls for a single block.
pub struct StatePerLayer {
    /// Last input token seen by the attention sub-layer, `(batch, hidden_size)`; it is mixed
    /// into the first position of the next call's input.
    pub extract_key_value: Tensor,
    /// Accumulated linear-attention state of the attention sub-layer; it is added to the
    /// per-step `(batch, heads, head_size, head_size)` key-value product.
    pub linear_attention: Tensor,
    /// Last input token seen by the feed-forward sub-layer, `(batch, hidden_size)`; it is
    /// mixed into the first position of the next call's input.
    pub feed_forward: Tensor,
}
63
/// Complete recurrent state of the model: one `StatePerLayer` per block.
pub struct State {
    /// Indexed by block (layer) id.
    pub per_layer: Vec<StatePerLayer>,
    /// Starts at 0 and is advanced by `Model::forward`.
    pub pos: usize,
}
68
69impl State {
70 pub fn new(batch_size: usize, cfg: &Config, dev: &Device) -> Result<Self> {
71 let mut per_layer = Vec::with_capacity(cfg.num_hidden_layers);
72 let num_attention_heads = cfg.hidden_size / cfg.num_attention_heads;
74 for _layer_idx in 0..cfg.num_hidden_layers {
75 let extract_key_value = Tensor::zeros((batch_size, cfg.hidden_size), DType::F32, dev)?;
76 let linear_attention = Tensor::zeros(
77 (
78 batch_size,
79 num_attention_heads,
80 cfg.hidden_size / num_attention_heads,
81 cfg.hidden_size / num_attention_heads,
82 ),
83 DType::F32,
84 dev,
85 )?;
86 let feed_forward = Tensor::zeros((batch_size, cfg.hidden_size), DType::F32, dev)?;
87 per_layer.push(StatePerLayer {
88 extract_key_value,
89 linear_attention,
90 feed_forward,
91 });
92 }
93 Ok(Self { per_layer, pos: 0 })
94 }
95}
96
/// Time-mixing ("attention") sub-layer of an RWKV v5 block.
#[derive(Debug, Clone)]
struct SelfAttention {
    /// Projection `hidden_size -> attention_hidden_size`.
    key: Linear,
    /// Projection `hidden_size -> attention_hidden_size`.
    receptance: Linear,
    /// Projection `hidden_size -> attention_hidden_size`.
    value: Linear,
    /// Projection `hidden_size -> attention_hidden_size`; its SiLU gates the attention output.
    gate: Linear,
    /// Projection `attention_hidden_size -> hidden_size`.
    output: Linear,
    /// Group norm over the attention output with `hidden_size / head_size` groups.
    ln_x: candle_nn::GroupNorm,
    /// `(1, 1, hidden_size)` weight interpolating the current and previous-token inputs.
    time_mix_key: Tensor,
    /// `(1, 1, hidden_size)` weight interpolating the current and previous-token inputs.
    time_mix_value: Tensor,
    /// `(1, 1, hidden_size)` weight interpolating the current and previous-token inputs.
    time_mix_receptance: Tensor,
    /// `(heads, head_size)`; the per-step state decay is `exp(-exp(time_decay))`.
    time_decay: Tensor,
    /// `(heads, head_size)`; weight applied to the current step's key-value product.
    time_faaaa: Tensor,
    /// `(1, 1, hidden_size)` weight interpolating the current and previous-token inputs.
    time_mix_gate: Tensor,
    /// Index of this block in `State::per_layer`.
    layer_id: usize,
    /// `hidden_size / head_size`.
    n_attn_heads: usize,
}
114
115impl SelfAttention {
116 pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
117 let hidden_size = cfg.hidden_size;
118 let attn_hidden_size = cfg.attention_hidden_size;
119 let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
120 let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
121 let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
122 let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
123 let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;
124 let ln_x = candle_nn::group_norm(
125 hidden_size / cfg.head_size,
126 hidden_size,
127 1e-5,
128 vb.pp("ln_x"),
129 )?;
130 let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
131 let time_mix_value = vb.get((1, 1, cfg.hidden_size), "time_mix_value")?;
132 let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
133 let n_attn_heads = cfg.hidden_size / cfg.head_size;
134 let time_decay = vb.get((n_attn_heads, cfg.head_size), "time_decay")?;
135 let time_faaaa = vb.get((n_attn_heads, cfg.head_size), "time_faaaa")?;
136 let time_mix_gate = vb.get((1, 1, cfg.hidden_size), "time_mix_gate")?;
137 Ok(Self {
138 key,
139 value,
140 receptance,
141 gate,
142 output,
143 ln_x,
144 time_mix_key,
145 time_mix_value,
146 time_mix_receptance,
147 time_decay,
148 time_faaaa,
149 time_mix_gate,
150 layer_id,
151 n_attn_heads,
152 })
153 }
154
155 pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
156 let h = self.time_decay.dim(0)?;
157 let (b, t, s) = xs.dims3()?;
158 let s = s / h;
159 let (receptance, key, value, gate) = {
160 let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
162 let shifted = if shifted.rank() == 2 {
163 shifted.unsqueeze(1)?
164 } else {
165 shifted
166 };
167 let key = ((xs * &self.time_mix_key)? + &shifted * (1.0 - &self.time_mix_key)?)?;
168 let value = ((xs * &self.time_mix_value)? + &shifted * (1.0 - &self.time_mix_value)?)?;
169 let receptance = ((xs * &self.time_mix_receptance)?
170 + &shifted * (1.0 - &self.time_mix_receptance)?)?;
171 let gate = ((xs * &self.time_mix_gate)? + &shifted * (1.0 - &self.time_mix_gate)?)?;
172
173 let key = self.key.forward(&key)?;
174 let value = self.value.forward(&value)?;
175 let receptance = self.receptance.forward(&receptance)?;
176 let gate = candle_nn::ops::silu(&self.gate.forward(&gate)?)?;
177 state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
178 (receptance, key, value, gate)
179 };
180 let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
182 let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
183 let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
184 let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;
185
186 let time_decay = self
187 .time_decay
188 .exp()?
189 .neg()?
190 .exp()?
191 .reshape(((), 1, 1))?
192 .reshape((self.n_attn_heads, (), 1))?;
193 let time_faaaa =
194 self.time_faaaa
195 .reshape(((), 1, 1))?
196 .reshape((self.n_attn_heads, (), 1))?;
197
198 let mut out: Vec<Tensor> = Vec::with_capacity(t);
199 for t_ in 0..t {
200 let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
201 let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
202 let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
203 let at = kt.matmul(&vt)?;
204 let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
205 let out_ = rt.matmul(&rhs)?.squeeze(2)?;
206 state_ = (&at + time_decay.broadcast_mul(&state_))?;
207 out.push(out_)
208 }
209 let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
210 let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
211 let out = (out * gate)?.apply(&self.output)?;
212 state.per_layer[self.layer_id].linear_attention = state_;
213 Ok(out)
214 }
215}
216
/// Channel-mixing (feed-forward) sub-layer of an RWKV v5 block.
#[derive(Debug, Clone)]
struct FeedForward {
    /// `(1, 1, hidden_size)` weight interpolating the current and previous-token inputs.
    time_mix_key: Tensor,
    /// `(1, 1, hidden_size)` weight interpolating the current and previous-token inputs.
    time_mix_receptance: Tensor,
    /// Projection `hidden_size -> intermediate size`, followed by a squared ReLU.
    key: Linear,
    /// Projection `hidden_size -> hidden_size`; its sigmoid gates the output.
    receptance: Linear,
    /// Projection `intermediate size -> hidden_size`.
    value: Linear,
    /// Index of this block in `State::per_layer`.
    layer_id: usize,
}
226
227impl FeedForward {
228 pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
229 let int_size = cfg
230 .intermediate_size
231 .unwrap_or(((cfg.hidden_size as f64 * 3.5) as usize) / 32 * 32);
232 let key = linear(cfg.hidden_size, int_size, vb.pp("key"))?;
233 let receptance = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("receptance"))?;
234 let value = linear(int_size, cfg.hidden_size, vb.pp("value"))?;
235 let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
236 let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
237 Ok(Self {
238 key,
239 receptance,
240 value,
241 time_mix_key,
242 time_mix_receptance,
243 layer_id,
244 })
245 }
246
247 pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
248 let shifted = &state.per_layer[self.layer_id].feed_forward;
249 let key = (xs.broadcast_mul(&self.time_mix_key)?
250 + shifted.broadcast_mul(&(1.0 - &self.time_mix_key)?)?)?;
251 let receptance = (xs.broadcast_mul(&self.time_mix_receptance)?
252 + shifted.broadcast_mul(&(1.0 - &self.time_mix_receptance)?)?)?;
253 let key = key.apply(&self.key)?.relu()?.sqr()?;
254 let value = key.apply(&self.value)?;
255 let receptance = candle_nn::ops::sigmoid(&receptance.apply(&self.receptance)?)?;
256 state.per_layer[self.layer_id].feed_forward = xs.i((.., xs.dim(1)? - 1))?;
257 let xs = (receptance * value)?;
258 Ok(xs)
259 }
260}
261
/// One RWKV v5 layer: time mixing then channel mixing, each applied to a layer-normed input
/// and added back to the residual stream.
#[derive(Debug, Clone)]
struct Block {
    /// Extra layer norm applied to the block input; only block 0 has one, so in practice it
    /// normalises the embeddings.
    pre_ln: Option<LayerNorm>,
    /// Norm applied before the attention sub-layer.
    ln1: LayerNorm,
    /// Norm applied before the feed-forward sub-layer.
    ln2: LayerNorm,
    attention: SelfAttention,
    feed_forward: FeedForward,
}
270
271impl Block {
272 pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
273 let ln1 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln1"))?;
274 let ln2 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln2"))?;
275 let pre_ln = if layer_id == 0 {
276 let ln = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("pre_ln"))?;
277 Some(ln)
278 } else {
279 None
280 };
281 let attention = SelfAttention::new(layer_id, cfg, vb.pp("attention"))?;
282 let feed_forward = FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?;
283 Ok(Self {
284 pre_ln,
285 ln1,
286 ln2,
287 attention,
288 feed_forward,
289 })
290 }
291
292 pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
293 let xs = match self.pre_ln.as_ref() {
294 None => xs.clone(),
295 Some(pre_ln) => xs.apply(pre_ln)?,
296 };
297 let attention = self.attention.forward(&xs.apply(&self.ln1)?, state)?;
298 let xs = (xs + attention)?;
299 let feed_forward = self.feed_forward.forward(&xs.apply(&self.ln2)?, state)?;
300 let xs = (xs + feed_forward)?;
301 Ok(xs)
302 }
303}
304
/// RWKV v5 language model: token embeddings, a stack of blocks, a final norm and an LM head.
#[derive(Debug, Clone)]
pub struct Model {
    embeddings: Embedding,
    blocks: Vec<Block>,
    /// Final layer norm (eps 1e-5) applied before `head`.
    ln_out: LayerNorm,
    /// Projection from `hidden_size` to `vocab_size` outputs.
    head: Linear,
    /// Block period for halving activations; only consulted when `layers_are_rescaled`.
    rescale_every: usize,
    /// Set to `false` by `Model::new`, so the halving branch is currently never taken.
    layers_are_rescaled: bool,
}
314
315impl Model {
316 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
317 let vb_m = vb.pp("rwkv");
318 let embeddings = embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
319 let mut blocks = Vec::with_capacity(cfg.num_hidden_layers);
320 let vb_b = vb_m.pp("blocks");
321 for block_index in 0..cfg.num_hidden_layers {
322 let block = Block::new(block_index, cfg, vb_b.pp(block_index))?;
323 blocks.push(block)
324 }
325 let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
326 let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
327 Ok(Self {
328 embeddings,
329 blocks,
330 ln_out,
331 head,
332 rescale_every: cfg.rescale_every,
333 layers_are_rescaled: false, })
335 }
336
337 pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
338 let (_b_size, _seq_len) = xs.dims2()?;
339 let mut xs = xs.apply(&self.embeddings)?;
340 for (block_idx, block) in self.blocks.iter().enumerate() {
341 xs = block.forward(&xs, state)?;
342 if self.layers_are_rescaled && (block_idx + 1) % self.rescale_every == 0 {
343 xs = (xs / 2.)?
344 }
345 }
346 let xs = xs.apply(&self.ln_out)?.apply(&self.head)?;
347 state.pos += 1;
348 Ok(xs)
349 }
350}
351
/// Byte string of a single vocabulary entry.
type Bytes = Vec<u8>;
353
/// Greedy byte-level tokenizer built from a JSON `token string -> id` vocabulary.
pub struct Tokenizer {
    /// `table[b0][b1]`: multi-byte tokens starting with bytes `b0, b1`, in the order
    /// `encode_bytes` tries them (the first one that prefixes the input wins).
    table: Vec<Vec<Vec<Bytes>>>,
    /// `good[b0]`: the second bytes `b1` for which `table[b0][b1]` is non-empty.
    good: Vec<HashSet<u8>>,
    /// Token id -> token bytes.
    idx2token: HashMap<u32, Vec<u8>>,
    /// Token bytes -> token id.
    token2idx: HashMap<Vec<u8>, u32>,
}
361
362impl Tokenizer {
363 pub fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
364 let file = std::fs::File::open(p)?;
365 let token2idx: HashMap<String, u32> =
366 serde_json::from_reader(file).map_err(candle::Error::wrap)?;
367 let token2idx = token2idx
368 .into_iter()
369 .map(|(key, value)| (key.into_bytes(), value))
370 .collect::<HashMap<_, _>>();
371 let idx2token = token2idx
372 .iter()
373 .map(|(key, value)| (*value, key.to_vec()))
374 .collect::<HashMap<_, _>>();
375
376 let max_idx = token2idx.values().copied().max().unwrap_or(0);
377
378 let mut table = vec![vec![vec![]; 256]; 256];
379 let mut good = vec![HashSet::new(); 256];
380 for idx in (0..(1 + max_idx)).rev() {
381 let s = match idx2token.get(&idx) {
382 None => continue,
383 Some(s) => s,
384 };
385 if s.len() >= 2 {
386 let (s0, s1) = (s[0], s[1]);
387 table[s0 as usize][s1 as usize].push(s.to_vec());
388 good[s0 as usize].insert(s1);
389 }
390 }
391 Ok(Self {
392 table,
393 good,
394 idx2token,
395 token2idx,
396 })
397 }
398
399 pub fn decode_bytes(&self, tokens: &[u32]) -> Vec<u8> {
400 let mut v = Vec::new();
401 for token_id in tokens.iter() {
402 if let Some(token) = self.idx2token.get(token_id) {
403 v.extend_from_slice(token.as_slice())
404 }
405 }
406 v
407 }
408
409 pub fn decode(&self, tokens: &[u32]) -> Result<String> {
410 let bytes = self.decode_bytes(tokens);
411 String::from_utf8(bytes).map_err(candle::Error::wrap)
412 }
413
414 pub fn encode_bytes(&self, bytes: &[u8]) -> Result<Vec<u32>> {
415 let mut tokens = Vec::new();
416 let mut i = 0;
417 while i < bytes.len() {
418 let mut s = vec![bytes[i]];
419 if i + 1 < bytes.len() && self.good[bytes[i] as usize].contains(&bytes[i + 1]) {
420 let table = &self.table[bytes[i] as usize][bytes[i + 1] as usize];
421 for table_elem in table.iter() {
422 if bytes[i..].starts_with(table_elem) {
423 s = table_elem.to_vec();
424 break;
425 }
426 }
427 }
428 i += s.len();
429 let token = match self.token2idx.get(&s) {
430 None => candle::bail!("unexpected token '{}' {s:?}", String::from_utf8_lossy(&s)),
431 Some(token) => *token,
432 };
433 tokens.push(token)
434 }
435 Ok(tokens)
436 }
437
438 pub fn encode(&self, str: &str) -> Result<Vec<u32>> {
439 self.encode_bytes(str.as_bytes())
440 }
441}