Frontend lähettää agentin asetukset (system_prompt, temperature, top_k, max_tokens, repeat_penalty, stop) API:lle. Hub välittää ne solmulle. Native-node ja Wasm-coder käyttävät välitettyjä arvoja hardkoodattujen sijaan.
208 lines
7.4 KiB
Rust
208 lines
7.4 KiB
Rust
use std::time::Instant;
|
|
use std::cell::RefCell;
|
|
|
|
/// Per-request sampling options forwarded to Ollama by `LlmEngine::generate`.
///
/// Every `Option` field left as `None` falls back to the default chosen inside
/// `generate`.
#[derive(Debug, Clone)]
pub struct GenerateOptions {
    /// Upper bound on generated tokens; sent as Ollama's `num_predict`.
    pub max_tokens: usize,
    /// Optional system prompt; sent as the request's `system` field only when present.
    pub system_prompt: Option<String>,
    /// Sampling temperature (`generate` uses 0.7 when `None`).
    pub temperature: Option<f64>,
    /// Top-k sampling cutoff (`generate` uses 40 when `None`).
    pub top_k: Option<u64>,
    /// Repetition penalty (`generate` uses 1.15 when `None`).
    pub repeat_penalty: Option<f64>,
    /// Stop sequences (`generate` uses its built-in list when `None`).
    pub stop: Option<Vec<String>>,
}
|
|
|
|
pub struct LlmEngine {
|
|
ollama_url: String,
|
|
model: RefCell<String>,
|
|
client: reqwest::Client,
|
|
}
|
|
|
|
impl LlmEngine {
|
|
pub async fn load() -> Result<Self, String> {
|
|
let client = reqwest::Client::builder()
|
|
.timeout(std::time::Duration::from_secs(600))
|
|
.connect_timeout(std::time::Duration::from_secs(3))
|
|
.build()
|
|
.map_err(|e| format!("HTTP client: {}", e))?;
|
|
|
|
// Jos OLLAMA_URL on asetettu, käytetään sitä suoraan
|
|
let ollama_url = if let Ok(url) = std::env::var("OLLAMA_URL") {
|
|
tracing::info!("Ollama backend (env): {}", url);
|
|
url
|
|
} else {
|
|
// Haistellaan Ollamaa tunnetuista osoitteista
|
|
let candidates = [
|
|
"http://localhost:11434",
|
|
"http://127.0.0.1:11434",
|
|
"http://ollama:11434",
|
|
"http://host.docker.internal:11434",
|
|
];
|
|
let mut found = None;
|
|
for url in &candidates {
|
|
let probe = reqwest::Client::builder()
|
|
.connect_timeout(std::time::Duration::from_secs(2))
|
|
.build().unwrap_or(client.clone());
|
|
if let Ok(resp) = probe.get(format!("{}/api/version", url)).send().await {
|
|
if resp.status().is_success() {
|
|
tracing::info!("Ollama löytyi osoitteesta: {}", url);
|
|
found = Some(url.to_string());
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
found.unwrap_or_else(|| {
|
|
tracing::warn!("Ollamaa ei löytynyt — käytetään oletusta http://localhost:11434");
|
|
"http://localhost:11434".to_string()
|
|
})
|
|
};
|
|
|
|
// Kysytään malli TUI:lla jos ei pakotettu ympäristöstä
|
|
let model = match std::env::var("OLLAMA_MODEL") {
|
|
Ok(m) if !m.is_empty() => m,
|
|
_ => crate::tui::select_model(&ollama_url, &client).await?
|
|
};
|
|
|
|
tracing::info!("Ollama backend: {} | malli: {}", ollama_url, model);
|
|
Ok(LlmEngine { ollama_url, model: RefCell::new(model), client })
|
|
}
|
|
|
|
pub fn model_name(&self) -> String {
|
|
self.model.borrow().clone()
|
|
}
|
|
|
|
pub fn set_model(&self, new_model: String) {
|
|
*self.model.borrow_mut() = new_model;
|
|
}
|
|
|
|
/// Varmistaa että malli on ladattu Ollamaan (ollama pull)
|
|
pub async fn ensure_model(&self) -> Result<(), String> {
|
|
let model = self.model.borrow().clone();
|
|
tracing::info!("Tarkistetaan malli {}...", model);
|
|
let resp = self.client.post(format!("{}/api/pull", self.ollama_url))
|
|
.json(&serde_json::json!({ "name": model, "stream": false }))
|
|
.send()
|
|
.await
|
|
.map_err(|e| format!("Ollama pull: {}", e))?;
|
|
|
|
if resp.status().is_success() {
|
|
tracing::info!("Malli {} valmis", model);
|
|
Ok(())
|
|
} else {
|
|
Err(format!("Ollama pull epäonnistui: {}", resp.status()))
|
|
}
|
|
}
|
|
|
|
/// Hakee kaikki Ollamaan asennetut mallit
|
|
pub async fn fetch_models(&self) -> Result<serde_json::Value, String> {
|
|
let resp = self.client.get(format!("{}/api/tags", self.ollama_url))
|
|
.send()
|
|
.await
|
|
.map_err(|e| format!("Ollama tags fetch: {}", e))?;
|
|
|
|
if resp.status().is_success() {
|
|
resp.json().await.map_err(|e| format!("Ollama tags json: {}", e))
|
|
} else {
|
|
Err(format!("Ollama tags epäonnistui: {}", resp.status()))
|
|
}
|
|
}
|
|
|
|
pub async fn generate(&self, prompt: &str, opts: &GenerateOptions) -> Result<GenerateResult, String> {
|
|
let model = self.model.borrow().clone();
|
|
|
|
let default_stop: Vec<String> = vec![
|
|
"<|im_end|>".into(), "\n###".into(), "\nExplanation".into(),
|
|
"\nNote:".into(), "\nPlease note".into(), "\nThis is".into(),
|
|
"\n```\n\n".into(), "\n// Example".into(), "\n# Example".into(),
|
|
];
|
|
|
|
let mut body = serde_json::json!({
|
|
"model": model,
|
|
"prompt": prompt,
|
|
"stream": false,
|
|
"options": {
|
|
"num_predict": opts.max_tokens,
|
|
"temperature": opts.temperature.unwrap_or(0.7),
|
|
"top_k": opts.top_k.unwrap_or(40),
|
|
"repeat_penalty": opts.repeat_penalty.unwrap_or(1.15),
|
|
"stop": opts.stop.as_ref().unwrap_or(&default_stop),
|
|
}
|
|
});
|
|
if let Some(ref sp) = opts.system_prompt {
|
|
body.as_object_mut().unwrap().insert("system".to_string(), serde_json::json!(sp));
|
|
}
|
|
|
|
let start = Instant::now();
|
|
let resp = self.client.post(format!("{}/api/generate", self.ollama_url))
|
|
.json(&body)
|
|
.send()
|
|
.await
|
|
.map_err(|e| format!("Ollama generate: {}", e))?;
|
|
|
|
if !resp.status().is_success() {
|
|
return Err(format!("Ollama HTTP {}", resp.status()));
|
|
}
|
|
|
|
let body: serde_json::Value = resp.json().await
|
|
.map_err(|e| format!("Ollama JSON: {}", e))?;
|
|
|
|
let text = body["response"].as_str().unwrap_or("").to_string();
|
|
let _total_duration_ns = body["total_duration"].as_u64().unwrap_or(0);
|
|
let eval_count = body["eval_count"].as_u64().unwrap_or(0) as usize;
|
|
let eval_duration_ns = body["eval_duration"].as_u64().unwrap_or(1);
|
|
|
|
let duration_ms = start.elapsed().as_millis() as f64;
|
|
let tokens_per_sec = if eval_duration_ns > 0 {
|
|
eval_count as f64 / (eval_duration_ns as f64 / 1_000_000_000.0)
|
|
} else { 0.0 };
|
|
|
|
Ok(GenerateResult {
|
|
text: strip_code_fences(&text),
|
|
tokens_generated: eval_count,
|
|
duration_ms,
|
|
tokens_per_sec,
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Cleans model output: removes markdown code-fence lines, trailing explanatory
/// text and a leading introductory sentence.
///
/// Steps, in order:
/// 1. Every line whose trimmed form starts with "```" (bare fence or fence plus a
///    language tag) is dropped.
/// 2. The text is cut at the earliest occurrence of any trailing-explanation marker
///    (matched ASCII-case-insensitively at the start of a line after the first).
/// 3. If the text starts with a known introductory phrase and contains a newline,
///    the whole first line is dropped.
///
/// The result is trimmed of surrounding whitespace.
fn strip_code_fences(text: &str) -> String {
    // Markers that begin a trailing explanation; everything from the marker on is removed.
    const TRAILING_MARKERS: [&str; 7] = [
        "\nplease note",
        "\nthis is a basic",
        "\nthis code",
        "\nnote that",
        "\nremember to",
        "\nyou can",
        "\nto run",
    ];
    // Phrases that identify an introductory first line.
    const INTRO_PREFIXES: [&str; 5] = ["sure!", "here is", "here's", "certainly!", "below is"];

    // Drop all fence lines (``` alone or ```python, ```rust, ...).
    let without_fences = text
        .lines()
        .filter(|line| !line.trim().starts_with("```"))
        .collect::<Vec<_>>()
        .join("\n");
    let mut result = without_fences.trim().to_string();

    // ASCII lowercasing keeps byte offsets identical to `result`, so positions found
    // in `lower` are always valid cut points in `result`. Cutting once at the earliest
    // marker avoids slicing with an offset that lies beyond an earlier truncation.
    let lower = result.to_ascii_lowercase();
    let earliest = TRAILING_MARKERS
        .iter()
        .filter_map(|marker| lower.find(*marker))
        .min();
    if let Some(cut) = earliest {
        result.truncate(cut);
        let kept = result.trim_end().len();
        result.truncate(kept);
    }

    // Drop an introductory first line; left untouched when there is no second line.
    let lower = result.to_ascii_lowercase();
    if INTRO_PREFIXES.iter().any(|prefix| lower.starts_with(*prefix)) {
        if let Some(nl) = result.find('\n') {
            result = result[nl + 1..].to_string();
        }
    }

    result.trim().to_string()
}
|
|
|
|
/// Outcome of a single `LlmEngine::generate` call.
#[derive(Debug, Clone)]
pub struct GenerateResult {
    /// Generated text after `strip_code_fences` cleanup.
    pub text: String,
    /// Number of generated tokens, taken from the `eval_count` field of Ollama's response.
    pub tokens_generated: usize,
    /// Wall-clock duration of the request in milliseconds, measured by the caller side.
    pub duration_ms: f64,
    /// Generation speed derived from the response's `eval_count` and `eval_duration`.
    pub tokens_per_sec: f64,
}
|