119 lines
4.5 KiB
Rust
119 lines
4.5 KiB
Rust
use burn::module::{Module, Param};
|
|
use burn::tensor::{backend::Backend, Tensor};
|
|
use super::rope::RoPE;
|
|
use super::config::SmolLMConfig;
|
|
|
|
#[derive(Clone, Debug)]
|
|
pub struct KVCache<B: Backend> {
|
|
pub k: Tensor<B, 4>,
|
|
pub v: Tensor<B, 4>,
|
|
}
|
|
|
|
#[derive(Module, Debug)]
|
|
pub struct Attention<B: Backend> {
|
|
pub q_proj: Param<Tensor<B, 2>>, // [hidden, num_heads * head_dim]
|
|
pub k_proj: Param<Tensor<B, 2>>, // [hidden, num_kv_heads * head_dim]
|
|
pub v_proj: Param<Tensor<B, 2>>, // [hidden, num_kv_heads * head_dim]
|
|
pub o_proj: Param<Tensor<B, 2>>, // [num_heads * head_dim, hidden]
|
|
|
|
num_heads: usize,
|
|
num_kv_heads: usize,
|
|
head_dim: usize,
|
|
|
|
rope: RoPE<B>,
|
|
}
|
|
|
|
impl<B: Backend> Attention<B> {
|
|
pub fn new(config: &SmolLMConfig, device: &B::Device) -> Self {
|
|
let head_dim = config.hidden_size / config.num_attention_heads;
|
|
|
|
Self {
|
|
q_proj: Param::from_tensor(Tensor::zeros([config.hidden_size, config.num_attention_heads * head_dim], device)),
|
|
k_proj: Param::from_tensor(Tensor::zeros([config.hidden_size, config.num_key_value_heads * head_dim], device)),
|
|
v_proj: Param::from_tensor(Tensor::zeros([config.hidden_size, config.num_key_value_heads * head_dim], device)),
|
|
o_proj: Param::from_tensor(Tensor::zeros([config.num_attention_heads * head_dim, config.hidden_size], device)),
|
|
|
|
num_heads: config.num_attention_heads,
|
|
num_kv_heads: config.num_key_value_heads,
|
|
head_dim,
|
|
|
|
rope: RoPE::new(head_dim, config.max_position_embeddings, config.rope_theta, device),
|
|
}
|
|
}
|
|
|
|
pub fn forward(
|
|
&self,
|
|
x: Tensor<B, 3>,
|
|
offset: usize,
|
|
cache: Option<KVCache<B>>
|
|
) -> (Tensor<B, 3>, KVCache<B>) {
|
|
let [batch, seq_len, hidden_dim] = x.dims();
|
|
|
|
// Project Q, K, V: x @ W -> [batch, seq, proj_dim]
|
|
let q = x.clone().matmul(self.q_proj.val().unsqueeze());
|
|
let k = x.clone().matmul(self.k_proj.val().unsqueeze());
|
|
let v = x.matmul(self.v_proj.val().unsqueeze());
|
|
|
|
// Reshape: [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim]
|
|
let q = q.reshape([batch, seq_len, self.num_heads, self.head_dim]).swap_dims(1, 2);
|
|
let k = k.reshape([batch, seq_len, self.num_kv_heads, self.head_dim]).swap_dims(1, 2);
|
|
let v = v.reshape([batch, seq_len, self.num_kv_heads, self.head_dim]).swap_dims(1, 2);
|
|
|
|
// Apply RoPE
|
|
let q = self.rope.forward(q, offset);
|
|
let k = self.rope.forward(k, offset);
|
|
|
|
// KV cache
|
|
let (k, v) = if let Some(c) = cache {
|
|
(Tensor::cat(vec![c.k, k], 2), Tensor::cat(vec![c.v, v], 2))
|
|
} else {
|
|
(k, v)
|
|
};
|
|
|
|
let new_cache = KVCache { k: k.clone(), v: v.clone() };
|
|
let kv_len = k.dims()[2];
|
|
|
|
// GQA: repeat K,V heads — [batch, kv_heads, kv_len, hd] -> [batch, num_heads, kv_len, hd]
|
|
let num_reps = self.num_heads / self.num_kv_heads;
|
|
let k = if num_reps > 1 {
|
|
let [b, kv_h, s, hd] = k.dims();
|
|
k.reshape([b, kv_h, 1, s, hd]).repeat_dim(2, num_reps).reshape([b, self.num_heads, s, hd])
|
|
} else { k };
|
|
let v = if num_reps > 1 {
|
|
let [b, kv_h, s, hd] = v.dims();
|
|
v.reshape([b, kv_h, 1, s, hd]).repeat_dim(2, num_reps).reshape([b, self.num_heads, s, hd])
|
|
} else { v };
|
|
|
|
// Attention: Q @ K^T / sqrt(d)
|
|
let scale = 1.0 / (self.head_dim as f64).sqrt();
|
|
let scores = q.matmul(k.swap_dims(2, 3)).mul_scalar(scale);
|
|
// scores: [batch, heads, seq_len, kv_len]
|
|
|
|
// Causal mask for prefill (seq_len > 1)
|
|
let scores = if seq_len > 1 {
|
|
let mask_data: Vec<f32> = (0..seq_len).flat_map(|i| {
|
|
(0..kv_len).map(move |j| {
|
|
if j > offset + i { f32::NEG_INFINITY } else { 0.0 }
|
|
})
|
|
}).collect();
|
|
let mask = Tensor::<B, 2>::from_data(
|
|
burn::tensor::TensorData::new(mask_data, [seq_len, kv_len]),
|
|
&scores.device()
|
|
).reshape([1, 1, seq_len, kv_len]);
|
|
scores + mask
|
|
} else {
|
|
scores
|
|
};
|
|
|
|
let attn_weights = burn::tensor::activation::softmax(scores, 3);
|
|
|
|
let context = attn_weights.matmul(v);
|
|
// [batch, heads, seq, hd] -> [batch, seq, heads*hd]
|
|
let context = context.swap_dims(1, 2).reshape([batch, seq_len, self.num_heads * self.head_dim]);
|
|
|
|
let output = context.matmul(self.o_proj.val().unsqueeze());
|
|
|
|
(output, new_cache)
|
|
}
|
|
}
|