use candle_core::{Device, Tensor, DType}; use candle_nn::VarBuilder; use candle_transformers::models::qwen2::{Config as QwenConfig, ModelForCausalLM as QwenModel}; use hf_hub::{api::sync::Api, Repo, RepoType}; use std::path::PathBuf; use std::time::Instant; pub struct LlmEngine { tokenizer: tokenizers::Tokenizer, model_path: PathBuf, device: Device, dtype: DType, config: QwenConfig, eos_token: u32, } impl LlmEngine { pub fn load() -> Result { let device = Device::cuda_if_available(0).map_err(|e| format!("Device: {}", e))?; let device_name = if device.is_cuda() { "CUDA" } else { "CPU" }; tracing::info!("LLM device: {}", device_name); let dtype = if device.is_cuda() { DType::F16 } else { DType::F32 }; tracing::info!("Ladataan Qwen2.5-0.5B-Instruct..."); let api = Api::new().map_err(|e| format!("HF API: {}", e))?; let repo = api.repo(Repo::with_revision( "Qwen/Qwen2.5-0.5B-Instruct".to_string(), RepoType::Model, "main".to_string(), )); let tokenizer_path = repo.get("tokenizer.json").map_err(|e| format!("Tokenizer lataus: {}", e))?; let model_path = repo.get("model.safetensors").map_err(|e| format!("Malli lataus: {}", e))?; tracing::info!("Ladataan tokenizer: {:?}", tokenizer_path); let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path) .map_err(|e| format!("Tokenizer: {}", e))?; let config = QwenConfig { vocab_size: 151936, hidden_size: 896, intermediate_size: 4864, num_hidden_layers: 24, num_attention_heads: 14, num_key_value_heads: 2, max_position_embeddings: 32768, sliding_window: 32768, max_window_layers: 21, tie_word_embeddings: true, rope_theta: 1000000.0, rms_norm_eps: 1e-6, use_sliding_window: false, hidden_act: candle_nn::Activation::Silu, }; // Testi-lataus varmistaa, että painot toimivat let start = Instant::now(); let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_path.clone()], dtype, &device) .map_err(|e| format!("VarBuilder: {}", e))? }; let _model = QwenModel::new(&config, vb).map_err(|e| format!("Malli: {}", e))?; tracing::info!("Malli ladattu ({:.1}s) — {}", start.elapsed().as_secs_f64(), device_name); Ok(LlmEngine { tokenizer, model_path, device, dtype, config, eos_token: 151645, }) } /// Luo tuore malliinstanssi (nollaa KV-cachen) fn fresh_model(&self) -> Result { let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[self.model_path.clone()], self.dtype, &self.device) .map_err(|e| format!("VarBuilder: {}", e))? }; QwenModel::new(&self.config, vb).map_err(|e| format!("Malli: {}", e)) } pub fn generate(&mut self, prompt: &str, max_tokens: usize) -> Result { let formatted = format!("<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", prompt); let encoding = self.tokenizer.encode(formatted.as_str(), true) .map_err(|e| format!("Encode: {}", e))?; let input_ids: Vec = encoding.get_ids().to_vec(); let input_len = input_ids.len(); // Tuore malli joka promptille (nollaa KV-cachen) let mut model = self.fresh_model()?; let start = Instant::now(); // Prefill let input = Tensor::new(input_ids.as_slice(), &self.device) .and_then(|t| t.unsqueeze(0)) .map_err(|e| format!("Tensor: {}", e))?; let logits = model.forward(&input, 0) .map_err(|e| format!("Forward prefill: {}", e))?; let logits = logits.squeeze(0).map_err(|e| format!("Squeeze: {}", e))?; let logits = if logits.dims().len() == 2 { logits.get(logits.dim(0).unwrap() - 1).map_err(|e| format!("Get: {}", e))? } else { logits }; let mut next_token = logits.argmax(0) .map_err(|e| format!("Argmax: {}", e))? .to_vec0::() .map_err(|e| format!("to_vec0: {}", e))?; let mut generated_text = String::new(); let mut tokens_generated: usize = 0; let mut all_tokens: Vec = Vec::new(); if next_token != self.eos_token { if let Ok(text) = self.tokenizer.decode(&[next_token], true) { generated_text.push_str(&text); } all_tokens.push(next_token); tokens_generated += 1; } // Autoregressive let mut pos = input_len; for _ in 1..max_tokens { if next_token == self.eos_token { break; } let input = Tensor::new(&[next_token], &self.device) .and_then(|t| t.unsqueeze(0)) .map_err(|e| format!("Tensor: {}", e))?; let logits = model.forward(&input, pos) .map_err(|e| format!("Forward pos {}: {}", pos, e))?; let logits = logits.squeeze(0).map_err(|e| format!("Squeeze: {}", e))?; let logits = if logits.dims().len() == 2 { logits.get(logits.dim(0).unwrap() - 1).map_err(|e| format!("Get: {}", e))? } else { logits }; next_token = logits.argmax(0) .map_err(|e| format!("Argmax: {}", e))? .to_vec0::() .map_err(|e| format!("to_vec0: {}", e))?; pos += 1; if next_token == self.eos_token { break; } if let Ok(text) = self.tokenizer.decode(&[next_token], true) { generated_text.push_str(&text); } all_tokens.push(next_token); tokens_generated += 1; } let gen_time = start.elapsed(); let tokens_per_sec = if gen_time.as_secs_f64() > 0.0 { tokens_generated as f64 / gen_time.as_secs_f64() } else { 0.0 }; Ok(GenerateResult { text: generated_text, tokens_generated, duration_ms: gen_time.as_millis() as f64, tokens_per_sec, }) } } pub struct GenerateResult { pub text: String, pub tokens_generated: usize, pub duration_ms: f64, pub tokens_per_sec: f64, }