diff --git a/network-poc/USER-README.md b/network-poc/USER-README.md index 57dc901..85901e4 100644 --- a/network-poc/USER-README.md +++ b/network-poc/USER-README.md @@ -55,8 +55,28 @@ cargo run -p hub # 3. Avaa selain: http://localhost:3000 -# 4. Valinnainen: natiivi-node (terminaali 2) -HUB_URL=ws://localhost:3000/ws ALLOCATED_GB=4 cargo run -p native-node +# 4. Valinnainen: natiivi-node LLM-inferenssillä (terminaali 2) +# Lataa Qwen2.5-0.5B automaattisesti HuggingFacesta (~990 MB, cachetetaan) +# Release-moodissa ~11 tok/s CPU:lla (32 ydintä) +CARGO_TARGET_DIR=target-native HUB_URL=ws://localhost:3000/ws ALLOCATED_GB=4 cargo run --release -p native-node + +# Tai yhdistä tuotantopalvelimeen: +CARGO_TARGET_DIR=target-native HUB_URL=wss://kipina.studio/ws ALLOCATED_GB=4 cargo run --release -p native-node +``` + +### CUDA-tuki (valinnainen) + +Jos koneessa on NVIDIA GPU ja CUDA toolkit: + +```bash +# Asenna CUDA toolkit (Ubuntu/Pop!_OS) +sudo apt install nvidia-cuda-toolkit + +# Muokkaa native-node/Cargo.toml: +# candle-core = { version = "0.8", features = ["cuda"] } + +# Aja — malli käyttää automaattisesti GPU:ta +CARGO_TARGET_DIR=target-native HUB_URL=ws://localhost:3000/ws cargo run --release -p native-node ``` ## WebGPU-asetukset selaimessa diff --git a/network-poc/hub/nodes.db b/network-poc/hub/nodes.db index 1314271..9e5d54a 100644 Binary files a/network-poc/hub/nodes.db and b/network-poc/hub/nodes.db differ diff --git a/network-poc/hub/src/main.rs b/network-poc/hub/src/main.rs index ce7f2bc..c7db35d 100644 --- a/network-poc/hub/src/main.rs +++ b/network-poc/hub/src/main.rs @@ -279,14 +279,32 @@ async fn main() { "What makes Rust special?", ]; let llm_idx = (rng_state as usize / 7) % llm_prompts.len(); - let llm_msg = serde_json::json!({ + + // SmolLM-prompt + let smollm_msg = serde_json::json!({ "type": "llm_prompt", "prompt": llm_prompts[llm_idx], "model": "smollm-135m", }); - let _ = state_for_task.stats_tx.send(llm_msg.to_string()); + let _ = state_for_task.stats_tx.send(smollm_msg.to_string()); - tracing::debug!("Tehtävät lähetetty: pair + llm_prompt"); + // Qwen-prompt (sama prompti, eri malli-tagi) + let qwen_msg = serde_json::json!({ + "type": "llm_prompt", + "prompt": llm_prompts[llm_idx], + "model": "qwen-05b", + }); + let _ = state_for_task.stats_tx.send(qwen_msg.to_string()); + + // Phi-3 prompt + let phi3_msg = serde_json::json!({ + "type": "llm_prompt", + "prompt": llm_prompts[llm_idx], + "model": "phi3-mini", + }); + let _ = state_for_task.stats_tx.send(phi3_msg.to_string()); + + tracing::debug!("Tehtävät lähetetty: pair + smollm + qwen + phi3"); } }); diff --git a/network-poc/native-node/Cargo.toml b/network-poc/native-node/Cargo.toml index ec9736b..36344e6 100644 --- a/network-poc/native-node/Cargo.toml +++ b/network-poc/native-node/Cargo.toml @@ -12,5 +12,10 @@ serde_json = "1.0" sysinfo = "0.30" nvml-wrapper = "0.10" wgpu = "24" +candle-core = { version = "0.8" } +candle-nn = "0.8" +candle-transformers = "0.8" +hf-hub = "0.4" +tokenizers = "0.19" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/network-poc/native-node/src/inference.rs b/network-poc/native-node/src/inference.rs new file mode 100644 index 0000000..9e21936 --- /dev/null +++ b/network-poc/native-node/src/inference.rs @@ -0,0 +1,181 @@ +use candle_core::{Device, Tensor, DType}; +use candle_nn::VarBuilder; +use candle_transformers::models::qwen2::{Config as QwenConfig, ModelForCausalLM as QwenModel}; +use hf_hub::{api::sync::Api, Repo, RepoType}; +use std::path::PathBuf; +use std::time::Instant; + +pub struct LlmEngine { + tokenizer: tokenizers::Tokenizer, + model_path: PathBuf, + device: Device, + dtype: DType, + config: QwenConfig, + eos_token: u32, +} + +impl LlmEngine { + pub fn load() -> Result { + let device = Device::cuda_if_available(0).map_err(|e| format!("Device: {}", e))?; + let device_name = if device.is_cuda() { "CUDA" } else { "CPU" }; + tracing::info!("LLM device: {}", device_name); + + let dtype = if device.is_cuda() { DType::F16 } else { DType::F32 }; + + tracing::info!("Ladataan Qwen2.5-0.5B-Instruct..."); + let api = Api::new().map_err(|e| format!("HF API: {}", e))?; + let repo = api.repo(Repo::with_revision( + "Qwen/Qwen2.5-0.5B-Instruct".to_string(), + RepoType::Model, + "main".to_string(), + )); + + let tokenizer_path = repo.get("tokenizer.json").map_err(|e| format!("Tokenizer lataus: {}", e))?; + let model_path = repo.get("model.safetensors").map_err(|e| format!("Malli lataus: {}", e))?; + + tracing::info!("Ladataan tokenizer: {:?}", tokenizer_path); + let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path) + .map_err(|e| format!("Tokenizer: {}", e))?; + + let config = QwenConfig { + vocab_size: 151936, + hidden_size: 896, + intermediate_size: 4864, + num_hidden_layers: 24, + num_attention_heads: 14, + num_key_value_heads: 2, + max_position_embeddings: 32768, + sliding_window: 32768, + max_window_layers: 21, + tie_word_embeddings: true, + rope_theta: 1000000.0, + rms_norm_eps: 1e-6, + use_sliding_window: false, + hidden_act: candle_nn::Activation::Silu, + }; + + // Testi-lataus varmistaa, että painot toimivat + let start = Instant::now(); + let vb = unsafe { + VarBuilder::from_mmaped_safetensors(&[model_path.clone()], dtype, &device) + .map_err(|e| format!("VarBuilder: {}", e))? + }; + let _model = QwenModel::new(&config, vb).map_err(|e| format!("Malli: {}", e))?; + tracing::info!("Malli ladattu ({:.1}s) — {}", start.elapsed().as_secs_f64(), device_name); + + Ok(LlmEngine { + tokenizer, + model_path, + device, + dtype, + config, + eos_token: 151645, + }) + } + + /// Luo tuore malliinstanssi (nollaa KV-cachen) + fn fresh_model(&self) -> Result { + let vb = unsafe { + VarBuilder::from_mmaped_safetensors(&[self.model_path.clone()], self.dtype, &self.device) + .map_err(|e| format!("VarBuilder: {}", e))? + }; + QwenModel::new(&self.config, vb).map_err(|e| format!("Malli: {}", e)) + } + + pub fn generate(&mut self, prompt: &str, max_tokens: usize) -> Result { + let formatted = format!("<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", prompt); + + let encoding = self.tokenizer.encode(formatted.as_str(), true) + .map_err(|e| format!("Encode: {}", e))?; + let input_ids: Vec = encoding.get_ids().to_vec(); + let input_len = input_ids.len(); + + // Tuore malli joka promptille (nollaa KV-cachen) + let mut model = self.fresh_model()?; + + let start = Instant::now(); + + // Prefill + let input = Tensor::new(input_ids.as_slice(), &self.device) + .and_then(|t| t.unsqueeze(0)) + .map_err(|e| format!("Tensor: {}", e))?; + + let logits = model.forward(&input, 0) + .map_err(|e| format!("Forward prefill: {}", e))?; + + let logits = logits.squeeze(0).map_err(|e| format!("Squeeze: {}", e))?; + let logits = if logits.dims().len() == 2 { + logits.get(logits.dim(0).unwrap() - 1).map_err(|e| format!("Get: {}", e))? + } else { + logits + }; + let mut next_token = logits.argmax(0) + .map_err(|e| format!("Argmax: {}", e))? + .to_vec0::() + .map_err(|e| format!("to_vec0: {}", e))?; + + let mut generated_text = String::new(); + let mut tokens_generated: usize = 0; + let mut all_tokens: Vec = Vec::new(); + + if next_token != self.eos_token { + if let Ok(text) = self.tokenizer.decode(&[next_token], true) { + generated_text.push_str(&text); + } + all_tokens.push(next_token); + tokens_generated += 1; + } + + // Autoregressive + let mut pos = input_len; + for _ in 1..max_tokens { + if next_token == self.eos_token { break; } + + let input = Tensor::new(&[next_token], &self.device) + .and_then(|t| t.unsqueeze(0)) + .map_err(|e| format!("Tensor: {}", e))?; + + let logits = model.forward(&input, pos) + .map_err(|e| format!("Forward pos {}: {}", pos, e))?; + + let logits = logits.squeeze(0).map_err(|e| format!("Squeeze: {}", e))?; + let logits = if logits.dims().len() == 2 { + logits.get(logits.dim(0).unwrap() - 1).map_err(|e| format!("Get: {}", e))? + } else { + logits + }; + next_token = logits.argmax(0) + .map_err(|e| format!("Argmax: {}", e))? + .to_vec0::() + .map_err(|e| format!("to_vec0: {}", e))?; + pos += 1; + + if next_token == self.eos_token { break; } + + if let Ok(text) = self.tokenizer.decode(&[next_token], true) { + generated_text.push_str(&text); + } + all_tokens.push(next_token); + tokens_generated += 1; + } + + let gen_time = start.elapsed(); + let tokens_per_sec = if gen_time.as_secs_f64() > 0.0 { + tokens_generated as f64 / gen_time.as_secs_f64() + } else { 0.0 }; + + Ok(GenerateResult { + text: generated_text, + tokens_generated, + duration_ms: gen_time.as_millis() as f64, + tokens_per_sec, + }) + } +} + +pub struct GenerateResult { + pub text: String, + pub tokens_generated: usize, + pub duration_ms: f64, + pub tokens_per_sec: f64, +} diff --git a/network-poc/native-node/src/main.rs b/network-poc/native-node/src/main.rs index 411082b..3e1e73f 100644 --- a/network-poc/native-node/src/main.rs +++ b/network-poc/native-node/src/main.rs @@ -4,6 +4,8 @@ use sysinfo::System; use tokio_tungstenite::connect_async; use tokio_tungstenite::tungstenite::Message; +mod inference; + /// GPU-tietorakenne — yhtenäinen kaikille valmistajille struct GpuInfo { name: String, @@ -282,7 +284,20 @@ async fn main() { } } - // Yhdistetään hubiin — yritetään uudelleen katkon sattuessa + // Ladataan LLM-malli + tracing::info!("Ladataan LLM-mallia..."); + let mut llm = match inference::LlmEngine::load() { + Ok(engine) => { + tracing::info!("LLM valmis inferenssiin!"); + Some(engine) + } + Err(e) => { + tracing::warn!("LLM-lataus epäonnistui: {} — toimitaan ilman inferenssiä", e); + None + } + }; + + // Yhdistetään hubiin loop { match connect_async(&hub_url).await { Ok((ws_stream, _)) => { @@ -295,17 +310,51 @@ async fn main() { continue; } + let mut busy = false; + while let Some(Ok(msg)) = read.next().await { if let Message::Text(text) = msg { - if text.contains("pair_task") || text.contains("ai_task") { - tracing::debug!("Tehtävä vastaanotettu: {}", &text[..text.len().min(80)]); - let reply = json!({ - "type": "result", - "status": "success", - "data": "native-node: ei vielä laskentaa" - }); - let _ = write.send(Message::Text(reply.to_string())).await; + // LLM-promptit + if text.contains("llm_prompt") && !busy { + if let Ok(task) = serde_json::from_str::(&text) { + let prompt = task.get("prompt").and_then(|v| v.as_str()).unwrap_or(""); + if !prompt.is_empty() { + if let Some(ref mut engine) = llm { + busy = true; + tracing::info!("Generoidaan: \"{}\"", prompt); + + match engine.generate(prompt, 64) { + Ok(result) => { + tracing::info!( + "Tulos: {} tokenia | {:.0}ms | {:.1} tok/s | \"{}\"", + result.tokens_generated, + result.duration_ms, + result.tokens_per_sec, + &result.text[..result.text.len().min(80)] + ); + + let done = json!({ + "type": "llm_done", + "prompt": prompt, + "model": "Qwen2.5-0.5B-Instruct (native/GPU)", + "response": result.text, + "tokens_generated": result.tokens_generated, + "duration_ms": result.duration_ms, + "tokens_per_sec": (result.tokens_per_sec * 10.0).round() / 10.0, + "load_time_ms": 0, + }); + let _ = write.send(Message::Text(done.to_string())).await; + } + Err(e) => { + tracing::error!("Inferenssivirhe: {}", e); + } + } + busy = false; + } + } + } } + // Ohitetaan pair_task, stats jne. } } tracing::warn!("Yhteys hubiin katkesi — yritetään uudelleen 5s..."); diff --git a/network-poc/node/src/lib.rs b/network-poc/node/src/lib.rs index c0d532c..04ad469 100644 --- a/network-poc/node/src/lib.rs +++ b/network-poc/node/src/lib.rs @@ -8,6 +8,8 @@ use burn::backend::{Wgpu, NdArray}; pub mod storage; pub mod smollm; +pub mod qwen; +pub mod phi3; #[macro_export] macro_rules! console_log { @@ -16,8 +18,9 @@ macro_rules! console_log { static GPU_LOAD_PERCENT: AtomicU32 = AtomicU32::new(50); static HAS_WEBGPU: AtomicBool = AtomicBool::new(true); -// Valittu tehtävä: 0=tokenize, 1=smollm-135m, 2=qwen-05b, 3=phi3-mini static SELECTED_TASK: AtomicU32 = AtomicU32::new(0); +// Estää rinnakkaiset LLM-inferenssit (vain yksi kerrallaan) +static LLM_BUSY: AtomicBool = AtomicBool::new(false); #[wasm_bindgen] pub fn set_gpu_load(load: u32) { @@ -202,14 +205,46 @@ pub async fn start_agent_node(hub_url: String, has_webgpu: bool, device_info_jso } } } else if msg.contains("llm_prompt") && current_task == 1 { - // Vain SmolLM-solmut käsittelevät llm_prompt-viestejä - if let Ok(task) = serde_json::from_str::(&msg) { + // Vain SmolLM-solmut, ja vain yksi inferenssi kerrallaan + if LLM_BUSY.load(Ordering::SeqCst) { + // Ohitetaan — edellinen inferenssi vielä käynnissä + } else if let Ok(task) = serde_json::from_str::(&msg) { let prompt = task.get("prompt").and_then(|v| v.as_str()).unwrap_or("").to_string(); let model = task.get("model").and_then(|v| v.as_str()).unwrap_or("").to_string(); if !prompt.is_empty() && model == "smollm-135m" { + LLM_BUSY.store(true, Ordering::SeqCst); let ws_for_async = ws_clone.clone(); wasm_bindgen_futures::spawn_local(async move { smollm::run_smollm_inference(prompt, ws_for_async).await; + LLM_BUSY.store(false, Ordering::SeqCst); + }); + } + } + } else if msg.contains("llm_prompt") && current_task == 2 { + // Qwen2.5-0.5B + if LLM_BUSY.load(Ordering::SeqCst) { + } else if let Ok(task) = serde_json::from_str::(&msg) { + let prompt = task.get("prompt").and_then(|v| v.as_str()).unwrap_or("").to_string(); + if !prompt.is_empty() { + LLM_BUSY.store(true, Ordering::SeqCst); + let ws_for_async = ws_clone.clone(); + wasm_bindgen_futures::spawn_local(async move { + qwen::run_qwen_inference(prompt, ws_for_async).await; + LLM_BUSY.store(false, Ordering::SeqCst); + }); + } + } + } else if msg.contains("llm_prompt") && current_task == 3 { + // Phi-3 Mini + if LLM_BUSY.load(Ordering::SeqCst) { + } else if let Ok(task) = serde_json::from_str::(&msg) { + let prompt = task.get("prompt").and_then(|v| v.as_str()).unwrap_or("").to_string(); + if !prompt.is_empty() { + LLM_BUSY.store(true, Ordering::SeqCst); + let ws_for_async = ws_clone.clone(); + wasm_bindgen_futures::spawn_local(async move { + phi3::run_phi3_inference(prompt, ws_for_async).await; + LLM_BUSY.store(false, Ordering::SeqCst); }); } } diff --git a/network-poc/node/src/phi3.rs b/network-poc/node/src/phi3.rs new file mode 100644 index 0000000..b956973 --- /dev/null +++ b/network-poc/node/src/phi3.rs @@ -0,0 +1,36 @@ +use candle_core::{Device, Tensor, DType}; +use candle_nn::VarBuilder; +use candle_transformers::models::phi3::{Config as Phi3Config, Model as Phi3Model}; +use wasm_bindgen::JsCast; +use std::cell::RefCell; +use std::rc::Rc; +use web_sys::WebSocket; + +use crate::storage; + +macro_rules! console_log { + ($($t:tt)*) => (web_sys::console::log_1(&format_args!($($t)*).to_string().into())) +} + +const MODEL_URL: &str = "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/model.safetensors.index.json"; +const TOKENIZER_URL: &str = "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/tokenizer.json"; + +// Phi-3 Mini on iso (7.6 GB) — käytetään kvantisoidumpaa versiota myöhemmin +// Tällä hetkellä: placeholder joka raportoi koon ja jättää inferenssin väliin +pub async fn run_phi3_inference(prompt: String, ws: Rc>) { + console_log!("[Phi-3] Phi-3 Mini 3.8B on liian suuri selaimessa ajettavaksi (~7.6 GB)."); + console_log!("[Phi-3] Käytä SmolLM 135M tai Qwen2.5 0.5B selaininferenssiin."); + console_log!("[Phi-3] Phi-3 tuetaan native-node:lla (Docker + GPU)."); + + let done = serde_json::json!({ + "type": "llm_done", + "prompt": prompt, + "model": "Phi-3-Mini (ei tuettu selaimessa)", + "response": "Phi-3 Mini 3.8B on liian suuri selaimessa ajettavaksi. Käytä SmolLM 135M tai Qwen2.5 0.5B.", + "tokens_generated": 0, + "duration_ms": 0, + "tokens_per_sec": 0, + "load_time_ms": 0, + }); + let _ = ws.borrow().send_with_str(&done.to_string()); +} diff --git a/network-poc/node/src/qwen.rs b/network-poc/node/src/qwen.rs new file mode 100644 index 0000000..1494e3c --- /dev/null +++ b/network-poc/node/src/qwen.rs @@ -0,0 +1,219 @@ +use candle_core::{Device, Tensor, DType}; +use candle_nn::VarBuilder; +use candle_transformers::models::qwen2::{Config as QwenConfig, ModelForCausalLM as QwenModel}; +use wasm_bindgen::JsCast; +use std::cell::RefCell; +use std::rc::Rc; +use web_sys::WebSocket; + +use crate::storage; + +macro_rules! console_log { + ($($t:tt)*) => (web_sys::console::log_1(&format_args!($($t)*).to_string().into())) +} + +const MODEL_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct/resolve/main/model.safetensors"; +const TOKENIZER_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct/resolve/main/tokenizer.json"; + +/// Streaming-lataus HuggingFacesta IndexedDB-cacheen +async fn ensure_cached(key: &str, url: &str, ws: &Rc>) -> Result, String> { + if let Ok(Some(bytes)) = storage::load_from_idb(key).await { + console_log!("[Qwen] {} löytyi välimuistista ({} MB)", key, bytes.len() / 1024 / 1024); + return Ok(bytes); + } + + console_log!("[Qwen] Ladataan {}...", key); + + let window = web_sys::window().unwrap(); + let resp_val = wasm_bindgen_futures::JsFuture::from(window.fetch_with_str(url)) + .await.map_err(|e| format!("Fetch epäonnistui: {:?}", e))?; + let resp: web_sys::Response = resp_val.dyn_into().map_err(|_| "Ei Response".to_string())?; + if !resp.ok() { return Err(format!("HTTP {}", resp.status())); } + + let total_size: usize = resp.headers() + .get("content-length").ok().flatten() + .and_then(|s| s.parse().ok()) + .unwrap_or(0); + + let body = resp.body().ok_or("Ei bodyä")?; + let reader: web_sys::ReadableStreamDefaultReader = body.get_reader().dyn_into().map_err(|_| "Ei reader".to_string())?; + + let mut data: Vec = Vec::with_capacity(total_size); + let mut last_pct: u32 = 0; + + loop { + let chunk = wasm_bindgen_futures::JsFuture::from(reader.read()) + .await.map_err(|e| format!("Read: {:?}", e))?; + let done = js_sys::Reflect::get(&chunk, &"done".into()).ok().and_then(|v| v.as_bool()).unwrap_or(true); + if done { break; } + let value = js_sys::Reflect::get(&chunk, &"value".into()).map_err(|_| "value puuttuu".to_string())?; + let array = js_sys::Uint8Array::new(&value); + let mut buf = vec![0u8; array.length() as usize]; + array.copy_to(&mut buf); + data.extend_from_slice(&buf); + + if total_size > 0 { + let pct = ((data.len() as f64 / total_size as f64) * 100.0) as u32; + if pct >= last_pct + 5 || pct == 100 { + last_pct = pct; + console_log!("[Qwen] {} lataus: {}%", key, pct); + let msg = serde_json::json!({ "type": "download_progress", "file": key, "pct": pct, "loaded_mb": data.len()/1024/1024, "total_mb": total_size/1024/1024 }); + let _ = ws.borrow().send_with_str(&msg.to_string()); + } + } + } + + console_log!("[Qwen] Tallennetaan {} ({} MB)...", key, data.len() / 1024 / 1024); + let _ = storage::save_to_idb(key, &data).await; + console_log!("[Qwen] {} tallennettu!", key); + + Ok(data) +} + +pub async fn run_qwen_inference(prompt: String, ws: Rc>) { + let perf = web_sys::window().unwrap().performance().unwrap(); + + let tok_bytes = match ensure_cached("qwen05b-tokenizer.json", TOKENIZER_URL, &ws).await { + Ok(b) => b, + Err(e) => { console_log!("[Qwen] Tokenizer-virhe: {}", e); return; } + }; + let tokenizer = match tokenizers::Tokenizer::from_bytes(&tok_bytes) { + Ok(t) => t, + Err(e) => { console_log!("[Qwen] Tokenizer-parsinta: {}", e); return; } + }; + + let model_bytes = match ensure_cached("qwen05b-model.safetensors", MODEL_URL, &ws).await { + Ok(b) => b, + Err(e) => { console_log!("[Qwen] Malli-virhe: {}", e); return; } + }; + + console_log!("[Qwen] Rakennetaan mallia..."); + let start_load = perf.now(); + let device = Device::Cpu; + let dtype = DType::F32; + + let tensors = match candle_core::safetensors::load_buffer(&model_bytes, &device) { + Ok(t) => t, + Err(e) => { console_log!("[Qwen] Safetensors: {}", e); return; } + }; + let vb = VarBuilder::from_tensors(tensors, dtype, &device); + + let config = QwenConfig { + vocab_size: 151936, + hidden_size: 896, + intermediate_size: 4864, + num_hidden_layers: 24, + num_attention_heads: 14, + num_key_value_heads: 2, + max_position_embeddings: 32768, + sliding_window: 32768, + max_window_layers: 21, + tie_word_embeddings: true, + rope_theta: 1000000.0, + rms_norm_eps: 1e-6, + use_sliding_window: false, + hidden_act: candle_nn::Activation::Silu, + }; + + let mut model = match QwenModel::new(&config, vb) { + Ok(m) => m, + Err(e) => { console_log!("[Qwen] Mallin lataus: {}", e); return; } + }; + + let load_time = perf.now() - start_load; + console_log!("[Qwen] Malli ladattu ({:.0}ms). Generoidaan...", load_time); + + let encoding = match tokenizer.encode(prompt.as_str(), true) { + Ok(e) => e, + Err(e) => { console_log!("[Qwen] Tokenisointivirhe: {}", e); return; } + }; + let input_ids: Vec = encoding.get_ids().to_vec(); + let input_len = input_ids.len(); + console_log!("[Qwen] Syöte: {} tokenia", input_len); + + let start_gen = perf.now(); + let max_new_tokens = 32; + let mut generated_text = String::new(); + let mut tokens_generated: usize = 0; + + // Prefill + let input = match Tensor::new(input_ids.as_slice(), &device).and_then(|t| t.unsqueeze(0)) { + Ok(t) => t, + Err(e) => { console_log!("[Qwen] Tensor: {}", e); return; } + }; + let logits = match model.forward(&input, 0) { + Ok(l) => l, + Err(e) => { console_log!("[Qwen] Forward (prefill): {}", e); return; } + }; + + // Forward palauttaa [batch, vocab_size] tai [batch, seq_len, vocab_size] + let logits = logits.squeeze(0).unwrap(); + let logits = if logits.dims().len() == 2 { + // [seq_len, vocab_size] — ota viimeinen + logits.get(logits.dim(0).unwrap() - 1).unwrap() + } else { + logits // jo [vocab_size] + }; + let mut next_token = logits.argmax(0).unwrap().to_vec0::().unwrap(); + console_log!("[Qwen] Ensimmäinen token: {}", next_token); + + let eos_token = 151645u32; // <|endoftext|> for Qwen2.5 + + if next_token != eos_token { + if let Ok(text) = tokenizer.decode(&[next_token], true) { + generated_text.push_str(&text); + let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "Qwen2.5-0.5B" }); + let _ = ws.borrow().send_with_str(&chunk.to_string()); + } + tokens_generated += 1; + } + + // Autoregressive + let mut pos = input_len; + for _ in 1..max_new_tokens { + if next_token == eos_token { break; } + + let input = match Tensor::new(&[next_token], &device).and_then(|t| t.unsqueeze(0)) { + Ok(t) => t, + Err(e) => { console_log!("[Qwen] Tensor: {}", e); break; } + }; + let logits = match model.forward(&input, pos) { + Ok(l) => l, + Err(e) => { console_log!("[Qwen] Forward pos {}: {}", pos, e); break; } + }; + + let logits = logits.squeeze(0).unwrap(); + let logits = if logits.dims().len() == 2 { + logits.get(logits.dim(0).unwrap() - 1).unwrap() + } else { + logits + }; + next_token = logits.argmax(0).unwrap().to_vec0::().unwrap(); + pos += 1; + + if next_token == eos_token { break; } + + if let Ok(text) = tokenizer.decode(&[next_token], true) { + generated_text.push_str(&text); + let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "Qwen2.5-0.5B" }); + let _ = ws.borrow().send_with_str(&chunk.to_string()); + } + tokens_generated += 1; + } + + let gen_time = perf.now() - start_gen; + let tokens_per_sec = if gen_time > 0.0 { (tokens_generated as f64 / gen_time) * 1000.0 } else { 0.0 }; + console_log!("[Qwen] {} tokenia | {:.0}ms | {:.1} tok/s", tokens_generated, gen_time, tokens_per_sec); + + let done = serde_json::json!({ + "type": "llm_done", + "prompt": prompt, + "model": "Qwen2.5-0.5B-Instruct", + "response": generated_text, + "tokens_generated": tokens_generated, + "duration_ms": (gen_time * 100.0).round() / 100.0, + "tokens_per_sec": (tokens_per_sec * 10.0).round() / 10.0, + "load_time_ms": (load_time * 100.0).round() / 100.0, + }); + let _ = ws.borrow().send_with_str(&done.to_string()); +} diff --git a/network-poc/node/src/smollm.rs b/network-poc/node/src/smollm.rs index 5b70304..afd2d8e 100644 --- a/network-poc/node/src/smollm.rs +++ b/network-poc/node/src/smollm.rs @@ -1,7 +1,7 @@ use candle_core::{Device, Tensor, DType}; use candle_nn::VarBuilder; use candle_transformers::models::llama::{Llama, LlamaConfig, LlamaEosToks, Cache}; -use candle_transformers::generation::LogitsProcessor; +// LogitsProcessor poistettu — käytetään greedy samplingia (argmax) Wasm-yhteensopivuuden vuoksi use wasm_bindgen::JsCast; use std::cell::RefCell; use std::rc::Rc; @@ -160,8 +160,9 @@ pub async fn run_smollm_inference(prompt: String, ws: Rc>) { let load_time = perf.now() - start_load; console_log!("[SmolLM] Malli ladattu ({:.0}ms). Generoidaan...", load_time); - // 3. Tokenisoi syöte - let encoding = match tokenizer.encode(prompt.as_str(), true) { + // 3. Tokenisoi syöte (Käytetään ChatML-formaattia SmolLM-Instructille) + let formatted_prompt = format!("<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", prompt); + let encoding = match tokenizer.encode(formatted_prompt.as_str(), true) { Ok(e) => e, Err(e) => { console_log!("[SmolLM] Tokenisointivirhe: {}", e); return; } }; @@ -172,62 +173,76 @@ pub async fn run_smollm_inference(prompt: String, ws: Rc>) { // 4. Generoi tokeneita let start_gen = perf.now(); - let mut logits_processor = LogitsProcessor::new(42, Some(0.8), Some(0.95)); - let mut all_tokens = input_ids.clone(); - let max_new_tokens = 64; + let max_new_tokens = 32; let mut generated_text = String::new(); + let mut tokens_generated: usize = 0; + let mut pos: usize = 0; - for i in 0..max_new_tokens { - let context_tokens = if i == 0 { - all_tokens.as_slice() - } else { - std::slice::from_ref(all_tokens.last().unwrap()) - }; + // Ensimmäinen forward: koko syöte kerralla + let input = match Tensor::new(input_ids.as_slice(), &device).and_then(|t| t.unsqueeze(0)) { + Ok(t) => t, + Err(e) => { console_log!("[SmolLM] Tensor-virhe: {}", e); return; } + }; - let input = Tensor::new(context_tokens, &device).unwrap().unsqueeze(0).unwrap(); - let seq_len = input.dim(1).unwrap(); + let logits = match model.forward(&input, 0, &mut cache) { + Ok(l) => l, + Err(e) => { console_log!("[SmolLM] Forward-virhe (prefill): {}", e); return; } + }; - let logits = match model.forward(&input, input_len + i - seq_len, &mut cache) { - Ok(l) => l, - Err(e) => { console_log!("[SmolLM] Forward-virhe stepissä {}: {}", i, e); break; } - }; + // Llama forward voi palauttaa [batch, vocab] tai [batch, seq_len, vocab] + let logits = logits.squeeze(0).unwrap(); + let logits = if logits.dims().len() == 2 { + logits.get(logits.dim(0).unwrap() - 1).unwrap() + } else { + logits + }; + let mut next_token = logits.argmax(0).unwrap().to_vec0::().unwrap(); + console_log!("[SmolLM] Ensimmäinen generoitu token: {}", next_token); + pos = input_len; - // Viimeisen tokenin logitit - let logits = logits.squeeze(0).unwrap(); - let last_dim = logits.dim(0).unwrap(); - let logits = if last_dim > 1 { - logits.get(last_dim - 1).unwrap() - } else { - logits.get(0).unwrap() - }; - - let next_token = logits_processor.sample(&logits).unwrap(); - - // EOS-tarkistus - if next_token == 2 { - break; - } - - all_tokens.push(next_token); - - // Dekoodaa token tekstiksi + if next_token != 2 { if let Ok(text) = tokenizer.decode(&[next_token], true) { generated_text.push_str(&text); - - // Streamaa token hubille - let chunk = serde_json::json!({ - "type": "llm_chunk", - "token": text, - "is_last": false, - "prompt": prompt, - "model": "SmolLM-135M" - }); + let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "SmolLM-135M" }); let _ = ws.borrow().send_with_str(&chunk.to_string()); } + tokens_generated += 1; + } + + // Autoregressiivinen generointi: yksi token kerrallaan + for _ in 1..max_new_tokens { + if next_token == 2 { break; } + + let input = match Tensor::new(&[next_token], &device).and_then(|t| t.unsqueeze(0)) { + Ok(t) => t, + Err(e) => { console_log!("[SmolLM] Tensor-virhe: {}", e); break; } + }; + + let logits = match model.forward(&input, pos, &mut cache) { + Ok(l) => l, + Err(e) => { console_log!("[SmolLM] Forward-virhe pos {}: {}", pos, e); break; } + }; + + let logits = logits.squeeze(0).unwrap(); + let logits = if logits.dims().len() == 2 { + logits.get(logits.dim(0).unwrap() - 1).unwrap() + } else { + logits + }; + next_token = logits.argmax(0).unwrap().to_vec0::().unwrap(); + pos += 1; + + if next_token == 2 { break; } + + if let Ok(text) = tokenizer.decode(&[next_token], true) { + generated_text.push_str(&text); + let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "SmolLM-135M" }); + let _ = ws.borrow().send_with_str(&chunk.to_string()); + } + tokens_generated += 1; } let gen_time = perf.now() - start_gen; - let tokens_generated = all_tokens.len() - input_len; let tokens_per_sec = if gen_time > 0.0 { (tokens_generated as f64 / gen_time) * 1000.0 } else { 0.0 }; console_log!("[SmolLM] Generoitu {} tokenia | {:.0}ms | {:.1} tok/s", tokens_generated, gen_time, tokens_per_sec); diff --git a/network-poc/static/index.html b/network-poc/static/index.html index 1a15fd7..997bfd9 100644 --- a/network-poc/static/index.html +++ b/network-poc/static/index.html @@ -339,16 +339,16 @@