on tämä työmaa

This commit is contained in:
2026-04-02 00:50:29 +03:00
parent e55ff64565
commit b693542116
11 changed files with 648 additions and 70 deletions

View File

@@ -0,0 +1,181 @@
use candle_core::{Device, Tensor, DType};
use candle_nn::VarBuilder;
use candle_transformers::models::qwen2::{Config as QwenConfig, ModelForCausalLM as QwenModel};
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::path::PathBuf;
use std::time::Instant;
pub struct LlmEngine {
tokenizer: tokenizers::Tokenizer,
model_path: PathBuf,
device: Device,
dtype: DType,
config: QwenConfig,
eos_token: u32,
}
impl LlmEngine {
pub fn load() -> Result<Self, String> {
let device = Device::cuda_if_available(0).map_err(|e| format!("Device: {}", e))?;
let device_name = if device.is_cuda() { "CUDA" } else { "CPU" };
tracing::info!("LLM device: {}", device_name);
let dtype = if device.is_cuda() { DType::F16 } else { DType::F32 };
tracing::info!("Ladataan Qwen2.5-0.5B-Instruct...");
let api = Api::new().map_err(|e| format!("HF API: {}", e))?;
let repo = api.repo(Repo::with_revision(
"Qwen/Qwen2.5-0.5B-Instruct".to_string(),
RepoType::Model,
"main".to_string(),
));
let tokenizer_path = repo.get("tokenizer.json").map_err(|e| format!("Tokenizer lataus: {}", e))?;
let model_path = repo.get("model.safetensors").map_err(|e| format!("Malli lataus: {}", e))?;
tracing::info!("Ladataan tokenizer: {:?}", tokenizer_path);
let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
.map_err(|e| format!("Tokenizer: {}", e))?;
let config = QwenConfig {
vocab_size: 151936,
hidden_size: 896,
intermediate_size: 4864,
num_hidden_layers: 24,
num_attention_heads: 14,
num_key_value_heads: 2,
max_position_embeddings: 32768,
sliding_window: 32768,
max_window_layers: 21,
tie_word_embeddings: true,
rope_theta: 1000000.0,
rms_norm_eps: 1e-6,
use_sliding_window: false,
hidden_act: candle_nn::Activation::Silu,
};
// Testi-lataus varmistaa, että painot toimivat
let start = Instant::now();
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[model_path.clone()], dtype, &device)
.map_err(|e| format!("VarBuilder: {}", e))?
};
let _model = QwenModel::new(&config, vb).map_err(|e| format!("Malli: {}", e))?;
tracing::info!("Malli ladattu ({:.1}s) — {}", start.elapsed().as_secs_f64(), device_name);
Ok(LlmEngine {
tokenizer,
model_path,
device,
dtype,
config,
eos_token: 151645,
})
}
/// Luo tuore malliinstanssi (nollaa KV-cachen)
fn fresh_model(&self) -> Result<QwenModel, String> {
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[self.model_path.clone()], self.dtype, &self.device)
.map_err(|e| format!("VarBuilder: {}", e))?
};
QwenModel::new(&self.config, vb).map_err(|e| format!("Malli: {}", e))
}
pub fn generate(&mut self, prompt: &str, max_tokens: usize) -> Result<GenerateResult, String> {
let formatted = format!("<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", prompt);
let encoding = self.tokenizer.encode(formatted.as_str(), true)
.map_err(|e| format!("Encode: {}", e))?;
let input_ids: Vec<u32> = encoding.get_ids().to_vec();
let input_len = input_ids.len();
// Tuore malli joka promptille (nollaa KV-cachen)
let mut model = self.fresh_model()?;
let start = Instant::now();
// Prefill
let input = Tensor::new(input_ids.as_slice(), &self.device)
.and_then(|t| t.unsqueeze(0))
.map_err(|e| format!("Tensor: {}", e))?;
let logits = model.forward(&input, 0)
.map_err(|e| format!("Forward prefill: {}", e))?;
let logits = logits.squeeze(0).map_err(|e| format!("Squeeze: {}", e))?;
let logits = if logits.dims().len() == 2 {
logits.get(logits.dim(0).unwrap() - 1).map_err(|e| format!("Get: {}", e))?
} else {
logits
};
let mut next_token = logits.argmax(0)
.map_err(|e| format!("Argmax: {}", e))?
.to_vec0::<u32>()
.map_err(|e| format!("to_vec0: {}", e))?;
let mut generated_text = String::new();
let mut tokens_generated: usize = 0;
let mut all_tokens: Vec<u32> = Vec::new();
if next_token != self.eos_token {
if let Ok(text) = self.tokenizer.decode(&[next_token], true) {
generated_text.push_str(&text);
}
all_tokens.push(next_token);
tokens_generated += 1;
}
// Autoregressive
let mut pos = input_len;
for _ in 1..max_tokens {
if next_token == self.eos_token { break; }
let input = Tensor::new(&[next_token], &self.device)
.and_then(|t| t.unsqueeze(0))
.map_err(|e| format!("Tensor: {}", e))?;
let logits = model.forward(&input, pos)
.map_err(|e| format!("Forward pos {}: {}", pos, e))?;
let logits = logits.squeeze(0).map_err(|e| format!("Squeeze: {}", e))?;
let logits = if logits.dims().len() == 2 {
logits.get(logits.dim(0).unwrap() - 1).map_err(|e| format!("Get: {}", e))?
} else {
logits
};
next_token = logits.argmax(0)
.map_err(|e| format!("Argmax: {}", e))?
.to_vec0::<u32>()
.map_err(|e| format!("to_vec0: {}", e))?;
pos += 1;
if next_token == self.eos_token { break; }
if let Ok(text) = self.tokenizer.decode(&[next_token], true) {
generated_text.push_str(&text);
}
all_tokens.push(next_token);
tokens_generated += 1;
}
let gen_time = start.elapsed();
let tokens_per_sec = if gen_time.as_secs_f64() > 0.0 {
tokens_generated as f64 / gen_time.as_secs_f64()
} else { 0.0 };
Ok(GenerateResult {
text: generated_text,
tokens_generated,
duration_ms: gen_time.as_millis() as f64,
tokens_per_sec,
})
}
}
pub struct GenerateResult {
pub text: String,
pub tokens_generated: usize,
pub duration_ms: f64,
pub tokens_per_sec: f64,
}