use burn::module::{Module, Param}; use burn::tensor::{backend::Backend, Tensor}; #[derive(Module, Debug)] pub struct RmsNorm { pub weight: Param>, epsilon: f64, } impl RmsNorm { pub fn new(size: usize, epsilon: f64, device: &B::Device) -> Self { let weight = Param::from_tensor(Tensor::ones([size], device)); Self { weight, epsilon } } pub fn forward(&self, x: Tensor) -> Tensor { // x: [batch, seq_len, dim] // RMSNorm: x * weight / sqrt(mean(x^2) + eps) let x_sq = x.clone().powf_scalar(2.0); // mean over last dim, keeping dims for broadcast let [b, s, d] = x_sq.dims(); let variance = x_sq.sum_dim(2).div_scalar(d as f32); let norm = x.div(variance.add_scalar(self.epsilon).sqrt()); let w = self.weight.val().unsqueeze::<2>().unsqueeze::<3>().reshape([1, 1, d]); norm * w } } #[derive(Module, Debug)] pub struct Mlp { pub gate_proj: Param>, // [in, intermediate] pub up_proj: Param>, // [in, intermediate] pub down_proj: Param>, // [intermediate, out] } impl Mlp { pub fn new(hidden_size: usize, intermediate_size: usize, device: &B::Device) -> Self { Self { gate_proj: Param::from_tensor(Tensor::zeros([hidden_size, intermediate_size], device)), up_proj: Param::from_tensor(Tensor::zeros([hidden_size, intermediate_size], device)), down_proj: Param::from_tensor(Tensor::zeros([intermediate_size, hidden_size], device)), } } pub fn forward(&self, x: Tensor) -> Tensor { // x: [batch, seq, hidden] // gate = x @ gate_proj -> [batch, seq, intermediate] let gate = x.clone().matmul(self.gate_proj.val().unsqueeze()); let up = x.matmul(self.up_proj.val().unsqueeze()); // SiLU(gate) * up let silu = gate.clone() * burn::tensor::activation::sigmoid(gate); let intermediate = silu * up; // intermediate @ down_proj -> [batch, seq, hidden] intermediate.matmul(self.down_proj.val().unsqueeze()) } }