hetki ennen webgpu inferenssiä

This commit is contained in:
2026-04-02 12:49:40 +03:00
parent d2920e5ab4
commit e1326b145e
10 changed files with 375 additions and 80 deletions

View File

@@ -0,0 +1,47 @@
use candle_core::Tensor;
/// Top-k sampling ilman softmaxia — kiertää Candlen SoftmaxLastDim Wasm-bugin.
/// Valitsee top-k logiteista ja poimii satunnaisen (painotettu).
/// Jos k=1, toimii kuten argmax (greedy).
pub fn sample_top_k(logits: &Tensor, k: usize, eos_penalty: f32) -> u32 {
// Muunnetaan Vec<f32>:ksi
let logits_vec: Vec<f32> = logits.to_vec1::<f32>().unwrap_or_default();
if logits_vec.is_empty() { return 0; }
// Rangotaan ja otetaan top-k indeksit
let mut indexed: Vec<(usize, f32)> = logits_vec.iter().enumerate().map(|(i, &v)| (i, v)).collect();
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
indexed.truncate(k);
// EOS-penaltti: vähennetään EOS-tokenin logitia
for item in indexed.iter_mut() {
if item.0 == 2 || item.0 == 151645 { // SmolLM EOS=2, Qwen EOS=151645
item.1 -= eos_penalty;
}
}
if k == 1 {
return indexed[0].0 as u32;
}
// Yksinkertainen "softmax" top-k:lle CPU:lla
let max_logit = indexed.iter().map(|x| x.1).fold(f32::NEG_INFINITY, f32::max);
let exps: Vec<f32> = indexed.iter().map(|x| (x.1 - max_logit).exp()).collect();
let sum: f32 = exps.iter().sum();
let probs: Vec<f32> = exps.iter().map(|e| e / sum).collect();
// Satunnainen valinta kumulatiivisella todennäköisyydellä
// Käytetään yksinkertaista XorShift-satunnaislukugeneraattoria (ei tarvita getrandom)
let seed = (js_sys::Date::now() * 1000.0) as u64;
let rand_val = ((seed ^ (seed >> 13) ^ (seed << 7)) % 10000) as f32 / 10000.0;
let mut cumulative = 0.0;
for (i, p) in probs.iter().enumerate() {
cumulative += p;
if rand_val < cumulative {
return indexed[i].0 as u32;
}
}
indexed[0].0 as u32
}