- Poistettu kaikki web_sys::window() -kutsut Rust WASM:sta
- Uudet Worker-yhteensopivat apufunktiot: perf_now(), worker_fetch(), sleep_ms()
- worker.js lataa ja ajaa WASM-moduulin erillisessä säikeessä
- ensureCoderNode käynnistää Workerin pääsäikeen sijaan
- Selaimen UI pysyy responsiivisena inferenssin aikana

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
233 lines
9.2 KiB
Rust
233 lines
9.2 KiB
Rust
use candle_core::{Device, Tensor, DType};
|
|
use candle_nn::VarBuilder;
|
|
use candle_transformers::models::llama::{Llama, LlamaConfig, LlamaEosToks, Cache};
|
|
// LogitsProcessor poistettu — käytetään greedy samplingia (argmax) Wasm-yhteensopivuuden vuoksi
|
|
use wasm_bindgen::JsCast;
|
|
use std::cell::RefCell;
|
|
use std::rc::Rc;
|
|
use web_sys::WebSocket;
|
|
|
|
use crate::storage;
|
|
|
|
/// Logs a `format!`-style message to the console.
///
/// Accepts exactly the same arguments as `format!`; the rendered `String` is
/// converted to a `JsValue` and passed to `console.log`.
macro_rules! console_log {
    ($($arg:tt)*) => {
        web_sys::console::log_1(&::std::format!($($arg)*).into())
    };
}
|
|
|
|
/// Direct download URL of the SmolLM-135M-Instruct weights (safetensors) on Hugging Face.
const MODEL_URL: &str = concat!(
    "https://huggingface.co/HuggingFaceTB/SmolLM-135M-Instruct/resolve/main/",
    "model.safetensors"
);

