Files
agentic-studio/network-poc/node/src/burn_smollm/modules.rs
2026-04-02 15:47:48 +03:00

60 lines
2.1 KiB
Rust

use burn::module::{Module, Param};
use burn::tensor::{backend::Backend, Tensor};
/// Root-mean-square normalization layer with a per-feature scale.
///
/// `forward` divides its input by `sqrt(mean(x^2) + epsilon)`, where the mean is
/// taken over the last axis, and then multiplies the result by `weight`.
#[derive(Module, Debug)]
pub struct RmsNorm<B: Backend> {
/// Per-feature scale, wrapped in `Param`; `new` creates it as a rank-1 tensor of
/// ones with length `size`.
pub weight: Param<Tensor<B, 1>>,
/// Constant added to the mean of squares before the square root is taken.
epsilon: f64,
}
impl<B: Backend> RmsNorm<B> {
    /// Creates an RMSNorm layer for inputs whose last axis has length `size`.
    ///
    /// * `size` - length of the normalized (last) axis; the scale vector has this length.
    /// * `epsilon` - constant added to the mean of squares before the square root.
    /// * `device` - device on which the scale tensor is allocated.
    ///
    /// The scale starts as all ones, so a freshly built layer only normalizes.
    pub fn new(size: usize, epsilon: f64, device: &B::Device) -> Self {
        let weight = Param::from_tensor(Tensor::ones([size], device));
        Self { weight, epsilon }
    }

    /// Applies RMS normalization over the last axis of `x`.
    ///
    /// Computes `x / sqrt(mean(x^2) + epsilon) * weight`.
    ///
    /// * `x` - input of shape `[batch, seq_len, dim]`.
    ///
    /// Returns a tensor with the same shape as `x`.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        // Mean of squares over the last axis. The reduced axis is kept with size 1,
        // so the result is [batch, seq_len, 1] and broadcasts against `x` below.
        let mean_sq = x.clone().powf_scalar(2.0).mean_dim(2);
        let normalized = x.div(mean_sq.add_scalar(self.epsilon).sqrt());

        // Prepend two singleton axes ([dim] -> [1, 1, dim]) so the scale broadcasts
        // over the batch and sequence axes.
        let scale = self.weight.val().unsqueeze::<3>();
        normalized * scale
    }
}
/// Gated feed-forward block.
///
/// `forward` computes `(silu(x @ gate_proj) * (x @ up_proj)) @ down_proj`. The
/// matrices are stored so that activations right-multiply them directly, with no
/// transpose.
#[derive(Module, Debug)]
pub struct Mlp<B: Backend> {
/// Weights of the branch that is passed through SiLU.
pub gate_proj: Param<Tensor<B, 2>>, // [in, intermediate]
/// Weights of the branch that is multiplied element-wise with the activated gate.
pub up_proj: Param<Tensor<B, 2>>, // [in, intermediate]
/// Weights that map the gated intermediate activations to the output.
pub down_proj: Param<Tensor<B, 2>>, // [intermediate, out]
}
impl<B: Backend> Mlp<B> {
    /// Creates the block with all three projection matrices filled with zeros.
    ///
    /// * `hidden_size` - width of the block's input and output.
    /// * `intermediate_size` - width of the gated inner activations.
    /// * `device` - device on which the weight tensors are allocated.
    ///
    /// NOTE(review): with zero weights `forward` returns all zeros; presumably the
    /// real weights are loaded into the module afterwards - confirm against callers.
    pub fn new(hidden_size: usize, intermediate_size: usize, device: &B::Device) -> Self {
        Self {
            gate_proj: Param::from_tensor(Tensor::zeros([hidden_size, intermediate_size], device)),
            up_proj: Param::from_tensor(Tensor::zeros([hidden_size, intermediate_size], device)),
            down_proj: Param::from_tensor(Tensor::zeros([intermediate_size, hidden_size], device)),
        }
    }

    /// Applies the gated feed-forward transform.
    ///
    /// Computes `(silu(x @ gate_proj) * (x @ up_proj)) @ down_proj`.
    ///
    /// * `x` - input of shape `[batch, seq, hidden]`.
    ///
    /// Returns a tensor of shape `[batch, seq, hidden]`.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        // Each weight matrix gets a leading singleton axis ([1, rows, cols]) so the
        // rank-3 matmul broadcasts it across the batch.
        let gate = x.clone().matmul(self.gate_proj.val().unsqueeze::<3>());
        let up = x.matmul(self.up_proj.val().unsqueeze::<3>());

        // SiLU(gate) * up, both [batch, seq, intermediate].
        let gated = burn::tensor::activation::silu(gate) * up;

        // Project back: [batch, seq, intermediate] -> [batch, seq, hidden].
        gated.matmul(self.down_proj.val().unsqueeze::<3>())
    }
}