on tämä työmaa
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
use candle_core::{Device, Tensor, DType};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::llama::{Llama, LlamaConfig, LlamaEosToks, Cache};
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
// LogitsProcessor poistettu — käytetään greedy samplingia (argmax) Wasm-yhteensopivuuden vuoksi
|
||||
use wasm_bindgen::JsCast;
|
||||
use std::cell::RefCell;
|
||||
use std::rc::Rc;
|
||||
@@ -160,8 +160,9 @@ pub async fn run_smollm_inference(prompt: String, ws: Rc<RefCell<WebSocket>>) {
|
||||
let load_time = perf.now() - start_load;
|
||||
console_log!("[SmolLM] Malli ladattu ({:.0}ms). Generoidaan...", load_time);
|
||||
|
||||
// 3. Tokenisoi syöte
|
||||
let encoding = match tokenizer.encode(prompt.as_str(), true) {
|
||||
// 3. Tokenisoi syöte (Käytetään ChatML-formaattia SmolLM-Instructille)
|
||||
let formatted_prompt = format!("<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", prompt);
|
||||
let encoding = match tokenizer.encode(formatted_prompt.as_str(), true) {
|
||||
Ok(e) => e,
|
||||
Err(e) => { console_log!("[SmolLM] Tokenisointivirhe: {}", e); return; }
|
||||
};
|
||||
@@ -172,62 +173,76 @@ pub async fn run_smollm_inference(prompt: String, ws: Rc<RefCell<WebSocket>>) {
|
||||
|
||||
// 4. Generoi tokeneita
|
||||
let start_gen = perf.now();
|
||||
let mut logits_processor = LogitsProcessor::new(42, Some(0.8), Some(0.95));
|
||||
let mut all_tokens = input_ids.clone();
|
||||
let max_new_tokens = 64;
|
||||
let max_new_tokens = 32;
|
||||
let mut generated_text = String::new();
|
||||
let mut tokens_generated: usize = 0;
|
||||
let mut pos: usize = 0;
|
||||
|
||||
for i in 0..max_new_tokens {
|
||||
let context_tokens = if i == 0 {
|
||||
all_tokens.as_slice()
|
||||
} else {
|
||||
std::slice::from_ref(all_tokens.last().unwrap())
|
||||
};
|
||||
// Ensimmäinen forward: koko syöte kerralla
|
||||
let input = match Tensor::new(input_ids.as_slice(), &device).and_then(|t| t.unsqueeze(0)) {
|
||||
Ok(t) => t,
|
||||
Err(e) => { console_log!("[SmolLM] Tensor-virhe: {}", e); return; }
|
||||
};
|
||||
|
||||
let input = Tensor::new(context_tokens, &device).unwrap().unsqueeze(0).unwrap();
|
||||
let seq_len = input.dim(1).unwrap();
|
||||
let logits = match model.forward(&input, 0, &mut cache) {
|
||||
Ok(l) => l,
|
||||
Err(e) => { console_log!("[SmolLM] Forward-virhe (prefill): {}", e); return; }
|
||||
};
|
||||
|
||||
let logits = match model.forward(&input, input_len + i - seq_len, &mut cache) {
|
||||
Ok(l) => l,
|
||||
Err(e) => { console_log!("[SmolLM] Forward-virhe stepissä {}: {}", i, e); break; }
|
||||
};
|
||||
// Llama forward voi palauttaa [batch, vocab] tai [batch, seq_len, vocab]
|
||||
let logits = logits.squeeze(0).unwrap();
|
||||
let logits = if logits.dims().len() == 2 {
|
||||
logits.get(logits.dim(0).unwrap() - 1).unwrap()
|
||||
} else {
|
||||
logits
|
||||
};
|
||||
let mut next_token = logits.argmax(0).unwrap().to_vec0::<u32>().unwrap();
|
||||
console_log!("[SmolLM] Ensimmäinen generoitu token: {}", next_token);
|
||||
pos = input_len;
|
||||
|
||||
// Viimeisen tokenin logitit
|
||||
let logits = logits.squeeze(0).unwrap();
|
||||
let last_dim = logits.dim(0).unwrap();
|
||||
let logits = if last_dim > 1 {
|
||||
logits.get(last_dim - 1).unwrap()
|
||||
} else {
|
||||
logits.get(0).unwrap()
|
||||
};
|
||||
|
||||
let next_token = logits_processor.sample(&logits).unwrap();
|
||||
|
||||
// EOS-tarkistus
|
||||
if next_token == 2 {
|
||||
break;
|
||||
}
|
||||
|
||||
all_tokens.push(next_token);
|
||||
|
||||
// Dekoodaa token tekstiksi
|
||||
if next_token != 2 {
|
||||
if let Ok(text) = tokenizer.decode(&[next_token], true) {
|
||||
generated_text.push_str(&text);
|
||||
|
||||
// Streamaa token hubille
|
||||
let chunk = serde_json::json!({
|
||||
"type": "llm_chunk",
|
||||
"token": text,
|
||||
"is_last": false,
|
||||
"prompt": prompt,
|
||||
"model": "SmolLM-135M"
|
||||
});
|
||||
let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "SmolLM-135M" });
|
||||
let _ = ws.borrow().send_with_str(&chunk.to_string());
|
||||
}
|
||||
tokens_generated += 1;
|
||||
}
|
||||
|
||||
// Autoregressiivinen generointi: yksi token kerrallaan
|
||||
for _ in 1..max_new_tokens {
|
||||
if next_token == 2 { break; }
|
||||
|
||||
let input = match Tensor::new(&[next_token], &device).and_then(|t| t.unsqueeze(0)) {
|
||||
Ok(t) => t,
|
||||
Err(e) => { console_log!("[SmolLM] Tensor-virhe: {}", e); break; }
|
||||
};
|
||||
|
||||
let logits = match model.forward(&input, pos, &mut cache) {
|
||||
Ok(l) => l,
|
||||
Err(e) => { console_log!("[SmolLM] Forward-virhe pos {}: {}", pos, e); break; }
|
||||
};
|
||||
|
||||
let logits = logits.squeeze(0).unwrap();
|
||||
let logits = if logits.dims().len() == 2 {
|
||||
logits.get(logits.dim(0).unwrap() - 1).unwrap()
|
||||
} else {
|
||||
logits
|
||||
};
|
||||
next_token = logits.argmax(0).unwrap().to_vec0::<u32>().unwrap();
|
||||
pos += 1;
|
||||
|
||||
if next_token == 2 { break; }
|
||||
|
||||
if let Ok(text) = tokenizer.decode(&[next_token], true) {
|
||||
generated_text.push_str(&text);
|
||||
let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "SmolLM-135M" });
|
||||
let _ = ws.borrow().send_with_str(&chunk.to_string());
|
||||
}
|
||||
tokens_generated += 1;
|
||||
}
|
||||
|
||||
let gen_time = perf.now() - start_gen;
|
||||
let tokens_generated = all_tokens.len() - input_len;
|
||||
let tokens_per_sec = if gen_time > 0.0 { (tokens_generated as f64 / gen_time) * 1000.0 } else { 0.0 };
|
||||
|
||||
console_log!("[SmolLM] Generoitu {} tokenia | {:.0}ms | {:.1} tok/s", tokens_generated, gen_time, tokens_per_sec);
|
||||
|
||||
Reference in New Issue
Block a user