toka toimiva vedos
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "node"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
edition = "2024"
|
||||
|
||||
[lib]
|
||||
crate-type = ["cdylib"]
|
||||
@@ -17,6 +17,12 @@ web-sys = { version = "0.3.68", features = [
|
||||
"MessageEvent",
|
||||
"Performance",
|
||||
"console",
|
||||
"Request",
|
||||
"RequestInit",
|
||||
"Response",
|
||||
"Headers",
|
||||
"ReadableStream",
|
||||
"ReadableStreamDefaultReader",
|
||||
] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
@@ -29,4 +35,8 @@ reqwest = { version = "0.12", default-features = false, features = ["json"] }
|
||||
tokenizers = { version = "0.19.1", default-features = false, features = ["unstable_wasm"] }
|
||||
rexie = "0.6"
|
||||
log = "0.4"
|
||||
candle-core = { version = "0.8" }
|
||||
candle-nn = "0.8"
|
||||
candle-transformers = "0.8"
|
||||
getrandom = { version = "0.3", features = ["wasm_js"] }
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use wasm_bindgen::prelude::*;
|
||||
use web_sys::{console, WebSocket, MessageEvent};
|
||||
use web_sys::{WebSocket, MessageEvent};
|
||||
use std::cell::RefCell;
|
||||
use std::rc::Rc;
|
||||
use std::sync::atomic::{AtomicU32, AtomicBool, Ordering};
|
||||
@@ -7,15 +7,17 @@ use burn::tensor::Tensor;
|
||||
use burn::backend::{Wgpu, NdArray};
|
||||
|
||||
pub mod storage;
|
||||
pub mod smollm;
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! console_log {
|
||||
($($t:tt)*) => (console::log_1(&format_args!($($t)*).to_string().into()))
|
||||
($($t:tt)*) => (web_sys::console::log_1(&format_args!($($t)*).to_string().into()))
|
||||
}
|
||||
|
||||
// Globaali muuttuja GPU Load Sliderille (25-100%)
|
||||
static GPU_LOAD_PERCENT: AtomicU32 = AtomicU32::new(50);
|
||||
// Onko WebGPU käytettävissä — asetetaan JS-puolelta käynnistyksessä
|
||||
static HAS_WEBGPU: AtomicBool = AtomicBool::new(true);
|
||||
// Valittu tehtävä: 0=tokenize, 1=smollm-135m, 2=qwen-05b, 3=phi3-mini
|
||||
static SELECTED_TASK: AtomicU32 = AtomicU32::new(0);
|
||||
|
||||
#[wasm_bindgen]
|
||||
pub fn set_gpu_load(load: u32) {
|
||||
@@ -148,12 +150,15 @@ async fn run_pair_comparison(en_text: String, fi_text: String, ws: Rc<RefCell<We
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
pub async fn start_agent_node(hub_url: String, has_webgpu: bool, device_info_json: String) -> Result<(), JsValue> {
|
||||
pub async fn start_agent_node(hub_url: String, has_webgpu: bool, device_info_json: String, task_id: u32) -> Result<(), JsValue> {
|
||||
console_error_panic_hook::set_once();
|
||||
|
||||
HAS_WEBGPU.store(has_webgpu, Ordering::SeqCst);
|
||||
SELECTED_TASK.store(task_id, Ordering::SeqCst);
|
||||
let backend_name = if has_webgpu { "WebGPU" } else { "CPU (NdArray)" };
|
||||
console_log!("Kipinä Agent Node käynnistyy — backend: {}", backend_name);
|
||||
let task_names = ["tokenize", "smollm-135m", "qwen-05b", "phi3-mini"];
|
||||
let task_name = task_names.get(task_id as usize).unwrap_or(&"tokenize");
|
||||
console_log!("Kipinä Agent Node käynnistyy — backend: {} | tehtävä: {}", backend_name, task_name);
|
||||
|
||||
let device_info = device_info_json.clone();
|
||||
|
||||
@@ -182,7 +187,10 @@ pub async fn start_agent_node(hub_url: String, has_webgpu: bool, device_info_jso
|
||||
if let Ok(txt) = e.data().dyn_into::<js_sys::JsString>() {
|
||||
let msg: String = txt.into();
|
||||
|
||||
if msg.contains("pair_task") {
|
||||
let current_task = SELECTED_TASK.load(Ordering::SeqCst);
|
||||
|
||||
if msg.contains("pair_task") && current_task == 0 {
|
||||
// Vain tokenisaatiosolmut käsittelevät pair_task-viestejä
|
||||
if let Ok(task) = serde_json::from_str::<serde_json::Value>(&msg) {
|
||||
let en = task.get("en").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
let fi = task.get("fi").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
@@ -193,6 +201,18 @@ pub async fn start_agent_node(hub_url: String, has_webgpu: bool, device_info_jso
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if msg.contains("llm_prompt") && current_task == 1 {
|
||||
// Vain SmolLM-solmut käsittelevät llm_prompt-viestejä
|
||||
if let Ok(task) = serde_json::from_str::<serde_json::Value>(&msg) {
|
||||
let prompt = task.get("prompt").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
let model = task.get("model").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
if !prompt.is_empty() && model == "smollm-135m" {
|
||||
let ws_for_async = ws_clone.clone();
|
||||
wasm_bindgen_futures::spawn_local(async move {
|
||||
smollm::run_smollm_inference(prompt, ws_for_async).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if msg.contains("ai_task") {
|
||||
console_log!("Hub task vastaanotettu, ajetaan GPU:lla...");
|
||||
let ws_for_async = ws_clone.clone();
|
||||
|
||||
246
network-poc/node/src/smollm.rs
Normal file
246
network-poc/node/src/smollm.rs
Normal file
@@ -0,0 +1,246 @@
|
||||
use candle_core::{Device, Tensor, DType};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::llama::{Llama, LlamaConfig, LlamaEosToks, Cache};
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use wasm_bindgen::JsCast;
|
||||
use std::cell::RefCell;
|
||||
use std::rc::Rc;
|
||||
use web_sys::WebSocket;
|
||||
|
||||
use crate::storage;
|
||||
|
||||
/// Formats its arguments like `format!` and writes the resulting string to the
/// browser console via `web_sys::console::log_1`.
macro_rules! console_log {
    ($($arg:tt)*) => {
        web_sys::console::log_1(&format!($($arg)*).into())
    };
}
|
||||
|
||||
const MODEL_URL: &str = "https://huggingface.co/HuggingFaceTB/SmolLM-135M-Instruct/resolve/main/model.safetensors";
|
||||
const TOKENIZER_URL: &str = "https://huggingface.co/HuggingFaceTB/SmolLM-135M-Instruct/resolve/main/tokenizer.json";
|
||||
|
||||
/// Lataa tiedosto HuggingFacesta streaming-latauksella (progress-ilmoitukset) ja tallentaa IndexedDB:hen
|
||||
async fn ensure_cached(key: &str, url: &str, ws: &Rc<RefCell<WebSocket>>) -> Result<Vec<u8>, String> {
|
||||
if let Ok(Some(bytes)) = storage::load_from_idb(key).await {
|
||||
console_log!("[SmolLM] {} löytyi välimuistista ({} MB)", key, bytes.len() / 1024 / 1024);
|
||||
send_progress(ws, key, 100, bytes.len(), bytes.len());
|
||||
return Ok(bytes);
|
||||
}
|
||||
|
||||
console_log!("[SmolLM] Ladataan {}...", key);
|
||||
send_progress(ws, key, 0, 0, 0);
|
||||
|
||||
// Fetch API:lla saadaan Content-Length ja streaming-luku
|
||||
let window = web_sys::window().unwrap();
|
||||
let resp_val = wasm_bindgen_futures::JsFuture::from(window.fetch_with_str(url))
|
||||
.await.map_err(|e| format!("Fetch epäonnistui: {:?}", e))?;
|
||||
let resp: web_sys::Response = resp_val.dyn_into().map_err(|_| "Ei Response-objekti".to_string())?;
|
||||
|
||||
if !resp.ok() {
|
||||
return Err(format!("HTTP {}", resp.status()));
|
||||
}
|
||||
|
||||
// Kokonaiskoko Content-Length-headerista
|
||||
let total_size: usize = resp.headers()
|
||||
.get("content-length").ok().flatten()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(0);
|
||||
|
||||
let body = resp.body().ok_or("Ei bodyä")?;
|
||||
let reader = body.get_reader();
|
||||
let reader: web_sys::ReadableStreamDefaultReader = reader.dyn_into().map_err(|_| "Ei ReadableStreamDefaultReader".to_string())?;
|
||||
|
||||
let mut data: Vec<u8> = Vec::with_capacity(total_size);
|
||||
let mut last_pct: u32 = 0;
|
||||
|
||||
loop {
|
||||
let chunk = wasm_bindgen_futures::JsFuture::from(reader.read())
|
||||
.await.map_err(|e| format!("Luku epäonnistui: {:?}", e))?;
|
||||
|
||||
let done = js_sys::Reflect::get(&chunk, &"done".into())
|
||||
.map_err(|_| "done-kenttä puuttuu".to_string())?
|
||||
.as_bool().unwrap_or(true);
|
||||
|
||||
if done { break; }
|
||||
|
||||
let value = js_sys::Reflect::get(&chunk, &"value".into())
|
||||
.map_err(|_| "value-kenttä puuttuu".to_string())?;
|
||||
let array = js_sys::Uint8Array::new(&value);
|
||||
let mut buf = vec![0u8; array.length() as usize];
|
||||
array.copy_to(&mut buf);
|
||||
data.extend_from_slice(&buf);
|
||||
|
||||
// Progress-päivitys (joka 5%)
|
||||
if total_size > 0 {
|
||||
let pct = ((data.len() as f64 / total_size as f64) * 100.0) as u32;
|
||||
if pct >= last_pct + 5 || pct == 100 {
|
||||
last_pct = pct;
|
||||
console_log!("[SmolLM] {} lataus: {}% ({}/{} MB)", key, pct, data.len() / 1024 / 1024, total_size / 1024 / 1024);
|
||||
send_progress(ws, key, pct, data.len(), total_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
console_log!("[SmolLM] Tallennetaan {} ({} MB) IndexedDB:hen...", key, data.len() / 1024 / 1024);
|
||||
let _ = storage::save_to_idb(key, &data).await;
|
||||
console_log!("[SmolLM] {} tallennettu!", key);
|
||||
send_progress(ws, key, 100, data.len(), data.len());
|
||||
|
||||
Ok(data)
|
||||
}
|
||||
|
||||
fn send_progress(ws: &Rc<RefCell<WebSocket>>, file: &str, pct: u32, loaded: usize, total: usize) {
|
||||
let msg = serde_json::json!({
|
||||
"type": "download_progress",
|
||||
"file": file,
|
||||
"pct": pct,
|
||||
"loaded_mb": loaded / 1024 / 1024,
|
||||
"total_mb": total / 1024 / 1024,
|
||||
});
|
||||
let _ = ws.borrow().send_with_str(&msg.to_string());
|
||||
}
|
||||
|
||||
/// Lataa malli ja tokenizer, suorita inferenssi ja streamaa tokenit hubille
|
||||
pub async fn run_smollm_inference(prompt: String, ws: Rc<RefCell<WebSocket>>) {
|
||||
let perf = web_sys::window().unwrap().performance().unwrap();
|
||||
|
||||
// 1. Lataa tokenizer
|
||||
let tok_bytes = match ensure_cached("smollm-tokenizer.json", TOKENIZER_URL, &ws).await {
|
||||
Ok(b) => b,
|
||||
Err(e) => { console_log!("[SmolLM] Tokenizer-virhe: {}", e); return; }
|
||||
};
|
||||
|
||||
let tokenizer = match tokenizers::Tokenizer::from_bytes(&tok_bytes) {
|
||||
Ok(t) => t,
|
||||
Err(e) => { console_log!("[SmolLM] Tokenizer-parsinta epäonnistui: {}", e); return; }
|
||||
};
|
||||
|
||||
// 2. Lataa mallin painot
|
||||
let model_bytes = match ensure_cached("smollm-model.safetensors", MODEL_URL, &ws).await {
|
||||
Ok(b) => b,
|
||||
Err(e) => { console_log!("[SmolLM] Malli-virhe: {}", e); return; }
|
||||
};
|
||||
|
||||
console_log!("[SmolLM] Rakennetaan mallia...");
|
||||
let start_load = perf.now();
|
||||
|
||||
let device = Device::Cpu;
|
||||
let dtype = DType::F32;
|
||||
|
||||
// Parsitaan safetensors
|
||||
let tensors = match candle_core::safetensors::load_buffer(&model_bytes, &device) {
|
||||
Ok(t) => t,
|
||||
Err(e) => { console_log!("[SmolLM] Safetensors-parsinta epäonnistui: {}", e); return; }
|
||||
};
|
||||
|
||||
let vb = VarBuilder::from_tensors(tensors, dtype, &device);
|
||||
|
||||
// SmolLM-135M config (Llama-arkkitehtuuri)
|
||||
let config = LlamaConfig {
|
||||
hidden_size: 576,
|
||||
intermediate_size: 1536,
|
||||
vocab_size: 49152,
|
||||
num_hidden_layers: 30,
|
||||
num_attention_heads: 9,
|
||||
num_key_value_heads: Some(3),
|
||||
rms_norm_eps: 1e-5,
|
||||
rope_theta: 10000.0,
|
||||
max_position_embeddings: 2048,
|
||||
tie_word_embeddings: Some(true),
|
||||
bos_token_id: Some(1u32),
|
||||
eos_token_id: Some(LlamaEosToks::Single(2)),
|
||||
rope_scaling: None,
|
||||
};
|
||||
|
||||
let llama_config = config.into_config(false); // false = ei flash attention
|
||||
let mut cache = Cache::new(true, dtype, &llama_config, &device).unwrap();
|
||||
|
||||
let model = match Llama::load(vb, &llama_config) {
|
||||
Ok(m) => m,
|
||||
Err(e) => { console_log!("[SmolLM] Mallin lataus epäonnistui: {}", e); return; }
|
||||
};
|
||||
|
||||
let load_time = perf.now() - start_load;
|
||||
console_log!("[SmolLM] Malli ladattu ({:.0}ms). Generoidaan...", load_time);
|
||||
|
||||
// 3. Tokenisoi syöte
|
||||
let encoding = match tokenizer.encode(prompt.as_str(), true) {
|
||||
Ok(e) => e,
|
||||
Err(e) => { console_log!("[SmolLM] Tokenisointivirhe: {}", e); return; }
|
||||
};
|
||||
|
||||
let input_ids: Vec<u32> = encoding.get_ids().to_vec();
|
||||
let input_len = input_ids.len();
|
||||
console_log!("[SmolLM] Syöte: {} tokenia", input_len);
|
||||
|
||||
// 4. Generoi tokeneita
|
||||
let start_gen = perf.now();
|
||||
let mut logits_processor = LogitsProcessor::new(42, Some(0.8), Some(0.95));
|
||||
let mut all_tokens = input_ids.clone();
|
||||
let max_new_tokens = 64;
|
||||
let mut generated_text = String::new();
|
||||
|
||||
for i in 0..max_new_tokens {
|
||||
let context_tokens = if i == 0 {
|
||||
all_tokens.as_slice()
|
||||
} else {
|
||||
std::slice::from_ref(all_tokens.last().unwrap())
|
||||
};
|
||||
|
||||
let input = Tensor::new(context_tokens, &device).unwrap().unsqueeze(0).unwrap();
|
||||
let seq_len = input.dim(1).unwrap();
|
||||
|
||||
let logits = match model.forward(&input, input_len + i - seq_len, &mut cache) {
|
||||
Ok(l) => l,
|
||||
Err(e) => { console_log!("[SmolLM] Forward-virhe stepissä {}: {}", i, e); break; }
|
||||
};
|
||||
|
||||
// Viimeisen tokenin logitit
|
||||
let logits = logits.squeeze(0).unwrap();
|
||||
let last_dim = logits.dim(0).unwrap();
|
||||
let logits = if last_dim > 1 {
|
||||
logits.get(last_dim - 1).unwrap()
|
||||
} else {
|
||||
logits.get(0).unwrap()
|
||||
};
|
||||
|
||||
let next_token = logits_processor.sample(&logits).unwrap();
|
||||
|
||||
// EOS-tarkistus
|
||||
if next_token == 2 {
|
||||
break;
|
||||
}
|
||||
|
||||
all_tokens.push(next_token);
|
||||
|
||||
// Dekoodaa token tekstiksi
|
||||
if let Ok(text) = tokenizer.decode(&[next_token], true) {
|
||||
generated_text.push_str(&text);
|
||||
|
||||
// Streamaa token hubille
|
||||
let chunk = serde_json::json!({
|
||||
"type": "llm_chunk",
|
||||
"token": text,
|
||||
"is_last": false,
|
||||
"prompt": prompt,
|
||||
"model": "SmolLM-135M"
|
||||
});
|
||||
let _ = ws.borrow().send_with_str(&chunk.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
let gen_time = perf.now() - start_gen;
|
||||
let tokens_generated = all_tokens.len() - input_len;
|
||||
let tokens_per_sec = if gen_time > 0.0 { (tokens_generated as f64 / gen_time) * 1000.0 } else { 0.0 };
|
||||
|
||||
console_log!("[SmolLM] Generoitu {} tokenia | {:.0}ms | {:.1} tok/s", tokens_generated, gen_time, tokens_per_sec);
|
||||
|
||||
let done = serde_json::json!({
|
||||
"type": "llm_done",
|
||||
"prompt": prompt,
|
||||
"model": "SmolLM-135M-Instruct",
|
||||
"response": generated_text,
|
||||
"tokens_generated": tokens_generated,
|
||||
"duration_ms": (gen_time * 100.0).round() / 100.0,
|
||||
"tokens_per_sec": (tokens_per_sec * 10.0).round() / 10.0,
|
||||
"load_time_ms": (load_time * 100.0).round() / 100.0,
|
||||
});
|
||||
let _ = ws.borrow().send_with_str(&done.to_string());
|
||||
}
|
||||
Reference in New Issue
Block a user