From 0ef1a3f7cdd1d3f12112a9554a33ccbc78c63497 Mon Sep 17 00:00:00 2001 From: Jaakko Vanhala Date: Sat, 4 Apr 2026 21:59:15 +0300 Subject: [PATCH] =?UTF-8?q?S=C3=A4=C3=A4t=C3=B6=20Hommia?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- network-poc/native-node/src/inference.rs | 29 +- network-poc/node/src/qwen_coder.rs | 344 +++++++++++------------ 2 files changed, 171 insertions(+), 202 deletions(-) diff --git a/network-poc/native-node/src/inference.rs b/network-poc/native-node/src/inference.rs index e0e023f..57fe516 100644 --- a/network-poc/native-node/src/inference.rs +++ b/network-poc/native-node/src/inference.rs @@ -2,7 +2,6 @@ use candle_core::{Device, Tensor, DType}; use candle_nn::VarBuilder; use candle_transformers::models::qwen2::{Config as QwenConfig, ModelForCausalLM as QwenModel}; use hf_hub::{api::sync::Api, Repo, RepoType}; -use std::path::PathBuf; use std::time::Instant; /// Top-k sampling with temperature and repetition penalty @@ -63,10 +62,8 @@ fn sample_top_k(logits: &Tensor, k: usize, temperature: f64, generated_tokens: & pub struct LlmEngine { tokenizer: tokenizers::Tokenizer, - model_path: PathBuf, + model: QwenModel, device: Device, - dtype: DType, - config: QwenConfig, eos_token: u32, } @@ -110,34 +107,22 @@ impl LlmEngine { hidden_act: candle_nn::Activation::Silu, }; - // Testi-lataus varmistaa, että painot toimivat let start = Instant::now(); let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_path.clone()], dtype, &device) .map_err(|e| format!("VarBuilder: {}", e))? }; - let _model = QwenModel::new(&config, vb).map_err(|e| format!("Malli: {}", e))?; + let model = QwenModel::new(&config, vb).map_err(|e| format!("Malli: {}", e))?; tracing::info!("Malli ladattu ({:.1}s) — {}", start.elapsed().as_secs_f64(), device_name); Ok(LlmEngine { tokenizer, - model_path, + model, device, - dtype, - config, eos_token: 151645, }) } - /// Luo tuore malliinstanssi (nollaa KV-cachen) - fn fresh_model(&self) -> Result { - let vb = unsafe { - VarBuilder::from_mmaped_safetensors(&[self.model_path.clone()], self.dtype, &self.device) - .map_err(|e| format!("VarBuilder: {}", e))? - }; - QwenModel::new(&self.config, vb).map_err(|e| format!("Malli: {}", e)) - } - pub fn generate(&mut self, prompt: &str, max_tokens: usize) -> Result { let formatted = format!("<|im_start|>system\nYou are a coding assistant. Respond with ONLY code. No explanations, no markdown, no comments unless asked.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", prompt); @@ -146,8 +131,8 @@ impl LlmEngine { let input_ids: Vec = encoding.get_ids().to_vec(); let input_len = input_ids.len(); - // Tuore malli joka promptille (nollaa KV-cachen) - let mut model = self.fresh_model()?; + // Nollataan KV-cache edellisestä promptista + self.model.clear_kv_cache(); // Sampling-parametrit let temperature = 0.7; @@ -165,7 +150,7 @@ impl LlmEngine { .and_then(|t| t.unsqueeze(0)) .map_err(|e| format!("Tensor: {}", e))?; - let logits = model.forward(&input, 0) + let logits = self.model.forward(&input, 0) .map_err(|e| format!("Forward prefill: {}", e))?; let logits = logits.squeeze(0).map_err(|e| format!("Squeeze: {}", e))?; @@ -200,7 +185,7 @@ impl LlmEngine { .and_then(|t| t.unsqueeze(0)) .map_err(|e| format!("Tensor: {}", e))?; - let logits = model.forward(&input, pos) + let logits = self.model.forward(&input, pos) .map_err(|e| format!("Forward pos {}: {}", pos, e))?; let logits = logits.squeeze(0).map_err(|e| format!("Squeeze: {}", e))?; diff --git a/network-poc/node/src/qwen_coder.rs b/network-poc/node/src/qwen_coder.rs index 7c43762..a1b95aa 100644 --- a/network-poc/node/src/qwen_coder.rs +++ b/network-poc/node/src/qwen_coder.rs @@ -21,8 +21,15 @@ const MODEL_3B_PART1_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-Coder-3B-I const MODEL_3B_PART2_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-Coder-3B-Instruct/resolve/main/model-00002-of-00002.safetensors"; const TOKENIZER_3B_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-Coder-3B-Instruct/resolve/main/tokenizer.json"; +struct CachedModel { + model: QwenModel, + tokenizer: tokenizers::Tokenizer, + is_3b: bool, +} + thread_local! { static RAM_CACHE: RefCell>>> = RefCell::new(std::collections::HashMap::new()); + static MODEL_CACHE: RefCell> = RefCell::new(None); } async fn ensure_cached(key: &str, url: &str, ws: &Rc>) -> Result>, String> { @@ -94,223 +101,200 @@ async fn ensure_cached(key: &str, url: &str, ws: &Rc>) -> Res Ok(rc_data) } +/// Lataa tai palauttaa välimuistista valmiin mallin + tokenizerin +async fn get_or_build_model(use_3b: bool, ws: &Rc>) -> Result<(), String> { + // Tarkistetaan onko oikea malli jo muistissa + let cache_hit = MODEL_CACHE.with(|c| { + c.borrow().as_ref().map(|m| m.is_3b == use_3b).unwrap_or(false) + }); + if cache_hit { + console_log!("[Coder] Malli löytyi muistista — ohitetaan lataus"); + return Ok(()); + } + + let device = Device::Cpu; + let dtype = DType::F32; + + // Tokenizer + let tok_url = if use_3b { TOKENIZER_3B_URL } else { TOKENIZER_05B_URL }; + let tok_key = if use_3b { "coder3b-tokenizer.json" } else { "coder05b-tokenizer.json" }; + let tok_bytes = ensure_cached(tok_key, tok_url, ws).await?; + let tokenizer = tokenizers::Tokenizer::from_bytes(&tok_bytes[..]) + .map_err(|e| format!("Tokenizer: {}", e))?; + + // Painot + let tensors = if use_3b { + let part1 = ensure_cached("coder3b-model-part1.safetensors", MODEL_3B_PART1_URL, ws).await?; + let part2 = ensure_cached("coder3b-model-part2.safetensors", MODEL_3B_PART2_URL, ws).await?; + console_log!("[Coder] Rakennetaan 3B-mallia..."); + let mut all_tensors = candle_core::safetensors::load_buffer(&part1[..], &device) + .map_err(|e| format!("Part1: {}", e))?; + let tensors2 = candle_core::safetensors::load_buffer(&part2[..], &device) + .map_err(|e| format!("Part2: {}", e))?; + all_tensors.extend(tensors2); + all_tensors + } else { + let model_bytes = ensure_cached("coder05b-model.safetensors", MODEL_05B_URL, ws).await?; + console_log!("[Coder] Rakennetaan 0.5B-mallia..."); + candle_core::safetensors::load_buffer(&model_bytes[..], &device) + .map_err(|e| format!("Safetensors: {}", e))? + }; + + let vb = VarBuilder::from_tensors(tensors, dtype, &device); + let config = if use_3b { + QwenConfig { + vocab_size: 151936, hidden_size: 2048, intermediate_size: 11008, + num_hidden_layers: 36, num_attention_heads: 16, num_key_value_heads: 2, + max_position_embeddings: 32768, sliding_window: 32768, max_window_layers: 36, + tie_word_embeddings: true, rope_theta: 1000000.0, rms_norm_eps: 1e-6, + use_sliding_window: false, hidden_act: candle_nn::Activation::Silu, + } + } else { + QwenConfig { + vocab_size: 151936, hidden_size: 896, intermediate_size: 4864, + num_hidden_layers: 24, num_attention_heads: 14, num_key_value_heads: 2, + max_position_embeddings: 32768, sliding_window: 32768, max_window_layers: 21, + tie_word_embeddings: true, rope_theta: 1000000.0, rms_norm_eps: 1e-6, + use_sliding_window: false, hidden_act: candle_nn::Activation::Silu, + } + }; + + let model = QwenModel::new(&config, vb).map_err(|e| format!("Malli: {}", e))?; + console_log!("[Coder] Malli rakennettu ja välimuistitettu"); + + MODEL_CACHE.with(|c| { + *c.borrow_mut() = Some(CachedModel { model, tokenizer, is_3b: use_3b }); + }); + + Ok(()) +} + /// use_3b: false = 0.5B (nopea), true = 3B (laadukas) pub async fn run_coder_inference(prompt: String, ws: Rc>, use_3b: bool, task_id: Option) { let perf = web_sys::window().unwrap().performance().unwrap(); let size_label = if use_3b { "3B" } else { "0.5B" }; - // Tokenizer (sama molemmille) - let tok_url = if use_3b { TOKENIZER_3B_URL } else { TOKENIZER_05B_URL }; - let tok_key = if use_3b { "coder3b-tokenizer.json" } else { "coder05b-tokenizer.json" }; - let tok_bytes = match ensure_cached(tok_key, tok_url, &ws).await { - Ok(b) => b, - Err(e) => { console_log!("[Coder] Tokenizer-virhe: {}", e); return; } - }; - let tokenizer = match tokenizers::Tokenizer::from_bytes(&tok_bytes[..]) { - Ok(t) => t, - Err(e) => { console_log!("[Coder] Tokenizer-parsinta: {}", e); return; } - }; - - // Mallin painot - let device = Device::Cpu; - let dtype = DType::F32; - - let tensors = if use_3b { - // 3B: kaksi osaa - let part1 = match ensure_cached("coder3b-model-part1.safetensors", MODEL_3B_PART1_URL, &ws).await { - Ok(b) => b, - Err(e) => { console_log!("[Coder] Malli osa 1 virhe: {}", e); return; } - }; - let part2 = match ensure_cached("coder3b-model-part2.safetensors", MODEL_3B_PART2_URL, &ws).await { - Ok(b) => b, - Err(e) => { console_log!("[Coder] Malli osa 2 virhe: {}", e); return; } - }; - console_log!("[Coder] Rakennetaan 3B-mallia..."); - let mut all_tensors = candle_core::safetensors::load_buffer(&part1[..], &device) - .map_err(|e| format!("Part1: {}", e)).unwrap(); - let tensors2 = candle_core::safetensors::load_buffer(&part2[..], &device) - .map_err(|e| format!("Part2: {}", e)).unwrap(); - all_tensors.extend(tensors2); - all_tensors - } else { - // 0.5B: yksi osa - let model_bytes = match ensure_cached("coder05b-model.safetensors", MODEL_05B_URL, &ws).await { - Ok(b) => b, - Err(e) => { console_log!("[Coder] Malli-virhe: {}", e); return; } - }; - console_log!("[Coder] Rakennetaan 0.5B-mallia..."); - match candle_core::safetensors::load_buffer(&model_bytes[..], &device) { - Ok(t) => t, - Err(e) => { console_log!("[Coder] Safetensors: {}", e); return; } - } - }; - let start_load = perf.now(); - let vb = VarBuilder::from_tensors(tensors, dtype, &device); - let config = if use_3b { - QwenConfig { - vocab_size: 151936, - hidden_size: 2048, - intermediate_size: 11008, - num_hidden_layers: 36, - num_attention_heads: 16, - num_key_value_heads: 2, - max_position_embeddings: 32768, - sliding_window: 32768, - max_window_layers: 36, - tie_word_embeddings: true, - rope_theta: 1000000.0, - rms_norm_eps: 1e-6, - use_sliding_window: false, - hidden_act: candle_nn::Activation::Silu, - } - } else { - QwenConfig { - vocab_size: 151936, - hidden_size: 896, - intermediate_size: 4864, - num_hidden_layers: 24, - num_attention_heads: 14, - num_key_value_heads: 2, - max_position_embeddings: 32768, - sliding_window: 32768, - max_window_layers: 21, - tie_word_embeddings: true, - rope_theta: 1000000.0, - rms_norm_eps: 1e-6, - use_sliding_window: false, - hidden_act: candle_nn::Activation::Silu, - } - }; - - let mut model = match QwenModel::new(&config, vb) { - Ok(m) => m, - Err(e) => { console_log!("[Coder] Mallin lataus: {}", e); return; } - }; + if let Err(e) = get_or_build_model(use_3b, &ws).await { + console_log!("[Coder] Mallin lataus: {}", e); + return; + } let load_time = perf.now() - start_load; - console_log!("[Coder] Malli ladattu ({:.0}ms). Generoidaan...", load_time); + if load_time > 100.0 { + console_log!("[Coder] Malli ladattu ({:.0}ms). Generoidaan...", load_time); + } // Parsitaan JSON-prompti tai käytetään teksti sellaisenaan + let default_system = "You are a coding assistant. Respond with ONLY code. No explanations, no markdown, no comments unless asked."; let (actual_prompt, system_msg, max_new_tokens) = if prompt.starts_with('{') { if let Ok(json) = serde_json::from_str::(&prompt) { let p = json.get("prompt").and_then(|v| v.as_str()).unwrap_or(&prompt).to_string(); - let s = json.get("system").and_then(|v| v.as_str()) - .unwrap_or("You are a coding assistant. Respond with ONLY code. No explanations, no markdown, no comments unless asked.").to_string(); + let s = json.get("system").and_then(|v| v.as_str()).unwrap_or(default_system).to_string(); let m = json.get("max_tokens").and_then(|v| v.as_u64()).unwrap_or(128) as usize; (p, s, m) } else { - (prompt.clone(), "You are a coding assistant. Respond with ONLY code. No explanations, no markdown, no comments unless asked.".to_string(), 128) + (prompt.clone(), default_system.to_string(), 128) } } else { - (prompt.clone(), "You are a coding assistant. Respond with ONLY code. No explanations, no markdown, no comments unless asked.".to_string(), 128) + (prompt.clone(), default_system.to_string(), 128) }; let formatted = format!("<|im_start|>system\n{}<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", system_msg, actual_prompt); - let encoding = match tokenizer.encode(formatted.as_str(), true) { - Ok(e) => e, - Err(e) => { console_log!("[Coder] Tokenisointivirhe: {}", e); return; } - }; - let input_ids: Vec = encoding.get_ids().to_vec(); - let input_len = input_ids.len(); - console_log!("[Coder] Syöte: {} tokenia", input_len); + // Inferenssi: käytetään välimuistissa olevaa mallia + let (generated_text, tokens_generated, gen_time) = MODEL_CACHE.with(|cache| { + let mut cache = cache.borrow_mut(); + let cached = cache.as_mut().expect("Malli pitää olla ladattu"); - let start_gen = perf.now(); - // max_new_tokens tulee JSON-promptista tai oletuksena 128 - let mut generated_text = String::new(); - let mut tokens_generated: usize = 0; - let eos_token = 151645u32; + let encoding = cached.tokenizer.encode(formatted.as_str(), true) + .map_err(|e| format!("Encode: {}", e)).unwrap(); + let input_ids: Vec = encoding.get_ids().to_vec(); + let input_len = input_ids.len(); + console_log!("[Coder] Syöte: {} tokenia", input_len); - // Prefill - let input = match Tensor::new(input_ids.as_slice(), &device).and_then(|t| t.unsqueeze(0)) { - Ok(t) => t, - Err(e) => { console_log!("[Coder] Tensor: {}", e); return; } - }; - let logits = match model.forward(&input, 0) { - Ok(l) => l, - Err(e) => { console_log!("[Coder] Forward (prefill): {}", e); return; } - }; + let device = Device::Cpu; + let start_gen = perf.now(); + let eos_token = 151645u32; + let temperature: f32 = 0.7; + let top_k: usize = 40; + let repetition_penalty: f32 = 1.15; - let logits = logits.squeeze(0).unwrap(); - let logits = if logits.dims().len() == 2 { - logits.get(logits.dim(0).unwrap() - 1).unwrap() - } else { - logits - }; + // Nollataan KV-cache edellisestä promptista + cached.model.clear_kv_cache(); - // Sampling-parametrit - let temperature: f32 = 0.7; - let top_k: usize = 40; - let repetition_penalty: f32 = 1.15; - let mut all_generated: Vec = Vec::new(); - - let mut next_token = crate::sampling::sample_top_k_with_penalty(&logits, top_k, temperature, &all_generated, repetition_penalty); - - if next_token != eos_token { - if let Ok(text) = tokenizer.decode(&[next_token], true) { - generated_text.push_str(&text); - let mut chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "Qwen2.5-Coder" }); - if let Some(ref tid) = task_id { chunk.as_object_mut().unwrap().insert("task_id".to_string(), serde_json::json!(tid)); } - let _ = ws.borrow().send_with_str(&chunk.to_string()); - } - all_generated.push(next_token); - tokens_generated += 1; - } - - // Autoregressive - let mut pos = input_len; - for _ in 1..max_new_tokens { - if next_token == eos_token { break; } - - let input = match Tensor::new(&[next_token], &device).and_then(|t| t.unsqueeze(0)) { - Ok(t) => t, - Err(e) => { console_log!("[Coder] Tensor: {}", e); break; } - }; - let logits = match model.forward(&input, pos) { - Ok(l) => l, - Err(e) => { console_log!("[Coder] Forward pos {}: {}", pos, e); break; } - }; + let mut generated_text = String::new(); + let mut tokens_generated: usize = 0; + let mut all_generated: Vec = Vec::new(); + // Prefill + let input = Tensor::new(input_ids.as_slice(), &device).and_then(|t| t.unsqueeze(0)).unwrap(); + let logits = cached.model.forward(&input, 0).unwrap(); let logits = logits.squeeze(0).unwrap(); let logits = if logits.dims().len() == 2 { logits.get(logits.dim(0).unwrap() - 1).unwrap() - } else { - logits - }; - next_token = crate::sampling::sample_top_k_with_penalty(&logits, top_k, temperature, &all_generated, repetition_penalty); - pos += 1; + } else { logits }; - if next_token == eos_token { break; } + let mut next_token = crate::sampling::sample_top_k_with_penalty(&logits, top_k, temperature, &all_generated, repetition_penalty); - if let Ok(text) = tokenizer.decode(&[next_token], true) { - generated_text.push_str(&text); - - // Stop-sekvenssit: katkaistaan kun malli alkaa selittää - let lower = generated_text.to_lowercase(); - if lower.contains("\n###") || lower.contains("\nexplanation") || lower.contains("\nnote:") || lower.contains("\noutput:") || lower.contains("\n```\n\n") { - // Trimmataan selitysosuus pois - for stop in &["\n###", "\nExplanation", "\nNote:", "\nOutput:", "\n```\n\n"] { - if let Some(pos) = generated_text.find(stop) { - generated_text.truncate(pos); - } - } - break; - } - - let mut chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "Qwen2.5-Coder" }); + if next_token != eos_token { + if let Ok(text) = cached.tokenizer.decode(&[next_token], true) { + generated_text.push_str(&text); + let mut chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "Qwen2.5-Coder" }); if let Some(ref tid) = task_id { chunk.as_object_mut().unwrap().insert("task_id".to_string(), serde_json::json!(tid)); } - let _ = ws.borrow().send_with_str(&chunk.to_string()); + let _ = ws.borrow().send_with_str(&chunk.to_string()); + } + all_generated.push(next_token); + tokens_generated += 1; } - all_generated.push(next_token); - tokens_generated += 1; - // Yield — vapautetaan selaimen event loop joka tokenin jälkeen - crate::sleep_ms(0).await; - } + // Autoregressive + let mut pos = input_len; + for _ in 1..max_new_tokens { + if next_token == eos_token { break; } + + let input = Tensor::new(&[next_token], &device).and_then(|t| t.unsqueeze(0)).unwrap(); + let logits = match cached.model.forward(&input, pos) { + Ok(l) => l, + Err(e) => { console_log!("[Coder] Forward pos {}: {}", pos, e); break; } + }; + + let logits = logits.squeeze(0).unwrap(); + let logits = if logits.dims().len() == 2 { + logits.get(logits.dim(0).unwrap() - 1).unwrap() + } else { logits }; + next_token = crate::sampling::sample_top_k_with_penalty(&logits, top_k, temperature, &all_generated, repetition_penalty); + pos += 1; + + if next_token == eos_token { break; } + + if let Ok(text) = cached.tokenizer.decode(&[next_token], true) { + generated_text.push_str(&text); + + // Stop-sekvenssit: katkaistaan kun malli alkaa selittää + let lower = generated_text.to_lowercase(); + if lower.contains("\n###") || lower.contains("\nexplanation") || lower.contains("\nnote:") || lower.contains("\noutput:") || lower.contains("\n```\n\n") { + for stop in &["\n###", "\nExplanation", "\nNote:", "\nOutput:", "\n```\n\n"] { + if let Some(pos) = generated_text.find(stop) { + generated_text.truncate(pos); + } + } + break; + } + + let mut chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "Qwen2.5-Coder" }); + if let Some(ref tid) = task_id { chunk.as_object_mut().unwrap().insert("task_id".to_string(), serde_json::json!(tid)); } + let _ = ws.borrow().send_with_str(&chunk.to_string()); + } + all_generated.push(next_token); + tokens_generated += 1; + } + + let gen_time = perf.now() - start_gen; + (generated_text, tokens_generated, gen_time) + }); - let gen_time = perf.now() - start_gen; let tokens_per_sec = if gen_time > 0.0 { (tokens_generated as f64 / gen_time) * 1000.0 } else { 0.0 }; console_log!("[Coder] {} tokenia | {:.0}ms | {:.1} tok/s", tokens_generated, gen_time, tokens_per_sec);