/// Direct download URL of the tokenizer definition from the same model repository.
const TOKENIZER_URL: &str = concat!(
    "https://huggingface.co/HuggingFaceTB/SmolLM-135M-Instruct/resolve/main/",
    "tokenizer.json"
);
|
|
|
|
/// Lataa tiedosto HuggingFacesta streaming-latauksella (progress-ilmoitukset) ja tallentaa IndexedDB:hen
|
|
async fn ensure_cached(key: &str, url: &str, ws: &Rc<RefCell<WebSocket>>) -> Result<Vec<u8>, String> {
|
|
if let Ok(Some(bytes)) = storage::load_from_idb(key).await {
|
|
console_log!("[SmolLM] {} löytyi välimuistista ({} MB)", key, bytes.len() / 1024 / 1024);
|
|
send_progress(ws, key, 100, bytes.len(), bytes.len());
|
|
return Ok(bytes);
|
|
}
|
|
|
|
console_log!("[SmolLM] Ladataan {}...", key);
|
|
send_progress(ws, key, 0, 0, 0);
|
|
|
|
// Fetch API:lla saadaan Content-Length ja streaming-luku
|
|
let resp = crate::worker_fetch(url).await?;
|
|
|
|
if !resp.ok() {
|
|
return Err(format!("HTTP {}", resp.status()));
|
|
}
|
|
|
|
// Kokonaiskoko Content-Length-headerista
|
|
let total_size: usize = resp.headers()
|
|
.get("content-length").ok().flatten()
|
|
.and_then(|s| s.parse().ok())
|
|
.unwrap_or(0);
|
|
|
|
let body = resp.body().ok_or("Ei bodyä")?;
|
|
let reader = body.get_reader();
|
|
let reader: web_sys::ReadableStreamDefaultReader = reader.dyn_into().map_err(|_| "Ei ReadableStreamDefaultReader".to_string())?;
|
|
|
|
let mut data: Vec<u8> = Vec::with_capacity(total_size);
|
|
let mut last_pct: u32 = 0;
|
|
|
|
loop {
|
|
let chunk = wasm_bindgen_futures::JsFuture::from(reader.read())
|
|
.await.map_err(|e| format!("Luku epäonnistui: {:?}", e))?;
|
|
|
|
let done = js_sys::Reflect::get(&chunk, &"done".into())
|
|
.map_err(|_| "done-kenttä puuttuu".to_string())?
|
|
.as_bool().unwrap_or(true);
|
|
|
|
if done { break; }
|
|
|
|
let value = js_sys::Reflect::get(&chunk, &"value".into())
|
|
.map_err(|_| "value-kenttä puuttuu".to_string())?;
|
|
let array = js_sys::Uint8Array::new(&value);
|
|
let mut buf = vec![0u8; array.length() as usize];
|
|
array.copy_to(&mut buf);
|
|
data.extend_from_slice(&buf);
|
|
|
|
// Progress-päivitys (joka 5%)
|
|
if total_size > 0 {
|
|
let pct = ((data.len() as f64 / total_size as f64) * 100.0) as u32;
|
|
if pct >= last_pct + 5 || pct == 100 {
|
|
last_pct = pct;
|
|
console_log!("[SmolLM] {} lataus: {}% ({}/{} MB)", key, pct, data.len() / 1024 / 1024, total_size / 1024 / 1024);
|
|
send_progress(ws, key, pct, data.len(), total_size);
|
|
}
|
|
}
|
|
}
|
|
|
|
console_log!("[SmolLM] Tallennetaan {} ({} MB) IndexedDB:hen...", key, data.len() / 1024 / 1024);
|
|
let _ = storage::save_to_idb(key, &data).await;
|
|
console_log!("[SmolLM] {} tallennettu!", key);
|
|
send_progress(ws, key, 100, data.len(), data.len());
|
|
|
|
Ok(data)
|
|
}
|
|
|
|
fn send_progress(ws: &Rc<RefCell<WebSocket>>, file: &str, pct: u32, loaded: usize, total: usize) {
|
|
let msg = serde_json::json!({
|
|
"type": "download_progress",
|
|
"file": file,
|
|
"pct": pct,
|
|
"loaded_mb": loaded / 1024 / 1024,
|
|
"total_mb": total / 1024 / 1024,
|
|
});
|
|
let _ = ws.borrow().send_with_str(&msg.to_string());
|
|
}
|
|
|
|
/// Lataa malli ja tokenizer, suorita inferenssi ja streamaa tokenit hubille
|
|
pub async fn run_smollm_inference(prompt: String, ws: Rc<RefCell<WebSocket>>) {
|
|
// performance via crate::perf_now()
|
|
|
|
// 1. Lataa tokenizer
|
|
let tok_bytes = match ensure_cached("smollm-tokenizer.json", TOKENIZER_URL, &ws).await {
|
|
Ok(b) => b,
|
|
Err(e) => { console_log!("[SmolLM] Tokenizer-virhe: {}", e); return; }
|
|
};
|
|
|
|
let tokenizer = match tokenizers::Tokenizer::from_bytes(&tok_bytes) {
|
|
Ok(t) => t,
|
|
Err(e) => { console_log!("[SmolLM] Tokenizer-parsinta epäonnistui: {}", e); return; }
|
|
};
|
|
|
|
// 2. Lataa mallin painot
|
|
let model_bytes = match ensure_cached("smollm-model.safetensors", MODEL_URL, &ws).await {
|
|
Ok(b) => b,
|
|
Err(e) => { console_log!("[SmolLM] Malli-virhe: {}", e); return; }
|
|
};
|
|
|
|
// Burn 0.14 wgpu ei yhteensopiva nykyisten selainten kanssa (maxInterStageShaderComponents)
|
|
// Burn 0.21-pre.2 cubecl-runtime ei käänny Wasmille (println! puuttuu)
|
|
// → NdArray kunnes Burn 0.21 stable + Wasm-tuki
|
|
console_log!("[SmolLM] Burn NdArray (CPU) inferenssi...");
|
|
run_burn_inference::<burn::backend::NdArray>(prompt, model_bytes, tokenizer, ws).await;
|
|
}
|
|
|
|
async fn run_burn_inference<B: burn::tensor::backend::Backend>(
|
|
prompt: String,
|
|
model_bytes: Vec<u8>,
|
|
tokenizer: tokenizers::Tokenizer,
|
|
ws: Rc<RefCell<WebSocket>>,
|
|
) {
|
|
let start_load = crate::perf_now();
|
|
|
|
let device = Default::default();
|
|
let config = crate::burn_smollm::config::SmolLMConfig::default();
|
|
|
|
console_log!("[SmolLM] Injektoidaan Safetensors -> Burn Params...");
|
|
let model = match crate::burn_smollm::loader::load_safetensors_to_model::<B>(&model_bytes, &config, &device) {
|
|
Ok(m) => m,
|
|
Err(e) => { console_log!("[SmolLM] Lataus epäonnistui: {}", e); return; }
|
|
};
|
|
|
|
let load_time = crate::perf_now() - start_load;
|
|
console_log!("[SmolLM] Burn-malli ladattu ({:.0}ms). Generoidaan...", load_time);
|
|
|
|
let formatted_prompt = format!("<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", prompt);
|
|
let encoding = match tokenizer.encode(formatted_prompt.as_str(), true) {
|
|
Ok(e) => e,
|
|
Err(e) => { console_log!("[SmolLM] Tokenisointivirhe: {}", e); return; }
|
|
};
|
|
|
|
let mut input_ids: Vec<u32> = encoding.get_ids().to_vec();
|
|
let input_len = input_ids.len();
|
|
console_log!("[SmolLM] Syöte: {} tokenia", input_len);
|
|
|
|
let start_gen = crate::perf_now();
|
|
let max_new_tokens = 32;
|
|
let mut generated_text = String::new();
|
|
let mut tokens_generated: usize = 0;
|
|
|
|
// KV-välimuistin taulukko kerroksittain
|
|
let mut caches: Vec<Option<crate::burn_smollm::attention::KVCache<B>>> = vec![None; config.num_hidden_layers];
|
|
let mut current_offset = 0;
|
|
|
|
// Prefill: yksitellen, vältetään future token leakage koska ei causal maskia
|
|
let input_ids_i32: Vec<i32> = input_ids.iter().map(|&x| x as i32).collect();
|
|
let mut last_logits = None;
|
|
|
|
for &id in &input_ids_i32 {
|
|
let input_tensor = burn::tensor::Tensor::<B, 1, burn::tensor::Int>::from_data(
|
|
burn::tensor::TensorData::from([id]),
|
|
&device
|
|
).unsqueeze::<2>(); // [1, 1]
|
|
|
|
last_logits = Some(model.forward(input_tensor, current_offset, &mut caches));
|
|
current_offset += 1;
|
|
}
|
|
|
|
let mut logits = last_logits.unwrap();
|
|
|
|
// Argmax sämpläys
|
|
let next_token_tensor = logits.clone().argmax(2);
|
|
let mut next_token: u32 = next_token_tensor.into_scalar().to_string().parse().unwrap_or(2); // Yksinkertainen cast koska int scalar
|
|
|
|
if next_token != 2 {
|
|
if let Ok(text) = tokenizer.decode(&[next_token], true) {
|
|
generated_text.push_str(&text);
|
|
let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "SmolLM-135M (WebGPU)" });
|
|
let _ = ws.borrow().send_with_str(&chunk.to_string());
|
|
}
|
|
tokens_generated += 1;
|
|
}
|
|
|
|
// Autoregressiivinen luuppi
|
|
for _ in 1..max_new_tokens {
|
|
if next_token == 2 { break; }
|
|
|
|
let mut input_tensor = burn::tensor::Tensor::<B, 1, burn::tensor::Int>::from_data(
|
|
burn::tensor::TensorData::from([next_token as i32]),
|
|
&device
|
|
).unsqueeze::<2>();
|
|
|
|
logits = model.forward(input_tensor, current_offset, &mut caches);
|
|
current_offset += 1;
|
|
|
|
let next_token_tensor = logits.argmax(2);
|
|
next_token = next_token_tensor.into_scalar().to_string().parse().unwrap_or(2);
|
|
|
|
if next_token == 2 { break; }
|
|
|
|
if let Ok(text) = tokenizer.decode(&[next_token], true) {
|
|
generated_text.push_str(&text);
|
|
let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "SmolLM-135M (WebGPU)" });
|
|
let _ = ws.borrow().send_with_str(&chunk.to_string());
|
|
}
|
|
tokens_generated += 1;
|
|
}
|
|
|
|
let gen_time = crate::perf_now() - start_gen;
|
|
let tokens_per_sec = if gen_time > 0.0 { (tokens_generated as f64 / gen_time) * 1000.0 } else { 0.0 };
|
|
|
|
let done = serde_json::json!({
|
|
"type": "llm_done",
|
|
"prompt": prompt,
|
|
"model": "SmolLM-135M-Instruct (WebGPU)",
|
|
"response": generated_text,
|
|
"tokens_generated": tokens_generated,
|
|
"duration_ms": (gen_time * 100.0).round() / 100.0,
|
|
"tokens_per_sec": (tokens_per_sec * 10.0).round() / 10.0,
|
|
"load_time_ms": (load_time * 100.0).round() / 100.0,
|
|
});
|
|
let _ = ws.borrow().send_with_str(&done.to_string());
|
|
}
|