182 lines
6.4 KiB
Rust
182 lines
6.4 KiB
Rust
use candle_core::{Device, Tensor, DType};
|
|
use candle_nn::VarBuilder;
|
|
use candle_transformers::models::qwen2::{Config as QwenConfig, ModelForCausalLM as QwenModel};
|
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
|
use std::path::PathBuf;
|
|
use std::time::Instant;
|
|
|
|
pub struct LlmEngine {
|
|
tokenizer: tokenizers::Tokenizer,
|
|
model_path: PathBuf,
|
|
device: Device,
|
|
dtype: DType,
|
|
config: QwenConfig,
|
|
eos_token: u32,
|
|
}
|
|
|
|
impl LlmEngine {
|
|
pub fn load() -> Result<Self, String> {
|
|
let device = Device::cuda_if_available(0).map_err(|e| format!("Device: {}", e))?;
|
|
let device_name = if device.is_cuda() { "CUDA" } else { "CPU" };
|
|
tracing::info!("LLM device: {}", device_name);
|
|
|
|
let dtype = if device.is_cuda() { DType::F16 } else { DType::F32 };
|
|
|
|
tracing::info!("Ladataan Qwen2.5-0.5B-Instruct...");
|
|
let api = Api::new().map_err(|e| format!("HF API: {}", e))?;
|
|
let repo = api.repo(Repo::with_revision(
|
|
"Qwen/Qwen2.5-0.5B-Instruct".to_string(),
|
|
RepoType::Model,
|
|
"main".to_string(),
|
|
));
|
|
|
|
let tokenizer_path = repo.get("tokenizer.json").map_err(|e| format!("Tokenizer lataus: {}", e))?;
|
|
let model_path = repo.get("model.safetensors").map_err(|e| format!("Malli lataus: {}", e))?;
|
|
|
|
tracing::info!("Ladataan tokenizer: {:?}", tokenizer_path);
|
|
let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
|
|
.map_err(|e| format!("Tokenizer: {}", e))?;
|
|
|
|
let config = QwenConfig {
|
|
vocab_size: 151936,
|
|
hidden_size: 896,
|
|
intermediate_size: 4864,
|
|
num_hidden_layers: 24,
|
|
num_attention_heads: 14,
|
|
num_key_value_heads: 2,
|
|
max_position_embeddings: 32768,
|
|
sliding_window: 32768,
|
|
max_window_layers: 21,
|
|
tie_word_embeddings: true,
|
|
rope_theta: 1000000.0,
|
|
rms_norm_eps: 1e-6,
|
|
use_sliding_window: false,
|
|
hidden_act: candle_nn::Activation::Silu,
|
|
};
|
|
|
|
// Testi-lataus varmistaa, että painot toimivat
|
|
let start = Instant::now();
|
|
let vb = unsafe {
|
|
VarBuilder::from_mmaped_safetensors(&[model_path.clone()], dtype, &device)
|
|
.map_err(|e| format!("VarBuilder: {}", e))?
|
|
};
|
|
let _model = QwenModel::new(&config, vb).map_err(|e| format!("Malli: {}", e))?;
|
|
tracing::info!("Malli ladattu ({:.1}s) — {}", start.elapsed().as_secs_f64(), device_name);
|
|
|
|
Ok(LlmEngine {
|
|
tokenizer,
|
|
model_path,
|
|
device,
|
|
dtype,
|
|
config,
|
|
eos_token: 151645,
|
|
})
|
|
}
|
|
|
|
/// Luo tuore malliinstanssi (nollaa KV-cachen)
|
|
fn fresh_model(&self) -> Result<QwenModel, String> {
|
|
let vb = unsafe {
|
|
VarBuilder::from_mmaped_safetensors(&[self.model_path.clone()], self.dtype, &self.device)
|
|
.map_err(|e| format!("VarBuilder: {}", e))?
|
|
};
|
|
QwenModel::new(&self.config, vb).map_err(|e| format!("Malli: {}", e))
|
|
}
|
|
|
|
pub fn generate(&mut self, prompt: &str, max_tokens: usize) -> Result<GenerateResult, String> {
|
|
let formatted = format!("<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", prompt);
|
|
|
|
let encoding = self.tokenizer.encode(formatted.as_str(), true)
|
|
.map_err(|e| format!("Encode: {}", e))?;
|
|
let input_ids: Vec<u32> = encoding.get_ids().to_vec();
|
|
let input_len = input_ids.len();
|
|
|
|
// Tuore malli joka promptille (nollaa KV-cachen)
|
|
let mut model = self.fresh_model()?;
|
|
|
|
let start = Instant::now();
|
|
|
|
// Prefill
|
|
let input = Tensor::new(input_ids.as_slice(), &self.device)
|
|
.and_then(|t| t.unsqueeze(0))
|
|
.map_err(|e| format!("Tensor: {}", e))?;
|
|
|
|
let logits = model.forward(&input, 0)
|
|
.map_err(|e| format!("Forward prefill: {}", e))?;
|
|
|
|
let logits = logits.squeeze(0).map_err(|e| format!("Squeeze: {}", e))?;
|
|
let logits = if logits.dims().len() == 2 {
|
|
logits.get(logits.dim(0).unwrap() - 1).map_err(|e| format!("Get: {}", e))?
|
|
} else {
|
|
logits
|
|
};
|
|
let mut next_token = logits.argmax(0)
|
|
.map_err(|e| format!("Argmax: {}", e))?
|
|
.to_vec0::<u32>()
|
|
.map_err(|e| format!("to_vec0: {}", e))?;
|
|
|
|
let mut generated_text = String::new();
|
|
let mut tokens_generated: usize = 0;
|
|
let mut all_tokens: Vec<u32> = Vec::new();
|
|
|
|
if next_token != self.eos_token {
|
|
if let Ok(text) = self.tokenizer.decode(&[next_token], true) {
|
|
generated_text.push_str(&text);
|
|
}
|
|
all_tokens.push(next_token);
|
|
tokens_generated += 1;
|
|
}
|
|
|
|
// Autoregressive
|
|
let mut pos = input_len;
|
|
for _ in 1..max_tokens {
|
|
if next_token == self.eos_token { break; }
|
|
|
|
let input = Tensor::new(&[next_token], &self.device)
|
|
.and_then(|t| t.unsqueeze(0))
|
|
.map_err(|e| format!("Tensor: {}", e))?;
|
|
|
|
let logits = model.forward(&input, pos)
|
|
.map_err(|e| format!("Forward pos {}: {}", pos, e))?;
|
|
|
|
let logits = logits.squeeze(0).map_err(|e| format!("Squeeze: {}", e))?;
|
|
let logits = if logits.dims().len() == 2 {
|
|
logits.get(logits.dim(0).unwrap() - 1).map_err(|e| format!("Get: {}", e))?
|
|
} else {
|
|
logits
|
|
};
|
|
next_token = logits.argmax(0)
|
|
.map_err(|e| format!("Argmax: {}", e))?
|
|
.to_vec0::<u32>()
|
|
.map_err(|e| format!("to_vec0: {}", e))?;
|
|
pos += 1;
|
|
|
|
if next_token == self.eos_token { break; }
|
|
|
|
if let Ok(text) = self.tokenizer.decode(&[next_token], true) {
|
|
generated_text.push_str(&text);
|
|
}
|
|
all_tokens.push(next_token);
|
|
tokens_generated += 1;
|
|
}
|
|
|
|
let gen_time = start.elapsed();
|
|
let tokens_per_sec = if gen_time.as_secs_f64() > 0.0 {
|
|
tokens_generated as f64 / gen_time.as_secs_f64()
|
|
} else { 0.0 };
|
|
|
|
Ok(GenerateResult {
|
|
text: generated_text,
|
|
tokens_generated,
|
|
duration_ms: gen_time.as_millis() as f64,
|
|
tokens_per_sec,
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Outcome of a single text-generation call.
#[derive(Debug, Clone, PartialEq)]
pub struct GenerateResult {
    /// Decoded reply text, with special tokens skipped.
    pub text: String,
    /// Number of tokens produced, not counting the stop token.
    pub tokens_generated: usize,
    /// Wall-clock generation time (prompt prefill plus decoding), in milliseconds.
    pub duration_ms: f64,
    /// Throughput: `tokens_generated` divided by the elapsed seconds, or `0.0`
    /// when no measurable time elapsed.
    pub tokens_per_sec: f64,
}
|