on tämä työmaa

This commit is contained in:
2026-04-02 00:50:29 +03:00
parent e55ff64565
commit b693542116
11 changed files with 648 additions and 70 deletions

View File

@@ -8,6 +8,8 @@ use burn::backend::{Wgpu, NdArray};
pub mod storage;
pub mod smollm;
pub mod qwen;
pub mod phi3;
#[macro_export]
macro_rules! console_log {
@@ -16,8 +18,9 @@ macro_rules! console_log {
static GPU_LOAD_PERCENT: AtomicU32 = AtomicU32::new(50);
static HAS_WEBGPU: AtomicBool = AtomicBool::new(true);
// Valittu tehtävä: 0=tokenize, 1=smollm-135m, 2=qwen-05b, 3=phi3-mini
static SELECTED_TASK: AtomicU32 = AtomicU32::new(0);
// Estää rinnakkaiset LLM-inferenssit (vain yksi kerrallaan)
static LLM_BUSY: AtomicBool = AtomicBool::new(false);
#[wasm_bindgen]
pub fn set_gpu_load(load: u32) {
@@ -202,14 +205,46 @@ pub async fn start_agent_node(hub_url: String, has_webgpu: bool, device_info_jso
}
}
} else if msg.contains("llm_prompt") && current_task == 1 {
// Vain SmolLM-solmut käsittelevät llm_prompt-viestejä
if let Ok(task) = serde_json::from_str::<serde_json::Value>(&msg) {
// Vain SmolLM-solmut, ja vain yksi inferenssi kerrallaan
if LLM_BUSY.load(Ordering::SeqCst) {
// Ohitetaan — edellinen inferenssi vielä käynnissä
} else if let Ok(task) = serde_json::from_str::<serde_json::Value>(&msg) {
let prompt = task.get("prompt").and_then(|v| v.as_str()).unwrap_or("").to_string();
let model = task.get("model").and_then(|v| v.as_str()).unwrap_or("").to_string();
if !prompt.is_empty() && model == "smollm-135m" {
LLM_BUSY.store(true, Ordering::SeqCst);
let ws_for_async = ws_clone.clone();
wasm_bindgen_futures::spawn_local(async move {
smollm::run_smollm_inference(prompt, ws_for_async).await;
LLM_BUSY.store(false, Ordering::SeqCst);
});
}
}
} else if msg.contains("llm_prompt") && current_task == 2 {
// Qwen2.5-0.5B
if LLM_BUSY.load(Ordering::SeqCst) {
} else if let Ok(task) = serde_json::from_str::<serde_json::Value>(&msg) {
let prompt = task.get("prompt").and_then(|v| v.as_str()).unwrap_or("").to_string();
if !prompt.is_empty() {
LLM_BUSY.store(true, Ordering::SeqCst);
let ws_for_async = ws_clone.clone();
wasm_bindgen_futures::spawn_local(async move {
qwen::run_qwen_inference(prompt, ws_for_async).await;
LLM_BUSY.store(false, Ordering::SeqCst);
});
}
}
} else if msg.contains("llm_prompt") && current_task == 3 {
// Phi-3 Mini
if LLM_BUSY.load(Ordering::SeqCst) {
} else if let Ok(task) = serde_json::from_str::<serde_json::Value>(&msg) {
let prompt = task.get("prompt").and_then(|v| v.as_str()).unwrap_or("").to_string();
if !prompt.is_empty() {
LLM_BUSY.store(true, Ordering::SeqCst);
let ws_for_async = ws_clone.clone();
wasm_bindgen_futures::spawn_local(async move {
phi3::run_phi3_inference(prompt, ws_for_async).await;
LLM_BUSY.store(false, Ordering::SeqCst);
});
}
}

View File

@@ -0,0 +1,36 @@
use candle_core::{Device, Tensor, DType};
use candle_nn::VarBuilder;
use candle_transformers::models::phi3::{Config as Phi3Config, Model as Phi3Model};
use wasm_bindgen::JsCast;
use std::cell::RefCell;
use std::rc::Rc;
use web_sys::WebSocket;
use crate::storage;
macro_rules! console_log {
($($t:tt)*) => (web_sys::console::log_1(&format_args!($($t)*).to_string().into()))
}
const MODEL_URL: &str = "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/model.safetensors.index.json";
const TOKENIZER_URL: &str = "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/tokenizer.json";
// Phi-3 Mini on iso (7.6 GB) — käytetään kvantisoidumpaa versiota myöhemmin
// Tällä hetkellä: placeholder joka raportoi koon ja jättää inferenssin väliin
pub async fn run_phi3_inference(prompt: String, ws: Rc<RefCell<WebSocket>>) {
console_log!("[Phi-3] Phi-3 Mini 3.8B on liian suuri selaimessa ajettavaksi (~7.6 GB).");
console_log!("[Phi-3] Käytä SmolLM 135M tai Qwen2.5 0.5B selaininferenssiin.");
console_log!("[Phi-3] Phi-3 tuetaan native-node:lla (Docker + GPU).");
let done = serde_json::json!({
"type": "llm_done",
"prompt": prompt,
"model": "Phi-3-Mini (ei tuettu selaimessa)",
"response": "Phi-3 Mini 3.8B on liian suuri selaimessa ajettavaksi. Käytä SmolLM 135M tai Qwen2.5 0.5B.",
"tokens_generated": 0,
"duration_ms": 0,
"tokens_per_sec": 0,
"load_time_ms": 0,
});
let _ = ws.borrow().send_with_str(&done.to_string());
}

View File

@@ -0,0 +1,219 @@
use candle_core::{Device, Tensor, DType};
use candle_nn::VarBuilder;
use candle_transformers::models::qwen2::{Config as QwenConfig, ModelForCausalLM as QwenModel};
use wasm_bindgen::JsCast;
use std::cell::RefCell;
use std::rc::Rc;
use web_sys::WebSocket;
use crate::storage;
macro_rules! console_log {
($($t:tt)*) => (web_sys::console::log_1(&format_args!($($t)*).to_string().into()))
}
const MODEL_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct/resolve/main/model.safetensors";
const TOKENIZER_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct/resolve/main/tokenizer.json";
/// Streaming-lataus HuggingFacesta IndexedDB-cacheen
async fn ensure_cached(key: &str, url: &str, ws: &Rc<RefCell<WebSocket>>) -> Result<Vec<u8>, String> {
if let Ok(Some(bytes)) = storage::load_from_idb(key).await {
console_log!("[Qwen] {} löytyi välimuistista ({} MB)", key, bytes.len() / 1024 / 1024);
return Ok(bytes);
}
console_log!("[Qwen] Ladataan {}...", key);
let window = web_sys::window().unwrap();
let resp_val = wasm_bindgen_futures::JsFuture::from(window.fetch_with_str(url))
.await.map_err(|e| format!("Fetch epäonnistui: {:?}", e))?;
let resp: web_sys::Response = resp_val.dyn_into().map_err(|_| "Ei Response".to_string())?;
if !resp.ok() { return Err(format!("HTTP {}", resp.status())); }
let total_size: usize = resp.headers()
.get("content-length").ok().flatten()
.and_then(|s| s.parse().ok())
.unwrap_or(0);
let body = resp.body().ok_or("Ei bodyä")?;
let reader: web_sys::ReadableStreamDefaultReader = body.get_reader().dyn_into().map_err(|_| "Ei reader".to_string())?;
let mut data: Vec<u8> = Vec::with_capacity(total_size);
let mut last_pct: u32 = 0;
loop {
let chunk = wasm_bindgen_futures::JsFuture::from(reader.read())
.await.map_err(|e| format!("Read: {:?}", e))?;
let done = js_sys::Reflect::get(&chunk, &"done".into()).ok().and_then(|v| v.as_bool()).unwrap_or(true);
if done { break; }
let value = js_sys::Reflect::get(&chunk, &"value".into()).map_err(|_| "value puuttuu".to_string())?;
let array = js_sys::Uint8Array::new(&value);
let mut buf = vec![0u8; array.length() as usize];
array.copy_to(&mut buf);
data.extend_from_slice(&buf);
if total_size > 0 {
let pct = ((data.len() as f64 / total_size as f64) * 100.0) as u32;
if pct >= last_pct + 5 || pct == 100 {
last_pct = pct;
console_log!("[Qwen] {} lataus: {}%", key, pct);
let msg = serde_json::json!({ "type": "download_progress", "file": key, "pct": pct, "loaded_mb": data.len()/1024/1024, "total_mb": total_size/1024/1024 });
let _ = ws.borrow().send_with_str(&msg.to_string());
}
}
}
console_log!("[Qwen] Tallennetaan {} ({} MB)...", key, data.len() / 1024 / 1024);
let _ = storage::save_to_idb(key, &data).await;
console_log!("[Qwen] {} tallennettu!", key);
Ok(data)
}
pub async fn run_qwen_inference(prompt: String, ws: Rc<RefCell<WebSocket>>) {
let perf = web_sys::window().unwrap().performance().unwrap();
let tok_bytes = match ensure_cached("qwen05b-tokenizer.json", TOKENIZER_URL, &ws).await {
Ok(b) => b,
Err(e) => { console_log!("[Qwen] Tokenizer-virhe: {}", e); return; }
};
let tokenizer = match tokenizers::Tokenizer::from_bytes(&tok_bytes) {
Ok(t) => t,
Err(e) => { console_log!("[Qwen] Tokenizer-parsinta: {}", e); return; }
};
let model_bytes = match ensure_cached("qwen05b-model.safetensors", MODEL_URL, &ws).await {
Ok(b) => b,
Err(e) => { console_log!("[Qwen] Malli-virhe: {}", e); return; }
};
console_log!("[Qwen] Rakennetaan mallia...");
let start_load = perf.now();
let device = Device::Cpu;
let dtype = DType::F32;
let tensors = match candle_core::safetensors::load_buffer(&model_bytes, &device) {
Ok(t) => t,
Err(e) => { console_log!("[Qwen] Safetensors: {}", e); return; }
};
let vb = VarBuilder::from_tensors(tensors, dtype, &device);
let config = QwenConfig {
vocab_size: 151936,
hidden_size: 896,
intermediate_size: 4864,
num_hidden_layers: 24,
num_attention_heads: 14,
num_key_value_heads: 2,
max_position_embeddings: 32768,
sliding_window: 32768,
max_window_layers: 21,
tie_word_embeddings: true,
rope_theta: 1000000.0,
rms_norm_eps: 1e-6,
use_sliding_window: false,
hidden_act: candle_nn::Activation::Silu,
};
let mut model = match QwenModel::new(&config, vb) {
Ok(m) => m,
Err(e) => { console_log!("[Qwen] Mallin lataus: {}", e); return; }
};
let load_time = perf.now() - start_load;
console_log!("[Qwen] Malli ladattu ({:.0}ms). Generoidaan...", load_time);
let encoding = match tokenizer.encode(prompt.as_str(), true) {
Ok(e) => e,
Err(e) => { console_log!("[Qwen] Tokenisointivirhe: {}", e); return; }
};
let input_ids: Vec<u32> = encoding.get_ids().to_vec();
let input_len = input_ids.len();
console_log!("[Qwen] Syöte: {} tokenia", input_len);
let start_gen = perf.now();
let max_new_tokens = 32;
let mut generated_text = String::new();
let mut tokens_generated: usize = 0;
// Prefill
let input = match Tensor::new(input_ids.as_slice(), &device).and_then(|t| t.unsqueeze(0)) {
Ok(t) => t,
Err(e) => { console_log!("[Qwen] Tensor: {}", e); return; }
};
let logits = match model.forward(&input, 0) {
Ok(l) => l,
Err(e) => { console_log!("[Qwen] Forward (prefill): {}", e); return; }
};
// Forward palauttaa [batch, vocab_size] tai [batch, seq_len, vocab_size]
let logits = logits.squeeze(0).unwrap();
let logits = if logits.dims().len() == 2 {
// [seq_len, vocab_size] — ota viimeinen
logits.get(logits.dim(0).unwrap() - 1).unwrap()
} else {
logits // jo [vocab_size]
};
let mut next_token = logits.argmax(0).unwrap().to_vec0::<u32>().unwrap();
console_log!("[Qwen] Ensimmäinen token: {}", next_token);
let eos_token = 151645u32; // <|endoftext|> for Qwen2.5
if next_token != eos_token {
if let Ok(text) = tokenizer.decode(&[next_token], true) {
generated_text.push_str(&text);
let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "Qwen2.5-0.5B" });
let _ = ws.borrow().send_with_str(&chunk.to_string());
}
tokens_generated += 1;
}
// Autoregressive
let mut pos = input_len;
for _ in 1..max_new_tokens {
if next_token == eos_token { break; }
let input = match Tensor::new(&[next_token], &device).and_then(|t| t.unsqueeze(0)) {
Ok(t) => t,
Err(e) => { console_log!("[Qwen] Tensor: {}", e); break; }
};
let logits = match model.forward(&input, pos) {
Ok(l) => l,
Err(e) => { console_log!("[Qwen] Forward pos {}: {}", pos, e); break; }
};
let logits = logits.squeeze(0).unwrap();
let logits = if logits.dims().len() == 2 {
logits.get(logits.dim(0).unwrap() - 1).unwrap()
} else {
logits
};
next_token = logits.argmax(0).unwrap().to_vec0::<u32>().unwrap();
pos += 1;
if next_token == eos_token { break; }
if let Ok(text) = tokenizer.decode(&[next_token], true) {
generated_text.push_str(&text);
let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "Qwen2.5-0.5B" });
let _ = ws.borrow().send_with_str(&chunk.to_string());
}
tokens_generated += 1;
}
let gen_time = perf.now() - start_gen;
let tokens_per_sec = if gen_time > 0.0 { (tokens_generated as f64 / gen_time) * 1000.0 } else { 0.0 };
console_log!("[Qwen] {} tokenia | {:.0}ms | {:.1} tok/s", tokens_generated, gen_time, tokens_per_sec);
let done = serde_json::json!({
"type": "llm_done",
"prompt": prompt,
"model": "Qwen2.5-0.5B-Instruct",
"response": generated_text,
"tokens_generated": tokens_generated,
"duration_ms": (gen_time * 100.0).round() / 100.0,
"tokens_per_sec": (tokens_per_sec * 10.0).round() / 10.0,
"load_time_ms": (load_time * 100.0).round() / 100.0,
});
let _ = ws.borrow().send_with_str(&done.to_string());
}

