use burn::module::{Module, Param}; use burn::tensor::{backend::Backend, Tensor, Int}; use super::modules::{RmsNorm, Mlp}; use super::attention::{Attention, KVCache}; use super::config::SmolLMConfig; #[derive(Module, Debug)] pub struct LlamaBlock { pub self_attn: Attention, pub mlp: Mlp, pub input_layernorm: RmsNorm, pub post_attention_layernorm: RmsNorm, } impl LlamaBlock { pub fn new(config: &SmolLMConfig, device: &B::Device) -> Self { Self { self_attn: Attention::new(config, device), mlp: Mlp::new(config.hidden_size, config.intermediate_size, device), input_layernorm: RmsNorm::new(config.hidden_size, config.rms_norm_eps, device), post_attention_layernorm: RmsNorm::new(config.hidden_size, config.rms_norm_eps, device), } } pub fn forward( &self, x: Tensor, offset: usize, cache: Option> ) -> (Tensor, KVCache) { let residual = x.clone(); let x_norm = self.input_layernorm.forward(x); let (attn_out, new_cache) = self.self_attn.forward(x_norm, offset, cache); let x = residual + attn_out; let residual = x.clone(); let x_norm = self.post_attention_layernorm.forward(x); let mlp_out = self.mlp.forward(x_norm); let x = residual + mlp_out; (x, new_cache) } } #[derive(Module, Debug)] pub struct LlamaModel { pub embed_tokens: Param>, pub layers: Vec>, pub norm: RmsNorm, pub lm_head: Param>, // For tie_word_embeddings this can point to embed_tokens } impl LlamaModel { pub fn new(config: &SmolLMConfig, device: &B::Device) -> Self { let embed = Tensor::zeros([config.vocab_size, config.hidden_size], device); let lm_head = Tensor::zeros([config.vocab_size, config.hidden_size], device); let mut layers = Vec::new(); for _ in 0..config.num_hidden_layers { layers.push(LlamaBlock::new(config, device)); } Self { embed_tokens: Param::from_tensor(embed), layers, norm: RmsNorm::new(config.hidden_size, config.rms_norm_eps, device), lm_head: Param::from_tensor(lm_head), } } pub fn forward( &self, input_ids: Tensor, offset: usize, caches: &mut Vec>> ) -> Tensor { let [_batch, _seq_len] = input_ids.dims(); let mut x = burn::tensor::module::embedding(self.embed_tokens.val(), input_ids); for (i, layer) in self.layers.iter().enumerate() { let cache = caches[i].take(); let (out, new_cache) = layer.forward(x, offset, cache); x = out; caches[i] = Some(new_cache); } x = self.norm.forward(x); // Matmul with lm_head (or embed_tokens if tied) to get logits // Notice: lm_head is typically [vocab_size, hidden_size] in HF, so we swap dims x.matmul(self.lm_head.val().swap_dims(0, 1).unsqueeze()) } }