use candle_core::{Device, Tensor, DType}; use candle_nn::VarBuilder; use candle_transformers::models::qwen2::{Config as QwenConfig, ModelForCausalLM as QwenModel}; use wasm_bindgen::JsCast; use std::cell::RefCell; use std::rc::Rc; use web_sys::WebSocket; use crate::storage; macro_rules! console_log { ($($t:tt)*) => (web_sys::console::log_1(&format_args!($($t)*).to_string().into())) } const MODEL_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct/resolve/main/model.safetensors"; const TOKENIZER_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct/resolve/main/tokenizer.json"; /// Streaming-lataus HuggingFacesta IndexedDB-cacheen async fn ensure_cached(key: &str, url: &str, ws: &Rc>) -> Result, String> { if let Ok(Some(bytes)) = storage::load_from_idb(key).await { console_log!("[Qwen] {} löytyi välimuistista ({} MB)", key, bytes.len() / 1024 / 1024); return Ok(bytes); } console_log!("[Qwen] Ladataan {}...", key); let resp = crate::worker_fetch(url).await?; if !resp.ok() { return Err(format!("HTTP {}", resp.status())); } let total_size: usize = resp.headers() .get("content-length").ok().flatten() .and_then(|s| s.parse().ok()) .unwrap_or(0); let body = resp.body().ok_or("Ei bodyä")?; let reader: web_sys::ReadableStreamDefaultReader = body.get_reader().dyn_into().map_err(|_| "Ei reader".to_string())?; let mut data: Vec = Vec::with_capacity(total_size); let mut last_pct: u32 = 0; loop { let chunk = wasm_bindgen_futures::JsFuture::from(reader.read()) .await.map_err(|e| format!("Read: {:?}", e))?; let done = js_sys::Reflect::get(&chunk, &"done".into()).ok().and_then(|v| v.as_bool()).unwrap_or(true); if done { break; } let value = js_sys::Reflect::get(&chunk, &"value".into()).map_err(|_| "value puuttuu".to_string())?; let array = js_sys::Uint8Array::new(&value); let mut buf = vec![0u8; array.length() as usize]; array.copy_to(&mut buf); data.extend_from_slice(&buf); if total_size > 0 { let pct = ((data.len() as f64 / total_size as f64) * 100.0) as u32; if pct >= last_pct + 5 || pct == 100 { last_pct = pct; console_log!("[Qwen] {} lataus: {}%", key, pct); let msg = serde_json::json!({ "type": "download_progress", "file": key, "pct": pct, "loaded_mb": data.len()/1024/1024, "total_mb": total_size/1024/1024 }); let _ = ws.borrow().send_with_str(&msg.to_string()); } } } console_log!("[Qwen] Tallennetaan {} ({} MB)...", key, data.len() / 1024 / 1024); let _ = storage::save_to_idb(key, &data).await; console_log!("[Qwen] {} tallennettu!", key); Ok(data) } pub async fn run_qwen_inference(prompt: String, ws: Rc>) { // performance via crate::perf_now() let tok_bytes = match ensure_cached("qwen05b-tokenizer.json", TOKENIZER_URL, &ws).await { Ok(b) => b, Err(e) => { console_log!("[Qwen] Tokenizer-virhe: {}", e); return; } }; let tokenizer = match tokenizers::Tokenizer::from_bytes(&tok_bytes) { Ok(t) => t, Err(e) => { console_log!("[Qwen] Tokenizer-parsinta: {}", e); return; } }; let model_bytes = match ensure_cached("qwen05b-model.safetensors", MODEL_URL, &ws).await { Ok(b) => b, Err(e) => { console_log!("[Qwen] Malli-virhe: {}", e); return; } }; console_log!("[Qwen] Rakennetaan mallia..."); let start_load = crate::perf_now(); let device = Device::Cpu; let dtype = DType::F32; let tensors = match candle_core::safetensors::load_buffer(&model_bytes, &device) { Ok(t) => t, Err(e) => { console_log!("[Qwen] Safetensors: {}", e); return; } }; let vb = VarBuilder::from_tensors(tensors, dtype, &device); let config = QwenConfig { vocab_size: 151936, hidden_size: 896, intermediate_size: 4864, num_hidden_layers: 24, num_attention_heads: 14, num_key_value_heads: 2, max_position_embeddings: 32768, sliding_window: 32768, max_window_layers: 21, tie_word_embeddings: true, rope_theta: 1000000.0, rms_norm_eps: 1e-6, use_sliding_window: false, hidden_act: candle_nn::Activation::Silu, }; let mut model = match QwenModel::new(&config, vb) { Ok(m) => m, Err(e) => { console_log!("[Qwen] Mallin lataus: {}", e); return; } }; let load_time = crate::perf_now() - start_load; console_log!("[Qwen] Malli ladattu ({:.0}ms). Generoidaan...", load_time); let encoding = match tokenizer.encode(prompt.as_str(), true) { Ok(e) => e, Err(e) => { console_log!("[Qwen] Tokenisointivirhe: {}", e); return; } }; let input_ids: Vec = encoding.get_ids().to_vec(); let input_len = input_ids.len(); console_log!("[Qwen] Syöte: {} tokenia", input_len); let start_gen = crate::perf_now(); let max_new_tokens = 32; let mut generated_text = String::new(); let mut tokens_generated: usize = 0; // Prefill let input = match Tensor::new(input_ids.as_slice(), &device).and_then(|t| t.unsqueeze(0)) { Ok(t) => t, Err(e) => { console_log!("[Qwen] Tensor: {}", e); return; } }; let logits = match model.forward(&input, 0) { Ok(l) => l, Err(e) => { console_log!("[Qwen] Forward (prefill): {}", e); return; } }; // Forward palauttaa [batch, vocab_size] tai [batch, seq_len, vocab_size] let logits = logits.squeeze(0).unwrap(); let logits = if logits.dims().len() == 2 { // [seq_len, vocab_size] — ota viimeinen logits.get(logits.dim(0).unwrap() - 1).unwrap() } else { logits // jo [vocab_size] }; let mut next_token = crate::sampling::sample_top_k(&logits, 10, 5.0); console_log!("[Qwen] Ensimmäinen token: {}", next_token); let eos_token = 151645u32; // <|endoftext|> for Qwen2.5 if next_token != eos_token { if let Ok(text) = tokenizer.decode(&[next_token], true) { generated_text.push_str(&text); let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "Qwen2.5-0.5B" }); let _ = ws.borrow().send_with_str(&chunk.to_string()); } tokens_generated += 1; } // Autoregressive let mut pos = input_len; for _ in 1..max_new_tokens { if next_token == eos_token { break; } let input = match Tensor::new(&[next_token], &device).and_then(|t| t.unsqueeze(0)) { Ok(t) => t, Err(e) => { console_log!("[Qwen] Tensor: {}", e); break; } }; let logits = match model.forward(&input, pos) { Ok(l) => l, Err(e) => { console_log!("[Qwen] Forward pos {}: {}", pos, e); break; } }; let logits = logits.squeeze(0).unwrap(); let logits = if logits.dims().len() == 2 { logits.get(logits.dim(0).unwrap() - 1).unwrap() } else { logits }; next_token = crate::sampling::sample_top_k(&logits, 10, 5.0); pos += 1; if next_token == eos_token { break; } if let Ok(text) = tokenizer.decode(&[next_token], true) { generated_text.push_str(&text); let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "Qwen2.5-0.5B" }); let _ = ws.borrow().send_with_str(&chunk.to_string()); } tokens_generated += 1; crate::sleep_ms(0).await; } let gen_time = crate::perf_now() - start_gen; let tokens_per_sec = if gen_time > 0.0 { (tokens_generated as f64 / gen_time) * 1000.0 } else { 0.0 }; console_log!("[Qwen] {} tokenia | {:.0}ms | {:.1} tok/s", tokens_generated, gen_time, tokens_per_sec); let done = serde_json::json!({ "type": "llm_done", "prompt": prompt, "model": "Qwen2.5-0.5B-Instruct", "response": generated_text, "tokens_generated": tokens_generated, "duration_ms": (gen_time * 100.0).round() / 100.0, "tokens_per_sec": (tokens_per_sec * 10.0).round() / 10.0, "load_time_ms": (load_time * 100.0).round() / 100.0, }); let _ = ws.borrow().send_with_str(&done.to_string()); }