60 lines
2.1 KiB
Rust
60 lines
2.1 KiB
Rust
use burn::module::{Module, Param};
|
|
use burn::tensor::{backend::Backend, Tensor};
|
|
|
|
#[derive(Module, Debug)]
|
|
pub struct RmsNorm<B: Backend> {
|
|
pub weight: Param<Tensor<B, 1>>,
|
|
epsilon: f64,
|
|
}
|
|
|
|
impl<B: Backend> RmsNorm<B> {
|
|
pub fn new(size: usize, epsilon: f64, device: &B::Device) -> Self {
|
|
let weight = Param::from_tensor(Tensor::ones([size], device));
|
|
Self { weight, epsilon }
|
|
}
|
|
|
|
pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
|
|
// x: [batch, seq_len, dim]
|
|
// RMSNorm: x * weight / sqrt(mean(x^2) + eps)
|
|
let x_sq = x.clone().powf_scalar(2.0);
|
|
// mean over last dim, keeping dims for broadcast
|
|
let [b, s, d] = x_sq.dims();
|
|
let variance = x_sq.sum_dim(2).div_scalar(d as f32);
|
|
let norm = x.div(variance.add_scalar(self.epsilon).sqrt());
|
|
|
|
let w = self.weight.val().unsqueeze::<2>().unsqueeze::<3>().reshape([1, 1, d]);
|
|
norm * w
|
|
}
|
|
}
|
|
|
|
#[derive(Module, Debug)]
|
|
pub struct Mlp<B: Backend> {
|
|
pub gate_proj: Param<Tensor<B, 2>>, // [in, intermediate]
|
|
pub up_proj: Param<Tensor<B, 2>>, // [in, intermediate]
|
|
pub down_proj: Param<Tensor<B, 2>>, // [intermediate, out]
|
|
}
|
|
|
|
impl<B: Backend> Mlp<B> {
|
|
pub fn new(hidden_size: usize, intermediate_size: usize, device: &B::Device) -> Self {
|
|
Self {
|
|
gate_proj: Param::from_tensor(Tensor::zeros([hidden_size, intermediate_size], device)),
|
|
up_proj: Param::from_tensor(Tensor::zeros([hidden_size, intermediate_size], device)),
|
|
down_proj: Param::from_tensor(Tensor::zeros([intermediate_size, hidden_size], device)),
|
|
}
|
|
}
|
|
|
|
pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
|
|
// x: [batch, seq, hidden]
|
|
// gate = x @ gate_proj -> [batch, seq, intermediate]
|
|
let gate = x.clone().matmul(self.gate_proj.val().unsqueeze());
|
|
let up = x.matmul(self.up_proj.val().unsqueeze());
|
|
|
|
// SiLU(gate) * up
|
|
let silu = gate.clone() * burn::tensor::activation::sigmoid(gate);
|
|
let intermediate = silu * up;
|
|
|
|
// intermediate @ down_proj -> [batch, seq, hidden]
|
|
intermediate.matmul(self.down_proj.val().unsqueeze())
|
|
}
|
|
}
|