View File

@@ -1,7 +1,7 @@
use candle_core::{Device, Tensor, DType};
use candle_nn::VarBuilder;
use candle_transformers::models::llama::{Llama, LlamaConfig, LlamaEosToks, Cache};
use candle_transformers::generation::LogitsProcessor;
// LogitsProcessor poistettu — käytetään greedy samplingia (argmax) Wasm-yhteensopivuuden vuoksi
use wasm_bindgen::JsCast;
use std::cell::RefCell;
use std::rc::Rc;
@@ -160,8 +160,9 @@ pub async fn run_smollm_inference(prompt: String, ws: Rc<RefCell<WebSocket>>) {
let load_time = perf.now() - start_load;
console_log!("[SmolLM] Malli ladattu ({:.0}ms). Generoidaan...", load_time);
// 3. Tokenisoi syöte
let encoding = match tokenizer.encode(prompt.as_str(), true) {
// 3. Tokenisoi syöte (Käytetään ChatML-formaattia SmolLM-Instructille)
let formatted_prompt = format!("<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", prompt);
let encoding = match tokenizer.encode(formatted_prompt.as_str(), true) {
Ok(e) => e,
Err(e) => { console_log!("[SmolLM] Tokenisointivirhe: {}", e); return; }
};
@@ -172,62 +173,76 @@ pub async fn run_smollm_inference(prompt: String, ws: Rc<RefCell<WebSocket>>) {
// 4. Generoi tokeneita
let start_gen = perf.now();
let mut logits_processor = LogitsProcessor::new(42, Some(0.8), Some(0.95));
let mut all_tokens = input_ids.clone();
let max_new_tokens = 64;
let max_new_tokens = 32;
let mut generated_text = String::new();
let mut tokens_generated: usize = 0;
let mut pos: usize = 0;
for i in 0..max_new_tokens {
let context_tokens = if i == 0 {
all_tokens.as_slice()
} else {
std::slice::from_ref(all_tokens.last().unwrap())
};
// Ensimmäinen forward: koko syöte kerralla
let input = match Tensor::new(input_ids.as_slice(), &device).and_then(|t| t.unsqueeze(0)) {
Ok(t) => t,
Err(e) => { console_log!("[SmolLM] Tensor-virhe: {}", e); return; }
};
let input = Tensor::new(context_tokens, &device).unwrap().unsqueeze(0).unwrap();
let seq_len = input.dim(1).unwrap();
let logits = match model.forward(&input, 0, &mut cache) {
Ok(l) => l,
Err(e) => { console_log!("[SmolLM] Forward-virhe (prefill): {}", e); return; }
};
let logits = match model.forward(&input, input_len + i - seq_len, &mut cache) {
Ok(l) => l,
Err(e) => { console_log!("[SmolLM] Forward-virhe stepissä {}: {}", i, e); break; }
};
// Llama forward voi palauttaa [batch, vocab] tai [batch, seq_len, vocab]
let logits = logits.squeeze(0).unwrap();
let logits = if logits.dims().len() == 2 {
logits.get(logits.dim(0).unwrap() - 1).unwrap()
} else {
logits
};
let mut next_token = logits.argmax(0).unwrap().to_vec0::<u32>().unwrap();
console_log!("[SmolLM] Ensimmäinen generoitu token: {}", next_token);
pos = input_len;
// Viimeisen tokenin logitit
let logits = logits.squeeze(0).unwrap();
let last_dim = logits.dim(0).unwrap();
let logits = if last_dim > 1 {
logits.get(last_dim - 1).unwrap()
} else {
logits.get(0).unwrap()
};
let next_token = logits_processor.sample(&logits).unwrap();
// EOS-tarkistus
if next_token == 2 {
break;
}
all_tokens.push(next_token);
// Dekoodaa token tekstiksi
if next_token != 2 {
if let Ok(text) = tokenizer.decode(&[next_token], true) {
generated_text.push_str(&text);
// Streamaa token hubille
let chunk = serde_json::json!({
"type": "llm_chunk",
"token": text,
"is_last": false,
"prompt": prompt,
"model": "SmolLM-135M"
});
let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "SmolLM-135M" });
let _ = ws.borrow().send_with_str(&chunk.to_string());
}
tokens_generated += 1;
}
// Autoregressiivinen generointi: yksi token kerrallaan
for _ in 1..max_new_tokens {
if next_token == 2 { break; }
let input = match Tensor::new(&[next_token], &device).and_then(|t| t.unsqueeze(0)) {
Ok(t) => t,
Err(e) => { console_log!("[SmolLM] Tensor-virhe: {}", e); break; }
};
let logits = match model.forward(&input, pos, &mut cache) {
Ok(l) => l,
Err(e) => { console_log!("[SmolLM] Forward-virhe pos {}: {}", pos, e); break; }
};
let logits = logits.squeeze(0).unwrap();
let logits = if logits.dims().len() == 2 {
logits.get(logits.dim(0).unwrap() - 1).unwrap()
} else {
logits
};
next_token = logits.argmax(0).unwrap().to_vec0::<u32>().unwrap();
pos += 1;
if next_token == 2 { break; }
if let Ok(text) = tokenizer.decode(&[next_token], true) {
generated_text.push_str(&text);
let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "SmolLM-135M" });
let _ = ws.borrow().send_with_str(&chunk.to_string());
}
tokens_generated += 1;
}
let gen_time = perf.now() - start_gen;
let tokens_generated = all_tokens.len() - input_len;
let tokens_per_sec = if gen_time > 0.0 { (tokens_generated as f64 / gen_time) * 1000.0 } else { 0.0 };
console_log!("[SmolLM] Generoitu {} tokenia | {:.0}ms | {:.1} tok/s", tokens_generated, gen_time, tokens_per_sec);