hetki ennen webgpu inferenssiä
This commit is contained in:
@@ -7,6 +7,7 @@ use burn::tensor::Tensor;
|
||||
use burn::backend::{Wgpu, NdArray};
|
||||
|
||||
pub mod storage;
|
||||
pub mod sampling;
|
||||
pub mod smollm;
|
||||
pub mod qwen;
|
||||
pub mod qwen_coder;
|
||||
|
||||
@@ -154,7 +154,7 @@ pub async fn run_qwen_inference(prompt: String, ws: Rc<RefCell<WebSocket>>) {
|
||||
} else {
|
||||
logits // jo [vocab_size]
|
||||
};
|
||||
let mut next_token = logits.argmax(0).unwrap().to_vec0::<u32>().unwrap();
|
||||
let mut next_token = crate::sampling::sample_top_k(&logits, 10, 5.0);
|
||||
console_log!("[Qwen] Ensimmäinen token: {}", next_token);
|
||||
|
||||
let eos_token = 151645u32; // <|endoftext|> for Qwen2.5
|
||||
@@ -188,7 +188,7 @@ pub async fn run_qwen_inference(prompt: String, ws: Rc<RefCell<WebSocket>>) {
|
||||
} else {
|
||||
logits
|
||||
};
|
||||
next_token = logits.argmax(0).unwrap().to_vec0::<u32>().unwrap();
|
||||
next_token = crate::sampling::sample_top_k(&logits, 10, 5.0);
|
||||
pos += 1;
|
||||
|
||||
if next_token == eos_token { break; }
|
||||
|
||||
@@ -173,8 +173,22 @@ pub async fn run_coder_inference(prompt: String, ws: Rc<RefCell<WebSocket>>, use
|
||||
let load_time = perf.now() - start_load;
|
||||
console_log!("[Coder] Malli ladattu ({:.0}ms). Generoidaan...", load_time);
|
||||
|
||||
// Muotoillaan chat-template
|
||||
let formatted = format!("<|im_start|>system\nYou are a Python coding assistant. Write only code, no explanations.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", prompt);
|
||||
// Parsitaan JSON-prompti tai käytetään teksti sellaisenaan
|
||||
let (actual_prompt, system_msg, max_new_tokens) = if prompt.starts_with('{') {
|
||||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&prompt) {
|
||||
let p = json.get("prompt").and_then(|v| v.as_str()).unwrap_or(&prompt).to_string();
|
||||
let s = json.get("system").and_then(|v| v.as_str())
|
||||
.unwrap_or("You are a Python coding assistant. Write only code, no explanations.").to_string();
|
||||
let m = json.get("max_tokens").and_then(|v| v.as_u64()).unwrap_or(128) as usize;
|
||||
(p, s, m)
|
||||
} else {
|
||||
(prompt.clone(), "You are a Python coding assistant. Write only code, no explanations.".to_string(), 128)
|
||||
}
|
||||
} else {
|
||||
(prompt.clone(), "You are a Python coding assistant. Write only code, no explanations.".to_string(), 128)
|
||||
};
|
||||
|
||||
let formatted = format!("<|im_start|>system\n{}<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", system_msg, actual_prompt);
|
||||
|
||||
let encoding = match tokenizer.encode(formatted.as_str(), true) {
|
||||
Ok(e) => e,
|
||||
@@ -185,7 +199,7 @@ pub async fn run_coder_inference(prompt: String, ws: Rc<RefCell<WebSocket>>, use
|
||||
console_log!("[Coder] Syöte: {} tokenia", input_len);
|
||||
|
||||
let start_gen = perf.now();
|
||||
let max_new_tokens = 128; // Koodille enemmän tokeneita
|
||||
// max_new_tokens tulee JSON-promptista tai oletuksena 128
|
||||
let mut generated_text = String::new();
|
||||
let mut tokens_generated: usize = 0;
|
||||
let eos_token = 151645u32;
|
||||
@@ -206,7 +220,7 @@ pub async fn run_coder_inference(prompt: String, ws: Rc<RefCell<WebSocket>>, use
|
||||
} else {
|
||||
logits
|
||||
};
|
||||
let mut next_token = logits.argmax(0).unwrap().to_vec0::<u32>().unwrap();
|
||||
let mut next_token = crate::sampling::sample_top_k(&logits, 10, 5.0);
|
||||
|
||||
if next_token != eos_token {
|
||||
if let Ok(text) = tokenizer.decode(&[next_token], true) {
|
||||
@@ -237,7 +251,7 @@ pub async fn run_coder_inference(prompt: String, ws: Rc<RefCell<WebSocket>>, use
|
||||
} else {
|
||||
logits
|
||||
};
|
||||
next_token = logits.argmax(0).unwrap().to_vec0::<u32>().unwrap();
|
||||
next_token = crate::sampling::sample_top_k(&logits, 10, 5.0);
|
||||
pos += 1;
|
||||
|
||||
if next_token == eos_token { break; }
|
||||
|
||||
47
network-poc/node/src/sampling.rs
Normal file
47
network-poc/node/src/sampling.rs
Normal file
@@ -0,0 +1,47 @@
|
||||
use candle_core::Tensor;
|
||||
|
||||
/// Top-k sampling ilman softmaxia — kiertää Candlen SoftmaxLastDim Wasm-bugin.
|
||||
/// Valitsee top-k logiteista ja poimii satunnaisen (painotettu).
|
||||
/// Jos k=1, toimii kuten argmax (greedy).
|
||||
pub fn sample_top_k(logits: &Tensor, k: usize, eos_penalty: f32) -> u32 {
|
||||
// Muunnetaan Vec<f32>:ksi
|
||||
let logits_vec: Vec<f32> = logits.to_vec1::<f32>().unwrap_or_default();
|
||||
if logits_vec.is_empty() { return 0; }
|
||||
|
||||
// Rangotaan ja otetaan top-k indeksit
|
||||
let mut indexed: Vec<(usize, f32)> = logits_vec.iter().enumerate().map(|(i, &v)| (i, v)).collect();
|
||||
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
indexed.truncate(k);
|
||||
|
||||
// EOS-penaltti: vähennetään EOS-tokenin logitia
|
||||
for item in indexed.iter_mut() {
|
||||
if item.0 == 2 || item.0 == 151645 { // SmolLM EOS=2, Qwen EOS=151645
|
||||
item.1 -= eos_penalty;
|
||||
}
|
||||
}
|
||||
|
||||
if k == 1 {
|
||||
return indexed[0].0 as u32;
|
||||
}
|
||||
|
||||
// Yksinkertainen "softmax" top-k:lle CPU:lla
|
||||
let max_logit = indexed.iter().map(|x| x.1).fold(f32::NEG_INFINITY, f32::max);
|
||||
let exps: Vec<f32> = indexed.iter().map(|x| (x.1 - max_logit).exp()).collect();
|
||||
let sum: f32 = exps.iter().sum();
|
||||
let probs: Vec<f32> = exps.iter().map(|e| e / sum).collect();
|
||||
|
||||
// Satunnainen valinta kumulatiivisella todennäköisyydellä
|
||||
// Käytetään yksinkertaista XorShift-satunnaislukugeneraattoria (ei tarvita getrandom)
|
||||
let seed = (js_sys::Date::now() * 1000.0) as u64;
|
||||
let rand_val = ((seed ^ (seed >> 13) ^ (seed << 7)) % 10000) as f32 / 10000.0;
|
||||
|
||||
let mut cumulative = 0.0;
|
||||
for (i, p) in probs.iter().enumerate() {
|
||||
cumulative += p;
|
||||
if rand_val < cumulative {
|
||||
return indexed[i].0 as u32;
|
||||
}
|
||||
}
|
||||
|
||||
indexed[0].0 as u32
|
||||
}
|
||||
@@ -196,7 +196,7 @@ pub async fn run_smollm_inference(prompt: String, ws: Rc<RefCell<WebSocket>>) {
|
||||
} else {
|
||||
logits
|
||||
};
|
||||
let mut next_token = logits.argmax(0).unwrap().to_vec0::<u32>().unwrap();
|
||||
let mut next_token = crate::sampling::sample_top_k(&logits, 10, 5.0);
|
||||
console_log!("[SmolLM] Ensimmäinen generoitu token: {}", next_token);
|
||||
pos = input_len;
|
||||
|
||||
@@ -229,7 +229,7 @@ pub async fn run_smollm_inference(prompt: String, ws: Rc<RefCell<WebSocket>>) {
|
||||
} else {
|
||||
logits
|
||||
};
|
||||
next_token = logits.argmax(0).unwrap().to_vec0::<u32>().unwrap();
|
||||
next_token = crate::sampling::sample_top_k(&logits, 10, 5.0);
|
||||
pos += 1;
|
||||
|
||||
if next_token == 2 { break; }
|
||||
|
||||
Reference in New Issue
Block a user