on tämä työmaa

This commit is contained in:
2026-04-02 00:50:29 +03:00
parent e55ff64565
commit b693542116
11 changed files with 648 additions and 70 deletions

View File

@@ -12,5 +12,10 @@ serde_json = "1.0"
sysinfo = "0.30"
nvml-wrapper = "0.10"
wgpu = "24"
candle-core = { version = "0.8" }
candle-nn = "0.8"
candle-transformers = "0.8"
hf-hub = "0.4"
tokenizers = "0.19"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

View File

@@ -0,0 +1,181 @@
use candle_core::{Device, Tensor, DType};
use candle_nn::VarBuilder;
use candle_transformers::models::qwen2::{Config as QwenConfig, ModelForCausalLM as QwenModel};
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::path::PathBuf;
use std::time::Instant;
pub struct LlmEngine {
tokenizer: tokenizers::Tokenizer,
model_path: PathBuf,
device: Device,
dtype: DType,
config: QwenConfig,
eos_token: u32,
}
impl LlmEngine {
pub fn load() -> Result<Self, String> {
let device = Device::cuda_if_available(0).map_err(|e| format!("Device: {}", e))?;
let device_name = if device.is_cuda() { "CUDA" } else { "CPU" };
tracing::info!("LLM device: {}", device_name);
let dtype = if device.is_cuda() { DType::F16 } else { DType::F32 };
tracing::info!("Ladataan Qwen2.5-0.5B-Instruct...");
let api = Api::new().map_err(|e| format!("HF API: {}", e))?;
let repo = api.repo(Repo::with_revision(
"Qwen/Qwen2.5-0.5B-Instruct".to_string(),
RepoType::Model,
"main".to_string(),
));
let tokenizer_path = repo.get("tokenizer.json").map_err(|e| format!("Tokenizer lataus: {}", e))?;
let model_path = repo.get("model.safetensors").map_err(|e| format!("Malli lataus: {}", e))?;
tracing::info!("Ladataan tokenizer: {:?}", tokenizer_path);
let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
.map_err(|e| format!("Tokenizer: {}", e))?;
let config = QwenConfig {
vocab_size: 151936,
hidden_size: 896,
intermediate_size: 4864,
num_hidden_layers: 24,
num_attention_heads: 14,
num_key_value_heads: 2,
max_position_embeddings: 32768,
sliding_window: 32768,
max_window_layers: 21,
tie_word_embeddings: true,
rope_theta: 1000000.0,
rms_norm_eps: 1e-6,
use_sliding_window: false,
hidden_act: candle_nn::Activation::Silu,
};
// Testi-lataus varmistaa, että painot toimivat
let start = Instant::now();
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[model_path.clone()], dtype, &device)
.map_err(|e| format!("VarBuilder: {}", e))?
};
let _model = QwenModel::new(&config, vb).map_err(|e| format!("Malli: {}", e))?;
tracing::info!("Malli ladattu ({:.1}s) — {}", start.elapsed().as_secs_f64(), device_name);
Ok(LlmEngine {
tokenizer,
model_path,
device,
dtype,
config,
eos_token: 151645,
})
}
/// Luo tuore malliinstanssi (nollaa KV-cachen)
fn fresh_model(&self) -> Result<QwenModel, String> {
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[self.model_path.clone()], self.dtype, &self.device)
.map_err(|e| format!("VarBuilder: {}", e))?
};
QwenModel::new(&self.config, vb).map_err(|e| format!("Malli: {}", e))
}
pub fn generate(&mut self, prompt: &str, max_tokens: usize) -> Result<GenerateResult, String> {
let formatted = format!("<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", prompt);
let encoding = self.tokenizer.encode(formatted.as_str(), true)
.map_err(|e| format!("Encode: {}", e))?;
let input_ids: Vec<u32> = encoding.get_ids().to_vec();
let input_len = input_ids.len();
// Tuore malli joka promptille (nollaa KV-cachen)
let mut model = self.fresh_model()?;
let start = Instant::now();
// Prefill
let input = Tensor::new(input_ids.as_slice(), &self.device)
.and_then(|t| t.unsqueeze(0))
.map_err(|e| format!("Tensor: {}", e))?;
let logits = model.forward(&input, 0)
.map_err(|e| format!("Forward prefill: {}", e))?;
let logits = logits.squeeze(0).map_err(|e| format!("Squeeze: {}", e))?;
let logits = if logits.dims().len() == 2 {
logits.get(logits.dim(0).unwrap() - 1).map_err(|e| format!("Get: {}", e))?
} else {
logits
};
let mut next_token = logits.argmax(0)
.map_err(|e| format!("Argmax: {}", e))?
.to_vec0::<u32>()
.map_err(|e| format!("to_vec0: {}", e))?;
let mut generated_text = String::new();
let mut tokens_generated: usize = 0;
let mut all_tokens: Vec<u32> = Vec::new();
if next_token != self.eos_token {
if let Ok(text) = self.tokenizer.decode(&[next_token], true) {
generated_text.push_str(&text);
}
all_tokens.push(next_token);
tokens_generated += 1;
}
// Autoregressive
let mut pos = input_len;
for _ in 1..max_tokens {
if next_token == self.eos_token { break; }
let input = Tensor::new(&[next_token], &self.device)
.and_then(|t| t.unsqueeze(0))
.map_err(|e| format!("Tensor: {}", e))?;
let logits = model.forward(&input, pos)
.map_err(|e| format!("Forward pos {}: {}", pos, e))?;
let logits = logits.squeeze(0).map_err(|e| format!("Squeeze: {}", e))?;
let logits = if logits.dims().len() == 2 {
logits.get(logits.dim(0).unwrap() - 1).map_err(|e| format!("Get: {}", e))?
} else {
logits
};
next_token = logits.argmax(0)
.map_err(|e| format!("Argmax: {}", e))?
.to_vec0::<u32>()
.map_err(|e| format!("to_vec0: {}", e))?;
pos += 1;
if next_token == self.eos_token { break; }
if let Ok(text) = self.tokenizer.decode(&[next_token], true) {
generated_text.push_str(&text);
}
all_tokens.push(next_token);
tokens_generated += 1;
}
let gen_time = start.elapsed();
let tokens_per_sec = if gen_time.as_secs_f64() > 0.0 {
tokens_generated as f64 / gen_time.as_secs_f64()
} else { 0.0 };
Ok(GenerateResult {
text: generated_text,
tokens_generated,
duration_ms: gen_time.as_millis() as f64,
tokens_per_sec,
})
}
}
pub struct GenerateResult {
pub text: String,
pub tokens_generated: usize,
pub duration_ms: f64,
pub tokens_per_sec: f64,
}

