koodilabran v0.1
This commit is contained in:
268
network-poc/node/src/qwen_coder.rs
Normal file
268
network-poc/node/src/qwen_coder.rs
Normal file
@@ -0,0 +1,268 @@
|
||||
use candle_core::{Device, Tensor, DType};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::qwen2::{Config as QwenConfig, ModelForCausalLM as QwenModel};
|
||||
use wasm_bindgen::JsCast;
|
||||
use std::cell::RefCell;
|
||||
use std::rc::Rc;
|
||||
use web_sys::WebSocket;
|
||||
|
||||
use crate::storage;
|
||||
|
||||
macro_rules! console_log {
|
||||
($($t:tt)*) => (web_sys::console::log_1(&format_args!($($t)*).to_string().into()))
|
||||
}
|
||||
|
||||
// 0.5B — nopea, sopii kaikille laitteille
|
||||
const MODEL_05B_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-Coder-0.5B-Instruct/resolve/main/model.safetensors";
|
||||
const TOKENIZER_05B_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-Coder-0.5B-Instruct/resolve/main/tokenizer.json";
|
||||
|
||||
// 3B — parempi laatu, vaatii enemmän muistia (~6 GB lataus, ~12 GB RAM)
|
||||
const MODEL_3B_PART1_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-Coder-3B-Instruct/resolve/main/model-00001-of-00002.safetensors";
|
||||
const MODEL_3B_PART2_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-Coder-3B-Instruct/resolve/main/model-00002-of-00002.safetensors";
|
||||
const TOKENIZER_3B_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-Coder-3B-Instruct/resolve/main/tokenizer.json";
|
||||
|
||||
async fn ensure_cached(key: &str, url: &str, ws: &Rc<RefCell<WebSocket>>) -> Result<Vec<u8>, String> {
|
||||
if let Ok(Some(bytes)) = storage::load_from_idb(key).await {
|
||||
console_log!("[Coder] {} löytyi välimuistista ({} MB)", key, bytes.len() / 1024 / 1024);
|
||||
return Ok(bytes);
|
||||
}
|
||||
|
||||
console_log!("[Coder] Ladataan {}...", key);
|
||||
|
||||
let window = web_sys::window().unwrap();
|
||||
let resp_val = wasm_bindgen_futures::JsFuture::from(window.fetch_with_str(url))
|
||||
.await.map_err(|e| format!("Fetch: {:?}", e))?;
|
||||
let resp: web_sys::Response = resp_val.dyn_into().map_err(|_| "Ei Response".to_string())?;
|
||||
if !resp.ok() { return Err(format!("HTTP {}", resp.status())); }
|
||||
|
||||
let total_size: usize = resp.headers()
|
||||
.get("content-length").ok().flatten()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(0);
|
||||
|
||||
let body = resp.body().ok_or("Ei bodyä")?;
|
||||
let reader: web_sys::ReadableStreamDefaultReader = body.get_reader().dyn_into().map_err(|_| "Ei reader".to_string())?;
|
||||
|
||||
let mut data: Vec<u8> = Vec::with_capacity(total_size);
|
||||
let mut last_pct: u32 = 0;
|
||||
|
||||
loop {
|
||||
let chunk = wasm_bindgen_futures::JsFuture::from(reader.read())
|
||||
.await.map_err(|e| format!("Read: {:?}", e))?;
|
||||
let done = js_sys::Reflect::get(&chunk, &"done".into()).ok().and_then(|v| v.as_bool()).unwrap_or(true);
|
||||
if done { break; }
|
||||
let value = js_sys::Reflect::get(&chunk, &"value".into()).map_err(|_| "value puuttuu".to_string())?;
|
||||
let array = js_sys::Uint8Array::new(&value);
|
||||
let mut buf = vec![0u8; array.length() as usize];
|
||||
array.copy_to(&mut buf);
|
||||
data.extend_from_slice(&buf);
|
||||
|
||||
if total_size > 0 {
|
||||
let pct = ((data.len() as f64 / total_size as f64) * 100.0) as u32;
|
||||
if pct >= last_pct + 5 || pct == 100 {
|
||||
last_pct = pct;
|
||||
console_log!("[Coder] {} lataus: {}%", key, pct);
|
||||
let msg = serde_json::json!({ "type": "download_progress", "file": key, "pct": pct, "loaded_mb": data.len()/1024/1024, "total_mb": total_size/1024/1024 });
|
||||
let _ = ws.borrow().send_with_str(&msg.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
console_log!("[Coder] Tallennetaan {} ({} MB)...", key, data.len() / 1024 / 1024);
|
||||
let _ = storage::save_to_idb(key, &data).await;
|
||||
console_log!("[Coder] {} tallennettu!", key);
|
||||
|
||||
Ok(data)
|
||||
}
|
||||
|
||||
/// use_3b: false = 0.5B (nopea), true = 3B (laadukas)
|
||||
pub async fn run_coder_inference(prompt: String, ws: Rc<RefCell<WebSocket>>, use_3b: bool) {
|
||||
let perf = web_sys::window().unwrap().performance().unwrap();
|
||||
let size_label = if use_3b { "3B" } else { "0.5B" };
|
||||
|
||||
// Tokenizer (sama molemmille)
|
||||
let tok_url = if use_3b { TOKENIZER_3B_URL } else { TOKENIZER_05B_URL };
|
||||
let tok_key = if use_3b { "coder3b-tokenizer.json" } else { "coder05b-tokenizer.json" };
|
||||
let tok_bytes = match ensure_cached(tok_key, tok_url, &ws).await {
|
||||
Ok(b) => b,
|
||||
Err(e) => { console_log!("[Coder] Tokenizer-virhe: {}", e); return; }
|
||||
};
|
||||
let tokenizer = match tokenizers::Tokenizer::from_bytes(&tok_bytes) {
|
||||
Ok(t) => t,
|
||||
Err(e) => { console_log!("[Coder] Tokenizer-parsinta: {}", e); return; }
|
||||
};
|
||||
|
||||
// Mallin painot
|
||||
let device = Device::Cpu;
|
||||
let dtype = DType::F32;
|
||||
|
||||
let tensors = if use_3b {
|
||||
// 3B: kaksi osaa
|
||||
let part1 = match ensure_cached("coder3b-model-part1.safetensors", MODEL_3B_PART1_URL, &ws).await {
|
||||
Ok(b) => b,
|
||||
Err(e) => { console_log!("[Coder] Malli osa 1 virhe: {}", e); return; }
|
||||
};
|
||||
let part2 = match ensure_cached("coder3b-model-part2.safetensors", MODEL_3B_PART2_URL, &ws).await {
|
||||
Ok(b) => b,
|
||||
Err(e) => { console_log!("[Coder] Malli osa 2 virhe: {}", e); return; }
|
||||
};
|
||||
console_log!("[Coder] Rakennetaan 3B-mallia...");
|
||||
let mut all_tensors = candle_core::safetensors::load_buffer(&part1, &device)
|
||||
.map_err(|e| format!("Part1: {}", e)).unwrap();
|
||||
let tensors2 = candle_core::safetensors::load_buffer(&part2, &device)
|
||||
.map_err(|e| format!("Part2: {}", e)).unwrap();
|
||||
all_tensors.extend(tensors2);
|
||||
all_tensors
|
||||
} else {
|
||||
// 0.5B: yksi osa
|
||||
let model_bytes = match ensure_cached("coder05b-model.safetensors", MODEL_05B_URL, &ws).await {
|
||||
Ok(b) => b,
|
||||
Err(e) => { console_log!("[Coder] Malli-virhe: {}", e); return; }
|
||||
};
|
||||
console_log!("[Coder] Rakennetaan 0.5B-mallia...");
|
||||
match candle_core::safetensors::load_buffer(&model_bytes, &device) {
|
||||
Ok(t) => t,
|
||||
Err(e) => { console_log!("[Coder] Safetensors: {}", e); return; }
|
||||
}
|
||||
};
|
||||
|
||||
let start_load = perf.now();
|
||||
let vb = VarBuilder::from_tensors(tensors, dtype, &device);
|
||||
|
||||
let config = if use_3b {
|
||||
QwenConfig {
|
||||
vocab_size: 151936,
|
||||
hidden_size: 2048,
|
||||
intermediate_size: 11008,
|
||||
num_hidden_layers: 36,
|
||||
num_attention_heads: 16,
|
||||
num_key_value_heads: 2,
|
||||
max_position_embeddings: 32768,
|
||||
sliding_window: 32768,
|
||||
max_window_layers: 36,
|
||||
tie_word_embeddings: true,
|
||||
rope_theta: 1000000.0,
|
||||
rms_norm_eps: 1e-6,
|
||||
use_sliding_window: false,
|
||||
hidden_act: candle_nn::Activation::Silu,
|
||||
}
|
||||
} else {
|
||||
QwenConfig {
|
||||
vocab_size: 151936,
|
||||
hidden_size: 896,
|
||||
intermediate_size: 4864,
|
||||
num_hidden_layers: 24,
|
||||
num_attention_heads: 14,
|
||||
num_key_value_heads: 2,
|
||||
max_position_embeddings: 32768,
|
||||
sliding_window: 32768,
|
||||
max_window_layers: 21,
|
||||
tie_word_embeddings: true,
|
||||
rope_theta: 1000000.0,
|
||||
rms_norm_eps: 1e-6,
|
||||
use_sliding_window: false,
|
||||
hidden_act: candle_nn::Activation::Silu,
|
||||
}
|
||||
};
|
||||
|
||||
let mut model = match QwenModel::new(&config, vb) {
|
||||
Ok(m) => m,
|
||||
Err(e) => { console_log!("[Coder] Mallin lataus: {}", e); return; }
|
||||
};
|
||||
|
||||
let load_time = perf.now() - start_load;
|
||||
console_log!("[Coder] Malli ladattu ({:.0}ms). Generoidaan...", load_time);
|
||||
|
||||
// Muotoillaan chat-template
|
||||
let formatted = format!("<|im_start|>system\nYou are a Python coding assistant. Write only code, no explanations.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", prompt);
|
||||
|
||||
let encoding = match tokenizer.encode(formatted.as_str(), true) {
|
||||
Ok(e) => e,
|
||||
Err(e) => { console_log!("[Coder] Tokenisointivirhe: {}", e); return; }
|
||||
};
|
||||
let input_ids: Vec<u32> = encoding.get_ids().to_vec();
|
||||
let input_len = input_ids.len();
|
||||
console_log!("[Coder] Syöte: {} tokenia", input_len);
|
||||
|
||||
let start_gen = perf.now();
|
||||
let max_new_tokens = 128; // Koodille enemmän tokeneita
|
||||
let mut generated_text = String::new();
|
||||
let mut tokens_generated: usize = 0;
|
||||
let eos_token = 151645u32;
|
||||
|
||||
// Prefill
|
||||
let input = match Tensor::new(input_ids.as_slice(), &device).and_then(|t| t.unsqueeze(0)) {
|
||||
Ok(t) => t,
|
||||
Err(e) => { console_log!("[Coder] Tensor: {}", e); return; }
|
||||
};
|
||||
let logits = match model.forward(&input, 0) {
|
||||
Ok(l) => l,
|
||||
Err(e) => { console_log!("[Coder] Forward (prefill): {}", e); return; }
|
||||
};
|
||||
|
||||
let logits = logits.squeeze(0).unwrap();
|
||||
let logits = if logits.dims().len() == 2 {
|
||||
logits.get(logits.dim(0).unwrap() - 1).unwrap()
|
||||
} else {
|
||||
logits
|
||||
};
|
||||
let mut next_token = logits.argmax(0).unwrap().to_vec0::<u32>().unwrap();
|
||||
|
||||
if next_token != eos_token {
|
||||
if let Ok(text) = tokenizer.decode(&[next_token], true) {
|
||||
generated_text.push_str(&text);
|
||||
let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "Qwen2.5-Coder" });
|
||||
let _ = ws.borrow().send_with_str(&chunk.to_string());
|
||||
}
|
||||
tokens_generated += 1;
|
||||
}
|
||||
|
||||
// Autoregressive
|
||||
let mut pos = input_len;
|
||||
for _ in 1..max_new_tokens {
|
||||
if next_token == eos_token { break; }
|
||||
|
||||
let input = match Tensor::new(&[next_token], &device).and_then(|t| t.unsqueeze(0)) {
|
||||
Ok(t) => t,
|
||||
Err(e) => { console_log!("[Coder] Tensor: {}", e); break; }
|
||||
};
|
||||
let logits = match model.forward(&input, pos) {
|
||||
Ok(l) => l,
|
||||
Err(e) => { console_log!("[Coder] Forward pos {}: {}", pos, e); break; }
|
||||
};
|
||||
|
||||
let logits = logits.squeeze(0).unwrap();
|
||||
let logits = if logits.dims().len() == 2 {
|
||||
logits.get(logits.dim(0).unwrap() - 1).unwrap()
|
||||
} else {
|
||||
logits
|
||||
};
|
||||
next_token = logits.argmax(0).unwrap().to_vec0::<u32>().unwrap();
|
||||
pos += 1;
|
||||
|
||||
if next_token == eos_token { break; }
|
||||
|
||||
if let Ok(text) = tokenizer.decode(&[next_token], true) {
|
||||
generated_text.push_str(&text);
|
||||
let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "Qwen2.5-Coder" });
|
||||
let _ = ws.borrow().send_with_str(&chunk.to_string());
|
||||
}
|
||||
tokens_generated += 1;
|
||||
}
|
||||
|
||||
let gen_time = perf.now() - start_gen;
|
||||
let tokens_per_sec = if gen_time > 0.0 { (tokens_generated as f64 / gen_time) * 1000.0 } else { 0.0 };
|
||||
console_log!("[Coder] {} tokenia | {:.0}ms | {:.1} tok/s", tokens_generated, gen_time, tokens_per_sec);
|
||||
|
||||
let done = serde_json::json!({
|
||||
"type": "llm_done",
|
||||
"prompt": prompt,
|
||||
"model": format!("Qwen2.5-Coder-{}-Instruct", size_label),
|
||||
"response": generated_text,
|
||||
"tokens_generated": tokens_generated,
|
||||
"duration_ms": (gen_time * 100.0).round() / 100.0,
|
||||
"tokens_per_sec": (tokens_per_sec * 10.0).round() / 10.0,
|
||||
"load_time_ms": (load_time * 100.0).round() / 100.0,
|
||||
});
|
||||
let _ = ws.borrow().send_with_str(&done.to_string());
|
||||
}
|
||||
Reference in New Issue
Block a user