use candle_core::Tensor; use std::cell::Cell; thread_local! { static RNG_STATE: Cell = Cell::new(0); } fn next_rand() -> f32 { RNG_STATE.with(|state| { let mut s = state.get(); if s == 0 { s = (js_sys::Date::now() * 1000.0) as u64 | 1; } s ^= s << 13; s ^= s >> 7; s ^= s << 17; state.set(s); (s % 10000) as f32 / 10000.0 }) } /// Top-k sampling with temperature and repetition penalty. /// `generated_tokens` sisältää aiemmin generoidut token-id:t toiston estämiseksi. pub fn sample_top_k_with_penalty(logits: &Tensor, k: usize, temperature: f32, generated_tokens: &[u32], repetition_penalty: f32) -> u32 { let mut logits_vec: Vec = logits.to_vec1::().unwrap_or_default(); if logits_vec.is_empty() { return 0; } // Repetition penalty if repetition_penalty != 1.0 { for &token_id in generated_tokens { if (token_id as usize) < logits_vec.len() { let logit = &mut logits_vec[token_id as usize]; if *logit > 0.0 { *logit /= repetition_penalty; } else { *logit *= repetition_penalty; } } } } // Temperature scaling if temperature > 0.0 && temperature != 1.0 { for logit in logits_vec.iter_mut() { *logit /= temperature; } } // Top-k let mut indexed: Vec<(usize, f32)> = logits_vec.iter().enumerate().map(|(i, &v)| (i, v)).collect(); indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); indexed.truncate(k); if k == 1 || temperature == 0.0 { return indexed[0].0 as u32; } // Softmax top-k:lle let max_logit = indexed[0].1; let exps: Vec = indexed.iter().map(|x| (x.1 - max_logit).exp()).collect(); let sum: f32 = exps.iter().sum(); let probs: Vec = exps.iter().map(|e| e / sum).collect(); let rand_val = next_rand(); let mut cumulative = 0.0; for (i, p) in probs.iter().enumerate() { cumulative += p; if rand_val < cumulative { return indexed[i].0 as u32; } } indexed[0].0 as u32 } /// Alkuperäinen API yhteensopivuudeksi SmolLM/Qwen-moduulien kanssa pub fn sample_top_k(logits: &Tensor, k: usize, eos_penalty: f32) -> u32 { let mut logits_vec: Vec = logits.to_vec1::().unwrap_or_default(); if logits_vec.is_empty() { return 0; } // EOS-penaltti for &eos_id in &[2u32, 151645] { if (eos_id as usize) < logits_vec.len() { logits_vec[eos_id as usize] -= eos_penalty; } } let mut indexed: Vec<(usize, f32)> = logits_vec.iter().enumerate().map(|(i, &v)| (i, v)).collect(); indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); indexed.truncate(k); if k == 1 { return indexed[0].0 as u32; } let max_logit = indexed[0].1; let exps: Vec = indexed.iter().map(|x| (x.1 - max_logit).exp()).collect(); let sum: f32 = exps.iter().sum(); let probs: Vec = exps.iter().map(|e| e / sum).collect(); let rand_val = next_rand(); let mut cumulative = 0.0; for (i, p) in probs.iter().enumerate() { cumulative += p; if rand_val < cumulative { return indexed[i].0 as u32; } } indexed[0].0 as u32 }