View File

@@ -4,6 +4,8 @@ use sysinfo::System;
use tokio_tungstenite::connect_async;
use tokio_tungstenite::tungstenite::Message;
mod inference;
/// GPU-tietorakenne — yhtenäinen kaikille valmistajille
struct GpuInfo {
name: String,
@@ -282,7 +284,20 @@ async fn main() {
}
}
// Yhdistetään hubiin — yritetään uudelleen katkon sattuessa
// Ladataan LLM-malli
tracing::info!("Ladataan LLM-mallia...");
let mut llm = match inference::LlmEngine::load() {
Ok(engine) => {
tracing::info!("LLM valmis inferenssiin!");
Some(engine)
}
Err(e) => {
tracing::warn!("LLM-lataus epäonnistui: {} — toimitaan ilman inferenssiä", e);
None
}
};
// Yhdistetään hubiin
loop {
match connect_async(&hub_url).await {
Ok((ws_stream, _)) => {
@@ -295,17 +310,51 @@ async fn main() {
continue;
}
let mut busy = false;
while let Some(Ok(msg)) = read.next().await {
if let Message::Text(text) = msg {
if text.contains("pair_task") || text.contains("ai_task") {
tracing::debug!("Tehtävä vastaanotettu: {}", &text[..text.len().min(80)]);
let reply = json!({
"type": "result",
"status": "success",
"data": "native-node: ei vielä laskentaa"
});
let _ = write.send(Message::Text(reply.to_string())).await;
// LLM-promptit
if text.contains("llm_prompt") && !busy {
if let Ok(task) = serde_json::from_str::<serde_json::Value>(&text) {
let prompt = task.get("prompt").and_then(|v| v.as_str()).unwrap_or("");
if !prompt.is_empty() {
if let Some(ref mut engine) = llm {
busy = true;
tracing::info!("Generoidaan: \"{}\"", prompt);
match engine.generate(prompt, 64) {
Ok(result) => {
tracing::info!(
"Tulos: {} tokenia | {:.0}ms | {:.1} tok/s | \"{}\"",
result.tokens_generated,
result.duration_ms,
result.tokens_per_sec,
&result.text[..result.text.len().min(80)]
);
let done = json!({
"type": "llm_done",
"prompt": prompt,
"model": "Qwen2.5-0.5B-Instruct (native/GPU)",
"response": result.text,
"tokens_generated": result.tokens_generated,
"duration_ms": result.duration_ms,
"tokens_per_sec": (result.tokens_per_sec * 10.0).round() / 10.0,
"load_time_ms": 0,
});
let _ = write.send(Message::Text(done.to_string())).await;
}
Err(e) => {
tracing::error!("Inferenssivirhe: {}", e);
}
}
busy = false;
}
}
}
}
// Ohitetaan pair_task, stats jne.
}
}
tracing::warn!("Yhteys hubiin katkesi — yritetään uudelleen 5s...");