kylä lähtee!

This commit is contained in:
2026-04-02 15:47:48 +03:00
parent e1326b145e
commit 92c952c07a
12 changed files with 524 additions and 84 deletions

View File

@@ -118,125 +118,106 @@ pub async fn run_smollm_inference(prompt: String, ws: Rc<RefCell<WebSocket>>) {
Err(e) => { console_log!("[SmolLM] Malli-virhe: {}", e); return; }
};
console_log!("[SmolLM] Rakennetaan mallia...");
let use_gpu = crate::HAS_WEBGPU.load(std::sync::atomic::Ordering::SeqCst);
if use_gpu {
console_log!("[SmolLM] Alustetaan Burn WebGPU...");
burn_wgpu::init_async::<burn_wgpu::AutoGraphicsApi>(&Default::default(), Default::default()).await;
run_burn_inference::<burn::backend::Wgpu>(prompt, model_bytes, tokenizer, ws, perf.clone()).await;
} else {
console_log!("[SmolLM] Käytetään CPU NdArrayta (vanha tapa)...");
run_burn_inference::<burn::backend::NdArray>(prompt, model_bytes, tokenizer, ws, perf.clone()).await;
}
}
async fn run_burn_inference<B: burn::tensor::backend::Backend>(
prompt: String,
model_bytes: Vec<u8>,
tokenizer: tokenizers::Tokenizer,
ws: Rc<RefCell<WebSocket>>,
perf: web_sys::Performance, // Korjattu Wasm-performanssi välitettäväksi
) {
let start_load = perf.now();
let device = Device::Cpu;
let dtype = DType::F32;
// Parsitaan safetensors
let tensors = match candle_core::safetensors::load_buffer(&model_bytes, &device) {
Ok(t) => t,
Err(e) => { console_log!("[SmolLM] Safetensors-parsinta epäonnistui: {}", e); return; }
};
let vb = VarBuilder::from_tensors(tensors, dtype, &device);
// SmolLM-135M config (Llama-arkkitehtuuri)
let config = LlamaConfig {
hidden_size: 576,
intermediate_size: 1536,
vocab_size: 49152,
num_hidden_layers: 30,
num_attention_heads: 9,
num_key_value_heads: Some(3),
rms_norm_eps: 1e-5,
rope_theta: 10000.0,
max_position_embeddings: 2048,
tie_word_embeddings: Some(true),
bos_token_id: Some(1u32),
eos_token_id: Some(LlamaEosToks::Single(2)),
rope_scaling: None,
};
let llama_config = config.into_config(false); // false = ei flash attention
let mut cache = Cache::new(true, dtype, &llama_config, &device).unwrap();
let model = match Llama::load(vb, &llama_config) {
let device = Default::default();
let config = crate::burn_smollm::config::SmolLMConfig::default();
console_log!("[SmolLM] Injektoidaan Safetensors -> Burn Params...");
let model = match crate::burn_smollm::loader::load_safetensors_to_model::<B>(&model_bytes, &config, &device) {
Ok(m) => m,
Err(e) => { console_log!("[SmolLM] Mallin lataus epäonnistui: {}", e); return; }
Err(e) => { console_log!("[SmolLM] Lataus epäonnistui: {}", e); return; }
};
let load_time = perf.now() - start_load;
console_log!("[SmolLM] Malli ladattu ({:.0}ms). Generoidaan...", load_time);
console_log!("[SmolLM] Burn-malli ladattu ({:.0}ms). Generoidaan...", load_time);
// 3. Tokenisoi syöte (Käytetään ChatML-formaattia SmolLM-Instructille)
let formatted_prompt = format!("<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", prompt);
let encoding = match tokenizer.encode(formatted_prompt.as_str(), true) {
Ok(e) => e,
Err(e) => { console_log!("[SmolLM] Tokenisointivirhe: {}", e); return; }
};
let input_ids: Vec<u32> = encoding.get_ids().to_vec();
let mut input_ids: Vec<u32> = encoding.get_ids().to_vec();
let input_len = input_ids.len();
console_log!("[SmolLM] Syöte: {} tokenia", input_len);
// 4. Generoi tokeneita
let start_gen = perf.now();
let max_new_tokens = 32;
let mut generated_text = String::new();
let mut tokens_generated: usize = 0;
let mut pos: usize = 0;
// KV-välimuistin taulukko kerroksittain
let mut caches: Vec<Option<crate::burn_smollm::attention::KVCache<B>>> = vec![None; config.num_hidden_layers];
let mut current_offset = 0;
// Ensimmäinen forward: koko syöte kerralla
let input = match Tensor::new(input_ids.as_slice(), &device).and_then(|t| t.unsqueeze(0)) {
Ok(t) => t,
Err(e) => { console_log!("[SmolLM] Tensor-virhe: {}", e); return; }
};
// Prefill: yksitellen, vältetään future token leakage koska ei causal maskia
let input_ids_i32: Vec<i32> = input_ids.iter().map(|&x| x as i32).collect();
let mut last_logits = None;
for &id in &input_ids_i32 {
let input_tensor = burn::tensor::Tensor::<B, 1, burn::tensor::Int>::from_data(
burn::tensor::TensorData::from([id]),
&device
).unsqueeze::<2>(); // [1, 1]
last_logits = Some(model.forward(input_tensor, current_offset, &mut caches));
current_offset += 1;
}
let logits = match model.forward(&input, 0, &mut cache) {
Ok(l) => l,
Err(e) => { console_log!("[SmolLM] Forward-virhe (prefill): {}", e); return; }
};
let mut logits = last_logits.unwrap();
// Llama forward voi palauttaa [batch, vocab] tai [batch, seq_len, vocab]
let logits = logits.squeeze(0).unwrap();
let logits = if logits.dims().len() == 2 {
logits.get(logits.dim(0).unwrap() - 1).unwrap()
} else {
logits
};
let mut next_token = crate::sampling::sample_top_k(&logits, 10, 5.0);
console_log!("[SmolLM] Ensimmäinen generoitu token: {}", next_token);
pos = input_len;
// Argmax sämpläys
let next_token_tensor = logits.clone().argmax(2);
let mut next_token: u32 = next_token_tensor.into_scalar().to_string().parse().unwrap_or(2); // Yksinkertainen cast koska int scalar
if next_token != 2 {
if let Ok(text) = tokenizer.decode(&[next_token], true) {
generated_text.push_str(&text);
let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "SmolLM-135M" });
let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "SmolLM-135M (WebGPU)" });
let _ = ws.borrow().send_with_str(&chunk.to_string());
}
tokens_generated += 1;
}
// Autoregressiivinen generointi: yksi token kerrallaan
// Autoregressiivinen luuppi
for _ in 1..max_new_tokens {
if next_token == 2 { break; }
let input = match Tensor::new(&[next_token], &device).and_then(|t| t.unsqueeze(0)) {
Ok(t) => t,
Err(e) => { console_log!("[SmolLM] Tensor-virhe: {}", e); break; }
};
let logits = match model.forward(&input, pos, &mut cache) {
Ok(l) => l,
Err(e) => { console_log!("[SmolLM] Forward-virhe pos {}: {}", pos, e); break; }
};
let logits = logits.squeeze(0).unwrap();
let logits = if logits.dims().len() == 2 {
logits.get(logits.dim(0).unwrap() - 1).unwrap()
} else {
logits
};
next_token = crate::sampling::sample_top_k(&logits, 10, 5.0);
pos += 1;
let mut input_tensor = burn::tensor::Tensor::<B, 1, burn::tensor::Int>::from_data(
burn::tensor::TensorData::from([next_token as i32]),
&device
).unsqueeze::<2>();
logits = model.forward(input_tensor, current_offset, &mut caches);
current_offset += 1;
let next_token_tensor = logits.argmax(2);
next_token = next_token_tensor.into_scalar().to_string().parse().unwrap_or(2);
if next_token == 2 { break; }
if let Ok(text) = tokenizer.decode(&[next_token], true) {
generated_text.push_str(&text);
let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "SmolLM-135M" });
let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "SmolLM-135M (WebGPU)" });
let _ = ws.borrow().send_with_str(&chunk.to_string());
}
tokens_generated += 1;
@@ -245,12 +226,10 @@ pub async fn run_smollm_inference(prompt: String, ws: Rc<RefCell<WebSocket>>) {
let gen_time = perf.now() - start_gen;
let tokens_per_sec = if gen_time > 0.0 { (tokens_generated as f64 / gen_time) * 1000.0 } else { 0.0 };
console_log!("[SmolLM] Generoitu {} tokenia | {:.0}ms | {:.1} tok/s", tokens_generated, gen_time, tokens_per_sec);
let done = serde_json::json!({
"type": "llm_done",
"prompt": prompt,
"model": "SmolLM-135M-Instruct",
"model": "SmolLM-135M-Instruct (WebGPU)",
"response": generated_text,
"tokens_generated": tokens_generated,
"duration_ms": (gen_time * 100.0).round() / 100.0,