Päivitetty juttuja
This commit is contained in:
@@ -1,39 +1,105 @@
|
||||
use candle_core::Tensor;
|
||||
use std::cell::Cell;
|
||||
|
||||
/// Top-k sampling ilman softmaxia — kiertää Candlen SoftmaxLastDim Wasm-bugin.
|
||||
/// Valitsee top-k logiteista ja poimii satunnaisen (painotettu).
|
||||
/// Jos k=1, toimii kuten argmax (greedy).
|
||||
pub fn sample_top_k(logits: &Tensor, k: usize, eos_penalty: f32) -> u32 {
|
||||
// Muunnetaan Vec<f32>:ksi
|
||||
let logits_vec: Vec<f32> = logits.to_vec1::<f32>().unwrap_or_default();
|
||||
thread_local! {
|
||||
static RNG_STATE: Cell<u64> = Cell::new(0);
|
||||
}
|
||||
|
||||
fn next_rand() -> f32 {
|
||||
RNG_STATE.with(|state| {
|
||||
let mut s = state.get();
|
||||
if s == 0 {
|
||||
s = (js_sys::Date::now() * 1000.0) as u64 | 1;
|
||||
}
|
||||
s ^= s << 13;
|
||||
s ^= s >> 7;
|
||||
s ^= s << 17;
|
||||
state.set(s);
|
||||
(s % 10000) as f32 / 10000.0
|
||||
})
|
||||
}
|
||||
|
||||
/// Top-k sampling with temperature and repetition penalty.
|
||||
/// `generated_tokens` sisältää aiemmin generoidut token-id:t toiston estämiseksi.
|
||||
pub fn sample_top_k_with_penalty(logits: &Tensor, k: usize, temperature: f32, generated_tokens: &[u32], repetition_penalty: f32) -> u32 {
|
||||
let mut logits_vec: Vec<f32> = logits.to_vec1::<f32>().unwrap_or_default();
|
||||
if logits_vec.is_empty() { return 0; }
|
||||
|
||||
// Rangotaan ja otetaan top-k indeksit
|
||||
// Repetition penalty
|
||||
if repetition_penalty != 1.0 {
|
||||
for &token_id in generated_tokens {
|
||||
if (token_id as usize) < logits_vec.len() {
|
||||
let logit = &mut logits_vec[token_id as usize];
|
||||
if *logit > 0.0 {
|
||||
*logit /= repetition_penalty;
|
||||
} else {
|
||||
*logit *= repetition_penalty;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Temperature scaling
|
||||
if temperature > 0.0 && temperature != 1.0 {
|
||||
for logit in logits_vec.iter_mut() {
|
||||
*logit /= temperature;
|
||||
}
|
||||
}
|
||||
|
||||
// Top-k
|
||||
let mut indexed: Vec<(usize, f32)> = logits_vec.iter().enumerate().map(|(i, &v)| (i, v)).collect();
|
||||
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
indexed.truncate(k);
|
||||
|
||||
// EOS-penaltti: vähennetään EOS-tokenin logitia
|
||||
for item in indexed.iter_mut() {
|
||||
if item.0 == 2 || item.0 == 151645 { // SmolLM EOS=2, Qwen EOS=151645
|
||||
item.1 -= eos_penalty;
|
||||
}
|
||||
}
|
||||
|
||||
if k == 1 {
|
||||
if k == 1 || temperature == 0.0 {
|
||||
return indexed[0].0 as u32;
|
||||
}
|
||||
|
||||
// Yksinkertainen "softmax" top-k:lle CPU:lla
|
||||
let max_logit = indexed.iter().map(|x| x.1).fold(f32::NEG_INFINITY, f32::max);
|
||||
// Softmax top-k:lle
|
||||
let max_logit = indexed[0].1;
|
||||
let exps: Vec<f32> = indexed.iter().map(|x| (x.1 - max_logit).exp()).collect();
|
||||
let sum: f32 = exps.iter().sum();
|
||||
let probs: Vec<f32> = exps.iter().map(|e| e / sum).collect();
|
||||
|
||||
// Satunnainen valinta kumulatiivisella todennäköisyydellä
|
||||
// Käytetään yksinkertaista XorShift-satunnaislukugeneraattoria (ei tarvita getrandom)
|
||||
let seed = (js_sys::Date::now() * 1000.0) as u64;
|
||||
let rand_val = ((seed ^ (seed >> 13) ^ (seed << 7)) % 10000) as f32 / 10000.0;
|
||||
let rand_val = next_rand();
|
||||
|
||||
let mut cumulative = 0.0;
|
||||
for (i, p) in probs.iter().enumerate() {
|
||||
cumulative += p;
|
||||
if rand_val < cumulative {
|
||||
return indexed[i].0 as u32;
|
||||
}
|
||||
}
|
||||
|
||||
indexed[0].0 as u32
|
||||
}
|
||||
|
||||
/// Alkuperäinen API yhteensopivuudeksi SmolLM/Qwen-moduulien kanssa
|
||||
pub fn sample_top_k(logits: &Tensor, k: usize, eos_penalty: f32) -> u32 {
|
||||
let mut logits_vec: Vec<f32> = logits.to_vec1::<f32>().unwrap_or_default();
|
||||
if logits_vec.is_empty() { return 0; }
|
||||
|
||||
// EOS-penaltti
|
||||
for &eos_id in &[2u32, 151645] {
|
||||
if (eos_id as usize) < logits_vec.len() {
|
||||
logits_vec[eos_id as usize] -= eos_penalty;
|
||||
}
|
||||
}
|
||||
|
||||
let mut indexed: Vec<(usize, f32)> = logits_vec.iter().enumerate().map(|(i, &v)| (i, v)).collect();
|
||||
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
indexed.truncate(k);
|
||||
|
||||
if k == 1 {
|
||||
return indexed[0].0 as u32;
|
||||
}
|
||||
|
||||
let max_logit = indexed[0].1;
|
||||
let exps: Vec<f32> = indexed.iter().map(|x| (x.1 - max_logit).exp()).collect();
|
||||
let sum: f32 = exps.iter().sum();
|
||||
let probs: Vec<f32> = exps.iter().map(|e| e / sum).collect();
|
||||
|
||||
let rand_val = next_rand();
|
||||
|
||||
let mut cumulative = 0.0;
|
||||
for (i, p) in probs.iter().enumerate() {
|
||||
|
||||
Reference in New Issue
Block a user