Päivitetty juttuja

This commit is contained in:
Jaakko Vanhala
2026-04-04 21:13:20 +03:00
parent 2e7ddf6f1e
commit 3ada8949d0
11 changed files with 457 additions and 105 deletions

View File

@@ -5,6 +5,62 @@ use hf_hub::{api::sync::Api, Repo, RepoType};
use std::path::PathBuf;
use std::time::Instant;
/// Top-k sampling with temperature and repetition penalty
fn sample_top_k(logits: &Tensor, k: usize, temperature: f64, generated_tokens: &[u32], repetition_penalty: f64, rng_state: &mut u64) -> Result<u32, String> {
let mut logits_vec: Vec<f32> = logits.to_vec1::<f32>().map_err(|e| format!("to_vec1: {}", e))?;
if logits_vec.is_empty() { return Err("Tyhjä logits".to_string()); }
// Repetition penalty: rankaisee jo generoituja tokeneita
for &token_id in generated_tokens {
if (token_id as usize) < logits_vec.len() {
let logit = &mut logits_vec[token_id as usize];
if *logit > 0.0 {
*logit /= repetition_penalty as f32;
} else {
*logit *= repetition_penalty as f32;
}
}
}
// Temperature scaling
if temperature > 0.0 && temperature != 1.0 {
for logit in logits_vec.iter_mut() {
*logit /= temperature as f32;
}
}
// Top-k: etsitään k suurinta
let mut indexed: Vec<(usize, f32)> = logits_vec.iter().enumerate().map(|(i, &v)| (i, v)).collect();
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
indexed.truncate(k);
if k == 1 || temperature == 0.0 {
return Ok(indexed[0].0 as u32);
}
// Softmax top-k:lle
let max_logit = indexed[0].1;
let exps: Vec<f32> = indexed.iter().map(|x| (x.1 - max_logit).exp()).collect();
let sum: f32 = exps.iter().sum();
let probs: Vec<f32> = exps.iter().map(|e| e / sum).collect();
// XorShift64 RNG
*rng_state ^= *rng_state << 13;
*rng_state ^= *rng_state >> 7;
*rng_state ^= *rng_state << 17;
let rand_val = (*rng_state % 10000) as f32 / 10000.0;
let mut cumulative = 0.0;
for (i, p) in probs.iter().enumerate() {
cumulative += p;
if rand_val < cumulative {
return Ok(indexed[i].0 as u32);
}
}
Ok(indexed[0].0 as u32)
}
pub struct LlmEngine {
tokenizer: tokenizers::Tokenizer,
model_path: PathBuf,
@@ -22,10 +78,10 @@ impl LlmEngine {
let dtype = if device.is_cuda() { DType::F16 } else { DType::F32 };
tracing::info!("Ladataan Qwen2.5-0.5B-Instruct...");
tracing::info!("Ladataan Qwen2.5-Coder-0.5B-Instruct...");
let api = Api::new().map_err(|e| format!("HF API: {}", e))?;
let repo = api.repo(Repo::with_revision(
"Qwen/Qwen2.5-0.5B-Instruct".to_string(),
"Qwen/Qwen2.5-Coder-0.5B-Instruct".to_string(),
RepoType::Model,
"main".to_string(),
));
@@ -93,6 +149,15 @@ impl LlmEngine {
// Tuore malli joka promptille (nollaa KV-cachen)
let mut model = self.fresh_model()?;
// Sampling-parametrit
let temperature = 0.7;
let top_k = 40;
let repetition_penalty = 1.15;
let mut rng_state: u64 = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_nanos() as u64;
let start = Instant::now();
// Prefill
@@ -105,19 +170,19 @@ impl LlmEngine {
let logits = logits.squeeze(0).map_err(|e| format!("Squeeze: {}", e))?;
let logits = if logits.dims().len() == 2 {
logits.get(logits.dim(0).unwrap() - 1).map_err(|e| format!("Get: {}", e))?
let seq_len = logits.dim(0).map_err(|e| format!("Dim: {}", e))?;
if seq_len == 0 { return Err("Tyhjä tensori".to_string()); }
logits.get(seq_len - 1).map_err(|e| format!("Get: {}", e))?
} else {
logits
};
let mut next_token = logits.argmax(0)
.map_err(|e| format!("Argmax: {}", e))?
.to_vec0::<u32>()
.map_err(|e| format!("to_vec0: {}", e))?;
let mut generated_text = String::new();
let mut tokens_generated: usize = 0;
let mut all_tokens: Vec<u32> = Vec::new();
let mut next_token = sample_top_k(&logits, top_k, temperature, &all_tokens, repetition_penalty, &mut rng_state)?;
if next_token != self.eos_token {
if let Ok(text) = self.tokenizer.decode(&[next_token], true) {
generated_text.push_str(&text);
@@ -140,14 +205,13 @@ impl LlmEngine {
let logits = logits.squeeze(0).map_err(|e| format!("Squeeze: {}", e))?;
let logits = if logits.dims().len() == 2 {
logits.get(logits.dim(0).unwrap() - 1).map_err(|e| format!("Get: {}", e))?
let seq_len = logits.dim(0).map_err(|e| format!("Dim: {}", e))?;
if seq_len == 0 { break; }
logits.get(seq_len - 1).map_err(|e| format!("Get: {}", e))?
} else {
logits
};
next_token = logits.argmax(0)
.map_err(|e| format!("Argmax: {}", e))?
.to_vec0::<u32>()
.map_err(|e| format!("to_vec0: {}", e))?;
next_token = sample_top_k(&logits, top_k, temperature, &all_tokens, repetition_penalty, &mut rng_state)?;
pos += 1;
if next_token == self.eos_token { break; }

View File

@@ -227,6 +227,7 @@ fn build_auth_message(allocated_gb: u32) -> String {
"status": "agent_ready",
"node_type": "native",
"allocated_gb": allocated_gb,
"selected_task": "qwen-coder-05b",
"system": sys,
});
@@ -318,10 +319,14 @@ async fn main() {
if text.contains("llm_prompt") && !busy {
if let Ok(task) = serde_json::from_str::<serde_json::Value>(&text) {
let prompt = task.get("prompt").and_then(|v| v.as_str()).unwrap_or("");
if !prompt.is_empty() {
let task_id = task.get("task_id").and_then(|v| v.as_str()).unwrap_or("?");
let msg_model = task.get("model").and_then(|v| v.as_str()).unwrap_or("");
if !prompt.is_empty() && msg_model.starts_with("qwen-coder") {
if let Some(ref mut engine) = llm {
busy = true;
tracing::info!("Generoidaan: \"{}\"", prompt);
tracing::info!("Generoidaan (task_id: {}): \"{}\"", task_id, prompt);
match engine.generate(prompt, 64) {
Ok(result) => {
@@ -336,12 +341,13 @@ async fn main() {
let done = json!({
"type": "llm_done",
"prompt": prompt,
"model": "Qwen2.5-0.5B-Instruct (native/GPU)",
"model": "Qwen2.5-Coder-0.5B (native/GPU)",
"response": result.text,
"tokens_generated": result.tokens_generated,
"duration_ms": result.duration_ms,
"tokens_per_sec": (result.tokens_per_sec * 10.0).round() / 10.0,
"load_time_ms": 0,
"task_id": task_id,
});
let _ = write.send(Message::Text(done.to_string())).await;
}