Säätö Hommia
This commit is contained in:
@@ -2,7 +2,6 @@ use candle_core::{Device, Tensor, DType};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::qwen2::{Config as QwenConfig, ModelForCausalLM as QwenModel};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use std::path::PathBuf;
|
||||
use std::time::Instant;
|
||||
|
||||
/// Top-k sampling with temperature and repetition penalty
|
||||
@@ -63,10 +62,8 @@ fn sample_top_k(logits: &Tensor, k: usize, temperature: f64, generated_tokens: &
|
||||
|
||||
pub struct LlmEngine {
|
||||
tokenizer: tokenizers::Tokenizer,
|
||||
model_path: PathBuf,
|
||||
model: QwenModel,
|
||||
device: Device,
|
||||
dtype: DType,
|
||||
config: QwenConfig,
|
||||
eos_token: u32,
|
||||
}
|
||||
|
||||
@@ -110,34 +107,22 @@ impl LlmEngine {
|
||||
hidden_act: candle_nn::Activation::Silu,
|
||||
};
|
||||
|
||||
// Testi-lataus varmistaa, että painot toimivat
|
||||
let start = Instant::now();
|
||||
let vb = unsafe {
|
||||
VarBuilder::from_mmaped_safetensors(&[model_path.clone()], dtype, &device)
|
||||
.map_err(|e| format!("VarBuilder: {}", e))?
|
||||
};
|
||||
let _model = QwenModel::new(&config, vb).map_err(|e| format!("Malli: {}", e))?;
|
||||
let model = QwenModel::new(&config, vb).map_err(|e| format!("Malli: {}", e))?;
|
||||
tracing::info!("Malli ladattu ({:.1}s) — {}", start.elapsed().as_secs_f64(), device_name);
|
||||
|
||||
Ok(LlmEngine {
|
||||
tokenizer,
|
||||
model_path,
|
||||
model,
|
||||
device,
|
||||
dtype,
|
||||
config,
|
||||
eos_token: 151645,
|
||||
})
|
||||
}
|
||||
|
||||
/// Luo tuore malliinstanssi (nollaa KV-cachen)
|
||||
fn fresh_model(&self) -> Result<QwenModel, String> {
|
||||
let vb = unsafe {
|
||||
VarBuilder::from_mmaped_safetensors(&[self.model_path.clone()], self.dtype, &self.device)
|
||||
.map_err(|e| format!("VarBuilder: {}", e))?
|
||||
};
|
||||
QwenModel::new(&self.config, vb).map_err(|e| format!("Malli: {}", e))
|
||||
}
|
||||
|
||||
pub fn generate(&mut self, prompt: &str, max_tokens: usize) -> Result<GenerateResult, String> {
|
||||
let formatted = format!("<|im_start|>system\nYou are a coding assistant. Respond with ONLY code. No explanations, no markdown, no comments unless asked.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", prompt);
|
||||
|
||||
@@ -146,8 +131,8 @@ impl LlmEngine {
|
||||
let input_ids: Vec<u32> = encoding.get_ids().to_vec();
|
||||
let input_len = input_ids.len();
|
||||
|
||||
// Tuore malli joka promptille (nollaa KV-cachen)
|
||||
let mut model = self.fresh_model()?;
|
||||
// Nollataan KV-cache edellisestä promptista
|
||||
self.model.clear_kv_cache();
|
||||
|
||||
// Sampling-parametrit
|
||||
let temperature = 0.7;
|
||||
@@ -165,7 +150,7 @@ impl LlmEngine {
|
||||
.and_then(|t| t.unsqueeze(0))
|
||||
.map_err(|e| format!("Tensor: {}", e))?;
|
||||
|
||||
let logits = model.forward(&input, 0)
|
||||
let logits = self.model.forward(&input, 0)
|
||||
.map_err(|e| format!("Forward prefill: {}", e))?;
|
||||
|
||||
let logits = logits.squeeze(0).map_err(|e| format!("Squeeze: {}", e))?;
|
||||
@@ -200,7 +185,7 @@ impl LlmEngine {
|
||||
.and_then(|t| t.unsqueeze(0))
|
||||
.map_err(|e| format!("Tensor: {}", e))?;
|
||||
|
||||
let logits = model.forward(&input, pos)
|
||||
let logits = self.model.forward(&input, pos)
|
||||
.map_err(|e| format!("Forward pos {}: {}", pos, e))?;
|
||||
|
||||
let logits = logits.squeeze(0).map_err(|e| format!("Squeeze: {}", e))?;
|
||||
|
||||
Reference in New Issue
Block a user