Hetki ennen WebGPU-inferenssiä

This commit is contained in:
2026-04-02 12:49:40 +03:00
parent d2920e5ab4
commit e1326b145e
10 changed files with 375 additions and 80 deletions

View File

@@ -173,8 +173,22 @@ pub async fn run_coder_inference(prompt: String, ws: Rc<RefCell<WebSocket>>, use
let load_time = perf.now() - start_load;
console_log!("[Coder] Malli ladattu ({:.0}ms). Generoidaan...", load_time);
// Muotoillaan chat-template
let formatted = format!("<|im_start|>system\nYou are a Python coding assistant. Write only code, no explanations.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", prompt);
// Parsitaan JSON-prompti tai käytetään teksti sellaisenaan
let (actual_prompt, system_msg, max_new_tokens) = if prompt.starts_with('{') {
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&prompt) {
let p = json.get("prompt").and_then(|v| v.as_str()).unwrap_or(&prompt).to_string();
let s = json.get("system").and_then(|v| v.as_str())
.unwrap_or("You are a Python coding assistant. Write only code, no explanations.").to_string();
let m = json.get("max_tokens").and_then(|v| v.as_u64()).unwrap_or(128) as usize;
(p, s, m)
} else {
(prompt.clone(), "You are a Python coding assistant. Write only code, no explanations.".to_string(), 128)
}
} else {
(prompt.clone(), "You are a Python coding assistant. Write only code, no explanations.".to_string(), 128)
};
let formatted = format!("<|im_start|>system\n{}<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", system_msg, actual_prompt);
let encoding = match tokenizer.encode(formatted.as_str(), true) {
Ok(e) => e,
@@ -185,7 +199,7 @@ pub async fn run_coder_inference(prompt: String, ws: Rc<RefCell<WebSocket>>, use
console_log!("[Coder] Syöte: {} tokenia", input_len);
let start_gen = perf.now();
let max_new_tokens = 128; // Koodille enemmän tokeneita
// max_new_tokens tulee JSON-promptista tai oletuksena 128
let mut generated_text = String::new();
let mut tokens_generated: usize = 0;
let eos_token = 151645u32;
@@ -206,7 +220,7 @@ pub async fn run_coder_inference(prompt: String, ws: Rc<RefCell<WebSocket>>, use
} else {
logits
};
let mut next_token = logits.argmax(0).unwrap().to_vec0::<u32>().unwrap();
let mut next_token = crate::sampling::sample_top_k(&logits, 10, 5.0);
if next_token != eos_token {
if let Ok(text) = tokenizer.decode(&[next_token], true) {
@@ -237,7 +251,7 @@ pub async fn run_coder_inference(prompt: String, ws: Rc<RefCell<WebSocket>>, use
} else {
logits
};
next_token = logits.argmax(0).unwrap().to_vec0::<u32>().unwrap();
next_token = crate::sampling::sample_top_k(&logits, 10, 5.0);
pos += 1;
if next_token == eos_token { break; }