Assistantin vastauksen alkuun syötetään valmiiksi backtick-koodiblokki, jolloin malli jatkaa suoraan koodilla eikä tuota "Sure! Here is..." -johdantotekstejä. Säästää tokeneita ja vastausaikaa. strip_markdown_wrapper poistaa ``` -merkit jälkikäteen. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
375 lines
17 KiB
Rust
375 lines
17 KiB
Rust
use candle_core::{Device, Tensor, DType};
|
|
use candle_nn::VarBuilder;
|
|
use candle_transformers::models::qwen2::{Config as QwenConfig, ModelForCausalLM as QwenModel};
|
|
use wasm_bindgen::JsCast;
|
|
use std::cell::RefCell;
|
|
use std::rc::Rc;
|
|
use web_sys::WebSocket;
|
|
|
|
use crate::storage;
|
|
|
|
/// Formats its arguments like `format!` and writes the result to the browser
/// console via `console.log`.
macro_rules! console_log {
    ($($arg:tt)*) => {
        web_sys::console::log_1(&::std::format!($($arg)*).into())
    };
}
|
|
|
|
// Builds, at compile time, the `resolve/main` download URL of a file in a
// Hugging Face model repository.
macro_rules! hf_resolve_url {
    ($repo:literal, $file:literal) => {
        concat!("https://huggingface.co/", $repo, "/resolve/main/", $file)
    };
}

// 0.5B checkpoint: fast, suitable for every device (single safetensors file).
const MODEL_05B_URL: &str = hf_resolve_url!("Qwen/Qwen2.5-Coder-0.5B-Instruct", "model.safetensors");
const TOKENIZER_05B_URL: &str = hf_resolve_url!("Qwen/Qwen2.5-Coder-0.5B-Instruct", "tokenizer.json");

// 3B checkpoint: better quality but much heavier (~6 GB download, ~12 GB RAM);
// the weights are split into two safetensors shards.
const MODEL_3B_PART1_URL: &str =
    hf_resolve_url!("Qwen/Qwen2.5-Coder-3B-Instruct", "model-00001-of-00002.safetensors");
const MODEL_3B_PART2_URL: &str =
    hf_resolve_url!("Qwen/Qwen2.5-Coder-3B-Instruct", "model-00002-of-00002.safetensors");
const TOKENIZER_3B_URL: &str = hf_resolve_url!("Qwen/Qwen2.5-Coder-3B-Instruct", "tokenizer.json");
|
|
|
|
struct CachedModel {
|
|
model: QwenModel,
|
|
tokenizer: tokenizers::Tokenizer,
|
|
is_3b: bool,
|
|
}
|
|
|
|
/// Reduces raw model output to just the code.
///
/// * If the text contains a ``` fence, the trimmed contents of the first fenced
///   block are returned. The rest of the opening fence line (a language tag such
///   as `python`) is dropped, and an unclosed block runs to the end of the text.
/// * Otherwise a chatty first line ("Sure!", "Here is", ...) is dropped, followed
///   by any leading `# ` comment lines that read like prose preambles. `#!`
///   shebang lines are kept.
///
/// Example: `"Sure! Here is...\n```python\nprint('hi')\n```"` → `"print('hi')"`.
fn strip_markdown_wrapper(text: &str) -> String {
    // Lower-case openers of an intro line.
    const INTRO_PREFIXES: [&str; 5] = ["sure!", "here is", "here's", "certainly!", "below is"];
    // Lower-case phrases that mark a `# ` comment line as a prose preamble.
    const PREAMBLE_PHRASES: [&str; 6] =
        ["this is", "simple", "program that", "here is", "the following", "below"];

    let text = text.trim();

    if let Some(fence) = text.find("```") {
        let after_fence = &text[fence + 3..];
        // Skip the remainder of the opening fence line (language tag).
        let body = after_fence
            .find('\n')
            .map_or(after_fence, |newline| &after_fence[newline + 1..]);
        // Stop at the closing fence when there is one.
        let inner = body.find("```").map_or(body, |close| &body[..close]);
        return inner.trim().to_string();
    }

    let starts_with_intro = {
        let lowered = text.to_lowercase();
        INTRO_PREFIXES.iter().any(|prefix| lowered.starts_with(prefix))
    };
    // An intro with no newline after it is left in place.
    let without_intro = if starts_with_intro {
        text.split_once('\n').map_or(text, |(_, rest)| rest)
    } else {
        text
    };

    let is_preamble_comment = |line: &str| -> bool {
        let line = line.trim();
        if !line.starts_with("# ") || line.starts_with("#!") {
            return false;
        }
        let lowered = line.to_lowercase();
        PREAMBLE_PHRASES.iter().any(|phrase| lowered.contains(phrase))
    };

    without_intro
        .trim()
        .lines()
        .skip_while(|&line| is_preamble_comment(line))
        .collect::<Vec<&str>>()
        .join("\n")
        .trim()
        .to_string()
}
|
|
|
|
thread_local! {
|
|
static RAM_CACHE: RefCell<std::collections::HashMap<String, Rc<Vec<u8>>>> = RefCell::new(std::collections::HashMap::new());
|
|
static MODEL_CACHE: RefCell<Option<CachedModel>> = RefCell::new(None);
|
|
}
|
|
|
|
async fn ensure_cached(key: &str, url: &str, ws: &Rc<RefCell<WebSocket>>) -> Result<Rc<Vec<u8>>, String> {
|
|
// 1. Tarkistetaan RAM välimuisti (estää OOM ja levy-I/O pullonkaulat)
|
|
let ram_hit = RAM_CACHE.with(|cache| {
|
|
cache.borrow().get(key).cloned()
|
|
});
|
|
if let Some(bytes) = ram_hit {
|
|
console_log!("[Coder] {} löytyi nopeasta RAM-välimuistista!", key);
|
|
return Ok(bytes);
|
|
}
|
|
|
|
// 2. Tarkistetaan IndexedDB (jos selain on suljettu aikaisemmin)
|
|
if let Ok(Some(bytes)) = storage::load_from_idb(key).await {
|
|
console_log!("[Coder] {} löytyi IndexedDB-välimuistista ({} MB)", key, bytes.len() / 1024 / 1024);
|
|
let rc_bytes = Rc::new(bytes);
|
|
RAM_CACHE.with(|cache| cache.borrow_mut().insert(key.to_string(), rc_bytes.clone()));
|
|
return Ok(rc_bytes);
|
|
}
|
|
|
|
console_log!("[Coder] Ladataan {}...", key);
|
|
|
|
let window = web_sys::window().unwrap();
|
|
let resp_val = wasm_bindgen_futures::JsFuture::from(window.fetch_with_str(url))
|
|
.await.map_err(|e| format!("Fetch: {:?}", e))?;
|
|
let resp: web_sys::Response = resp_val.dyn_into().map_err(|_| "Ei Response".to_string())?;
|
|
if !resp.ok() { return Err(format!("HTTP {}", resp.status())); }
|
|
|
|
let total_size: usize = resp.headers()
|
|
.get("content-length").ok().flatten()
|
|
.and_then(|s| s.parse().ok())
|
|
.unwrap_or(0);
|
|
|
|
let body = resp.body().ok_or("Ei bodyä")?;
|
|
let reader: web_sys::ReadableStreamDefaultReader = body.get_reader().dyn_into().map_err(|_| "Ei reader".to_string())?;
|
|
|
|
let mut data: Vec<u8> = Vec::with_capacity(total_size);
|
|
let mut last_pct: u32 = 0;
|
|
|
|
loop {
|
|
let chunk = wasm_bindgen_futures::JsFuture::from(reader.read())
|
|
.await.map_err(|e| format!("Read: {:?}", e))?;
|
|
let done = js_sys::Reflect::get(&chunk, &"done".into()).ok().and_then(|v| v.as_bool()).unwrap_or(true);
|
|
if done { break; }
|
|
let value = js_sys::Reflect::get(&chunk, &"value".into()).map_err(|_| "value puuttuu".to_string())?;
|
|
let array = js_sys::Uint8Array::new(&value);
|
|
let mut buf = vec![0u8; array.length() as usize];
|
|
array.copy_to(&mut buf);
|
|
data.extend_from_slice(&buf);
|
|
|
|
if total_size > 0 {
|
|
let pct = ((data.len() as f64 / total_size as f64) * 100.0) as u32;
|
|
if pct >= last_pct + 5 || pct == 100 {
|
|
last_pct = pct;
|
|
console_log!("[Coder] {} lataus: {}%", key, pct);
|
|
let msg = serde_json::json!({ "type": "download_progress", "file": key, "pct": pct, "loaded_mb": data.len()/1024/1024, "total_mb": total_size/1024/1024 });
|
|
let _ = ws.borrow().send_with_str(&msg.to_string());
|
|
}
|
|
}
|
|
}
|
|
|
|
console_log!("[Coder] Tallennetaan {} ({} MB) IndexedDB:hen...", key, data.len() / 1024 / 1024);
|
|
let _ = storage::save_to_idb(key, &data).await;
|
|
console_log!("[Coder] {} tallennettu!", key);
|
|
|
|
let rc_data = Rc::new(data);
|
|
RAM_CACHE.with(|cache| cache.borrow_mut().insert(key.to_string(), rc_data.clone()));
|
|
|
|
Ok(rc_data)
|
|
}
|
|
|
|
/// Lataa tai palauttaa välimuistista valmiin mallin + tokenizerin
|
|
async fn get_or_build_model(use_3b: bool, ws: &Rc<RefCell<WebSocket>>) -> Result<(), String> {
|
|
// Tarkistetaan onko oikea malli jo muistissa
|
|
let cache_hit = MODEL_CACHE.with(|c| {
|
|
c.borrow().as_ref().map(|m| m.is_3b == use_3b).unwrap_or(false)
|
|
});
|
|
if cache_hit {
|
|
// Logitetaan kaikki välivaiheet valmiiksi, jotta pipeline-UI päivittyy
|
|
console_log!("[Coder] tokenizer löytyi (cache)");
|
|
console_log!("[Coder] model löytyi (cache)");
|
|
console_log!("[Coder] Malli ladattu (välimuistista)");
|
|
return Ok(());
|
|
}
|
|
|
|
let device = Device::Cpu;
|
|
let dtype = DType::F32;
|
|
|
|
// Tokenizer
|
|
let tok_url = if use_3b { TOKENIZER_3B_URL } else { TOKENIZER_05B_URL };
|
|
let tok_key = if use_3b { "coder3b-tokenizer.json" } else { "coder05b-tokenizer.json" };
|
|
let tok_bytes = ensure_cached(tok_key, tok_url, ws).await?;
|
|
let tokenizer = tokenizers::Tokenizer::from_bytes(&tok_bytes[..])
|
|
.map_err(|e| format!("Tokenizer: {}", e))?;
|
|
|
|
// Painot
|
|
let tensors = if use_3b {
|
|
let part1 = ensure_cached("coder3b-model-part1.safetensors", MODEL_3B_PART1_URL, ws).await?;
|
|
let part2 = ensure_cached("coder3b-model-part2.safetensors", MODEL_3B_PART2_URL, ws).await?;
|
|
console_log!("[Coder] Rakennetaan 3B-mallia...");
|
|
let mut all_tensors = candle_core::safetensors::load_buffer(&part1[..], &device)
|
|
.map_err(|e| format!("Part1: {}", e))?;
|
|
let tensors2 = candle_core::safetensors::load_buffer(&part2[..], &device)
|
|
.map_err(|e| format!("Part2: {}", e))?;
|
|
all_tensors.extend(tensors2);
|
|
all_tensors
|
|
} else {
|
|
let model_bytes = ensure_cached("coder05b-model.safetensors", MODEL_05B_URL, ws).await?;
|
|
console_log!("[Coder] Rakennetaan 0.5B-mallia...");
|
|
candle_core::safetensors::load_buffer(&model_bytes[..], &device)
|
|
.map_err(|e| format!("Safetensors: {}", e))?
|
|
};
|
|
|
|
let vb = VarBuilder::from_tensors(tensors, dtype, &device);
|
|
let config = if use_3b {
|
|
QwenConfig {
|
|
vocab_size: 151936, hidden_size: 2048, intermediate_size: 11008,
|
|
num_hidden_layers: 36, num_attention_heads: 16, num_key_value_heads: 2,
|
|
max_position_embeddings: 32768, sliding_window: 32768, max_window_layers: 36,
|
|
tie_word_embeddings: true, rope_theta: 1000000.0, rms_norm_eps: 1e-6,
|
|
use_sliding_window: false, hidden_act: candle_nn::Activation::Silu,
|
|
}
|
|
} else {
|
|
QwenConfig {
|
|
vocab_size: 151936, hidden_size: 896, intermediate_size: 4864,
|
|
num_hidden_layers: 24, num_attention_heads: 14, num_key_value_heads: 2,
|
|
max_position_embeddings: 32768, sliding_window: 32768, max_window_layers: 21,
|
|
tie_word_embeddings: true, rope_theta: 1000000.0, rms_norm_eps: 1e-6,
|
|
use_sliding_window: false, hidden_act: candle_nn::Activation::Silu,
|
|
}
|
|
};
|
|
|
|
let model = QwenModel::new(&config, vb).map_err(|e| format!("Malli: {}", e))?;
|
|
console_log!("[Coder] Malli ladattu ja välimuistitettu");
|
|
|
|
MODEL_CACHE.with(|c| {
|
|
*c.borrow_mut() = Some(CachedModel { model, tokenizer, is_3b: use_3b });
|
|
});
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// use_3b: false = 0.5B (nopea), true = 3B (laadukas)
|
|
pub async fn run_coder_inference(prompt: String, ws: Rc<RefCell<WebSocket>>, use_3b: bool, task_id: Option<String>) {
|
|
let perf = web_sys::window().unwrap().performance().unwrap();
|
|
let size_label = if use_3b { "3B" } else { "0.5B" };
|
|
|
|
let start_load = perf.now();
|
|
|
|
if let Err(e) = get_or_build_model(use_3b, &ws).await {
|
|
console_log!("[Coder] Mallin lataus: {}", e);
|
|
return;
|
|
}
|
|
|
|
let load_time = perf.now() - start_load;
|
|
if load_time > 100.0 {
|
|
console_log!("[Coder] Malli ladattu ({:.0}ms). Generoidaan...", load_time);
|
|
}
|
|
|
|
// Parsitaan JSON-prompti tai käytetään teksti sellaisenaan
|
|
let default_system = "You are a coding assistant. Respond with ONLY code. No explanations, no markdown, no comments unless asked.";
|
|
let (actual_prompt, system_msg, max_new_tokens) = if prompt.starts_with('{') {
|
|
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&prompt) {
|
|
let p = json.get("prompt").and_then(|v| v.as_str()).unwrap_or(&prompt).to_string();
|
|
let s = json.get("system").and_then(|v| v.as_str()).unwrap_or(default_system).to_string();
|
|
let m = json.get("max_tokens").and_then(|v| v.as_u64()).unwrap_or(128) as usize;
|
|
(p, s, m)
|
|
} else {
|
|
(prompt.clone(), default_system.to_string(), 128)
|
|
}
|
|
} else {
|
|
(prompt.clone(), default_system.to_string(), 128)
|
|
};
|
|
|
|
// Prefill: aloitetaan vastaus ```-koodiblokkilla, jolloin malli jatkaa suoraan koodilla
|
|
// eikä tuota "Sure! Here is..." -johdantoa. strip_markdown_wrapper poistaa ``` jälkikäteen.
|
|
let formatted = format!("<|im_start|>system\n{}<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n```\n", system_msg, actual_prompt);
|
|
|
|
// Inferenssi: käytetään välimuistissa olevaa mallia
|
|
let (generated_text, tokens_generated, gen_time) = MODEL_CACHE.with(|cache| {
|
|
let mut cache = cache.borrow_mut();
|
|
let cached = cache.as_mut().expect("Malli pitää olla ladattu");
|
|
|
|
let encoding = cached.tokenizer.encode(formatted.as_str(), true)
|
|
.map_err(|e| format!("Encode: {}", e)).unwrap();
|
|
let input_ids: Vec<u32> = encoding.get_ids().to_vec();
|
|
let input_len = input_ids.len();
|
|
console_log!("[Coder] Syöte: {} tokenia", input_len);
|
|
|
|
let device = Device::Cpu;
|
|
let start_gen = perf.now();
|
|
let eos_token = 151645u32;
|
|
let temperature: f32 = 0.7;
|
|
let top_k: usize = 40;
|
|
let repetition_penalty: f32 = 1.15;
|
|
|
|
// Nollataan KV-cache edellisestä promptista
|
|
cached.model.clear_kv_cache();
|
|
|
|
let mut generated_text = String::new();
|
|
let mut tokens_generated: usize = 0;
|
|
let mut all_generated: Vec<u32> = Vec::new();
|
|
|
|
// Prefill
|
|
let input = Tensor::new(input_ids.as_slice(), &device).and_then(|t| t.unsqueeze(0)).unwrap();
|
|
let logits = cached.model.forward(&input, 0).unwrap();
|
|
let logits = logits.squeeze(0).unwrap();
|
|
let logits = if logits.dims().len() == 2 {
|
|
logits.get(logits.dim(0).unwrap() - 1).unwrap()
|
|
} else { logits };
|
|
|
|
let mut next_token = crate::sampling::sample_top_k_with_penalty(&logits, top_k, temperature, &all_generated, repetition_penalty);
|
|
|
|
if next_token != eos_token {
|
|
if let Ok(text) = cached.tokenizer.decode(&[next_token], true) {
|
|
generated_text.push_str(&text);
|
|
let mut chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "Qwen2.5-Coder" });
|
|
if let Some(ref tid) = task_id { chunk.as_object_mut().unwrap().insert("task_id".to_string(), serde_json::json!(tid)); }
|
|
let _ = ws.borrow().send_with_str(&chunk.to_string());
|
|
}
|
|
all_generated.push(next_token);
|
|
tokens_generated += 1;
|
|
}
|
|
|
|
// Autoregressive
|
|
let mut pos = input_len;
|
|
for _ in 1..max_new_tokens {
|
|
if next_token == eos_token { break; }
|
|
|
|
let input = Tensor::new(&[next_token], &device).and_then(|t| t.unsqueeze(0)).unwrap();
|
|
let logits = match cached.model.forward(&input, pos) {
|
|
Ok(l) => l,
|
|
Err(e) => { console_log!("[Coder] Forward pos {}: {}", pos, e); break; }
|
|
};
|
|
|
|
let logits = logits.squeeze(0).unwrap();
|
|
let logits = if logits.dims().len() == 2 {
|
|
logits.get(logits.dim(0).unwrap() - 1).unwrap()
|
|
} else { logits };
|
|
next_token = crate::sampling::sample_top_k_with_penalty(&logits, top_k, temperature, &all_generated, repetition_penalty);
|
|
pos += 1;
|
|
|
|
if next_token == eos_token { break; }
|
|
|
|
if let Ok(text) = cached.tokenizer.decode(&[next_token], true) {
|
|
generated_text.push_str(&text);
|
|
|
|
// Stop-sekvenssit: katkaistaan kun malli alkaa selittää
|
|
let lower = generated_text.to_lowercase();
|
|
if lower.contains("\n###") || lower.contains("\nexplanation") || lower.contains("\nnote:") || lower.contains("\noutput:") || lower.contains("\n```\n\n") {
|
|
for stop in &["\n###", "\nExplanation", "\nNote:", "\nOutput:", "\n```\n\n"] {
|
|
if let Some(pos) = generated_text.find(stop) {
|
|
generated_text.truncate(pos);
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
|
|
let mut chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "Qwen2.5-Coder" });
|
|
if let Some(ref tid) = task_id { chunk.as_object_mut().unwrap().insert("task_id".to_string(), serde_json::json!(tid)); }
|
|
let _ = ws.borrow().send_with_str(&chunk.to_string());
|
|
}
|
|
all_generated.push(next_token);
|
|
tokens_generated += 1;
|
|
}
|
|
|
|
let gen_time = perf.now() - start_gen;
|
|
|
|
// Siivotaan vastaus: poista markdown-koodiblokit ja johdantotekstit
|
|
let cleaned = strip_markdown_wrapper(&generated_text);
|
|
|
|
(cleaned, tokens_generated, gen_time)
|
|
});
|
|
|
|
let tokens_per_sec = if gen_time > 0.0 { (tokens_generated as f64 / gen_time) * 1000.0 } else { 0.0 };
|
|
console_log!("[Coder] {} tokenia | {:.0}ms | {:.1} tok/s", tokens_generated, gen_time, tokens_per_sec);
|
|
|
|
let mut done = serde_json::json!({
|
|
"type": "llm_done",
|
|
"prompt": prompt,
|
|
"model": format!("Qwen2.5-Coder-{}-Instruct", size_label),
|
|
"response": generated_text,
|
|
"tokens_generated": tokens_generated,
|
|
"duration_ms": (gen_time * 100.0).round() / 100.0,
|
|
"tokens_per_sec": (tokens_per_sec * 10.0).round() / 10.0,
|
|
"load_time_ms": (load_time * 100.0).round() / 100.0,
|
|
});
|
|
if let Some(tid) = task_id {
|
|
done.as_object_mut().unwrap().insert("task_id".to_string(), serde_json::json!(tid));
|
|
}
|
|
let _ = ws.borrow().send_with_str(&done.to_string());
|
|
}
|