use burn::module::{Module, Param}; use burn::tensor::{backend::Backend, Tensor}; use super::rope::RoPE; use super::config::SmolLMConfig; #[derive(Clone, Debug)] pub struct KVCache { pub k: Tensor, pub v: Tensor, } #[derive(Module, Debug)] pub struct Attention { pub q_proj: Param>, // [hidden, num_heads * head_dim] pub k_proj: Param>, // [hidden, num_kv_heads * head_dim] pub v_proj: Param>, // [hidden, num_kv_heads * head_dim] pub o_proj: Param>, // [num_heads * head_dim, hidden] num_heads: usize, num_kv_heads: usize, head_dim: usize, rope: RoPE, } impl Attention { pub fn new(config: &SmolLMConfig, device: &B::Device) -> Self { let head_dim = config.hidden_size / config.num_attention_heads; Self { q_proj: Param::from_tensor(Tensor::zeros([config.hidden_size, config.num_attention_heads * head_dim], device)), k_proj: Param::from_tensor(Tensor::zeros([config.hidden_size, config.num_key_value_heads * head_dim], device)), v_proj: Param::from_tensor(Tensor::zeros([config.hidden_size, config.num_key_value_heads * head_dim], device)), o_proj: Param::from_tensor(Tensor::zeros([config.num_attention_heads * head_dim, config.hidden_size], device)), num_heads: config.num_attention_heads, num_kv_heads: config.num_key_value_heads, head_dim, rope: RoPE::new(head_dim, config.max_position_embeddings, config.rope_theta, device), } } pub fn forward( &self, x: Tensor, offset: usize, cache: Option> ) -> (Tensor, KVCache) { let [batch, seq_len, hidden_dim] = x.dims(); // Project Q, K, V: x @ W -> [batch, seq, proj_dim] let q = x.clone().matmul(self.q_proj.val().unsqueeze()); let k = x.clone().matmul(self.k_proj.val().unsqueeze()); let v = x.matmul(self.v_proj.val().unsqueeze()); // Reshape: [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim] let q = q.reshape([batch, seq_len, self.num_heads, self.head_dim]).swap_dims(1, 2); let k = k.reshape([batch, seq_len, self.num_kv_heads, self.head_dim]).swap_dims(1, 2); let v = v.reshape([batch, seq_len, self.num_kv_heads, self.head_dim]).swap_dims(1, 2); // Apply RoPE let q = self.rope.forward(q, offset); let k = self.rope.forward(k, offset); // KV cache let (k, v) = if let Some(c) = cache { (Tensor::cat(vec![c.k, k], 2), Tensor::cat(vec![c.v, v], 2)) } else { (k, v) }; let new_cache = KVCache { k: k.clone(), v: v.clone() }; let kv_len = k.dims()[2]; // GQA: repeat K,V heads — [batch, kv_heads, kv_len, hd] -> [batch, num_heads, kv_len, hd] let num_reps = self.num_heads / self.num_kv_heads; let k = if num_reps > 1 { let [b, kv_h, s, hd] = k.dims(); k.reshape([b, kv_h, 1, s, hd]).repeat_dim(2, num_reps).reshape([b, self.num_heads, s, hd]) } else { k }; let v = if num_reps > 1 { let [b, kv_h, s, hd] = v.dims(); v.reshape([b, kv_h, 1, s, hd]).repeat_dim(2, num_reps).reshape([b, self.num_heads, s, hd]) } else { v }; // Attention: Q @ K^T / sqrt(d) let scale = 1.0 / (self.head_dim as f64).sqrt(); let scores = q.matmul(k.swap_dims(2, 3)).mul_scalar(scale); // scores: [batch, heads, seq_len, kv_len] // Causal mask for prefill (seq_len > 1) let scores = if seq_len > 1 { let mask_data: Vec = (0..seq_len).flat_map(|i| { (0..kv_len).map(move |j| { if j > offset + i { f32::NEG_INFINITY } else { 0.0 } }) }).collect(); let mask = Tensor::::from_data( burn::tensor::TensorData::new(mask_data, [seq_len, kv_len]), &scores.device() ).reshape([1, 1, seq_len, kv_len]); scores + mask } else { scores }; let attn_weights = burn::tensor::activation::softmax(scores, 3); let context = attn_weights.matmul(v); // [batch, heads, seq, hd] -> [batch, seq, heads*hd] let context = context.swap_dims(1, 2).reshape([batch, seq_len, self.num_heads * self.head_dim]); let output = context.matmul(self.o_proj.val().unsqueeze()); (output, new_cache) } }