Päivitetty juttuja
This commit is contained in:
@@ -5,6 +5,62 @@ use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use std::path::PathBuf;
|
||||
use std::time::Instant;
|
||||
|
||||
/// Top-k sampling with temperature and repetition penalty
|
||||
fn sample_top_k(logits: &Tensor, k: usize, temperature: f64, generated_tokens: &[u32], repetition_penalty: f64, rng_state: &mut u64) -> Result<u32, String> {
|
||||
let mut logits_vec: Vec<f32> = logits.to_vec1::<f32>().map_err(|e| format!("to_vec1: {}", e))?;
|
||||
if logits_vec.is_empty() { return Err("Tyhjä logits".to_string()); }
|
||||
|
||||
// Repetition penalty: rankaisee jo generoituja tokeneita
|
||||
for &token_id in generated_tokens {
|
||||
if (token_id as usize) < logits_vec.len() {
|
||||
let logit = &mut logits_vec[token_id as usize];
|
||||
if *logit > 0.0 {
|
||||
*logit /= repetition_penalty as f32;
|
||||
} else {
|
||||
*logit *= repetition_penalty as f32;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Temperature scaling
|
||||
if temperature > 0.0 && temperature != 1.0 {
|
||||
for logit in logits_vec.iter_mut() {
|
||||
*logit /= temperature as f32;
|
||||
}
|
||||
}
|
||||
|
||||
// Top-k: etsitään k suurinta
|
||||
let mut indexed: Vec<(usize, f32)> = logits_vec.iter().enumerate().map(|(i, &v)| (i, v)).collect();
|
||||
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
indexed.truncate(k);
|
||||
|
||||
if k == 1 || temperature == 0.0 {
|
||||
return Ok(indexed[0].0 as u32);
|
||||
}
|
||||
|
||||
// Softmax top-k:lle
|
||||
let max_logit = indexed[0].1;
|
||||
let exps: Vec<f32> = indexed.iter().map(|x| (x.1 - max_logit).exp()).collect();
|
||||
let sum: f32 = exps.iter().sum();
|
||||
let probs: Vec<f32> = exps.iter().map(|e| e / sum).collect();
|
||||
|
||||
// XorShift64 RNG
|
||||
*rng_state ^= *rng_state << 13;
|
||||
*rng_state ^= *rng_state >> 7;
|
||||
*rng_state ^= *rng_state << 17;
|
||||
let rand_val = (*rng_state % 10000) as f32 / 10000.0;
|
||||
|
||||
let mut cumulative = 0.0;
|
||||
for (i, p) in probs.iter().enumerate() {
|
||||
cumulative += p;
|
||||
if rand_val < cumulative {
|
||||
return Ok(indexed[i].0 as u32);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(indexed[0].0 as u32)
|
||||
}
|
||||
|
||||
pub struct LlmEngine {
|
||||
tokenizer: tokenizers::Tokenizer,
|
||||
model_path: PathBuf,
|
||||
@@ -22,10 +78,10 @@ impl LlmEngine {
|
||||
|
||||
let dtype = if device.is_cuda() { DType::F16 } else { DType::F32 };
|
||||
|
||||
tracing::info!("Ladataan Qwen2.5-0.5B-Instruct...");
|
||||
tracing::info!("Ladataan Qwen2.5-Coder-0.5B-Instruct...");
|
||||
let api = Api::new().map_err(|e| format!("HF API: {}", e))?;
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
"Qwen/Qwen2.5-0.5B-Instruct".to_string(),
|
||||
"Qwen/Qwen2.5-Coder-0.5B-Instruct".to_string(),
|
||||
RepoType::Model,
|
||||
"main".to_string(),
|
||||
));
|
||||
@@ -93,6 +149,15 @@ impl LlmEngine {
|
||||
// Tuore malli joka promptille (nollaa KV-cachen)
|
||||
let mut model = self.fresh_model()?;
|
||||
|
||||
// Sampling-parametrit
|
||||
let temperature = 0.7;
|
||||
let top_k = 40;
|
||||
let repetition_penalty = 1.15;
|
||||
let mut rng_state: u64 = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_nanos() as u64;
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
// Prefill
|
||||
@@ -105,19 +170,19 @@ impl LlmEngine {
|
||||
|
||||
let logits = logits.squeeze(0).map_err(|e| format!("Squeeze: {}", e))?;
|
||||
let logits = if logits.dims().len() == 2 {
|
||||
logits.get(logits.dim(0).unwrap() - 1).map_err(|e| format!("Get: {}", e))?
|
||||
let seq_len = logits.dim(0).map_err(|e| format!("Dim: {}", e))?;
|
||||
if seq_len == 0 { return Err("Tyhjä tensori".to_string()); }
|
||||
logits.get(seq_len - 1).map_err(|e| format!("Get: {}", e))?
|
||||
} else {
|
||||
logits
|
||||
};
|
||||
let mut next_token = logits.argmax(0)
|
||||
.map_err(|e| format!("Argmax: {}", e))?
|
||||
.to_vec0::<u32>()
|
||||
.map_err(|e| format!("to_vec0: {}", e))?;
|
||||
|
||||
let mut generated_text = String::new();
|
||||
let mut tokens_generated: usize = 0;
|
||||
let mut all_tokens: Vec<u32> = Vec::new();
|
||||
|
||||
let mut next_token = sample_top_k(&logits, top_k, temperature, &all_tokens, repetition_penalty, &mut rng_state)?;
|
||||
|
||||
if next_token != self.eos_token {
|
||||
if let Ok(text) = self.tokenizer.decode(&[next_token], true) {
|
||||
generated_text.push_str(&text);
|
||||
@@ -140,14 +205,13 @@ impl LlmEngine {
|
||||
|
||||
let logits = logits.squeeze(0).map_err(|e| format!("Squeeze: {}", e))?;
|
||||
let logits = if logits.dims().len() == 2 {
|
||||
logits.get(logits.dim(0).unwrap() - 1).map_err(|e| format!("Get: {}", e))?
|
||||
let seq_len = logits.dim(0).map_err(|e| format!("Dim: {}", e))?;
|
||||
if seq_len == 0 { break; }
|
||||
logits.get(seq_len - 1).map_err(|e| format!("Get: {}", e))?
|
||||
} else {
|
||||
logits
|
||||
};
|
||||
next_token = logits.argmax(0)
|
||||
.map_err(|e| format!("Argmax: {}", e))?
|
||||
.to_vec0::<u32>()
|
||||
.map_err(|e| format!("to_vec0: {}", e))?;
|
||||
next_token = sample_top_k(&logits, top_k, temperature, &all_tokens, repetition_penalty, &mut rng_state)?;
|
||||
pos += 1;
|
||||
|
||||
if next_token == self.eos_token { break; }
|
||||
|
||||
@@ -227,6 +227,7 @@ fn build_auth_message(allocated_gb: u32) -> String {
|
||||
"status": "agent_ready",
|
||||
"node_type": "native",
|
||||
"allocated_gb": allocated_gb,
|
||||
"selected_task": "qwen-coder-05b",
|
||||
"system": sys,
|
||||
});
|
||||
|
||||
@@ -318,10 +319,14 @@ async fn main() {
|
||||
if text.contains("llm_prompt") && !busy {
|
||||
if let Ok(task) = serde_json::from_str::<serde_json::Value>(&text) {
|
||||
let prompt = task.get("prompt").and_then(|v| v.as_str()).unwrap_or("");
|
||||
if !prompt.is_empty() {
|
||||
let task_id = task.get("task_id").and_then(|v| v.as_str()).unwrap_or("?");
|
||||
let msg_model = task.get("model").and_then(|v| v.as_str()).unwrap_or("");
|
||||
|
||||
if !prompt.is_empty() && msg_model.starts_with("qwen-coder") {
|
||||
|
||||
if let Some(ref mut engine) = llm {
|
||||
busy = true;
|
||||
tracing::info!("Generoidaan: \"{}\"", prompt);
|
||||
tracing::info!("Generoidaan (task_id: {}): \"{}\"", task_id, prompt);
|
||||
|
||||
match engine.generate(prompt, 64) {
|
||||
Ok(result) => {
|
||||
@@ -336,12 +341,13 @@ async fn main() {
|
||||
let done = json!({
|
||||
"type": "llm_done",
|
||||
"prompt": prompt,
|
||||
"model": "Qwen2.5-0.5B-Instruct (native/GPU)",
|
||||
"model": "Qwen2.5-Coder-0.5B (native/GPU)",
|
||||
"response": result.text,
|
||||
"tokens_generated": result.tokens_generated,
|
||||
"duration_ms": result.duration_ms,
|
||||
"tokens_per_sec": (result.tokens_per_sec * 10.0).round() / 10.0,
|
||||
"load_time_ms": 0,
|
||||
"task_id": task_id,
|
||||
});
|
||||
let _ = write.send(Message::Text(done.to_string())).await;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user