kylä lähtee!
This commit is contained in:
118
network-poc/node/src/burn_smollm/attention.rs
Normal file
118
network-poc/node/src/burn_smollm/attention.rs
Normal file
@@ -0,0 +1,118 @@
|
||||
use burn::module::{Module, Param};
|
||||
use burn::tensor::{backend::Backend, Tensor};
|
||||
use super::rope::RoPE;
|
||||
use super::config::SmolLMConfig;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct KVCache<B: Backend> {
|
||||
pub k: Tensor<B, 4>,
|
||||
pub v: Tensor<B, 4>,
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Attention<B: Backend> {
|
||||
pub q_proj: Param<Tensor<B, 2>>, // [hidden, num_heads * head_dim]
|
||||
pub k_proj: Param<Tensor<B, 2>>, // [hidden, num_kv_heads * head_dim]
|
||||
pub v_proj: Param<Tensor<B, 2>>, // [hidden, num_kv_heads * head_dim]
|
||||
pub o_proj: Param<Tensor<B, 2>>, // [num_heads * head_dim, hidden]
|
||||
|
||||
num_heads: usize,
|
||||
num_kv_heads: usize,
|
||||
head_dim: usize,
|
||||
|
||||
rope: RoPE<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Attention<B> {
|
||||
pub fn new(config: &SmolLMConfig, device: &B::Device) -> Self {
|
||||
let head_dim = config.hidden_size / config.num_attention_heads;
|
||||
|
||||
Self {
|
||||
q_proj: Param::from_tensor(Tensor::zeros([config.hidden_size, config.num_attention_heads * head_dim], device)),
|
||||
k_proj: Param::from_tensor(Tensor::zeros([config.hidden_size, config.num_key_value_heads * head_dim], device)),
|
||||
v_proj: Param::from_tensor(Tensor::zeros([config.hidden_size, config.num_key_value_heads * head_dim], device)),
|
||||
o_proj: Param::from_tensor(Tensor::zeros([config.num_attention_heads * head_dim, config.hidden_size], device)),
|
||||
|
||||
num_heads: config.num_attention_heads,
|
||||
num_kv_heads: config.num_key_value_heads,
|
||||
head_dim,
|
||||
|
||||
rope: RoPE::new(head_dim, config.max_position_embeddings, config.rope_theta, device),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
x: Tensor<B, 3>,
|
||||
offset: usize,
|
||||
cache: Option<KVCache<B>>
|
||||
) -> (Tensor<B, 3>, KVCache<B>) {
|
||||
let [batch, seq_len, hidden_dim] = x.dims();
|
||||
|
||||
// Project Q, K, V: x @ W -> [batch, seq, proj_dim]
|
||||
let q = x.clone().matmul(self.q_proj.val().unsqueeze());
|
||||
let k = x.clone().matmul(self.k_proj.val().unsqueeze());
|
||||
let v = x.matmul(self.v_proj.val().unsqueeze());
|
||||
|
||||
// Reshape: [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim]
|
||||
let q = q.reshape([batch, seq_len, self.num_heads, self.head_dim]).swap_dims(1, 2);
|
||||
let k = k.reshape([batch, seq_len, self.num_kv_heads, self.head_dim]).swap_dims(1, 2);
|
||||
let v = v.reshape([batch, seq_len, self.num_kv_heads, self.head_dim]).swap_dims(1, 2);
|
||||
|
||||
// Apply RoPE
|
||||
let q = self.rope.forward(q, offset);
|
||||
let k = self.rope.forward(k, offset);
|
||||
|
||||
// KV cache
|
||||
let (k, v) = if let Some(c) = cache {
|
||||
(Tensor::cat(vec![c.k, k], 2), Tensor::cat(vec![c.v, v], 2))
|
||||
} else {
|
||||
(k, v)
|
||||
};
|
||||
|
||||
let new_cache = KVCache { k: k.clone(), v: v.clone() };
|
||||
let kv_len = k.dims()[2];
|
||||
|
||||
// GQA: repeat K,V heads — [batch, kv_heads, kv_len, hd] -> [batch, num_heads, kv_len, hd]
|
||||
let num_reps = self.num_heads / self.num_kv_heads;
|
||||
let k = if num_reps > 1 {
|
||||
let [b, kv_h, s, hd] = k.dims();
|
||||
k.reshape([b, kv_h, 1, s, hd]).repeat_dim(2, num_reps).reshape([b, self.num_heads, s, hd])
|
||||
} else { k };
|
||||
let v = if num_reps > 1 {
|
||||
let [b, kv_h, s, hd] = v.dims();
|
||||
v.reshape([b, kv_h, 1, s, hd]).repeat_dim(2, num_reps).reshape([b, self.num_heads, s, hd])
|
||||
} else { v };
|
||||
|
||||
// Attention: Q @ K^T / sqrt(d)
|
||||
let scale = 1.0 / (self.head_dim as f64).sqrt();
|
||||
let scores = q.matmul(k.swap_dims(2, 3)).mul_scalar(scale);
|
||||
// scores: [batch, heads, seq_len, kv_len]
|
||||
|
||||
// Causal mask for prefill (seq_len > 1)
|
||||
let scores = if seq_len > 1 {
|
||||
let mask_data: Vec<f32> = (0..seq_len).flat_map(|i| {
|
||||
(0..kv_len).map(move |j| {
|
||||
if j > offset + i { f32::NEG_INFINITY } else { 0.0 }
|
||||
})
|
||||
}).collect();
|
||||
let mask = Tensor::<B, 2>::from_data(
|
||||
burn::tensor::TensorData::new(mask_data, [seq_len, kv_len]),
|
||||
&scores.device()
|
||||
).reshape([1, 1, seq_len, kv_len]);
|
||||
scores + mask
|
||||
} else {
|
||||
scores
|
||||
};
|
||||
|
||||
let attn_weights = burn::tensor::activation::softmax(scores, 3);
|
||||
|
||||
let context = attn_weights.matmul(v);
|
||||
// [batch, heads, seq, hd] -> [batch, seq, heads*hd]
|
||||
let context = context.swap_dims(1, 2).reshape([batch, seq_len, self.num_heads * self.head_dim]);
|
||||
|
||||
let output = context.matmul(self.o_proj.val().unsqueeze());
|
||||
|
||||
(output, new_cache)
|
||||
}
|
||||
}
|
||||
28
network-poc/node/src/burn_smollm/config.rs
Normal file
28
network-poc/node/src/burn_smollm/config.rs
Normal file
@@ -0,0 +1,28 @@
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct SmolLMConfig {
|
||||
pub hidden_size: usize,
|
||||
pub intermediate_size: usize,
|
||||
pub vocab_size: usize,
|
||||
pub num_hidden_layers: usize,
|
||||
pub num_attention_heads: usize,
|
||||
pub num_key_value_heads: usize,
|
||||
pub rms_norm_eps: f64,
|
||||
pub rope_theta: f32,
|
||||
pub max_position_embeddings: usize,
|
||||
}
|
||||
|
||||
impl Default for SmolLMConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
hidden_size: 576,
|
||||
intermediate_size: 1536,
|
||||
vocab_size: 49152,
|
||||
num_hidden_layers: 30,
|
||||
num_attention_heads: 9,
|
||||
num_key_value_heads: 3,
|
||||
rms_norm_eps: 1e-5,
|
||||
rope_theta: 10000.0,
|
||||
max_position_embeddings: 2048,
|
||||
}
|
||||
}
|
||||
}
|
||||
90
network-poc/node/src/burn_smollm/loader.rs
Normal file
90
network-poc/node/src/burn_smollm/loader.rs
Normal file
@@ -0,0 +1,90 @@
|
||||
use burn::tensor::{backend::Backend, Tensor, Data};
|
||||
use candle_core::safetensors;
|
||||
use candle_core::Device as CandleDevice;
|
||||
use burn::module::Param;
|
||||
use super::model::LlamaModel;
|
||||
use super::config::SmolLMConfig;
|
||||
|
||||
fn load_tensor_2d<B: Backend>(
|
||||
tensors_map: &std::collections::HashMap<String, candle_core::Tensor>,
|
||||
name: &str,
|
||||
device: &B::Device,
|
||||
shape_out_in: [usize; 2]
|
||||
) -> Result<Param<Tensor<B, 2>>, String> {
|
||||
let t = tensors_map.get(name).ok_or_else(|| format!("Puuttuu: {}", name))?;
|
||||
let t = t.to_dtype(candle_core::DType::F32).unwrap();
|
||||
let vec = t.flatten_all().unwrap().to_vec1::<f32>().unwrap();
|
||||
let t_burn = Tensor::<B, 2>::from_data(burn::tensor::TensorData::new(vec, shape_out_in), device);
|
||||
// transpose from [out, in] to [in, out]
|
||||
Ok(Param::from_tensor(t_burn.transpose()))
|
||||
}
|
||||
|
||||
fn load_tensor_1d<B: Backend>(
|
||||
tensors_map: &std::collections::HashMap<String, candle_core::Tensor>,
|
||||
name: &str,
|
||||
device: &B::Device,
|
||||
_shape: [usize; 1]
|
||||
) -> Result<Param<Tensor<B, 1>>, String> {
|
||||
let t = tensors_map.get(name).ok_or_else(|| format!("Puuttuu: {}", name))?;
|
||||
let t = t.to_dtype(candle_core::DType::F32).unwrap();
|
||||
let vec = t.flatten_all().unwrap().to_vec1::<f32>().unwrap();
|
||||
Ok(Param::from_tensor(Tensor::<B, 1>::from_floats(vec.as_slice(), device)))
|
||||
}
|
||||
|
||||
fn load_embed<B: Backend>(
|
||||
tensors_map: &std::collections::HashMap<String, candle_core::Tensor>,
|
||||
name: &str,
|
||||
device: &B::Device,
|
||||
shape: [usize; 2]
|
||||
) -> Result<Param<Tensor<B, 2>>, String> {
|
||||
let t = tensors_map.get(name).ok_or_else(|| format!("Puuttuu: {}", name))?;
|
||||
let t = t.to_dtype(candle_core::DType::F32).unwrap();
|
||||
let vec = t.flatten_all().unwrap().to_vec1::<f32>().unwrap();
|
||||
// Embed ei transponoi samalla tavalla, se pysyy [vocab, hidden]
|
||||
Ok(Param::from_tensor(Tensor::<B, 2>::from_data(burn::tensor::TensorData::new(vec, shape), device)))
|
||||
}
|
||||
|
||||
pub fn load_safetensors_to_model<B: Backend>(
|
||||
buffer: &[u8],
|
||||
config: &SmolLMConfig,
|
||||
device: &B::Device
|
||||
) -> Result<LlamaModel<B>, String> {
|
||||
|
||||
let mut model = LlamaModel::new(config, device);
|
||||
let tensors_map = safetensors::load_buffer(buffer, &CandleDevice::Cpu)
|
||||
.map_err(|e| format!("Virhe Safetensors luennassa: {}", e))?;
|
||||
|
||||
// Embeddings
|
||||
model.embed_tokens = load_embed(&tensors_map, "model.embed_tokens.weight", device, [config.vocab_size, config.hidden_size])?;
|
||||
model.norm.weight = load_tensor_1d(&tensors_map, "model.norm.weight", device, [config.hidden_size])?;
|
||||
model.lm_head = load_embed(&tensors_map, "lm_head.weight", device, [config.vocab_size, config.hidden_size]).or_else(|_| {
|
||||
load_embed(&tensors_map, "model.embed_tokens.weight", device, [config.vocab_size, config.hidden_size])
|
||||
})?;
|
||||
|
||||
let head_dim = config.hidden_size / config.num_attention_heads;
|
||||
|
||||
for i in 0..config.num_hidden_layers {
|
||||
let prefix = format!("model.layers.{}", i);
|
||||
|
||||
let layer = &mut model.layers[i];
|
||||
|
||||
// Norms
|
||||
layer.input_layernorm.weight = load_tensor_1d(&tensors_map, &format!("{}.input_layernorm.weight", prefix), device, [config.hidden_size])?;
|
||||
layer.post_attention_layernorm.weight = load_tensor_1d(&tensors_map, &format!("{}.post_attention_layernorm.weight", prefix), device, [config.hidden_size])?;
|
||||
|
||||
// Attention
|
||||
let num_heads = config.num_attention_heads;
|
||||
let num_kv_heads = config.num_key_value_heads;
|
||||
layer.self_attn.q_proj = load_tensor_2d(&tensors_map, &format!("{}.self_attn.q_proj.weight", prefix), device, [num_heads * head_dim, config.hidden_size])?;
|
||||
layer.self_attn.k_proj = load_tensor_2d(&tensors_map, &format!("{}.self_attn.k_proj.weight", prefix), device, [num_kv_heads * head_dim, config.hidden_size])?;
|
||||
layer.self_attn.v_proj = load_tensor_2d(&tensors_map, &format!("{}.self_attn.v_proj.weight", prefix), device, [num_kv_heads * head_dim, config.hidden_size])?;
|
||||
layer.self_attn.o_proj = load_tensor_2d(&tensors_map, &format!("{}.self_attn.o_proj.weight", prefix), device, [config.hidden_size, num_heads * head_dim])?;
|
||||
|
||||
// MLP
|
||||
layer.mlp.gate_proj = load_tensor_2d(&tensors_map, &format!("{}.mlp.gate_proj.weight", prefix), device, [config.intermediate_size, config.hidden_size])?;
|
||||
layer.mlp.up_proj = load_tensor_2d(&tensors_map, &format!("{}.mlp.up_proj.weight", prefix), device, [config.intermediate_size, config.hidden_size])?;
|
||||
layer.mlp.down_proj = load_tensor_2d(&tensors_map, &format!("{}.mlp.down_proj.weight", prefix), device, [config.hidden_size, config.intermediate_size])?;
|
||||
}
|
||||
|
||||
Ok(model)
|
||||
}
|
||||
6
network-poc/node/src/burn_smollm/mod.rs
Normal file
6
network-poc/node/src/burn_smollm/mod.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
pub mod attention;
|
||||
pub mod config;
|
||||
pub mod loader;
|
||||
pub mod model;
|
||||
pub mod modules;
|
||||
pub mod rope;
|
||||
96
network-poc/node/src/burn_smollm/model.rs
Normal file
96
network-poc/node/src/burn_smollm/model.rs
Normal file
@@ -0,0 +1,96 @@
|
||||
use burn::module::{Module, Param};
|
||||
use burn::tensor::{backend::Backend, Tensor, Int};
|
||||
use super::modules::{RmsNorm, Mlp};
|
||||
use super::attention::{Attention, KVCache};
|
||||
use super::config::SmolLMConfig;
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct LlamaBlock<B: Backend> {
|
||||
pub self_attn: Attention<B>,
|
||||
pub mlp: Mlp<B>,
|
||||
pub input_layernorm: RmsNorm<B>,
|
||||
pub post_attention_layernorm: RmsNorm<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> LlamaBlock<B> {
|
||||
pub fn new(config: &SmolLMConfig, device: &B::Device) -> Self {
|
||||
Self {
|
||||
self_attn: Attention::new(config, device),
|
||||
mlp: Mlp::new(config.hidden_size, config.intermediate_size, device),
|
||||
input_layernorm: RmsNorm::new(config.hidden_size, config.rms_norm_eps, device),
|
||||
post_attention_layernorm: RmsNorm::new(config.hidden_size, config.rms_norm_eps, device),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
x: Tensor<B, 3>,
|
||||
offset: usize,
|
||||
cache: Option<KVCache<B>>
|
||||
) -> (Tensor<B, 3>, KVCache<B>) {
|
||||
let residual = x.clone();
|
||||
let x_norm = self.input_layernorm.forward(x);
|
||||
|
||||
let (attn_out, new_cache) = self.self_attn.forward(x_norm, offset, cache);
|
||||
|
||||
let x = residual + attn_out;
|
||||
|
||||
let residual = x.clone();
|
||||
let x_norm = self.post_attention_layernorm.forward(x);
|
||||
let mlp_out = self.mlp.forward(x_norm);
|
||||
|
||||
let x = residual + mlp_out;
|
||||
(x, new_cache)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct LlamaModel<B: Backend> {
|
||||
pub embed_tokens: Param<Tensor<B, 2>>,
|
||||
pub layers: Vec<LlamaBlock<B>>,
|
||||
pub norm: RmsNorm<B>,
|
||||
pub lm_head: Param<Tensor<B, 2>>, // For tie_word_embeddings this can point to embed_tokens
|
||||
}
|
||||
|
||||
impl<B: Backend> LlamaModel<B> {
|
||||
pub fn new(config: &SmolLMConfig, device: &B::Device) -> Self {
|
||||
let embed = Tensor::zeros([config.vocab_size, config.hidden_size], device);
|
||||
let lm_head = Tensor::zeros([config.vocab_size, config.hidden_size], device);
|
||||
|
||||
let mut layers = Vec::new();
|
||||
for _ in 0..config.num_hidden_layers {
|
||||
layers.push(LlamaBlock::new(config, device));
|
||||
}
|
||||
|
||||
Self {
|
||||
embed_tokens: Param::from_tensor(embed),
|
||||
layers,
|
||||
norm: RmsNorm::new(config.hidden_size, config.rms_norm_eps, device),
|
||||
lm_head: Param::from_tensor(lm_head),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
input_ids: Tensor<B, 2, Int>,
|
||||
offset: usize,
|
||||
caches: &mut Vec<Option<KVCache<B>>>
|
||||
) -> Tensor<B, 3> {
|
||||
let [_batch, _seq_len] = input_ids.dims();
|
||||
|
||||
let mut x = burn::tensor::module::embedding(self.embed_tokens.val(), input_ids);
|
||||
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
let cache = caches[i].take();
|
||||
let (out, new_cache) = layer.forward(x, offset, cache);
|
||||
x = out;
|
||||
caches[i] = Some(new_cache);
|
||||
}
|
||||
|
||||
x = self.norm.forward(x);
|
||||
|
||||
// Matmul with lm_head (or embed_tokens if tied) to get logits
|
||||
// Notice: lm_head is typically [vocab_size, hidden_size] in HF, so we swap dims
|
||||
x.matmul(self.lm_head.val().swap_dims(0, 1).unsqueeze())
|
||||
}
|
||||
}
|
||||
59
network-poc/node/src/burn_smollm/modules.rs
Normal file
59
network-poc/node/src/burn_smollm/modules.rs
Normal file
@@ -0,0 +1,59 @@
|
||||
use burn::module::{Module, Param};
|
||||
use burn::tensor::{backend::Backend, Tensor};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct RmsNorm<B: Backend> {
|
||||
pub weight: Param<Tensor<B, 1>>,
|
||||
epsilon: f64,
|
||||
}
|
||||
|
||||
impl<B: Backend> RmsNorm<B> {
|
||||
pub fn new(size: usize, epsilon: f64, device: &B::Device) -> Self {
|
||||
let weight = Param::from_tensor(Tensor::ones([size], device));
|
||||
Self { weight, epsilon }
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
|
||||
// x: [batch, seq_len, dim]
|
||||
// RMSNorm: x * weight / sqrt(mean(x^2) + eps)
|
||||
let x_sq = x.clone().powf_scalar(2.0);
|
||||
// mean over last dim, keeping dims for broadcast
|
||||
let [b, s, d] = x_sq.dims();
|
||||
let variance = x_sq.sum_dim(2).div_scalar(d as f32);
|
||||
let norm = x.div(variance.add_scalar(self.epsilon).sqrt());
|
||||
|
||||
let w = self.weight.val().unsqueeze::<2>().unsqueeze::<3>().reshape([1, 1, d]);
|
||||
norm * w
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Mlp<B: Backend> {
|
||||
pub gate_proj: Param<Tensor<B, 2>>, // [in, intermediate]
|
||||
pub up_proj: Param<Tensor<B, 2>>, // [in, intermediate]
|
||||
pub down_proj: Param<Tensor<B, 2>>, // [intermediate, out]
|
||||
}
|
||||
|
||||
impl<B: Backend> Mlp<B> {
|
||||
pub fn new(hidden_size: usize, intermediate_size: usize, device: &B::Device) -> Self {
|
||||
Self {
|
||||
gate_proj: Param::from_tensor(Tensor::zeros([hidden_size, intermediate_size], device)),
|
||||
up_proj: Param::from_tensor(Tensor::zeros([hidden_size, intermediate_size], device)),
|
||||
down_proj: Param::from_tensor(Tensor::zeros([intermediate_size, hidden_size], device)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
|
||||
// x: [batch, seq, hidden]
|
||||
// gate = x @ gate_proj -> [batch, seq, intermediate]
|
||||
let gate = x.clone().matmul(self.gate_proj.val().unsqueeze());
|
||||
let up = x.matmul(self.up_proj.val().unsqueeze());
|
||||
|
||||
// SiLU(gate) * up
|
||||
let silu = gate.clone() * burn::tensor::activation::sigmoid(gate);
|
||||
let intermediate = silu * up;
|
||||
|
||||
// intermediate @ down_proj -> [batch, seq, hidden]
|
||||
intermediate.matmul(self.down_proj.val().unsqueeze())
|
||||
}
|
||||
}
|
||||
59
network-poc/node/src/burn_smollm/rope.rs
Normal file
59
network-poc/node/src/burn_smollm/rope.rs
Normal file
@@ -0,0 +1,59 @@
|
||||
use burn::module::Module;
|
||||
use burn::tensor::{backend::Backend, Tensor};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct RoPE<B: Backend> {
|
||||
cos_cache: Tensor<B, 2>,
|
||||
sin_cache: Tensor<B, 2>,
|
||||
}
|
||||
|
||||
impl<B: Backend> RoPE<B> {
|
||||
pub fn new(head_dim: usize, max_seq_len: usize, theta: f32, device: &B::Device) -> Self {
|
||||
// (head_dim / 2) values
|
||||
let half_dim = head_dim / 2;
|
||||
let inv_freq: Vec<f32> = (0..half_dim)
|
||||
.map(|i| 1.0 / theta.powf((2 * i) as f32 / head_dim as f32))
|
||||
.collect();
|
||||
|
||||
let inv_freq = Tensor::<B, 1>::from_floats(inv_freq.as_slice(), device).unsqueeze::<2>();
|
||||
let t_floats: Vec<f32> = (0..max_seq_len).map(|v| v as f32).collect();
|
||||
let t = Tensor::<B, 1>::from_floats(t_floats.as_slice(), device).unsqueeze::<2>().transpose();
|
||||
// t shape: [max_seq_len, 1]
|
||||
// inv_freq shape: [1, half_dim]
|
||||
|
||||
// freqs shape: [max_seq_len, half_dim]
|
||||
let freqs = t.matmul(inv_freq);
|
||||
|
||||
let cos_cache = freqs.clone().cos();
|
||||
let sin_cache = freqs.sin();
|
||||
|
||||
Self {
|
||||
cos_cache,
|
||||
sin_cache,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: Tensor<B, 4>, offset: usize) -> Tensor<B, 4> {
|
||||
let [batch, heads, seq_len, head_dim] = x.dims();
|
||||
let half_dim = head_dim / 2;
|
||||
|
||||
// x shape: [batch, heads, seq_len, head_dim]
|
||||
// valitaan viipaleet (x1 ja x2) jotta saadaan pyöritettyä rotaatiot
|
||||
let x1 = x.clone().slice([0..batch, 0..heads, 0..seq_len, 0..half_dim]);
|
||||
let x2 = x.clone().slice([0..batch, 0..heads, 0..seq_len, half_dim..head_dim]);
|
||||
|
||||
// haetaan vastaava seq offsetista alkaen
|
||||
let cos = self.cos_cache.clone().slice([offset..offset+seq_len, 0..half_dim])
|
||||
.unsqueeze::<4>() // [seq, half_dim, 1]
|
||||
.reshape([1, 1, seq_len, half_dim]);
|
||||
let sin = self.sin_cache.clone().slice([offset..offset+seq_len, 0..half_dim])
|
||||
.reshape([1, 1, seq_len, half_dim]);
|
||||
|
||||
// x1 * cos - x2 * sin
|
||||
let o1 = x1.clone().mul(cos.clone()) - x2.clone().mul(sin.clone());
|
||||
// x2 * cos + x1 * sin
|
||||
let o2 = x2.mul(cos) + x1.mul(sin);
|
||||
|
||||
Tensor::cat(vec![o1, o2], 3)
|
||||
}
|
||||
}
|
||||
@@ -12,6 +12,7 @@ pub mod smollm;
|
||||
pub mod qwen;
|
||||
pub mod qwen_coder;
|
||||
pub mod phi3;
|
||||
pub mod burn_smollm;
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! console_log {
|
||||
|
||||
@@ -118,125 +118,106 @@ pub async fn run_smollm_inference(prompt: String, ws: Rc<RefCell<WebSocket>>) {
|
||||
Err(e) => { console_log!("[SmolLM] Malli-virhe: {}", e); return; }
|
||||
};
|
||||
|
||||
console_log!("[SmolLM] Rakennetaan mallia...");
|
||||
let use_gpu = crate::HAS_WEBGPU.load(std::sync::atomic::Ordering::SeqCst);
|
||||
if use_gpu {
|
||||
console_log!("[SmolLM] Alustetaan Burn WebGPU...");
|
||||
burn_wgpu::init_async::<burn_wgpu::AutoGraphicsApi>(&Default::default(), Default::default()).await;
|
||||
run_burn_inference::<burn::backend::Wgpu>(prompt, model_bytes, tokenizer, ws, perf.clone()).await;
|
||||
} else {
|
||||
console_log!("[SmolLM] Käytetään CPU NdArrayta (vanha tapa)...");
|
||||
run_burn_inference::<burn::backend::NdArray>(prompt, model_bytes, tokenizer, ws, perf.clone()).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_burn_inference<B: burn::tensor::backend::Backend>(
|
||||
prompt: String,
|
||||
model_bytes: Vec<u8>,
|
||||
tokenizer: tokenizers::Tokenizer,
|
||||
ws: Rc<RefCell<WebSocket>>,
|
||||
perf: web_sys::Performance, // Korjattu Wasm-performanssi välitettäväksi
|
||||
) {
|
||||
let start_load = perf.now();
|
||||
|
||||
let device = Device::Cpu;
|
||||
let dtype = DType::F32;
|
||||
|
||||
// Parsitaan safetensors
|
||||
let tensors = match candle_core::safetensors::load_buffer(&model_bytes, &device) {
|
||||
Ok(t) => t,
|
||||
Err(e) => { console_log!("[SmolLM] Safetensors-parsinta epäonnistui: {}", e); return; }
|
||||
};
|
||||
|
||||
let vb = VarBuilder::from_tensors(tensors, dtype, &device);
|
||||
|
||||
// SmolLM-135M config (Llama-arkkitehtuuri)
|
||||
let config = LlamaConfig {
|
||||
hidden_size: 576,
|
||||
intermediate_size: 1536,
|
||||
vocab_size: 49152,
|
||||
num_hidden_layers: 30,
|
||||
num_attention_heads: 9,
|
||||
num_key_value_heads: Some(3),
|
||||
rms_norm_eps: 1e-5,
|
||||
rope_theta: 10000.0,
|
||||
max_position_embeddings: 2048,
|
||||
tie_word_embeddings: Some(true),
|
||||
bos_token_id: Some(1u32),
|
||||
eos_token_id: Some(LlamaEosToks::Single(2)),
|
||||
rope_scaling: None,
|
||||
};
|
||||
|
||||
let llama_config = config.into_config(false); // false = ei flash attention
|
||||
let mut cache = Cache::new(true, dtype, &llama_config, &device).unwrap();
|
||||
|
||||
let model = match Llama::load(vb, &llama_config) {
|
||||
let device = Default::default();
|
||||
let config = crate::burn_smollm::config::SmolLMConfig::default();
|
||||
|
||||
console_log!("[SmolLM] Injektoidaan Safetensors -> Burn Params...");
|
||||
let model = match crate::burn_smollm::loader::load_safetensors_to_model::<B>(&model_bytes, &config, &device) {
|
||||
Ok(m) => m,
|
||||
Err(e) => { console_log!("[SmolLM] Mallin lataus epäonnistui: {}", e); return; }
|
||||
Err(e) => { console_log!("[SmolLM] Lataus epäonnistui: {}", e); return; }
|
||||
};
|
||||
|
||||
let load_time = perf.now() - start_load;
|
||||
console_log!("[SmolLM] Malli ladattu ({:.0}ms). Generoidaan...", load_time);
|
||||
console_log!("[SmolLM] Burn-malli ladattu ({:.0}ms). Generoidaan...", load_time);
|
||||
|
||||
// 3. Tokenisoi syöte (Käytetään ChatML-formaattia SmolLM-Instructille)
|
||||
let formatted_prompt = format!("<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", prompt);
|
||||
let encoding = match tokenizer.encode(formatted_prompt.as_str(), true) {
|
||||
Ok(e) => e,
|
||||
Err(e) => { console_log!("[SmolLM] Tokenisointivirhe: {}", e); return; }
|
||||
};
|
||||
|
||||
let input_ids: Vec<u32> = encoding.get_ids().to_vec();
|
||||
let mut input_ids: Vec<u32> = encoding.get_ids().to_vec();
|
||||
let input_len = input_ids.len();
|
||||
console_log!("[SmolLM] Syöte: {} tokenia", input_len);
|
||||
|
||||
// 4. Generoi tokeneita
|
||||
let start_gen = perf.now();
|
||||
let max_new_tokens = 32;
|
||||
let mut generated_text = String::new();
|
||||
let mut tokens_generated: usize = 0;
|
||||
let mut pos: usize = 0;
|
||||
|
||||
// KV-välimuistin taulukko kerroksittain
|
||||
let mut caches: Vec<Option<crate::burn_smollm::attention::KVCache<B>>> = vec![None; config.num_hidden_layers];
|
||||
let mut current_offset = 0;
|
||||
|
||||
// Ensimmäinen forward: koko syöte kerralla
|
||||
let input = match Tensor::new(input_ids.as_slice(), &device).and_then(|t| t.unsqueeze(0)) {
|
||||
Ok(t) => t,
|
||||
Err(e) => { console_log!("[SmolLM] Tensor-virhe: {}", e); return; }
|
||||
};
|
||||
// Prefill: yksitellen, vältetään future token leakage koska ei causal maskia
|
||||
let input_ids_i32: Vec<i32> = input_ids.iter().map(|&x| x as i32).collect();
|
||||
let mut last_logits = None;
|
||||
|
||||
for &id in &input_ids_i32 {
|
||||
let input_tensor = burn::tensor::Tensor::<B, 1, burn::tensor::Int>::from_data(
|
||||
burn::tensor::TensorData::from([id]),
|
||||
&device
|
||||
).unsqueeze::<2>(); // [1, 1]
|
||||
|
||||
last_logits = Some(model.forward(input_tensor, current_offset, &mut caches));
|
||||
current_offset += 1;
|
||||
}
|
||||
|
||||
let logits = match model.forward(&input, 0, &mut cache) {
|
||||
Ok(l) => l,
|
||||
Err(e) => { console_log!("[SmolLM] Forward-virhe (prefill): {}", e); return; }
|
||||
};
|
||||
let mut logits = last_logits.unwrap();
|
||||
|
||||
// Llama forward voi palauttaa [batch, vocab] tai [batch, seq_len, vocab]
|
||||
let logits = logits.squeeze(0).unwrap();
|
||||
let logits = if logits.dims().len() == 2 {
|
||||
logits.get(logits.dim(0).unwrap() - 1).unwrap()
|
||||
} else {
|
||||
logits
|
||||
};
|
||||
let mut next_token = crate::sampling::sample_top_k(&logits, 10, 5.0);
|
||||
console_log!("[SmolLM] Ensimmäinen generoitu token: {}", next_token);
|
||||
pos = input_len;
|
||||
// Argmax sämpläys
|
||||
let next_token_tensor = logits.clone().argmax(2);
|
||||
let mut next_token: u32 = next_token_tensor.into_scalar().to_string().parse().unwrap_or(2); // Yksinkertainen cast koska int scalar
|
||||
|
||||
if next_token != 2 {
|
||||
if let Ok(text) = tokenizer.decode(&[next_token], true) {
|
||||
generated_text.push_str(&text);
|
||||
let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "SmolLM-135M" });
|
||||
let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "SmolLM-135M (WebGPU)" });
|
||||
let _ = ws.borrow().send_with_str(&chunk.to_string());
|
||||
}
|
||||
tokens_generated += 1;
|
||||
}
|
||||
|
||||
// Autoregressiivinen generointi: yksi token kerrallaan
|
||||
// Autoregressiivinen luuppi
|
||||
for _ in 1..max_new_tokens {
|
||||
if next_token == 2 { break; }
|
||||
|
||||
let input = match Tensor::new(&[next_token], &device).and_then(|t| t.unsqueeze(0)) {
|
||||
Ok(t) => t,
|
||||
Err(e) => { console_log!("[SmolLM] Tensor-virhe: {}", e); break; }
|
||||
};
|
||||
|
||||
let logits = match model.forward(&input, pos, &mut cache) {
|
||||
Ok(l) => l,
|
||||
Err(e) => { console_log!("[SmolLM] Forward-virhe pos {}: {}", pos, e); break; }
|
||||
};
|
||||
|
||||
let logits = logits.squeeze(0).unwrap();
|
||||
let logits = if logits.dims().len() == 2 {
|
||||
logits.get(logits.dim(0).unwrap() - 1).unwrap()
|
||||
} else {
|
||||
logits
|
||||
};
|
||||
next_token = crate::sampling::sample_top_k(&logits, 10, 5.0);
|
||||
pos += 1;
|
||||
|
||||
let mut input_tensor = burn::tensor::Tensor::<B, 1, burn::tensor::Int>::from_data(
|
||||
burn::tensor::TensorData::from([next_token as i32]),
|
||||
&device
|
||||
).unsqueeze::<2>();
|
||||
|
||||
logits = model.forward(input_tensor, current_offset, &mut caches);
|
||||
current_offset += 1;
|
||||
|
||||
let next_token_tensor = logits.argmax(2);
|
||||
next_token = next_token_tensor.into_scalar().to_string().parse().unwrap_or(2);
|
||||
|
||||
if next_token == 2 { break; }
|
||||
|
||||
if let Ok(text) = tokenizer.decode(&[next_token], true) {
|
||||
generated_text.push_str(&text);
|
||||
let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "SmolLM-135M" });
|
||||
let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "SmolLM-135M (WebGPU)" });
|
||||
let _ = ws.borrow().send_with_str(&chunk.to_string());
|
||||
}
|
||||
tokens_generated += 1;
|
||||
@@ -245,12 +226,10 @@ pub async fn run_smollm_inference(prompt: String, ws: Rc<RefCell<WebSocket>>) {
|
||||
let gen_time = perf.now() - start_gen;
|
||||
let tokens_per_sec = if gen_time > 0.0 { (tokens_generated as f64 / gen_time) * 1000.0 } else { 0.0 };
|
||||
|
||||
console_log!("[SmolLM] Generoitu {} tokenia | {:.0}ms | {:.1} tok/s", tokens_generated, gen_time, tokens_per_sec);
|
||||
|
||||
let done = serde_json::json!({
|
||||
"type": "llm_done",
|
||||
"prompt": prompt,
|
||||
"model": "SmolLM-135M-Instruct",
|
||||
"model": "SmolLM-135M-Instruct (WebGPU)",
|
||||
"response": generated_text,
|
||||
"tokens_generated": tokens_generated,
|
||||
"duration_ms": (gen_time * 100.0).round() / 100.0,
|
||||
|
||||
Reference in New Issue
Block a user