Files
agentic-studio/network-poc/node/src/burn_smollm/attention.rs
2026-04-02 15:47:48 +03:00

119 lines
4.5 KiB
Rust

use burn::module::{Module, Param};
use burn::tensor::{backend::Backend, Tensor};
use super::rope::RoPE;
use super::config::SmolLMConfig;
/// Keys and values retained from previously processed tokens of one attention layer.
///
/// As produced by [`Attention::forward`], both tensors are laid out as
/// `[batch, num_kv_heads, cached_len, head_dim]`: keys already have RoPE applied, and the
/// KV heads are stored *before* grouped-query repetition. New tokens are appended along
/// dimension 2.
#[derive(Debug, Clone)]
pub struct KVCache<B: Backend> {
    /// Cached (rotated) keys.
    pub k: Tensor<B, 4>,
    /// Cached values.
    pub v: Tensor<B, 4>,
}
/// Grouped-query self-attention layer with rotary position embeddings.
///
/// Projection weights are stored as `[in_features, out_features]` matrices and applied as
/// `x @ W` (there are no bias terms). `num_kv_heads` may be smaller than `num_heads`; the
/// K/V heads are then repeated to match the query heads inside `forward`.
///
/// NOTE(review): field names/types presumably define the derived module record used for
/// checkpoint loading — confirm before renaming or reordering.
#[derive(Module, Debug)]
pub struct Attention<B: Backend> {
/// Query projection weight.
pub q_proj: Param<Tensor<B, 2>>, // [hidden, num_heads * head_dim]
/// Key projection weight.
pub k_proj: Param<Tensor<B, 2>>, // [hidden, num_kv_heads * head_dim]
/// Value projection weight.
pub v_proj: Param<Tensor<B, 2>>, // [hidden, num_kv_heads * head_dim]
/// Output projection weight.
pub o_proj: Param<Tensor<B, 2>>, // [num_heads * head_dim, hidden]
/// Number of query heads.
num_heads: usize,
/// Number of key/value heads (grouped-query attention).
num_kv_heads: usize,
/// Per-head feature size, set in `new` as `hidden_size / num_attention_heads`.
head_dim: usize,
/// Rotary position embedding applied to queries and keys.
rope: RoPE<B>,
}
impl<B: Backend> Attention<B> {
    /// Builds an attention layer whose projection weights are zero-initialised.
    ///
    /// The zero weights are placeholders; presumably they are overwritten by a checkpoint
    /// loader elsewhere — TODO confirm. `head_dim` is derived as
    /// `hidden_size / num_attention_heads`.
    ///
    /// # Panics
    ///
    /// Panics if `num_attention_heads` is zero or does not evenly divide `hidden_size`, or if
    /// `num_key_value_heads` is zero or does not evenly divide `num_attention_heads`
    /// (grouped-query attention needs an integer number of query heads per KV head).
    pub fn new(config: &SmolLMConfig, device: &B::Device) -> Self {
        let hidden = config.hidden_size;
        let num_heads = config.num_attention_heads;
        let num_kv_heads = config.num_key_value_heads;
        // Fail fast here instead of silently truncating head_dim or panicking later inside
        // an opaque reshape in `forward`.
        assert!(
            num_heads > 0 && hidden % num_heads == 0,
            "hidden_size ({}) must be divisible by num_attention_heads ({})",
            hidden,
            num_heads
        );
        assert!(
            num_kv_heads > 0 && num_heads % num_kv_heads == 0,
            "num_attention_heads ({}) must be a multiple of num_key_value_heads ({})",
            num_heads,
            num_kv_heads
        );
        let head_dim = hidden / num_heads;
        Self {
            q_proj: Param::from_tensor(Tensor::zeros([hidden, num_heads * head_dim], device)),
            k_proj: Param::from_tensor(Tensor::zeros([hidden, num_kv_heads * head_dim], device)),
            v_proj: Param::from_tensor(Tensor::zeros([hidden, num_kv_heads * head_dim], device)),
            o_proj: Param::from_tensor(Tensor::zeros([num_heads * head_dim, hidden], device)),
            num_heads,
            num_kv_heads,
            head_dim,
            rope: RoPE::new(head_dim, config.max_position_embeddings, config.rope_theta, device),
        }
    }

