use candle_core::{Device, Tensor, DType}; use candle_nn::VarBuilder; use candle_transformers::models::llama::{Llama, LlamaConfig, LlamaEosToks, Cache}; // LogitsProcessor poistettu — käytetään greedy samplingia (argmax) Wasm-yhteensopivuuden vuoksi use wasm_bindgen::JsCast; use std::cell::RefCell; use std::rc::Rc; use web_sys::WebSocket; use crate::storage; macro_rules! console_log { ($($t:tt)*) => (web_sys::console::log_1(&format_args!($($t)*).to_string().into())) } const MODEL_URL: &str = "https://huggingface.co/HuggingFaceTB/SmolLM-135M-Instruct/resolve/main/model.safetensors"; const TOKENIZER_URL: &str = "https://huggingface.co/HuggingFaceTB/SmolLM-135M-Instruct/resolve/main/tokenizer.json"; /// Lataa tiedosto HuggingFacesta streaming-latauksella (progress-ilmoitukset) ja tallentaa IndexedDB:hen async fn ensure_cached(key: &str, url: &str, ws: &Rc>) -> Result, String> { if let Ok(Some(bytes)) = storage::load_from_idb(key).await { console_log!("[SmolLM] {} löytyi välimuistista ({} MB)", key, bytes.len() / 1024 / 1024); send_progress(ws, key, 100, bytes.len(), bytes.len()); return Ok(bytes); } console_log!("[SmolLM] Ladataan {}...", key); send_progress(ws, key, 0, 0, 0); // Fetch API:lla saadaan Content-Length ja streaming-luku let resp = crate::worker_fetch(url).await?; if !resp.ok() { return Err(format!("HTTP {}", resp.status())); } // Kokonaiskoko Content-Length-headerista let total_size: usize = resp.headers() .get("content-length").ok().flatten() .and_then(|s| s.parse().ok()) .unwrap_or(0); let body = resp.body().ok_or("Ei bodyä")?; let reader = body.get_reader(); let reader: web_sys::ReadableStreamDefaultReader = reader.dyn_into().map_err(|_| "Ei ReadableStreamDefaultReader".to_string())?; let mut data: Vec = Vec::with_capacity(total_size); let mut last_pct: u32 = 0; loop { let chunk = wasm_bindgen_futures::JsFuture::from(reader.read()) .await.map_err(|e| format!("Luku epäonnistui: {:?}", e))?; let done = js_sys::Reflect::get(&chunk, &"done".into()) .map_err(|_| "done-kenttä puuttuu".to_string())? .as_bool().unwrap_or(true); if done { break; } let value = js_sys::Reflect::get(&chunk, &"value".into()) .map_err(|_| "value-kenttä puuttuu".to_string())?; let array = js_sys::Uint8Array::new(&value); let mut buf = vec![0u8; array.length() as usize]; array.copy_to(&mut buf); data.extend_from_slice(&buf); // Progress-päivitys (joka 5%) if total_size > 0 { let pct = ((data.len() as f64 / total_size as f64) * 100.0) as u32; if pct >= last_pct + 5 || pct == 100 { last_pct = pct; console_log!("[SmolLM] {} lataus: {}% ({}/{} MB)", key, pct, data.len() / 1024 / 1024, total_size / 1024 / 1024); send_progress(ws, key, pct, data.len(), total_size); } } } console_log!("[SmolLM] Tallennetaan {} ({} MB) IndexedDB:hen...", key, data.len() / 1024 / 1024); let _ = storage::save_to_idb(key, &data).await; console_log!("[SmolLM] {} tallennettu!", key); send_progress(ws, key, 100, data.len(), data.len()); Ok(data) } fn send_progress(ws: &Rc>, file: &str, pct: u32, loaded: usize, total: usize) { let msg = serde_json::json!({ "type": "download_progress", "file": file, "pct": pct, "loaded_mb": loaded / 1024 / 1024, "total_mb": total / 1024 / 1024, }); let _ = ws.borrow().send_with_str(&msg.to_string()); } /// Lataa malli ja tokenizer, suorita inferenssi ja streamaa tokenit hubille pub async fn run_smollm_inference(prompt: String, ws: Rc>) { // performance via crate::perf_now() // 1. Lataa tokenizer let tok_bytes = match ensure_cached("smollm-tokenizer.json", TOKENIZER_URL, &ws).await { Ok(b) => b, Err(e) => { console_log!("[SmolLM] Tokenizer-virhe: {}", e); return; } }; let tokenizer = match tokenizers::Tokenizer::from_bytes(&tok_bytes) { Ok(t) => t, Err(e) => { console_log!("[SmolLM] Tokenizer-parsinta epäonnistui: {}", e); return; } }; // 2. Lataa mallin painot let model_bytes = match ensure_cached("smollm-model.safetensors", MODEL_URL, &ws).await { Ok(b) => b, Err(e) => { console_log!("[SmolLM] Malli-virhe: {}", e); return; } }; // Burn 0.14 wgpu ei yhteensopiva nykyisten selainten kanssa (maxInterStageShaderComponents) // Burn 0.21-pre.2 cubecl-runtime ei käänny Wasmille (println! puuttuu) // → NdArray kunnes Burn 0.21 stable + Wasm-tuki console_log!("[SmolLM] Burn NdArray (CPU) inferenssi..."); run_burn_inference::(prompt, model_bytes, tokenizer, ws).await; } async fn run_burn_inference( prompt: String, model_bytes: Vec, tokenizer: tokenizers::Tokenizer, ws: Rc>, ) { let start_load = crate::perf_now(); let device = Default::default(); let config = crate::burn_smollm::config::SmolLMConfig::default(); console_log!("[SmolLM] Injektoidaan Safetensors -> Burn Params..."); let model = match crate::burn_smollm::loader::load_safetensors_to_model::(&model_bytes, &config, &device) { Ok(m) => m, Err(e) => { console_log!("[SmolLM] Lataus epäonnistui: {}", e); return; } }; let load_time = crate::perf_now() - start_load; console_log!("[SmolLM] Burn-malli ladattu ({:.0}ms). Generoidaan...", load_time); let formatted_prompt = format!("<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", prompt); let encoding = match tokenizer.encode(formatted_prompt.as_str(), true) { Ok(e) => e, Err(e) => { console_log!("[SmolLM] Tokenisointivirhe: {}", e); return; } }; let mut input_ids: Vec = encoding.get_ids().to_vec(); let input_len = input_ids.len(); console_log!("[SmolLM] Syöte: {} tokenia", input_len); let start_gen = crate::perf_now(); let max_new_tokens = 32; let mut generated_text = String::new(); let mut tokens_generated: usize = 0; // KV-välimuistin taulukko kerroksittain let mut caches: Vec>> = vec![None; config.num_hidden_layers]; let mut current_offset = 0; // Prefill: yksitellen, vältetään future token leakage koska ei causal maskia let input_ids_i32: Vec = input_ids.iter().map(|&x| x as i32).collect(); let mut last_logits = None; for &id in &input_ids_i32 { let input_tensor = burn::tensor::Tensor::::from_data( burn::tensor::TensorData::from([id]), &device ).unsqueeze::<2>(); // [1, 1] last_logits = Some(model.forward(input_tensor, current_offset, &mut caches)); current_offset += 1; } let mut logits = last_logits.unwrap(); // Argmax sämpläys let next_token_tensor = logits.clone().argmax(2); let mut next_token: u32 = next_token_tensor.into_scalar().to_string().parse().unwrap_or(2); // Yksinkertainen cast koska int scalar if next_token != 2 { if let Ok(text) = tokenizer.decode(&[next_token], true) { generated_text.push_str(&text); let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "SmolLM-135M (WebGPU)" }); let _ = ws.borrow().send_with_str(&chunk.to_string()); } tokens_generated += 1; } // Autoregressiivinen luuppi for _ in 1..max_new_tokens { if next_token == 2 { break; } let mut input_tensor = burn::tensor::Tensor::::from_data( burn::tensor::TensorData::from([next_token as i32]), &device ).unsqueeze::<2>(); logits = model.forward(input_tensor, current_offset, &mut caches); current_offset += 1; let next_token_tensor = logits.argmax(2); next_token = next_token_tensor.into_scalar().to_string().parse().unwrap_or(2); if next_token == 2 { break; } if let Ok(text) = tokenizer.decode(&[next_token], true) { generated_text.push_str(&text); let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "SmolLM-135M (WebGPU)" }); let _ = ws.borrow().send_with_str(&chunk.to_string()); } tokens_generated += 1; } let gen_time = crate::perf_now() - start_gen; let tokens_per_sec = if gen_time > 0.0 { (tokens_generated as f64 / gen_time) * 1000.0 } else { 0.0 }; let done = serde_json::json!({ "type": "llm_done", "prompt": prompt, "model": "SmolLM-135M-Instruct (WebGPU)", "response": generated_text, "tokens_generated": tokens_generated, "duration_ms": (gen_time * 100.0).round() / 100.0, "tokens_per_sec": (tokens_per_sec * 10.0).round() / 10.0, "load_time_ms": (load_time * 100.0).round() / 100.0, }); let _ = ws.borrow().send_with_str(&done.to_string()); }