Säätö Hommia
This commit is contained in:
@@ -2,7 +2,6 @@ use candle_core::{Device, Tensor, DType};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::qwen2::{Config as QwenConfig, ModelForCausalLM as QwenModel};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use std::path::PathBuf;
|
||||
use std::time::Instant;
|
||||
|
||||
/// Top-k sampling with temperature and repetition penalty
|
||||
@@ -63,10 +62,8 @@ fn sample_top_k(logits: &Tensor, k: usize, temperature: f64, generated_tokens: &
|
||||
|
||||
pub struct LlmEngine {
|
||||
tokenizer: tokenizers::Tokenizer,
|
||||
model_path: PathBuf,
|
||||
model: QwenModel,
|
||||
device: Device,
|
||||
dtype: DType,
|
||||
config: QwenConfig,
|
||||
eos_token: u32,
|
||||
}
|
||||
|
||||
@@ -110,34 +107,22 @@ impl LlmEngine {
|
||||
hidden_act: candle_nn::Activation::Silu,
|
||||
};
|
||||
|
||||
// Testi-lataus varmistaa, että painot toimivat
|
||||
let start = Instant::now();
|
||||
let vb = unsafe {
|
||||
VarBuilder::from_mmaped_safetensors(&[model_path.clone()], dtype, &device)
|
||||
.map_err(|e| format!("VarBuilder: {}", e))?
|
||||
};
|
||||
let _model = QwenModel::new(&config, vb).map_err(|e| format!("Malli: {}", e))?;
|
||||
let model = QwenModel::new(&config, vb).map_err(|e| format!("Malli: {}", e))?;
|
||||
tracing::info!("Malli ladattu ({:.1}s) — {}", start.elapsed().as_secs_f64(), device_name);
|
||||
|
||||
Ok(LlmEngine {
|
||||
tokenizer,
|
||||
model_path,
|
||||
model,
|
||||
device,
|
||||
dtype,
|
||||
config,
|
||||
eos_token: 151645,
|
||||
})
|
||||
}
|
||||
|
||||
/// Luo tuore malliinstanssi (nollaa KV-cachen)
|
||||
fn fresh_model(&self) -> Result<QwenModel, String> {
|
||||
let vb = unsafe {
|
||||
VarBuilder::from_mmaped_safetensors(&[self.model_path.clone()], self.dtype, &self.device)
|
||||
.map_err(|e| format!("VarBuilder: {}", e))?
|
||||
};
|
||||
QwenModel::new(&self.config, vb).map_err(|e| format!("Malli: {}", e))
|
||||
}
|
||||
|
||||
pub fn generate(&mut self, prompt: &str, max_tokens: usize) -> Result<GenerateResult, String> {
|
||||
let formatted = format!("<|im_start|>system\nYou are a coding assistant. Respond with ONLY code. No explanations, no markdown, no comments unless asked.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", prompt);
|
||||
|
||||
@@ -146,8 +131,8 @@ impl LlmEngine {
|
||||
let input_ids: Vec<u32> = encoding.get_ids().to_vec();
|
||||
let input_len = input_ids.len();
|
||||
|
||||
// Tuore malli joka promptille (nollaa KV-cachen)
|
||||
let mut model = self.fresh_model()?;
|
||||
// Nollataan KV-cache edellisestä promptista
|
||||
self.model.clear_kv_cache();
|
||||
|
||||
// Sampling-parametrit
|
||||
let temperature = 0.7;
|
||||
@@ -165,7 +150,7 @@ impl LlmEngine {
|
||||
.and_then(|t| t.unsqueeze(0))
|
||||
.map_err(|e| format!("Tensor: {}", e))?;
|
||||
|
||||
let logits = model.forward(&input, 0)
|
||||
let logits = self.model.forward(&input, 0)
|
||||
.map_err(|e| format!("Forward prefill: {}", e))?;
|
||||
|
||||
let logits = logits.squeeze(0).map_err(|e| format!("Squeeze: {}", e))?;
|
||||
@@ -200,7 +185,7 @@ impl LlmEngine {
|
||||
.and_then(|t| t.unsqueeze(0))
|
||||
.map_err(|e| format!("Tensor: {}", e))?;
|
||||
|
||||
let logits = model.forward(&input, pos)
|
||||
let logits = self.model.forward(&input, pos)
|
||||
.map_err(|e| format!("Forward pos {}: {}", pos, e))?;
|
||||
|
||||
let logits = logits.squeeze(0).map_err(|e| format!("Squeeze: {}", e))?;
|
||||
|
||||
@@ -21,8 +21,15 @@ const MODEL_3B_PART1_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-Coder-3B-I
|
||||
const MODEL_3B_PART2_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-Coder-3B-Instruct/resolve/main/model-00002-of-00002.safetensors";
|
||||
const TOKENIZER_3B_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-Coder-3B-Instruct/resolve/main/tokenizer.json";
|
||||
|
||||
struct CachedModel {
|
||||
model: QwenModel,
|
||||
tokenizer: tokenizers::Tokenizer,
|
||||
is_3b: bool,
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
static RAM_CACHE: RefCell<std::collections::HashMap<String, Rc<Vec<u8>>>> = RefCell::new(std::collections::HashMap::new());
|
||||
static MODEL_CACHE: RefCell<Option<CachedModel>> = RefCell::new(None);
|
||||
}
|
||||
|
||||
async fn ensure_cached(key: &str, url: &str, ws: &Rc<RefCell<WebSocket>>) -> Result<Rc<Vec<u8>>, String> {
|
||||
@@ -94,223 +101,200 @@ async fn ensure_cached(key: &str, url: &str, ws: &Rc<RefCell<WebSocket>>) -> Res
|
||||
Ok(rc_data)
|
||||
}
|
||||
|
||||
/// Lataa tai palauttaa välimuistista valmiin mallin + tokenizerin
|
||||
async fn get_or_build_model(use_3b: bool, ws: &Rc<RefCell<WebSocket>>) -> Result<(), String> {
|
||||
// Tarkistetaan onko oikea malli jo muistissa
|
||||
let cache_hit = MODEL_CACHE.with(|c| {
|
||||
c.borrow().as_ref().map(|m| m.is_3b == use_3b).unwrap_or(false)
|
||||
});
|
||||
if cache_hit {
|
||||
console_log!("[Coder] Malli löytyi muistista — ohitetaan lataus");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let device = Device::Cpu;
|
||||
let dtype = DType::F32;
|
||||
|
||||
// Tokenizer
|
||||
let tok_url = if use_3b { TOKENIZER_3B_URL } else { TOKENIZER_05B_URL };
|
||||
let tok_key = if use_3b { "coder3b-tokenizer.json" } else { "coder05b-tokenizer.json" };
|
||||
let tok_bytes = ensure_cached(tok_key, tok_url, ws).await?;
|
||||
let tokenizer = tokenizers::Tokenizer::from_bytes(&tok_bytes[..])
|
||||
.map_err(|e| format!("Tokenizer: {}", e))?;
|
||||
|
||||
// Painot
|
||||
let tensors = if use_3b {
|
||||
let part1 = ensure_cached("coder3b-model-part1.safetensors", MODEL_3B_PART1_URL, ws).await?;
|
||||
let part2 = ensure_cached("coder3b-model-part2.safetensors", MODEL_3B_PART2_URL, ws).await?;
|
||||
console_log!("[Coder] Rakennetaan 3B-mallia...");
|
||||
let mut all_tensors = candle_core::safetensors::load_buffer(&part1[..], &device)
|
||||
.map_err(|e| format!("Part1: {}", e))?;
|
||||
let tensors2 = candle_core::safetensors::load_buffer(&part2[..], &device)
|
||||
.map_err(|e| format!("Part2: {}", e))?;
|
||||
all_tensors.extend(tensors2);
|
||||
all_tensors
|
||||
} else {
|
||||
let model_bytes = ensure_cached("coder05b-model.safetensors", MODEL_05B_URL, ws).await?;
|
||||
console_log!("[Coder] Rakennetaan 0.5B-mallia...");
|
||||
candle_core::safetensors::load_buffer(&model_bytes[..], &device)
|
||||
.map_err(|e| format!("Safetensors: {}", e))?
|
||||
};
|
||||
|
||||
let vb = VarBuilder::from_tensors(tensors, dtype, &device);
|
||||
let config = if use_3b {
|
||||
QwenConfig {
|
||||
vocab_size: 151936, hidden_size: 2048, intermediate_size: 11008,
|
||||
num_hidden_layers: 36, num_attention_heads: 16, num_key_value_heads: 2,
|
||||
max_position_embeddings: 32768, sliding_window: 32768, max_window_layers: 36,
|
||||
tie_word_embeddings: true, rope_theta: 1000000.0, rms_norm_eps: 1e-6,
|
||||
use_sliding_window: false, hidden_act: candle_nn::Activation::Silu,
|
||||
}
|
||||
} else {
|
||||
QwenConfig {
|
||||
vocab_size: 151936, hidden_size: 896, intermediate_size: 4864,
|
||||
num_hidden_layers: 24, num_attention_heads: 14, num_key_value_heads: 2,
|
||||
max_position_embeddings: 32768, sliding_window: 32768, max_window_layers: 21,
|
||||
tie_word_embeddings: true, rope_theta: 1000000.0, rms_norm_eps: 1e-6,
|
||||
use_sliding_window: false, hidden_act: candle_nn::Activation::Silu,
|
||||
}
|
||||
};
|
||||
|
||||
let model = QwenModel::new(&config, vb).map_err(|e| format!("Malli: {}", e))?;
|
||||
console_log!("[Coder] Malli rakennettu ja välimuistitettu");
|
||||
|
||||
MODEL_CACHE.with(|c| {
|
||||
*c.borrow_mut() = Some(CachedModel { model, tokenizer, is_3b: use_3b });
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// use_3b: false = 0.5B (nopea), true = 3B (laadukas)
|
||||
pub async fn run_coder_inference(prompt: String, ws: Rc<RefCell<WebSocket>>, use_3b: bool, task_id: Option<String>) {
|
||||
let perf = web_sys::window().unwrap().performance().unwrap();
|
||||
let size_label = if use_3b { "3B" } else { "0.5B" };
|
||||
|
||||
// Tokenizer (sama molemmille)
|
||||
let tok_url = if use_3b { TOKENIZER_3B_URL } else { TOKENIZER_05B_URL };
|
||||
let tok_key = if use_3b { "coder3b-tokenizer.json" } else { "coder05b-tokenizer.json" };
|
||||
let tok_bytes = match ensure_cached(tok_key, tok_url, &ws).await {
|
||||
Ok(b) => b,
|
||||
Err(e) => { console_log!("[Coder] Tokenizer-virhe: {}", e); return; }
|
||||
};
|
||||
let tokenizer = match tokenizers::Tokenizer::from_bytes(&tok_bytes[..]) {
|
||||
Ok(t) => t,
|
||||
Err(e) => { console_log!("[Coder] Tokenizer-parsinta: {}", e); return; }
|
||||
};
|
||||
|
||||
// Mallin painot
|
||||
let device = Device::Cpu;
|
||||
let dtype = DType::F32;
|
||||
|
||||
let tensors = if use_3b {
|
||||
// 3B: kaksi osaa
|
||||
let part1 = match ensure_cached("coder3b-model-part1.safetensors", MODEL_3B_PART1_URL, &ws).await {
|
||||
Ok(b) => b,
|
||||
Err(e) => { console_log!("[Coder] Malli osa 1 virhe: {}", e); return; }
|
||||
};
|
||||
let part2 = match ensure_cached("coder3b-model-part2.safetensors", MODEL_3B_PART2_URL, &ws).await {
|
||||
Ok(b) => b,
|
||||
Err(e) => { console_log!("[Coder] Malli osa 2 virhe: {}", e); return; }
|
||||
};
|
||||
console_log!("[Coder] Rakennetaan 3B-mallia...");
|
||||
let mut all_tensors = candle_core::safetensors::load_buffer(&part1[..], &device)
|
||||
.map_err(|e| format!("Part1: {}", e)).unwrap();
|
||||
let tensors2 = candle_core::safetensors::load_buffer(&part2[..], &device)
|
||||
.map_err(|e| format!("Part2: {}", e)).unwrap();
|
||||
all_tensors.extend(tensors2);
|
||||
all_tensors
|
||||
} else {
|
||||
// 0.5B: yksi osa
|
||||
let model_bytes = match ensure_cached("coder05b-model.safetensors", MODEL_05B_URL, &ws).await {
|
||||
Ok(b) => b,
|
||||
Err(e) => { console_log!("[Coder] Malli-virhe: {}", e); return; }
|
||||
};
|
||||
console_log!("[Coder] Rakennetaan 0.5B-mallia...");
|
||||
match candle_core::safetensors::load_buffer(&model_bytes[..], &device) {
|
||||
Ok(t) => t,
|
||||
Err(e) => { console_log!("[Coder] Safetensors: {}", e); return; }
|
||||
}
|
||||
};
|
||||
|
||||
let start_load = perf.now();
|
||||
let vb = VarBuilder::from_tensors(tensors, dtype, &device);
|
||||
|
||||
let config = if use_3b {
|
||||
QwenConfig {
|
||||
vocab_size: 151936,
|
||||
hidden_size: 2048,
|
||||
intermediate_size: 11008,
|
||||
num_hidden_layers: 36,
|
||||
num_attention_heads: 16,
|
||||
num_key_value_heads: 2,
|
||||
max_position_embeddings: 32768,
|
||||
sliding_window: 32768,
|
||||
max_window_layers: 36,
|
||||
tie_word_embeddings: true,
|
||||
rope_theta: 1000000.0,
|
||||
rms_norm_eps: 1e-6,
|
||||
use_sliding_window: false,
|
||||
hidden_act: candle_nn::Activation::Silu,
|
||||
}
|
||||
} else {
|
||||
QwenConfig {
|
||||
vocab_size: 151936,
|
||||
hidden_size: 896,
|
||||
intermediate_size: 4864,
|
||||
num_hidden_layers: 24,
|
||||
num_attention_heads: 14,
|
||||
num_key_value_heads: 2,
|
||||
max_position_embeddings: 32768,
|
||||
sliding_window: 32768,
|
||||
max_window_layers: 21,
|
||||
tie_word_embeddings: true,
|
||||
rope_theta: 1000000.0,
|
||||
rms_norm_eps: 1e-6,
|
||||
use_sliding_window: false,
|
||||
hidden_act: candle_nn::Activation::Silu,
|
||||
}
|
||||
};
|
||||
|
||||
let mut model = match QwenModel::new(&config, vb) {
|
||||
Ok(m) => m,
|
||||
Err(e) => { console_log!("[Coder] Mallin lataus: {}", e); return; }
|
||||
};
|
||||
if let Err(e) = get_or_build_model(use_3b, &ws).await {
|
||||
console_log!("[Coder] Mallin lataus: {}", e);
|
||||
return;
|
||||
}
|
||||
|
||||
let load_time = perf.now() - start_load;
|
||||
console_log!("[Coder] Malli ladattu ({:.0}ms). Generoidaan...", load_time);
|
||||
if load_time > 100.0 {
|
||||
console_log!("[Coder] Malli ladattu ({:.0}ms). Generoidaan...", load_time);
|
||||
}
|
||||
|
||||
// Parsitaan JSON-prompti tai käytetään teksti sellaisenaan
|
||||
let default_system = "You are a coding assistant. Respond with ONLY code. No explanations, no markdown, no comments unless asked.";
|
||||
let (actual_prompt, system_msg, max_new_tokens) = if prompt.starts_with('{') {
|
||||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&prompt) {
|
||||
let p = json.get("prompt").and_then(|v| v.as_str()).unwrap_or(&prompt).to_string();
|
||||
let s = json.get("system").and_then(|v| v.as_str())
|
||||
.unwrap_or("You are a coding assistant. Respond with ONLY code. No explanations, no markdown, no comments unless asked.").to_string();
|
||||
let s = json.get("system").and_then(|v| v.as_str()).unwrap_or(default_system).to_string();
|
||||
let m = json.get("max_tokens").and_then(|v| v.as_u64()).unwrap_or(128) as usize;
|
||||
(p, s, m)
|
||||
} else {
|
||||
(prompt.clone(), "You are a coding assistant. Respond with ONLY code. No explanations, no markdown, no comments unless asked.".to_string(), 128)
|
||||
(prompt.clone(), default_system.to_string(), 128)
|
||||
}
|
||||
} else {
|
||||
(prompt.clone(), "You are a coding assistant. Respond with ONLY code. No explanations, no markdown, no comments unless asked.".to_string(), 128)
|
||||
(prompt.clone(), default_system.to_string(), 128)
|
||||
};
|
||||
|
||||
let formatted = format!("<|im_start|>system\n{}<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", system_msg, actual_prompt);
|
||||
|
||||
let encoding = match tokenizer.encode(formatted.as_str(), true) {
|
||||
Ok(e) => e,
|
||||
Err(e) => { console_log!("[Coder] Tokenisointivirhe: {}", e); return; }
|
||||
};
|
||||
let input_ids: Vec<u32> = encoding.get_ids().to_vec();
|
||||
let input_len = input_ids.len();
|
||||
console_log!("[Coder] Syöte: {} tokenia", input_len);
|
||||
// Inferenssi: käytetään välimuistissa olevaa mallia
|
||||
let (generated_text, tokens_generated, gen_time) = MODEL_CACHE.with(|cache| {
|
||||
let mut cache = cache.borrow_mut();
|
||||
let cached = cache.as_mut().expect("Malli pitää olla ladattu");
|
||||
|
||||
let start_gen = perf.now();
|
||||
// max_new_tokens tulee JSON-promptista tai oletuksena 128
|
||||
let mut generated_text = String::new();
|
||||
let mut tokens_generated: usize = 0;
|
||||
let eos_token = 151645u32;
|
||||
let encoding = cached.tokenizer.encode(formatted.as_str(), true)
|
||||
.map_err(|e| format!("Encode: {}", e)).unwrap();
|
||||
let input_ids: Vec<u32> = encoding.get_ids().to_vec();
|
||||
let input_len = input_ids.len();
|
||||
console_log!("[Coder] Syöte: {} tokenia", input_len);
|
||||
|
||||
// Prefill
|
||||
let input = match Tensor::new(input_ids.as_slice(), &device).and_then(|t| t.unsqueeze(0)) {
|
||||
Ok(t) => t,
|
||||
Err(e) => { console_log!("[Coder] Tensor: {}", e); return; }
|
||||
};
|
||||
let logits = match model.forward(&input, 0) {
|
||||
Ok(l) => l,
|
||||
Err(e) => { console_log!("[Coder] Forward (prefill): {}", e); return; }
|
||||
};
|
||||
let device = Device::Cpu;
|
||||
let start_gen = perf.now();
|
||||
let eos_token = 151645u32;
|
||||
let temperature: f32 = 0.7;
|
||||
let top_k: usize = 40;
|
||||
let repetition_penalty: f32 = 1.15;
|
||||
|
||||
let logits = logits.squeeze(0).unwrap();
|
||||
let logits = if logits.dims().len() == 2 {
|
||||
logits.get(logits.dim(0).unwrap() - 1).unwrap()
|
||||
} else {
|
||||
logits
|
||||
};
|
||||
// Nollataan KV-cache edellisestä promptista
|
||||
cached.model.clear_kv_cache();
|
||||
|
||||
// Sampling-parametrit
|
||||
let temperature: f32 = 0.7;
|
||||
let top_k: usize = 40;
|
||||
let repetition_penalty: f32 = 1.15;
|
||||
let mut all_generated: Vec<u32> = Vec::new();
|
||||
|
||||
let mut next_token = crate::sampling::sample_top_k_with_penalty(&logits, top_k, temperature, &all_generated, repetition_penalty);
|
||||
|
||||
if next_token != eos_token {
|
||||
if let Ok(text) = tokenizer.decode(&[next_token], true) {
|
||||
generated_text.push_str(&text);
|
||||
let mut chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "Qwen2.5-Coder" });
|
||||
if let Some(ref tid) = task_id { chunk.as_object_mut().unwrap().insert("task_id".to_string(), serde_json::json!(tid)); }
|
||||
let _ = ws.borrow().send_with_str(&chunk.to_string());
|
||||
}
|
||||
all_generated.push(next_token);
|
||||
tokens_generated += 1;
|
||||
}
|
||||
|
||||
// Autoregressive
|
||||
let mut pos = input_len;
|
||||
for _ in 1..max_new_tokens {
|
||||
if next_token == eos_token { break; }
|
||||
|
||||
let input = match Tensor::new(&[next_token], &device).and_then(|t| t.unsqueeze(0)) {
|
||||
Ok(t) => t,
|
||||
Err(e) => { console_log!("[Coder] Tensor: {}", e); break; }
|
||||
};
|
||||
let logits = match model.forward(&input, pos) {
|
||||
Ok(l) => l,
|
||||
Err(e) => { console_log!("[Coder] Forward pos {}: {}", pos, e); break; }
|
||||
};
|
||||
let mut generated_text = String::new();
|
||||
let mut tokens_generated: usize = 0;
|
||||
let mut all_generated: Vec<u32> = Vec::new();
|
||||
|
||||
// Prefill
|
||||
let input = Tensor::new(input_ids.as_slice(), &device).and_then(|t| t.unsqueeze(0)).unwrap();
|
||||
let logits = cached.model.forward(&input, 0).unwrap();
|
||||
let logits = logits.squeeze(0).unwrap();
|
||||
let logits = if logits.dims().len() == 2 {
|
||||
logits.get(logits.dim(0).unwrap() - 1).unwrap()
|
||||
} else {
|
||||
logits
|
||||
};
|
||||
next_token = crate::sampling::sample_top_k_with_penalty(&logits, top_k, temperature, &all_generated, repetition_penalty);
|
||||
pos += 1;
|
||||
} else { logits };
|
||||
|
||||
if next_token == eos_token { break; }
|
||||
let mut next_token = crate::sampling::sample_top_k_with_penalty(&logits, top_k, temperature, &all_generated, repetition_penalty);
|
||||
|
||||
if let Ok(text) = tokenizer.decode(&[next_token], true) {
|
||||
generated_text.push_str(&text);
|
||||
|
||||
// Stop-sekvenssit: katkaistaan kun malli alkaa selittää
|
||||
let lower = generated_text.to_lowercase();
|
||||
if lower.contains("\n###") || lower.contains("\nexplanation") || lower.contains("\nnote:") || lower.contains("\noutput:") || lower.contains("\n```\n\n") {
|
||||
// Trimmataan selitysosuus pois
|
||||
for stop in &["\n###", "\nExplanation", "\nNote:", "\nOutput:", "\n```\n\n"] {
|
||||
if let Some(pos) = generated_text.find(stop) {
|
||||
generated_text.truncate(pos);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
let mut chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "Qwen2.5-Coder" });
|
||||
if next_token != eos_token {
|
||||
if let Ok(text) = cached.tokenizer.decode(&[next_token], true) {
|
||||
generated_text.push_str(&text);
|
||||
let mut chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "Qwen2.5-Coder" });
|
||||
if let Some(ref tid) = task_id { chunk.as_object_mut().unwrap().insert("task_id".to_string(), serde_json::json!(tid)); }
|
||||
let _ = ws.borrow().send_with_str(&chunk.to_string());
|
||||
let _ = ws.borrow().send_with_str(&chunk.to_string());
|
||||
}
|
||||
all_generated.push(next_token);
|
||||
tokens_generated += 1;
|
||||
}
|
||||
all_generated.push(next_token);
|
||||
tokens_generated += 1;
|
||||
|
||||
// Yield — vapautetaan selaimen event loop joka tokenin jälkeen
|
||||
crate::sleep_ms(0).await;
|
||||
}
|
||||
// Autoregressive
|
||||
let mut pos = input_len;
|
||||
for _ in 1..max_new_tokens {
|
||||
if next_token == eos_token { break; }
|
||||
|
||||
let input = Tensor::new(&[next_token], &device).and_then(|t| t.unsqueeze(0)).unwrap();
|
||||
let logits = match cached.model.forward(&input, pos) {
|
||||
Ok(l) => l,
|
||||
Err(e) => { console_log!("[Coder] Forward pos {}: {}", pos, e); break; }
|
||||
};
|
||||
|
||||
let logits = logits.squeeze(0).unwrap();
|
||||
let logits = if logits.dims().len() == 2 {
|
||||
logits.get(logits.dim(0).unwrap() - 1).unwrap()
|
||||
} else { logits };
|
||||
next_token = crate::sampling::sample_top_k_with_penalty(&logits, top_k, temperature, &all_generated, repetition_penalty);
|
||||
pos += 1;
|
||||
|
||||
if next_token == eos_token { break; }
|
||||
|
||||
if let Ok(text) = cached.tokenizer.decode(&[next_token], true) {
|
||||
generated_text.push_str(&text);
|
||||
|
||||
// Stop-sekvenssit: katkaistaan kun malli alkaa selittää
|
||||
let lower = generated_text.to_lowercase();
|
||||
if lower.contains("\n###") || lower.contains("\nexplanation") || lower.contains("\nnote:") || lower.contains("\noutput:") || lower.contains("\n```\n\n") {
|
||||
for stop in &["\n###", "\nExplanation", "\nNote:", "\nOutput:", "\n```\n\n"] {
|
||||
if let Some(pos) = generated_text.find(stop) {
|
||||
generated_text.truncate(pos);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
let mut chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "Qwen2.5-Coder" });
|
||||
if let Some(ref tid) = task_id { chunk.as_object_mut().unwrap().insert("task_id".to_string(), serde_json::json!(tid)); }
|
||||
let _ = ws.borrow().send_with_str(&chunk.to_string());
|
||||
}
|
||||
all_generated.push(next_token);
|
||||
tokens_generated += 1;
|
||||
}
|
||||
|
||||
let gen_time = perf.now() - start_gen;
|
||||
(generated_text, tokens_generated, gen_time)
|
||||
});
|
||||
|
||||
let gen_time = perf.now() - start_gen;
|
||||
let tokens_per_sec = if gen_time > 0.0 { (tokens_generated as f64 / gen_time) * 1000.0 } else { 0.0 };
|
||||
console_log!("[Coder] {} tokenia | {:.0}ms | {:.1} tok/s", tokens_generated, gen_time, tokens_per_sec);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user