diff --git a/network-poc/cargo-errors.log b/network-poc/cargo-errors.log new file mode 100644 index 0000000..d4edf5f --- /dev/null +++ b/network-poc/cargo-errors.log @@ -0,0 +1,4 @@ +error: failed to write `/home/jaakko/code/kipinä/digikipinae/agentic-office/network-poc/target/wasm32-unknown-unknown/debug/.fingerprint/num-traits-0a015ef9fd3732e0/run-build-script-build-script-build` + +Caused by: + Permission denied (os error 13) diff --git a/network-poc/docker-compose.yml b/network-poc/docker-compose.yml index f85ae0b..46d4c4b 100644 --- a/network-poc/docker-compose.yml +++ b/network-poc/docker-compose.yml @@ -9,7 +9,7 @@ services: volumes: - .:/app # Käännetään aina käynnistyksen yhteydessä varmuuden vuoksi Wasm uusimmista koodeista, ja päälle pyöräytetään Hub! - command: bash -c "cd node && wasm-pack build --dev --target web --out-dir ../static/pkg && cd ../hub && cargo run" + command: bash -c "cd node && wasm-pack build --target web --out-dir ../static/pkg && cd ../hub && cargo run" # Valinnainen natiivi-solmu — kerää oikeat laitteistotiedot (nvidia-smi-taso) native-node: diff --git a/network-poc/hub/nodes.db b/network-poc/hub/nodes.db index 9e5d54a..b30b85c 100644 Binary files a/network-poc/hub/nodes.db and b/network-poc/hub/nodes.db differ diff --git a/network-poc/node/src/burn_smollm/attention.rs b/network-poc/node/src/burn_smollm/attention.rs new file mode 100644 index 0000000..b69acfb --- /dev/null +++ b/network-poc/node/src/burn_smollm/attention.rs @@ -0,0 +1,118 @@ +use burn::module::{Module, Param}; +use burn::tensor::{backend::Backend, Tensor}; +use super::rope::RoPE; +use super::config::SmolLMConfig; + +#[derive(Clone, Debug)] +pub struct KVCache { + pub k: Tensor, + pub v: Tensor, +} + +#[derive(Module, Debug)] +pub struct Attention { + pub q_proj: Param>, // [hidden, num_heads * head_dim] + pub k_proj: Param>, // [hidden, num_kv_heads * head_dim] + pub v_proj: Param>, // [hidden, num_kv_heads * head_dim] + pub o_proj: Param>, // [num_heads * head_dim, hidden] + + num_heads: usize, + num_kv_heads: usize, + head_dim: usize, + + rope: RoPE, +} + +impl Attention { + pub fn new(config: &SmolLMConfig, device: &B::Device) -> Self { + let head_dim = config.hidden_size / config.num_attention_heads; + + Self { + q_proj: Param::from_tensor(Tensor::zeros([config.hidden_size, config.num_attention_heads * head_dim], device)), + k_proj: Param::from_tensor(Tensor::zeros([config.hidden_size, config.num_key_value_heads * head_dim], device)), + v_proj: Param::from_tensor(Tensor::zeros([config.hidden_size, config.num_key_value_heads * head_dim], device)), + o_proj: Param::from_tensor(Tensor::zeros([config.num_attention_heads * head_dim, config.hidden_size], device)), + + num_heads: config.num_attention_heads, + num_kv_heads: config.num_key_value_heads, + head_dim, + + rope: RoPE::new(head_dim, config.max_position_embeddings, config.rope_theta, device), + } + } + + pub fn forward( + &self, + x: Tensor, + offset: usize, + cache: Option> + ) -> (Tensor, KVCache) { + let [batch, seq_len, hidden_dim] = x.dims(); + + // Project Q, K, V: x @ W -> [batch, seq, proj_dim] + let q = x.clone().matmul(self.q_proj.val().unsqueeze()); + let k = x.clone().matmul(self.k_proj.val().unsqueeze()); + let v = x.matmul(self.v_proj.val().unsqueeze()); + + // Reshape: [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim] + let q = q.reshape([batch, seq_len, self.num_heads, self.head_dim]).swap_dims(1, 2); + let k = k.reshape([batch, seq_len, self.num_kv_heads, self.head_dim]).swap_dims(1, 2); + let v = v.reshape([batch, seq_len, self.num_kv_heads, self.head_dim]).swap_dims(1, 2); + + // Apply RoPE + let q = self.rope.forward(q, offset); + let k = self.rope.forward(k, offset); + + // KV cache + let (k, v) = if let Some(c) = cache { + (Tensor::cat(vec![c.k, k], 2), Tensor::cat(vec![c.v, v], 2)) + } else { + (k, v) + }; + + let new_cache = KVCache { k: k.clone(), v: v.clone() }; + let kv_len = k.dims()[2]; + + // GQA: repeat K,V heads — [batch, kv_heads, kv_len, hd] -> [batch, num_heads, kv_len, hd] + let num_reps = self.num_heads / self.num_kv_heads; + let k = if num_reps > 1 { + let [b, kv_h, s, hd] = k.dims(); + k.reshape([b, kv_h, 1, s, hd]).repeat_dim(2, num_reps).reshape([b, self.num_heads, s, hd]) + } else { k }; + let v = if num_reps > 1 { + let [b, kv_h, s, hd] = v.dims(); + v.reshape([b, kv_h, 1, s, hd]).repeat_dim(2, num_reps).reshape([b, self.num_heads, s, hd]) + } else { v }; + + // Attention: Q @ K^T / sqrt(d) + let scale = 1.0 / (self.head_dim as f64).sqrt(); + let scores = q.matmul(k.swap_dims(2, 3)).mul_scalar(scale); + // scores: [batch, heads, seq_len, kv_len] + + // Causal mask for prefill (seq_len > 1) + let scores = if seq_len > 1 { + let mask_data: Vec = (0..seq_len).flat_map(|i| { + (0..kv_len).map(move |j| { + if j > offset + i { f32::NEG_INFINITY } else { 0.0 } + }) + }).collect(); + let mask = Tensor::::from_data( + burn::tensor::TensorData::new(mask_data, [seq_len, kv_len]), + &scores.device() + ).reshape([1, 1, seq_len, kv_len]); + scores + mask + } else { + scores + }; + + let attn_weights = burn::tensor::activation::softmax(scores, 3); + + let context = attn_weights.matmul(v); + // [batch, heads, seq, hd] -> [batch, seq, heads*hd] + let context = context.swap_dims(1, 2).reshape([batch, seq_len, self.num_heads * self.head_dim]); + + let output = context.matmul(self.o_proj.val().unsqueeze()); + + (output, new_cache) + } +} diff --git a/network-poc/node/src/burn_smollm/config.rs b/network-poc/node/src/burn_smollm/config.rs new file mode 100644 index 0000000..ac0b263 --- /dev/null +++ b/network-poc/node/src/burn_smollm/config.rs @@ -0,0 +1,28 @@ +#[derive(Clone, Debug)] +pub struct SmolLMConfig { + pub hidden_size: usize, + pub intermediate_size: usize, + pub vocab_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub num_key_value_heads: usize, + pub rms_norm_eps: f64, + pub rope_theta: f32, + pub max_position_embeddings: usize, +} + +impl Default for SmolLMConfig { + fn default() -> Self { + Self { + hidden_size: 576, + intermediate_size: 1536, + vocab_size: 49152, + num_hidden_layers: 30, + num_attention_heads: 9, + num_key_value_heads: 3, + rms_norm_eps: 1e-5, + rope_theta: 10000.0, + max_position_embeddings: 2048, + } + } +} diff --git a/network-poc/node/src/burn_smollm/loader.rs b/network-poc/node/src/burn_smollm/loader.rs new file mode 100644 index 0000000..d5cf29b --- /dev/null +++ b/network-poc/node/src/burn_smollm/loader.rs @@ -0,0 +1,90 @@ +use burn::tensor::{backend::Backend, Tensor, Data}; +use candle_core::safetensors; +use candle_core::Device as CandleDevice; +use burn::module::Param; +use super::model::LlamaModel; +use super::config::SmolLMConfig; + +fn load_tensor_2d( + tensors_map: &std::collections::HashMap, + name: &str, + device: &B::Device, + shape_out_in: [usize; 2] +) -> Result>, String> { + let t = tensors_map.get(name).ok_or_else(|| format!("Puuttuu: {}", name))?; + let t = t.to_dtype(candle_core::DType::F32).unwrap(); + let vec = t.flatten_all().unwrap().to_vec1::().unwrap(); + let t_burn = Tensor::::from_data(burn::tensor::TensorData::new(vec, shape_out_in), device); + // transpose from [out, in] to [in, out] + Ok(Param::from_tensor(t_burn.transpose())) +} + +fn load_tensor_1d( + tensors_map: &std::collections::HashMap, + name: &str, + device: &B::Device, + _shape: [usize; 1] +) -> Result>, String> { + let t = tensors_map.get(name).ok_or_else(|| format!("Puuttuu: {}", name))?; + let t = t.to_dtype(candle_core::DType::F32).unwrap(); + let vec = t.flatten_all().unwrap().to_vec1::().unwrap(); + Ok(Param::from_tensor(Tensor::::from_floats(vec.as_slice(), device))) +} + +fn load_embed( + tensors_map: &std::collections::HashMap, + name: &str, + device: &B::Device, + shape: [usize; 2] +) -> Result>, String> { + let t = tensors_map.get(name).ok_or_else(|| format!("Puuttuu: {}", name))?; + let t = t.to_dtype(candle_core::DType::F32).unwrap(); + let vec = t.flatten_all().unwrap().to_vec1::().unwrap(); + // Embed ei transponoi samalla tavalla, se pysyy [vocab, hidden] + Ok(Param::from_tensor(Tensor::::from_data(burn::tensor::TensorData::new(vec, shape), device))) +} + +pub fn load_safetensors_to_model( + buffer: &[u8], + config: &SmolLMConfig, + device: &B::Device +) -> Result, String> { + + let mut model = LlamaModel::new(config, device); + let tensors_map = safetensors::load_buffer(buffer, &CandleDevice::Cpu) + .map_err(|e| format!("Virhe Safetensors luennassa: {}", e))?; + + // Embeddings + model.embed_tokens = load_embed(&tensors_map, "model.embed_tokens.weight", device, [config.vocab_size, config.hidden_size])?; + model.norm.weight = load_tensor_1d(&tensors_map, "model.norm.weight", device, [config.hidden_size])?; + model.lm_head = load_embed(&tensors_map, "lm_head.weight", device, [config.vocab_size, config.hidden_size]).or_else(|_| { + load_embed(&tensors_map, "model.embed_tokens.weight", device, [config.vocab_size, config.hidden_size]) + })?; + + let head_dim = config.hidden_size / config.num_attention_heads; + + for i in 0..config.num_hidden_layers { + let prefix = format!("model.layers.{}", i); + + let layer = &mut model.layers[i]; + + // Norms + layer.input_layernorm.weight = load_tensor_1d(&tensors_map, &format!("{}.input_layernorm.weight", prefix), device, [config.hidden_size])?; + layer.post_attention_layernorm.weight = load_tensor_1d(&tensors_map, &format!("{}.post_attention_layernorm.weight", prefix), device, [config.hidden_size])?; + + // Attention + let num_heads = config.num_attention_heads; + let num_kv_heads = config.num_key_value_heads; + layer.self_attn.q_proj = load_tensor_2d(&tensors_map, &format!("{}.self_attn.q_proj.weight", prefix), device, [num_heads * head_dim, config.hidden_size])?; + layer.self_attn.k_proj = load_tensor_2d(&tensors_map, &format!("{}.self_attn.k_proj.weight", prefix), device, [num_kv_heads * head_dim, config.hidden_size])?; + layer.self_attn.v_proj = load_tensor_2d(&tensors_map, &format!("{}.self_attn.v_proj.weight", prefix), device, [num_kv_heads * head_dim, config.hidden_size])?; + layer.self_attn.o_proj = load_tensor_2d(&tensors_map, &format!("{}.self_attn.o_proj.weight", prefix), device, [config.hidden_size, num_heads * head_dim])?; + + // MLP + layer.mlp.gate_proj = load_tensor_2d(&tensors_map, &format!("{}.mlp.gate_proj.weight", prefix), device, [config.intermediate_size, config.hidden_size])?; + layer.mlp.up_proj = load_tensor_2d(&tensors_map, &format!("{}.mlp.up_proj.weight", prefix), device, [config.intermediate_size, config.hidden_size])?; + layer.mlp.down_proj = load_tensor_2d(&tensors_map, &format!("{}.mlp.down_proj.weight", prefix), device, [config.hidden_size, config.intermediate_size])?; + } + + Ok(model) +} diff --git a/network-poc/node/src/burn_smollm/mod.rs b/network-poc/node/src/burn_smollm/mod.rs new file mode 100644 index 0000000..3664e61 --- /dev/null +++ b/network-poc/node/src/burn_smollm/mod.rs @@ -0,0 +1,6 @@ +pub mod attention; +pub mod config; +pub mod loader; +pub mod model; +pub mod modules; +pub mod rope; diff --git a/network-poc/node/src/burn_smollm/model.rs b/network-poc/node/src/burn_smollm/model.rs new file mode 100644 index 0000000..9a4f485 --- /dev/null +++ b/network-poc/node/src/burn_smollm/model.rs @@ -0,0 +1,96 @@ +use burn::module::{Module, Param}; +use burn::tensor::{backend::Backend, Tensor, Int}; +use super::modules::{RmsNorm, Mlp}; +use super::attention::{Attention, KVCache}; +use super::config::SmolLMConfig; + +#[derive(Module, Debug)] +pub struct LlamaBlock { + pub self_attn: Attention, + pub mlp: Mlp, + pub input_layernorm: RmsNorm, + pub post_attention_layernorm: RmsNorm, +} + +impl LlamaBlock { + pub fn new(config: &SmolLMConfig, device: &B::Device) -> Self { + Self { + self_attn: Attention::new(config, device), + mlp: Mlp::new(config.hidden_size, config.intermediate_size, device), + input_layernorm: RmsNorm::new(config.hidden_size, config.rms_norm_eps, device), + post_attention_layernorm: RmsNorm::new(config.hidden_size, config.rms_norm_eps, device), + } + } + + pub fn forward( + &self, + x: Tensor, + offset: usize, + cache: Option> + ) -> (Tensor, KVCache) { + let residual = x.clone(); + let x_norm = self.input_layernorm.forward(x); + + let (attn_out, new_cache) = self.self_attn.forward(x_norm, offset, cache); + + let x = residual + attn_out; + + let residual = x.clone(); + let x_norm = self.post_attention_layernorm.forward(x); + let mlp_out = self.mlp.forward(x_norm); + + let x = residual + mlp_out; + (x, new_cache) + } +} + +#[derive(Module, Debug)] +pub struct LlamaModel { + pub embed_tokens: Param>, + pub layers: Vec>, + pub norm: RmsNorm, + pub lm_head: Param>, // For tie_word_embeddings this can point to embed_tokens +} + +impl LlamaModel { + pub fn new(config: &SmolLMConfig, device: &B::Device) -> Self { + let embed = Tensor::zeros([config.vocab_size, config.hidden_size], device); + let lm_head = Tensor::zeros([config.vocab_size, config.hidden_size], device); + + let mut layers = Vec::new(); + for _ in 0..config.num_hidden_layers { + layers.push(LlamaBlock::new(config, device)); + } + + Self { + embed_tokens: Param::from_tensor(embed), + layers, + norm: RmsNorm::new(config.hidden_size, config.rms_norm_eps, device), + lm_head: Param::from_tensor(lm_head), + } + } + + pub fn forward( + &self, + input_ids: Tensor, + offset: usize, + caches: &mut Vec>> + ) -> Tensor { + let [_batch, _seq_len] = input_ids.dims(); + + let mut x = burn::tensor::module::embedding(self.embed_tokens.val(), input_ids); + + for (i, layer) in self.layers.iter().enumerate() { + let cache = caches[i].take(); + let (out, new_cache) = layer.forward(x, offset, cache); + x = out; + caches[i] = Some(new_cache); + } + + x = self.norm.forward(x); + + // Matmul with lm_head (or embed_tokens if tied) to get logits + // Notice: lm_head is typically [vocab_size, hidden_size] in HF, so we swap dims + x.matmul(self.lm_head.val().swap_dims(0, 1).unsqueeze()) + } +} diff --git a/network-poc/node/src/burn_smollm/modules.rs b/network-poc/node/src/burn_smollm/modules.rs new file mode 100644 index 0000000..b1dc9cb --- /dev/null +++ b/network-poc/node/src/burn_smollm/modules.rs @@ -0,0 +1,59 @@ +use burn::module::{Module, Param}; +use burn::tensor::{backend::Backend, Tensor}; + +#[derive(Module, Debug)] +pub struct RmsNorm { + pub weight: Param>, + epsilon: f64, +} + +impl RmsNorm { + pub fn new(size: usize, epsilon: f64, device: &B::Device) -> Self { + let weight = Param::from_tensor(Tensor::ones([size], device)); + Self { weight, epsilon } + } + + pub fn forward(&self, x: Tensor) -> Tensor { + // x: [batch, seq_len, dim] + // RMSNorm: x * weight / sqrt(mean(x^2) + eps) + let x_sq = x.clone().powf_scalar(2.0); + // mean over last dim, keeping dims for broadcast + let [b, s, d] = x_sq.dims(); + let variance = x_sq.sum_dim(2).div_scalar(d as f32); + let norm = x.div(variance.add_scalar(self.epsilon).sqrt()); + + let w = self.weight.val().unsqueeze::<2>().unsqueeze::<3>().reshape([1, 1, d]); + norm * w + } +} + +#[derive(Module, Debug)] +pub struct Mlp { + pub gate_proj: Param>, // [in, intermediate] + pub up_proj: Param>, // [in, intermediate] + pub down_proj: Param>, // [intermediate, out] +} + +impl Mlp { + pub fn new(hidden_size: usize, intermediate_size: usize, device: &B::Device) -> Self { + Self { + gate_proj: Param::from_tensor(Tensor::zeros([hidden_size, intermediate_size], device)), + up_proj: Param::from_tensor(Tensor::zeros([hidden_size, intermediate_size], device)), + down_proj: Param::from_tensor(Tensor::zeros([intermediate_size, hidden_size], device)), + } + } + + pub fn forward(&self, x: Tensor) -> Tensor { + // x: [batch, seq, hidden] + // gate = x @ gate_proj -> [batch, seq, intermediate] + let gate = x.clone().matmul(self.gate_proj.val().unsqueeze()); + let up = x.matmul(self.up_proj.val().unsqueeze()); + + // SiLU(gate) * up + let silu = gate.clone() * burn::tensor::activation::sigmoid(gate); + let intermediate = silu * up; + + // intermediate @ down_proj -> [batch, seq, hidden] + intermediate.matmul(self.down_proj.val().unsqueeze()) + } +} diff --git a/network-poc/node/src/burn_smollm/rope.rs b/network-poc/node/src/burn_smollm/rope.rs new file mode 100644 index 0000000..2ed2993 --- /dev/null +++ b/network-poc/node/src/burn_smollm/rope.rs @@ -0,0 +1,59 @@ +use burn::module::Module; +use burn::tensor::{backend::Backend, Tensor}; + +#[derive(Module, Debug)] +pub struct RoPE { + cos_cache: Tensor, + sin_cache: Tensor, +} + +impl RoPE { + pub fn new(head_dim: usize, max_seq_len: usize, theta: f32, device: &B::Device) -> Self { + // (head_dim / 2) values + let half_dim = head_dim / 2; + let inv_freq: Vec = (0..half_dim) + .map(|i| 1.0 / theta.powf((2 * i) as f32 / head_dim as f32)) + .collect(); + + let inv_freq = Tensor::::from_floats(inv_freq.as_slice(), device).unsqueeze::<2>(); + let t_floats: Vec = (0..max_seq_len).map(|v| v as f32).collect(); + let t = Tensor::::from_floats(t_floats.as_slice(), device).unsqueeze::<2>().transpose(); + // t shape: [max_seq_len, 1] + // inv_freq shape: [1, half_dim] + + // freqs shape: [max_seq_len, half_dim] + let freqs = t.matmul(inv_freq); + + let cos_cache = freqs.clone().cos(); + let sin_cache = freqs.sin(); + + Self { + cos_cache, + sin_cache, + } + } + + pub fn forward(&self, x: Tensor, offset: usize) -> Tensor { + let [batch, heads, seq_len, head_dim] = x.dims(); + let half_dim = head_dim / 2; + + // x shape: [batch, heads, seq_len, head_dim] + // valitaan viipaleet (x1 ja x2) jotta saadaan pyöritettyä rotaatiot + let x1 = x.clone().slice([0..batch, 0..heads, 0..seq_len, 0..half_dim]); + let x2 = x.clone().slice([0..batch, 0..heads, 0..seq_len, half_dim..head_dim]); + + // haetaan vastaava seq offsetista alkaen + let cos = self.cos_cache.clone().slice([offset..offset+seq_len, 0..half_dim]) + .unsqueeze::<4>() // [seq, half_dim, 1] + .reshape([1, 1, seq_len, half_dim]); + let sin = self.sin_cache.clone().slice([offset..offset+seq_len, 0..half_dim]) + .reshape([1, 1, seq_len, half_dim]); + + // x1 * cos - x2 * sin + let o1 = x1.clone().mul(cos.clone()) - x2.clone().mul(sin.clone()); + // x2 * cos + x1 * sin + let o2 = x2.mul(cos) + x1.mul(sin); + + Tensor::cat(vec![o1, o2], 3) + } +} diff --git a/network-poc/node/src/lib.rs b/network-poc/node/src/lib.rs index 03df2db..7df3a50 100644 --- a/network-poc/node/src/lib.rs +++ b/network-poc/node/src/lib.rs @@ -12,6 +12,7 @@ pub mod smollm; pub mod qwen; pub mod qwen_coder; pub mod phi3; +pub mod burn_smollm; #[macro_export] macro_rules! console_log { diff --git a/network-poc/node/src/smollm.rs b/network-poc/node/src/smollm.rs index 2176467..0a622d4 100644 --- a/network-poc/node/src/smollm.rs +++ b/network-poc/node/src/smollm.rs @@ -118,125 +118,106 @@ pub async fn run_smollm_inference(prompt: String, ws: Rc>) { Err(e) => { console_log!("[SmolLM] Malli-virhe: {}", e); return; } }; - console_log!("[SmolLM] Rakennetaan mallia..."); + let use_gpu = crate::HAS_WEBGPU.load(std::sync::atomic::Ordering::SeqCst); + if use_gpu { + console_log!("[SmolLM] Alustetaan Burn WebGPU..."); + burn_wgpu::init_async::(&Default::default(), Default::default()).await; + run_burn_inference::(prompt, model_bytes, tokenizer, ws, perf.clone()).await; + } else { + console_log!("[SmolLM] Käytetään CPU NdArrayta (vanha tapa)..."); + run_burn_inference::(prompt, model_bytes, tokenizer, ws, perf.clone()).await; + } +} + +async fn run_burn_inference( + prompt: String, + model_bytes: Vec, + tokenizer: tokenizers::Tokenizer, + ws: Rc>, + perf: web_sys::Performance, // Korjattu Wasm-performanssi välitettäväksi +) { let start_load = perf.now(); - let device = Device::Cpu; - let dtype = DType::F32; - - // Parsitaan safetensors - let tensors = match candle_core::safetensors::load_buffer(&model_bytes, &device) { - Ok(t) => t, - Err(e) => { console_log!("[SmolLM] Safetensors-parsinta epäonnistui: {}", e); return; } - }; - - let vb = VarBuilder::from_tensors(tensors, dtype, &device); - - // SmolLM-135M config (Llama-arkkitehtuuri) - let config = LlamaConfig { - hidden_size: 576, - intermediate_size: 1536, - vocab_size: 49152, - num_hidden_layers: 30, - num_attention_heads: 9, - num_key_value_heads: Some(3), - rms_norm_eps: 1e-5, - rope_theta: 10000.0, - max_position_embeddings: 2048, - tie_word_embeddings: Some(true), - bos_token_id: Some(1u32), - eos_token_id: Some(LlamaEosToks::Single(2)), - rope_scaling: None, - }; - - let llama_config = config.into_config(false); // false = ei flash attention - let mut cache = Cache::new(true, dtype, &llama_config, &device).unwrap(); - - let model = match Llama::load(vb, &llama_config) { + let device = Default::default(); + let config = crate::burn_smollm::config::SmolLMConfig::default(); + + console_log!("[SmolLM] Injektoidaan Safetensors -> Burn Params..."); + let model = match crate::burn_smollm::loader::load_safetensors_to_model::(&model_bytes, &config, &device) { Ok(m) => m, - Err(e) => { console_log!("[SmolLM] Mallin lataus epäonnistui: {}", e); return; } + Err(e) => { console_log!("[SmolLM] Lataus epäonnistui: {}", e); return; } }; let load_time = perf.now() - start_load; - console_log!("[SmolLM] Malli ladattu ({:.0}ms). Generoidaan...", load_time); + console_log!("[SmolLM] Burn-malli ladattu ({:.0}ms). Generoidaan...", load_time); - // 3. Tokenisoi syöte (Käytetään ChatML-formaattia SmolLM-Instructille) let formatted_prompt = format!("<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", prompt); let encoding = match tokenizer.encode(formatted_prompt.as_str(), true) { Ok(e) => e, Err(e) => { console_log!("[SmolLM] Tokenisointivirhe: {}", e); return; } }; - let input_ids: Vec = encoding.get_ids().to_vec(); + let mut input_ids: Vec = encoding.get_ids().to_vec(); let input_len = input_ids.len(); console_log!("[SmolLM] Syöte: {} tokenia", input_len); - // 4. Generoi tokeneita let start_gen = perf.now(); let max_new_tokens = 32; let mut generated_text = String::new(); let mut tokens_generated: usize = 0; - let mut pos: usize = 0; + + // KV-välimuistin taulukko kerroksittain + let mut caches: Vec>> = vec![None; config.num_hidden_layers]; + let mut current_offset = 0; - // Ensimmäinen forward: koko syöte kerralla - let input = match Tensor::new(input_ids.as_slice(), &device).and_then(|t| t.unsqueeze(0)) { - Ok(t) => t, - Err(e) => { console_log!("[SmolLM] Tensor-virhe: {}", e); return; } - }; + // Prefill: yksitellen, vältetään future token leakage koska ei causal maskia + let input_ids_i32: Vec = input_ids.iter().map(|&x| x as i32).collect(); + let mut last_logits = None; + + for &id in &input_ids_i32 { + let input_tensor = burn::tensor::Tensor::::from_data( + burn::tensor::TensorData::from([id]), + &device + ).unsqueeze::<2>(); // [1, 1] + + last_logits = Some(model.forward(input_tensor, current_offset, &mut caches)); + current_offset += 1; + } - let logits = match model.forward(&input, 0, &mut cache) { - Ok(l) => l, - Err(e) => { console_log!("[SmolLM] Forward-virhe (prefill): {}", e); return; } - }; + let mut logits = last_logits.unwrap(); - // Llama forward voi palauttaa [batch, vocab] tai [batch, seq_len, vocab] - let logits = logits.squeeze(0).unwrap(); - let logits = if logits.dims().len() == 2 { - logits.get(logits.dim(0).unwrap() - 1).unwrap() - } else { - logits - }; - let mut next_token = crate::sampling::sample_top_k(&logits, 10, 5.0); - console_log!("[SmolLM] Ensimmäinen generoitu token: {}", next_token); - pos = input_len; + // Argmax sämpläys + let next_token_tensor = logits.clone().argmax(2); + let mut next_token: u32 = next_token_tensor.into_scalar().to_string().parse().unwrap_or(2); // Yksinkertainen cast koska int scalar if next_token != 2 { if let Ok(text) = tokenizer.decode(&[next_token], true) { generated_text.push_str(&text); - let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "SmolLM-135M" }); + let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "SmolLM-135M (WebGPU)" }); let _ = ws.borrow().send_with_str(&chunk.to_string()); } tokens_generated += 1; } - // Autoregressiivinen generointi: yksi token kerrallaan + // Autoregressiivinen luuppi for _ in 1..max_new_tokens { if next_token == 2 { break; } - - let input = match Tensor::new(&[next_token], &device).and_then(|t| t.unsqueeze(0)) { - Ok(t) => t, - Err(e) => { console_log!("[SmolLM] Tensor-virhe: {}", e); break; } - }; - - let logits = match model.forward(&input, pos, &mut cache) { - Ok(l) => l, - Err(e) => { console_log!("[SmolLM] Forward-virhe pos {}: {}", pos, e); break; } - }; - - let logits = logits.squeeze(0).unwrap(); - let logits = if logits.dims().len() == 2 { - logits.get(logits.dim(0).unwrap() - 1).unwrap() - } else { - logits - }; - next_token = crate::sampling::sample_top_k(&logits, 10, 5.0); - pos += 1; + + let mut input_tensor = burn::tensor::Tensor::::from_data( + burn::tensor::TensorData::from([next_token as i32]), + &device + ).unsqueeze::<2>(); + + logits = model.forward(input_tensor, current_offset, &mut caches); + current_offset += 1; + + let next_token_tensor = logits.argmax(2); + next_token = next_token_tensor.into_scalar().to_string().parse().unwrap_or(2); if next_token == 2 { break; } if let Ok(text) = tokenizer.decode(&[next_token], true) { generated_text.push_str(&text); - let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "SmolLM-135M" }); + let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "SmolLM-135M (WebGPU)" }); let _ = ws.borrow().send_with_str(&chunk.to_string()); } tokens_generated += 1; @@ -245,12 +226,10 @@ pub async fn run_smollm_inference(prompt: String, ws: Rc>) { let gen_time = perf.now() - start_gen; let tokens_per_sec = if gen_time > 0.0 { (tokens_generated as f64 / gen_time) * 1000.0 } else { 0.0 }; - console_log!("[SmolLM] Generoitu {} tokenia | {:.0}ms | {:.1} tok/s", tokens_generated, gen_time, tokens_per_sec); - let done = serde_json::json!({ "type": "llm_done", "prompt": prompt, - "model": "SmolLM-135M-Instruct", + "model": "SmolLM-135M-Instruct (WebGPU)", "response": generated_text, "tokens_generated": tokens_generated, "duration_ms": (gen_time * 100.0).round() / 100.0,