Files
agentic-studio/network-poc/node/src/burn_smollm/model.rs
2026-04-02 15:47:48 +03:00

97 lines
3.2 KiB
Rust

use burn::module::{Module, Param};
use burn::tensor::{backend::Backend, Tensor, Int};
use super::modules::{RmsNorm, Mlp};
use super::attention::{Attention, KVCache};
use super::config::SmolLMConfig;
/// One decoder layer: an `Attention` sub-module and an `Mlp` sub-module,
/// each preceded by its own `RmsNorm`.
#[derive(Module, Debug)]
pub struct LlamaBlock<B: Backend> {
/// Attention sub-module; its forward pass also returns the layer's `KVCache`.
pub self_attn: Attention<B>,
/// Feed-forward sub-module, built from `hidden_size` and `intermediate_size`.
pub mlp: Mlp<B>,
/// Normalization applied to the block input before it is fed to `self_attn`.
pub input_layernorm: RmsNorm<B>,
/// Normalization applied to the attention residual sum before it is fed to `mlp`.
pub post_attention_layernorm: RmsNorm<B>,
}
impl<B: Backend> LlamaBlock<B> {
    /// Builds a decoder block whose sub-modules are sized from `config` and
    /// created on `device`.
    pub fn new(config: &SmolLMConfig, device: &B::Device) -> Self {
        let hidden = config.hidden_size;
        let eps = config.rms_norm_eps;

        // Sub-modules are constructed in field order, as before.
        let self_attn = Attention::new(config, device);
        let mlp = Mlp::new(hidden, config.intermediate_size, device);
        let input_layernorm = RmsNorm::new(hidden, eps, device);
        let post_attention_layernorm = RmsNorm::new(hidden, eps, device);

        Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        }
    }

    /// Runs the block on `x` and returns the new hidden states together with
    /// the cache produced by the attention sub-module.
    ///
    /// * `x` - hidden states entering the block.
    /// * `offset` - forwarded unchanged to the attention sub-module.
    /// * `cache` - previous cache for this layer, if any; forwarded to attention.
    pub fn forward(
        &self,
        x: Tensor<B, 3>,
        offset: usize,
        cache: Option<KVCache<B>>,
    ) -> (Tensor<B, 3>, KVCache<B>) {
        // Attention branch: normalize a copy of the input, then add the
        // attention output back onto the untouched input.
        let normed = self.input_layernorm.forward(x.clone());
        let (attn_out, new_cache) = self.self_attn.forward(normed, offset, cache);
        let hidden = x + attn_out;

        // MLP branch: same normalize-then-add-back shape.
        let normed = self.post_attention_layernorm.forward(hidden.clone());
        let hidden = hidden + self.mlp.forward(normed);

        (hidden, new_cache)
    }
}
/// Full decoder stack: token embedding table, a list of `LlamaBlock`s, a final
/// `RmsNorm`, and an output projection used to turn hidden states into logits.
#[derive(Module, Debug)]
pub struct LlamaModel<B: Backend> {
/// Embedding table, created as `[vocab_size, hidden_size]`.
pub embed_tokens: Param<Tensor<B, 2>>,
/// Decoder blocks, applied in order; one per `num_hidden_layers`.
pub layers: Vec<LlamaBlock<B>>,
/// Normalization applied after the last decoder block.
pub norm: RmsNorm<B>,
/// Output projection weight, created as `[vocab_size, hidden_size]`.
/// For tie_word_embeddings this can point to embed_tokens.
pub lm_head: Param<Tensor<B, 2>>,
}
impl<B: Backend> LlamaModel<B> {
    /// Builds a model with `config.num_hidden_layers` decoder blocks on `device`.
    ///
    /// `embed_tokens` and `lm_head` are created as all-zero
    /// `[vocab_size, hidden_size]` tensors; real weights are presumably loaded
    /// into the module afterwards (not shown in this file).
    pub fn new(config: &SmolLMConfig, device: &B::Device) -> Self {
        let embed = Tensor::zeros([config.vocab_size, config.hidden_size], device);
        let lm_head = Tensor::zeros([config.vocab_size, config.hidden_size], device);
        let layers: Vec<_> = (0..config.num_hidden_layers)
            .map(|_| LlamaBlock::new(config, device))
            .collect();

        Self {
            embed_tokens: Param::from_tensor(embed),
            layers,
            norm: RmsNorm::new(config.hidden_size, config.rms_norm_eps, device),
            lm_head: Param::from_tensor(lm_head),
        }
    }

    /// Runs the whole stack on `input_ids` and returns the logits.
    ///
    /// * `input_ids` - rank-2 integer tensor of token ids.
    /// * `offset` - forwarded unchanged to every decoder block.
    /// * `caches` - one cache slot per layer. Each slot is taken, passed to its
    ///   layer, and replaced with the cache that layer returns. If the vector
    ///   holds fewer slots than there are layers (for example an empty vector
    ///   on the first call) it is grown with `None` entries instead of
    ///   panicking; surplus slots are left untouched.
    pub fn forward(
        &self,
        input_ids: Tensor<B, 2, Int>,
        offset: usize,
        caches: &mut Vec<Option<KVCache<B>>>,
    ) -> Tensor<B, 3> {
        // Guarantee a slot for every layer so the loop below cannot index out
        // of bounds and no layer is silently skipped.
        if caches.len() < self.layers.len() {
            caches.resize_with(self.layers.len(), || None);
        }

        let mut x = burn::tensor::module::embedding(self.embed_tokens.val(), input_ids);

        for (layer, slot) in self.layers.iter().zip(caches.iter_mut()) {
            let (out, new_cache) = layer.forward(x, offset, slot.take());
            x = out;
            *slot = Some(new_cache);
        }

        let x = self.norm.forward(x);

        // Matmul with lm_head (or embed_tokens if tied) to get logits.
        // lm_head is typically [vocab_size, hidden_size] in HF, so the two axes
        // are swapped; `unsqueeze` then adds a leading axis to match the rank of `x`.
        x.matmul(self.lm_head.val().swap_dims(0, 1).unsqueeze())
    }
}