97 lines
3.2 KiB
Rust
97 lines
3.2 KiB
Rust
use burn::module::{Module, Param};
|
|
use burn::tensor::{backend::Backend, Tensor, Int};
|
|
use super::modules::{RmsNorm, Mlp};
|
|
use super::attention::{Attention, KVCache};
|
|
use super::config::SmolLMConfig;
|
|
|
|
#[derive(Module, Debug)]
|
|
pub struct LlamaBlock<B: Backend> {
|
|
pub self_attn: Attention<B>,
|
|
pub mlp: Mlp<B>,
|
|
pub input_layernorm: RmsNorm<B>,
|
|
pub post_attention_layernorm: RmsNorm<B>,
|
|
}
|
|
|
|
impl<B: Backend> LlamaBlock<B> {
|
|
pub fn new(config: &SmolLMConfig, device: &B::Device) -> Self {
|
|
Self {
|
|
self_attn: Attention::new(config, device),
|
|
mlp: Mlp::new(config.hidden_size, config.intermediate_size, device),
|
|
input_layernorm: RmsNorm::new(config.hidden_size, config.rms_norm_eps, device),
|
|
post_attention_layernorm: RmsNorm::new(config.hidden_size, config.rms_norm_eps, device),
|
|
}
|
|
}
|
|
|
|
pub fn forward(
|
|
&self,
|
|
x: Tensor<B, 3>,
|
|
offset: usize,
|
|
cache: Option<KVCache<B>>
|
|
) -> (Tensor<B, 3>, KVCache<B>) {
|
|
let residual = x.clone();
|
|
let x_norm = self.input_layernorm.forward(x);
|
|
|
|
let (attn_out, new_cache) = self.self_attn.forward(x_norm, offset, cache);
|
|
|
|
let x = residual + attn_out;
|
|
|
|
let residual = x.clone();
|
|
let x_norm = self.post_attention_layernorm.forward(x);
|
|
let mlp_out = self.mlp.forward(x_norm);
|
|
|
|
let x = residual + mlp_out;
|
|
(x, new_cache)
|
|
}
|
|
}
|
|
|
|
#[derive(Module, Debug)]
|
|
pub struct LlamaModel<B: Backend> {
|
|
pub embed_tokens: Param<Tensor<B, 2>>,
|
|
pub layers: Vec<LlamaBlock<B>>,
|
|
pub norm: RmsNorm<B>,
|
|
pub lm_head: Param<Tensor<B, 2>>, // For tie_word_embeddings this can point to embed_tokens
|
|
}
|
|
|
|
impl<B: Backend> LlamaModel<B> {
|
|
pub fn new(config: &SmolLMConfig, device: &B::Device) -> Self {
|
|
let embed = Tensor::zeros([config.vocab_size, config.hidden_size], device);
|
|
let lm_head = Tensor::zeros([config.vocab_size, config.hidden_size], device);
|
|
|
|
let mut layers = Vec::new();
|
|
for _ in 0..config.num_hidden_layers {
|
|
layers.push(LlamaBlock::new(config, device));
|
|
}
|
|
|
|
Self {
|
|
embed_tokens: Param::from_tensor(embed),
|
|
layers,
|
|
norm: RmsNorm::new(config.hidden_size, config.rms_norm_eps, device),
|
|
lm_head: Param::from_tensor(lm_head),
|
|
}
|
|
}
|
|
|
|
pub fn forward(
|
|
&self,
|
|
input_ids: Tensor<B, 2, Int>,
|
|
offset: usize,
|
|
caches: &mut Vec<Option<KVCache<B>>>
|
|
) -> Tensor<B, 3> {
|
|
let [_batch, _seq_len] = input_ids.dims();
|
|
|
|
let mut x = burn::tensor::module::embedding(self.embed_tokens.val(), input_ids);
|
|
|
|
for (i, layer) in self.layers.iter().enumerate() {
|
|
let cache = caches[i].take();
|
|
let (out, new_cache) = layer.forward(x, offset, cache);
|
|
x = out;
|
|
caches[i] = Some(new_cache);
|
|
}
|
|
|
|
x = self.norm.forward(x);
|
|
|
|
// Matmul with lm_head (or embed_tokens if tied) to get logits
|
|
// Notice: lm_head is typically [vocab_size, hidden_size] in HF, so we swap dims
|
|
x.matmul(self.lm_head.val().swap_dims(0, 1).unsqueeze())
|
|
}
|
|
}
|