    /// Runs causal self-attention over `x`, optionally continuing from a KV cache.
    ///
    /// * `x` — `[batch, seq_len, hidden]` activations (the last dim must match the rows of
    ///   the projection weights).
    /// * `offset` — position offset handed to RoPE for the new tokens; normally the number
    ///   of tokens already stored in `cache`.
    /// * `cache` — keys/values of earlier tokens, as returned by a previous call.
    ///
    /// Returns the `[batch, seq_len, hidden]` output and the updated cache (past plus new
    /// keys/values, before GQA head repetition).
    ///
    /// The causal mask is derived from the cache length, not from `offset`: query `i` is
    /// the key at index `kv_len - seq_len + i`, so it stays causal even when `offset`
    /// differs from the number of cached tokens.
    pub fn forward(
        &self,
        x: Tensor<B, 3>,
        offset: usize,
        cache: Option<KVCache<B>>,
    ) -> (Tensor<B, 3>, KVCache<B>) {
        let [batch, seq_len, _hidden] = x.dims();

        // Project Q, K, V: x @ W. The 2-D weights are unsqueezed to [1, in, out] so the
        // matmul broadcasts over the batch dimension.
        let q = x.clone().matmul(self.q_proj.val().unsqueeze());
        let k = x.clone().matmul(self.k_proj.val().unsqueeze());
        let v = x.matmul(self.v_proj.val().unsqueeze());

        // Split heads: [batch, seq, heads * hd] -> [batch, seq, heads, hd] -> [batch, heads, seq, hd].
        let q = q.reshape([batch, seq_len, self.num_heads, self.head_dim]).swap_dims(1, 2);
        let k = k.reshape([batch, seq_len, self.num_kv_heads, self.head_dim]).swap_dims(1, 2);
        let v = v.reshape([batch, seq_len, self.num_kv_heads, self.head_dim]).swap_dims(1, 2);

        // RoPE is applied to the new tokens only; cached keys were rotated when produced.
        let q = self.rope.forward(q, offset);
        let k = self.rope.forward(k, offset);

        // Append new keys/values after the cached ones along the sequence axis (dim 2).
        let (k, v) = match cache {
            Some(past) => (Tensor::cat(vec![past.k, k], 2), Tensor::cat(vec![past.v, v], 2)),
            None => (k, v),
        };
        let new_cache = KVCache { k: k.clone(), v: v.clone() };
        let kv_len = k.dims()[2];

        // GQA: share each KV head across its group of query heads.
        let k = self.repeat_kv(k);
        let v = self.repeat_kv(v);

        // Scaled dot-product scores: [batch, heads, seq_len, kv_len].
        let scale = 1.0 / (self.head_dim as f64).sqrt();
        let scores = q.matmul(k.swap_dims(2, 3)).mul_scalar(scale);

        // A single new query is the last key, so it may see every key; masking is only
        // needed when several queries are processed at once.
        let scores = if seq_len > 1 {
            let mask = Self::causal_mask(seq_len, kv_len, &scores.device());
            scores + mask
        } else {
            scores
        };

        let attn_weights = burn::tensor::activation::softmax(scores, 3);
        let context = attn_weights.matmul(v);

        // Merge heads: [batch, heads, seq, hd] -> [batch, seq, heads * hd].
        let context = context
            .swap_dims(1, 2)
            .reshape([batch, seq_len, self.num_heads * self.head_dim]);
        let output = context.matmul(self.o_proj.val().unsqueeze());
        (output, new_cache)
    }

    /// Expands `[batch, num_kv_heads, len, head_dim]` to `[batch, num_heads, len, head_dim]`
    /// by repeating each KV head `num_heads / num_kv_heads` times consecutively, so query
    /// head `h` reads KV head `h / reps`. Returns the input unchanged when no repetition
    /// is needed.
    fn repeat_kv(&self, t: Tensor<B, 4>) -> Tensor<B, 4> {
        let reps = self.num_heads / self.num_kv_heads;
        if reps <= 1 {
            return t;
        }
        let [b, kv_heads, len, hd] = t.dims();
        t.reshape([b, kv_heads, 1, len, hd])
            .repeat_dim(2, reps)
            .reshape([b, self.num_heads, len, hd])
    }

    /// Builds an additive causal mask of shape `[1, 1, seq_len, kv_len]`.
    ///
    /// The `seq_len` queries correspond to the last `seq_len` of the `kv_len` keys, so query
    /// `i` may attend to keys `0..=kv_len - seq_len + i`; later keys get `-inf`, the rest `0`.
    /// Requires `kv_len >= seq_len`.
    fn causal_mask(seq_len: usize, kv_len: usize, device: &B::Device) -> Tensor<B, 4> {
        debug_assert!(kv_len >= seq_len, "kv_len must include the new tokens");
        let past_len = kv_len - seq_len;
        let data: Vec<f32> = (0..seq_len)
            .flat_map(|i| {
                (0..kv_len).map(move |j| if j > past_len + i { f32::NEG_INFINITY } else { 0.0 })
            })
            .collect();
        Tensor::<B, 2>::from_data(burn::tensor::TensorData::new(data, [seq_len, kv_len]), device)
            .reshape([1, 1, seq_len, kv_len])
    }
}