on tämä työmaa
This commit is contained in:
@@ -12,5 +12,10 @@ serde_json = "1.0"
|
||||
sysinfo = "0.30"
|
||||
nvml-wrapper = "0.10"
|
||||
wgpu = "24"
|
||||
candle-core = { version = "0.8" }
|
||||
candle-nn = "0.8"
|
||||
candle-transformers = "0.8"
|
||||
hf-hub = "0.4"
|
||||
tokenizers = "0.19"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
|
||||
181
network-poc/native-node/src/inference.rs
Normal file
181
network-poc/native-node/src/inference.rs
Normal file
@@ -0,0 +1,181 @@
|
||||
use candle_core::{Device, Tensor, DType};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::qwen2::{Config as QwenConfig, ModelForCausalLM as QwenModel};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use std::path::PathBuf;
|
||||
use std::time::Instant;
|
||||
|
||||
pub struct LlmEngine {
|
||||
tokenizer: tokenizers::Tokenizer,
|
||||
model_path: PathBuf,
|
||||
device: Device,
|
||||
dtype: DType,
|
||||
config: QwenConfig,
|
||||
eos_token: u32,
|
||||
}
|
||||
|
||||
impl LlmEngine {
|
||||
pub fn load() -> Result<Self, String> {
|
||||
let device = Device::cuda_if_available(0).map_err(|e| format!("Device: {}", e))?;
|
||||
let device_name = if device.is_cuda() { "CUDA" } else { "CPU" };
|
||||
tracing::info!("LLM device: {}", device_name);
|
||||
|
||||
let dtype = if device.is_cuda() { DType::F16 } else { DType::F32 };
|
||||
|
||||
tracing::info!("Ladataan Qwen2.5-0.5B-Instruct...");
|
||||
let api = Api::new().map_err(|e| format!("HF API: {}", e))?;
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
"Qwen/Qwen2.5-0.5B-Instruct".to_string(),
|
||||
RepoType::Model,
|
||||
"main".to_string(),
|
||||
));
|
||||
|
||||
let tokenizer_path = repo.get("tokenizer.json").map_err(|e| format!("Tokenizer lataus: {}", e))?;
|
||||
let model_path = repo.get("model.safetensors").map_err(|e| format!("Malli lataus: {}", e))?;
|
||||
|
||||
tracing::info!("Ladataan tokenizer: {:?}", tokenizer_path);
|
||||
let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
|
||||
.map_err(|e| format!("Tokenizer: {}", e))?;
|
||||
|
||||
let config = QwenConfig {
|
||||
vocab_size: 151936,
|
||||
hidden_size: 896,
|
||||
intermediate_size: 4864,
|
||||
num_hidden_layers: 24,
|
||||
num_attention_heads: 14,
|
||||
num_key_value_heads: 2,
|
||||
max_position_embeddings: 32768,
|
||||
sliding_window: 32768,
|
||||
max_window_layers: 21,
|
||||
tie_word_embeddings: true,
|
||||
rope_theta: 1000000.0,
|
||||
rms_norm_eps: 1e-6,
|
||||
use_sliding_window: false,
|
||||
hidden_act: candle_nn::Activation::Silu,
|
||||
};
|
||||
|
||||
// Testi-lataus varmistaa, että painot toimivat
|
||||
let start = Instant::now();
|
||||
let vb = unsafe {
|
||||
VarBuilder::from_mmaped_safetensors(&[model_path.clone()], dtype, &device)
|
||||
.map_err(|e| format!("VarBuilder: {}", e))?
|
||||
};
|
||||
let _model = QwenModel::new(&config, vb).map_err(|e| format!("Malli: {}", e))?;
|
||||
tracing::info!("Malli ladattu ({:.1}s) — {}", start.elapsed().as_secs_f64(), device_name);
|
||||
|
||||
Ok(LlmEngine {
|
||||
tokenizer,
|
||||
model_path,
|
||||
device,
|
||||
dtype,
|
||||
config,
|
||||
eos_token: 151645,
|
||||
})
|
||||
}
|
||||
|
||||
/// Luo tuore malliinstanssi (nollaa KV-cachen)
|
||||
fn fresh_model(&self) -> Result<QwenModel, String> {
|
||||
let vb = unsafe {
|
||||
VarBuilder::from_mmaped_safetensors(&[self.model_path.clone()], self.dtype, &self.device)
|
||||
.map_err(|e| format!("VarBuilder: {}", e))?
|
||||
};
|
||||
QwenModel::new(&self.config, vb).map_err(|e| format!("Malli: {}", e))
|
||||
}
|
||||
|
||||
pub fn generate(&mut self, prompt: &str, max_tokens: usize) -> Result<GenerateResult, String> {
|
||||
let formatted = format!("<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", prompt);
|
||||
|
||||
let encoding = self.tokenizer.encode(formatted.as_str(), true)
|
||||
.map_err(|e| format!("Encode: {}", e))?;
|
||||
let input_ids: Vec<u32> = encoding.get_ids().to_vec();
|
||||
let input_len = input_ids.len();
|
||||
|
||||
// Tuore malli joka promptille (nollaa KV-cachen)
|
||||
let mut model = self.fresh_model()?;
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
// Prefill
|
||||
let input = Tensor::new(input_ids.as_slice(), &self.device)
|
||||
.and_then(|t| t.unsqueeze(0))
|
||||
.map_err(|e| format!("Tensor: {}", e))?;
|
||||
|
||||
let logits = model.forward(&input, 0)
|
||||
.map_err(|e| format!("Forward prefill: {}", e))?;
|
||||
|
||||
let logits = logits.squeeze(0).map_err(|e| format!("Squeeze: {}", e))?;
|
||||
let logits = if logits.dims().len() == 2 {
|
||||
logits.get(logits.dim(0).unwrap() - 1).map_err(|e| format!("Get: {}", e))?
|
||||
} else {
|
||||
logits
|
||||
};
|
||||
let mut next_token = logits.argmax(0)
|
||||
.map_err(|e| format!("Argmax: {}", e))?
|
||||
.to_vec0::<u32>()
|
||||
.map_err(|e| format!("to_vec0: {}", e))?;
|
||||
|
||||
let mut generated_text = String::new();
|
||||
let mut tokens_generated: usize = 0;
|
||||
let mut all_tokens: Vec<u32> = Vec::new();
|
||||
|
||||
if next_token != self.eos_token {
|
||||
if let Ok(text) = self.tokenizer.decode(&[next_token], true) {
|
||||
generated_text.push_str(&text);
|
||||
}
|
||||
all_tokens.push(next_token);
|
||||
tokens_generated += 1;
|
||||
}
|
||||
|
||||
// Autoregressive
|
||||
let mut pos = input_len;
|
||||
for _ in 1..max_tokens {
|
||||
if next_token == self.eos_token { break; }
|
||||
|
||||
let input = Tensor::new(&[next_token], &self.device)
|
||||
.and_then(|t| t.unsqueeze(0))
|
||||
.map_err(|e| format!("Tensor: {}", e))?;
|
||||
|
||||
let logits = model.forward(&input, pos)
|
||||
.map_err(|e| format!("Forward pos {}: {}", pos, e))?;
|
||||
|
||||
let logits = logits.squeeze(0).map_err(|e| format!("Squeeze: {}", e))?;
|
||||
let logits = if logits.dims().len() == 2 {
|
||||
logits.get(logits.dim(0).unwrap() - 1).map_err(|e| format!("Get: {}", e))?
|
||||
} else {
|
||||
logits
|
||||
};
|
||||
next_token = logits.argmax(0)
|
||||
.map_err(|e| format!("Argmax: {}", e))?
|
||||
.to_vec0::<u32>()
|
||||
.map_err(|e| format!("to_vec0: {}", e))?;
|
||||
pos += 1;
|
||||
|
||||
if next_token == self.eos_token { break; }
|
||||
|
||||
if let Ok(text) = self.tokenizer.decode(&[next_token], true) {
|
||||
generated_text.push_str(&text);
|
||||
}
|
||||
all_tokens.push(next_token);
|
||||
tokens_generated += 1;
|
||||
}
|
||||
|
||||
let gen_time = start.elapsed();
|
||||
let tokens_per_sec = if gen_time.as_secs_f64() > 0.0 {
|
||||
tokens_generated as f64 / gen_time.as_secs_f64()
|
||||
} else { 0.0 };
|
||||
|
||||
Ok(GenerateResult {
|
||||
text: generated_text,
|
||||
tokens_generated,
|
||||
duration_ms: gen_time.as_millis() as f64,
|
||||
tokens_per_sec,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct GenerateResult {
|
||||
pub text: String,
|
||||
pub tokens_generated: usize,
|
||||
pub duration_ms: f64,
|
||||
pub tokens_per_sec: f64,
|
||||
}
|
||||
@@ -4,6 +4,8 @@ use sysinfo::System;
|
||||
use tokio_tungstenite::connect_async;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
|
||||
mod inference;
|
||||
|
||||
/// GPU-tietorakenne — yhtenäinen kaikille valmistajille
|
||||
struct GpuInfo {
|
||||
name: String,
|
||||
@@ -282,7 +284,20 @@ async fn main() {
|
||||
}
|
||||
}
|
||||
|
||||
// Yhdistetään hubiin — yritetään uudelleen katkon sattuessa
|
||||
// Ladataan LLM-malli
|
||||
tracing::info!("Ladataan LLM-mallia...");
|
||||
let mut llm = match inference::LlmEngine::load() {
|
||||
Ok(engine) => {
|
||||
tracing::info!("LLM valmis inferenssiin!");
|
||||
Some(engine)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("LLM-lataus epäonnistui: {} — toimitaan ilman inferenssiä", e);
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
// Yhdistetään hubiin
|
||||
loop {
|
||||
match connect_async(&hub_url).await {
|
||||
Ok((ws_stream, _)) => {
|
||||
@@ -295,17 +310,51 @@ async fn main() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut busy = false;
|
||||
|
||||
while let Some(Ok(msg)) = read.next().await {
|
||||
if let Message::Text(text) = msg {
|
||||
if text.contains("pair_task") || text.contains("ai_task") {
|
||||
tracing::debug!("Tehtävä vastaanotettu: {}", &text[..text.len().min(80)]);
|
||||
let reply = json!({
|
||||
"type": "result",
|
||||
"status": "success",
|
||||
"data": "native-node: ei vielä laskentaa"
|
||||
});
|
||||
let _ = write.send(Message::Text(reply.to_string())).await;
|
||||
// LLM-promptit
|
||||
if text.contains("llm_prompt") && !busy {
|
||||
if let Ok(task) = serde_json::from_str::<serde_json::Value>(&text) {
|
||||
let prompt = task.get("prompt").and_then(|v| v.as_str()).unwrap_or("");
|
||||
if !prompt.is_empty() {
|
||||
if let Some(ref mut engine) = llm {
|
||||
busy = true;
|
||||
tracing::info!("Generoidaan: \"{}\"", prompt);
|
||||
|
||||
match engine.generate(prompt, 64) {
|
||||
Ok(result) => {
|
||||
tracing::info!(
|
||||
"Tulos: {} tokenia | {:.0}ms | {:.1} tok/s | \"{}\"",
|
||||
result.tokens_generated,
|
||||
result.duration_ms,
|
||||
result.tokens_per_sec,
|
||||
&result.text[..result.text.len().min(80)]
|
||||
);
|
||||
|
||||
let done = json!({
|
||||
"type": "llm_done",
|
||||
"prompt": prompt,
|
||||
"model": "Qwen2.5-0.5B-Instruct (native/GPU)",
|
||||
"response": result.text,
|
||||
"tokens_generated": result.tokens_generated,
|
||||
"duration_ms": result.duration_ms,
|
||||
"tokens_per_sec": (result.tokens_per_sec * 10.0).round() / 10.0,
|
||||
"load_time_ms": 0,
|
||||
});
|
||||
let _ = write.send(Message::Text(done.to_string())).await;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Inferenssivirhe: {}", e);
|
||||
}
|
||||
}
|
||||
busy = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Ohitetaan pair_task, stats jne.
|
||||
}
|
||||
}
|
||||
tracing::warn!("Yhteys hubiin katkesi — yritetään uudelleen 5s...");
|
||||
|
||||
Reference in New Issue
Block a user