Compare commits
31 Commits
2e7ddf6f1e
...
agentic-co
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ac15336c9f | ||
|
|
7a15cacebf | ||
|
|
27135a8f14 | ||
|
|
e28a715f32 | ||
|
|
24d29d9ba9 | ||
|
|
7eca426e77 | ||
|
|
7a1352ead7 | ||
|
|
b9017448d8 | ||
|
|
3d1b406e8d | ||
|
|
aa6c4739dd | ||
|
|
cbbf427a93 | ||
|
|
0a216f19e2 | ||
|
|
a2e7ed53ff | ||
|
|
950cae9d96 | ||
|
|
ff3a720b8d | ||
|
|
6f14614af8 | ||
|
|
518c6dc5cb | ||
|
|
b48eeb6f5f | ||
|
|
6bc7d03676 | ||
|
|
13b2911d38 | ||
|
|
38054452e2 | ||
|
|
50ff34cb09 | ||
|
|
949f34833f | ||
|
|
88fd31ca8c | ||
|
|
57c6506f91 | ||
|
|
12fd3c4eae | ||
|
|
1845ddf8c2 | ||
|
|
0ef1a3f7cd | ||
|
|
4e49cfbbfa | ||
|
|
133ff38fa4 | ||
|
|
3ada8949d0 |
42
TODO.md
42
TODO.md
@@ -1 +1,41 @@
|
|||||||
Lisää viesteihin tietoturvallinen kryptaus - mitään selkokielistä ei ole hyvä lähettää.
|
# Kipinä Agentic Network: TODO-lista
|
||||||
|
|
||||||
|
- [x] **Tietoturva & yksityisyys:** Lisää viesteihin tietoturvallinen kryptaus (E2E-salaus / Blind Orchestrator). Mitään selkokielistä ei ole hyvä lähettää vieraalle solmulle.
|
||||||
|
- [x] **Reititysarkkitehtuuri:** Hubin kohdennettu reititys. Broadcastin sijaan tehtävät ohjataan vain parhaalle vapana olevalle solmulle (Node Registry & Matchmaking) tehtävän tyypin ja resurssien perusteella.
|
||||||
|
- [x] **P2P-jakelu:** WebRTC Data Channels mallipainojen jakamiseen suoraan solmujen välillä kaistan ja latausaikojen säästämiseksi.
|
||||||
|
- [x] **Tulosten varmentaminen:** Proof of Compute / Konsensus-mekanismi, jossa sama tehtävä annetaan kahdelle solmulle, ja tila hyväksytään vasta kun ristiintarkastus täsmää.
|
||||||
|
- [x] **Optimaalinen laitekiihdytys:** Selainpuolen laajennus tulevaa WebNN-standardia (NPU API) varten WebGPU:n rinnalle.
|
||||||
|
- [x] **Insentiivit:** Gamifikaatio, pistetaulukko tai token-talous (esim. Kipinä Tokens), joka motivoi käyttäjiä tarjoamaan laitteensa laskentatehoa verkoston käyttöön pidemmäksi aikaa.
|
||||||
|
- [x] **Pelimerkkien UI-synkkaus:** Pelimerkkien saldon synkronointi reaaliajassa Hubista takaisin valikossa olevalle selainsolmulle ja luvun visuaalinen näyttäminen.
|
||||||
|
- [x] **XSS-suojaus:** HTML-escape kaikelle backend-datalle joka renderöidään DOM:iin (prompt, response, tokenisaatiotekstit).
|
||||||
|
- [x] **System prompt -vuoto:** Agents-pipelinen system prompt ei enää näy käyttäjälle vastauksissa.
|
||||||
|
- [x] **Token-saldon data race:** Korjattu atomiseksi operaatioksi.
|
||||||
|
- [x] **UTF-8 slicing panic:** Korjattu kaikki `&text[..n]` → `text.chars().take(n)`.
|
||||||
|
- [x] **Tensor dim unwrap:** Lisätty virheenkäsittely tyhjälle tensorille natiivisolmussa.
|
||||||
|
- [x] **llm_error-viestien tuki:** Lisätty hubiin ja frontendiin, streaming-kortti siivoutuu virhetilanteessa.
|
||||||
|
- [x] **Malli-cache (selain):** QwenModel pidetään muistissa `thread_local! MODEL_CACHE`:ssa, `clear_kv_cache()` promptien välillä.
|
||||||
|
- [x] **Malli-cache (natiivi):** `LlmEngine` pitää mallin muistissa, `fresh_model()` poistettu.
|
||||||
|
- [x] **Sampling:** Greedy argmax korvattu temperature + top-k + repetition penalty -samplingillä (sekä selain että natiivi).
|
||||||
|
- [x] **Stop-sekvenssit:** Generointi katkaistaan kun malli alkaa tuottaa selityksiä.
|
||||||
|
- [x] **Codelab/Agents-reititys:** `llm_done` ja `llm_chunk` reitittyy `task_id`:n perusteella oikeaan näkymään.
|
||||||
|
- [x] **Broadcast Lag:** `RecvError::Lagged` käsitellään gracefully sekä sender-taskissa että API-endpointissa — solmu ei enää tipu verkosta.
|
||||||
|
- [x] **Busy-tila reititys:** Hub seuraa solmujen busy-tilaa (`node_busy`). Tehtäviä ei enää reititetä varatuille solmuille.
|
||||||
|
- [x] **Rate limiting:** `/api/v1/chat/completions` rajoittaa max 10 pyyntöä/minuutti per IP.
|
||||||
|
- [x] **Gamification-validointi:** Kipinä-merkkejä jaetaan vain tehtävistä joiden `task_id` on hubin jakama (`pending_task_ids`).
|
||||||
|
- [x] **Base64:** Oma base64-dekooderi korvattu `base64`-cratella.
|
||||||
|
- [x] **Atominen siivous:** Solmun disconnect-siivouksessa kaikki lukot otetaan kerralla.
|
||||||
|
- [x] **DOM-vuoto:** Terminaalin trim ei enää poista aktiivista streaming-riviä.
|
||||||
|
|
||||||
|
## Havaitut Bugaavat Ominaisuudet ja Arkkitehtuuriongelmat
|
||||||
|
|
||||||
|
### Keskitaso (eivät estä käyttöä)
|
||||||
|
|
||||||
|
- [ ] **Origin-headerin validoinnin ohitus:** Natiivisolmut eivät lähetä Origin-headeria, joten tarkistus ohitetaan. Hyökkääjä voi esiintyä natiivisolmuna. Korjaus: vaadi autentikaatio natiivisolmuilta (API-avain tai token).
|
||||||
|
- [ ] **Kovakoodattu oletussalasana:** Admin-paneelin oletussalasana on `"kipina"` jos `ADMIN_PASSWORD`-ympäristömuuttujaa ei aseta. Tuotannossa pitää asettaa pakollisesti. Varoitus logitetaan.
|
||||||
|
|
||||||
|
### Arkkitehtuuriparannukset (tulevaisuus)
|
||||||
|
|
||||||
|
- [ ] **E2E-salaus:** Promptit ja vastaukset kulkevat selkokielisinä WebSocketin yli. Placeholder-kommentti koodissa, mutta ei toteutusta.
|
||||||
|
- [ ] **Proof of Work / konsensus:** Solmu voi lähettää väärennettyjä tuloksia. Merkitty TODO:ksi, mutta ei toteutusta.
|
||||||
|
- [ ] **WebGPU-inferenssi Candle-mallille:** Selainsolmu käyttää aina CPU:ta Candle-inferenssiin. Candle ei vielä tue WebGPU:ta.
|
||||||
|
- [ ] **Streaming yield -optimointi:** Pitkillä generoinneilla (>128 tok) selaimen event loop voi jäätyä hetkeksi koska generointilooppi ajetaan synkronisessa closuressa. Korjaus: pilko generointilooppi eriin ja yield joka N:s token.
|
||||||
|
|||||||
525
network-poc/BUILDING_BLOCKS.md
Normal file
525
network-poc/BUILDING_BLOCKS.md
Normal file
@@ -0,0 +1,525 @@
|
|||||||
|
# Kipinä Agentic Studio — Rakennuspalaset
|
||||||
|
|
||||||
|
Tämä dokumentti kuvaa projektin UI-komponentit, arkkitehtuuripatternit ja työnkulut niin, että vastaavan hajautetun AI-laskentaverkon ja agenttipohjaisen käyttöliittymän voi rakentaa alusta asti.
|
||||||
|
|
||||||
|
## Yleiskuva
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────────────────────────────────────────┐
|
||||||
|
│ Selain (käyttäjä) │
|
||||||
|
│ ┌──────────┐ ┌──────────┐ ┌───────────────────┐ │
|
||||||
|
│ │ Verkko- │ │ Koodi- │ │ Agents-näkymä │ │
|
||||||
|
│ │ näkymä │ │ labra │ │ ┌───────────────┐ │ │
|
||||||
|
│ │ │ │ │ │ │ Terminaali │ │ │
|
||||||
|
│ │ Stats │ │ Editor │ │ │ Tab-complete │ │ │
|
||||||
|
│ │ Chat │ │ Pipeline │ │ │ Dropdown │ │ │
|
||||||
|
│ │ Tokenit │ │ Tulokset │ │ │ Historia │ │ │
|
||||||
|
│ └────┬─────┘ └────┬─────┘ │ └───────────────┘ │ │
|
||||||
|
│ │ │ └────────┬──────────┘ │
|
||||||
|
│ └──────────┬───┘ │ │
|
||||||
|
│ UI WebSocket HTTP API │
|
||||||
|
│ │ /api/v1/chat │
|
||||||
|
│ ┌───────────────┴──────────────┐ │ │
|
||||||
|
│ │ Wasm Compute Node │ │ │
|
||||||
|
│ │ (Candle + Burn) │ │ │
|
||||||
|
│ │ ┌─────────┐ ┌────────────┐ │ │ │
|
||||||
|
│ │ │ RAM │ │ IndexedDB │ │ │ │
|
||||||
|
│ │ │ Cache │ │ Cache │ │ │ │
|
||||||
|
│ │ └─────────┘ └────────────┘ │ │ │
|
||||||
|
│ │ ┌─────────────────────────┐ │ │ │
|
||||||
|
│ │ │ Model Cache (QwenModel) │ │ │ │
|
||||||
|
│ │ └─────────────────────────┘ │ │ │
|
||||||
|
│ └──────────────┬───────────────┘ │ │
|
||||||
|
│ │ WS │ │
|
||||||
|
└─────────────────┼──────────────────────┼─────────────┘
|
||||||
|
│ │
|
||||||
|
┌────────┴──────────────────────┴──┐
|
||||||
|
│ Hub (Axum + Tokio) │
|
||||||
|
│ ┌────────────┐ ┌─────────────┐ │
|
||||||
|
│ │ Broadcast │ │ Node │ │
|
||||||
|
│ │ Channel │ │ Registry │ │
|
||||||
|
│ └────────────┘ └─────────────┘ │
|
||||||
|
│ ┌────────────┐ ┌─────────────┐ │
|
||||||
|
│ │ Busy-State │ │ Rate Limit │ │
|
||||||
|
│ │ Tracker │ │ + Auth │ │
|
||||||
|
│ └────────────┘ └─────────────┘ │
|
||||||
|
│ ┌─────────────────────────────┐ │
|
||||||
|
│ │ SQLite (sessiot, tulokset) │ │
|
||||||
|
│ └─────────────────────────────┘ │
|
||||||
|
└──────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. WebSocket-reaaliaikakommunikaatio
|
||||||
|
|
||||||
|
### 1.1 Hub ↔ Node broadcast-kanava
|
||||||
|
|
||||||
|
**Tarkoitus:** Jakaa tehtäviä ja vastaanottaa tuloksia kaikilta laskentasolmuilta.
|
||||||
|
|
||||||
|
**Työnkulku:**
|
||||||
|
1. Hub luo `tokio::sync::broadcast::channel(100)`
|
||||||
|
2. Jokainen solmu saa oman `rx = stats_tx.subscribe()`
|
||||||
|
3. Hub broadcastaa tehtävät: `stats_tx.send(json)`
|
||||||
|
4. Solmut suodattavat viestin tyypin ja `selected_task`:n perusteella
|
||||||
|
|
||||||
|
**Viestityupit:**
|
||||||
|
|
||||||
|
| Tyyppi | Suunta | Sisältö |
|
||||||
|
|--------|--------|---------|
|
||||||
|
| `stats` | Hub → kaikki | nodes, vram_gb, tasks |
|
||||||
|
| `pair_task` | Hub → tokenize-solmut | en, fi tekstiparit |
|
||||||
|
| `llm_prompt` | Hub → valittu solmu | prompt, model, task_id |
|
||||||
|
| `llm_chunk` | Solmu → Hub → UI | token (1 kerrallaan) |
|
||||||
|
| `llm_done` | Solmu → Hub → UI | response, tokens_generated, duration_ms |
|
||||||
|
| `llm_error` | Solmu → Hub → UI | error, task_id |
|
||||||
|
| `task_routed` | Hub → UI | status (routed/queued), node_id, message |
|
||||||
|
|
||||||
|
**Lagged-viestien käsittely:**
|
||||||
|
```rust
|
||||||
|
match rx.recv().await {
|
||||||
|
Ok(msg) => { /* käsittele */ }
|
||||||
|
Err(broadcast::error::RecvError::Lagged(n)) => {
|
||||||
|
// Ohitetaan vanhat viestit, ei katkaista yhteyttä
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
Err(_) => break, // Kanava suljettu
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 1.2 Kohdennettu reititys (Direct Channel)
|
||||||
|
|
||||||
|
**Tarkoitus:** Lähetä tehtävä yhdelle tietylle solmulle broadcastin sijaan.
|
||||||
|
|
||||||
|
**Työnkulku:**
|
||||||
|
1. Jokainen solmu saa `mpsc::unbounded_channel` yhdistyessään
|
||||||
|
2. Hub tallentaa `node_channels: HashMap<u64, UnboundedSender>`
|
||||||
|
3. API-pyyntö → valitaan vapaa solmu → lähetetään suoraan kanavaan
|
||||||
|
4. Broadcast-kanavaa käytetään vain tuloksen välittämiseen UI:lle
|
||||||
|
|
||||||
|
```rust
|
||||||
|
let channels = state.node_channels.read().await;
|
||||||
|
if let Some(tx) = channels.get(&target_node_id) {
|
||||||
|
tx.send(msg.to_string());
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 1.3 Busy-state ja työjono
|
||||||
|
|
||||||
|
**Tarkoitus:** Estä tehtävien reititys varatuille solmuille.
|
||||||
|
|
||||||
|
**Rakenne:**
|
||||||
|
- `node_busy: HashSet<u64>` — solmut joilla on aktiivinen tehtävä
|
||||||
|
- Asetetaan kun tehtävä reititetään, vapautetaan `llm_done`/`llm_error`:ssa
|
||||||
|
- Jos kaikki solmut varattuja → pollaa 500ms välein, max 30s
|
||||||
|
|
||||||
|
**UI-palaute:**
|
||||||
|
```json
|
||||||
|
{"type": "task_routed", "status": "queued", "message": "Kaikki 2 solmua varattuja — odotetaan..."}
|
||||||
|
{"type": "task_routed", "status": "routed", "node_id": 3, "message": "Solmu #3 vapautui (2.5s jonossa)"}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. Wasm-laskentasolmu
|
||||||
|
|
||||||
|
### 2.1 Elinkaari
|
||||||
|
|
||||||
|
```
|
||||||
|
init() → start_agent_node(ws_url, has_webgpu, device_info, task_id)
|
||||||
|
│
|
||||||
|
├─ Avaa WebSocket hubiin
|
||||||
|
├─ Lähettää auth-viestin (laitetiedot, selected_task)
|
||||||
|
├─ Rekisteröityy onmessage-käsittelijä
|
||||||
|
│ ├─ pair_task → tokenize
|
||||||
|
│ ├─ llm_prompt → inference
|
||||||
|
│ └─ ai_task → tensor matmul
|
||||||
|
└─ Odottaa tehtäviä loopissa
|
||||||
|
```
|
||||||
|
|
||||||
|
**Globaali tila (atominen, lukitsematon):**
|
||||||
|
```rust
|
||||||
|
static GPU_LOAD_PERCENT: AtomicU32 = AtomicU32::new(50);
|
||||||
|
static LLM_BUSY: AtomicBool = AtomicBool::new(false);
|
||||||
|
static SELECTED_TASK: AtomicU32 = AtomicU32::new(0);
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2.2 Kolmitasoinen cache
|
||||||
|
|
||||||
|
```
|
||||||
|
Pyyntö → [1] RAM-cache (thread_local HashMap)
|
||||||
|
│ miss
|
||||||
|
▼
|
||||||
|
[2] IndexedDB (selaimen pysyvä tallennus)
|
||||||
|
│ miss
|
||||||
|
▼
|
||||||
|
[3] Verkko (HuggingFace CDN, streaming + 5% progressi)
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
Tallenna → IndexedDB → RAM-cache
|
||||||
|
```
|
||||||
|
|
||||||
|
| Taso | Nopeus | Koko | Pysyvyys |
|
||||||
|
|------|--------|------|----------|
|
||||||
|
| RAM | ~0ms | Rajaton | Sivulataus |
|
||||||
|
| IndexedDB | ~50ms | ~50GB | Pysyvä |
|
||||||
|
| Verkko | ~10s/100MB | ∞ | — |
|
||||||
|
|
||||||
|
**Malliinstanssin cache (neljäs taso):**
|
||||||
|
```rust
|
||||||
|
thread_local! {
|
||||||
|
static MODEL_CACHE: RefCell<Option<CachedModel>> = RefCell::new(None);
|
||||||
|
}
|
||||||
|
// clear_kv_cache() promptien välillä — ei tarvitse rakentaa mallia uusiksi
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2.3 Warmup-esilataus
|
||||||
|
|
||||||
|
**Tarkoitus:** Lataa malli valmiiksi ennen ensimmäistä oikeaa promptia.
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
// Lähetetään 1 tokenin warmup heti kun WS on auki
|
||||||
|
uiSocket.send(JSON.stringify({
|
||||||
|
type: 'user_text',
|
||||||
|
text: '{"prompt":"warmup","max_tokens":1}',
|
||||||
|
task_type: 'qwen-coder'
|
||||||
|
}));
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. LLM-inferenssipipeline
|
||||||
|
|
||||||
|
### 3.1 Prompt-formaatti (ChatML + prefill)
|
||||||
|
|
||||||
|
```
|
||||||
|
<|im_start|>system
|
||||||
|
You are a coding assistant. Respond with ONLY code.<|im_end|>
|
||||||
|
<|im_start|>user
|
||||||
|
hello world in python<|im_end|>
|
||||||
|
<|im_start|>assistant
|
||||||
|
``` ← PREFILL: pakottaa mallin aloittamaan koodilla
|
||||||
|
```
|
||||||
|
|
||||||
|
**Prefill-tekniikka:** Lisäämällä ` ``` ` assistantin vastauksen alkuun malli jatkaa suoraan koodilla eikä tuota "Sure! Here is..." -johdantoa. Säästää 10-20 tokenia per vastaus.
|
||||||
|
|
||||||
|
### 3.2 Sampling-parametrit
|
||||||
|
|
||||||
|
| Parametri | Arvo | Tarkoitus |
|
||||||
|
|-----------|------|-----------|
|
||||||
|
| `temperature` | 0.7 | Pehmentää jakaumaa, vähentää toistoa |
|
||||||
|
| `top_k` | 40 | Rajaa valinnan 40 todennäköisimpään tokeniin |
|
||||||
|
| `repetition_penalty` | 1.15 | Rankaisee jo generoitujen tokenien uudelleenvalintaa |
|
||||||
|
| `max_tokens` | 128 | Oletusraja, JSON-promptilla konfiguroitavissa |
|
||||||
|
|
||||||
|
**Sampling-funktio (top-k + temperature + repetition penalty):**
|
||||||
|
```rust
|
||||||
|
fn sample_top_k_with_penalty(logits, k, temperature, generated_tokens, penalty) -> u32 {
|
||||||
|
// 1. Repetition penalty: vähennä aiempien tokenien logitteja
|
||||||
|
// 2. Temperature scaling: jaa logitit temperaturella
|
||||||
|
// 3. Top-k: ota k suurinta
|
||||||
|
// 4. Softmax top-k:lle
|
||||||
|
// 5. Satunnaisvalinta kumulatiivisella todennäköisyydellä (XorShift RNG)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3.3 Stop-sekvenssit
|
||||||
|
|
||||||
|
Generointi katkaistaan ja teksti trimmataan kun malli alkaa selittää:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
let stop_patterns = ["\n###", "\nExplanation", "\nNote:", "\nOutput:", "\n```\n\n"];
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3.4 Vastauksen siivous
|
||||||
|
|
||||||
|
```
|
||||||
|
Raakavastaus: "Sure! Here is...\n```python\n# This is a simple program\nprint('hi')\n```"
|
||||||
|
│
|
||||||
|
strip_markdown: "# This is a simple program\nprint('hi')"
|
||||||
|
│
|
||||||
|
strip_preamble: "print('hi')"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Tunnistettavat selityskommentit:** `# This is`, `# simple`, `# program that`, `# here is`, `# the following`, `# below`
|
||||||
|
|
||||||
|
### 3.5 Streaming
|
||||||
|
|
||||||
|
Jokainen generoitu token lähetetään heti `llm_chunk`-viestinä:
|
||||||
|
```json
|
||||||
|
{"type": "llm_chunk", "token": "print", "prompt": "...", "model": "Qwen2.5-Coder", "task_id": "uuid"}
|
||||||
|
```
|
||||||
|
|
||||||
|
UI päivittää streaming-korttia reaaliaikaisesti appendaamalla tokeneita.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. Terminaaliemulaattori
|
||||||
|
|
||||||
|
### 4.1 Rakenne
|
||||||
|
|
||||||
|
```html
|
||||||
|
<div id="agent-hub-status"> <!-- Status-palkki (Hub + Laskenta) -->
|
||||||
|
<div id="agent-terminal"> <!-- Scrollaava tulosalue, max 100 riviä -->
|
||||||
|
<div> <!-- Input-rivi -->
|
||||||
|
<span>$</span>
|
||||||
|
<input id="term-input">
|
||||||
|
<div id="term-dropdown"> <!-- Autocompletion-valikko -->
|
||||||
|
</div>
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.2 Komentojen käsittely
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
function termExec(cmd) {
|
||||||
|
// Parsitaan: "kpn" + alikomento + argumentit
|
||||||
|
// Tuetut: help, run, pipeline, load, status, models, hello, clear
|
||||||
|
// Agenttinimi → malli-mapping: "coder" → "qwen-coder"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.3 Tab-completion (kolmitasoinen)
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
const kpnCommands = {
|
||||||
|
'kpn': ['help', 'run', 'pipeline', 'load', ...],
|
||||||
|
'kpn run': ['coder', 'manager', 'qwen-coder', ...],
|
||||||
|
};
|
||||||
|
const kpnExamples = {
|
||||||
|
'kpn run coder': ['"hello world in python"', ...],
|
||||||
|
};
|
||||||
|
```
|
||||||
|
|
||||||
|
**Käyttö:**
|
||||||
|
|
||||||
|
| Näppäin | Toiminto |
|
||||||
|
|---------|----------|
|
||||||
|
| TAB | Täydennä seuraava sana tai avaa dropdown |
|
||||||
|
| Shift-TAB | Poista viimeinen sana (lainausmerkit kokonaisuutena) |
|
||||||
|
| ↑ / ↓ | Navigoi dropdownissa (tai komentohistoriassa) |
|
||||||
|
| Enter | Valitse dropdownista tai suorita komento |
|
||||||
|
| Esc | Sulje dropdown |
|
||||||
|
|
||||||
|
### 4.4 Dropdown-valikko
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
function showDropdown(items, prefix) {
|
||||||
|
// Luo div.term-dd-item per vaihtoehto
|
||||||
|
// Positio: absolute, bottom: 100% (inputin yläpuolella)
|
||||||
|
// Mouseenter → highlight, click → valinta
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.5 Komentohistoria
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
const termHistory = []; // Kaikki ajetut komennot (viimeisin ensin)
|
||||||
|
let termHistIdx = -1; // Nykyinen positio historiassa
|
||||||
|
// ArrowUp: termHistIdx++, ArrowDown: termHistIdx--
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. Status-palkit ja tilaindikaattorit
|
||||||
|
|
||||||
|
### 5.1 Hub-yhteyden tila
|
||||||
|
|
||||||
|
| Tila | Väri | Teksti | Tooltip |
|
||||||
|
|------|------|--------|---------|
|
||||||
|
| Yhdistetään | 🟡 | "Yhdistetään..." | WebSocket-yhteys Kipinä Hubiin |
|
||||||
|
| Yhdistetty | 🟢 | "Yhdistetty" | Tehtävien jakelu aktiivinen |
|
||||||
|
| Katkennut | 🔴 | "Yhteys katkennut" | Tarkista verkko, lataa uudelleen |
|
||||||
|
|
||||||
|
### 5.2 Laskentasolmun tila
|
||||||
|
|
||||||
|
| Tila | Väri | Teksti | Nappi |
|
||||||
|
|------|------|--------|-------|
|
||||||
|
| Ei käynnissä | ⚫ | "—" | `[Alusta laskentasolmu]` sininen |
|
||||||
|
| Lataa | 🟡 | "Ladataan..." | `[Peruuta]` punainen |
|
||||||
|
| Valmis | 🟢 | "Qwen2.5-Coder" | `[✓ Valmis]` vihreä |
|
||||||
|
|
||||||
|
### 5.3 Pipeline-tilakone (Codelab)
|
||||||
|
|
||||||
|
```
|
||||||
|
Step 1: WebAssembly-ytimen lataus [◯ → ◷ → ✓]
|
||||||
|
Step 2: Tokenizer (7 MB) [◯ → ◷ → ✓]
|
||||||
|
Step 3: Mallipainot (990 MB) [◯ → ◷ 45% → ✓ cache]
|
||||||
|
Step 4: Mallin rakentaminen [◯ → ◷ → ✓]
|
||||||
|
Step 5: Valmis generoimaan [◯ → ✓]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Seuranta console.log-viesteistä:**
|
||||||
|
```javascript
|
||||||
|
if (msg.includes('[Coder]') && msg.includes('Malli ladattu')) {
|
||||||
|
// Merkkaa kaikki vaiheet valmiiksi (myös cache-hitillä)
|
||||||
|
setStep('step-wasm', 'done');
|
||||||
|
setStep('step-tokenizer', 'done');
|
||||||
|
setStep('step-model', 'done', 'cache');
|
||||||
|
setStep('step-build', 'done');
|
||||||
|
setStep('step-ready', 'done');
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. Tietoturva
|
||||||
|
|
||||||
|
### 6.1 XSS-suojaus
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
function esc(str) {
|
||||||
|
return String(str).replace(/&/g,'&').replace(/</g,'<')
|
||||||
|
.replace(/>/g,'>').replace(/"/g,'"');
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Käyttöpaikat:** Kaikki `innerHTML`-insertoinnit joissa on käyttäjä- tai backend-dataa.
|
||||||
|
|
||||||
|
### 6.2 System prompt -piilotus
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
function stripSystemPrompt(prompt) {
|
||||||
|
const parts = prompt.split('\n\n');
|
||||||
|
return parts[parts.length - 1] || prompt;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.3 Viestityyppivalidointi (backend)
|
||||||
|
|
||||||
|
```rust
|
||||||
|
const ALLOWED_MSG_TYPES: &[&str] = &[
|
||||||
|
"auth", "result", "pair_done", "llm_chunk", "llm_done",
|
||||||
|
"llm_error", "download_progress", "user_text", "single_tokenize_done"
|
||||||
|
];
|
||||||
|
|
||||||
|
fn validate_message(text: &str) -> Result<Value, &'static str> {
|
||||||
|
// 1. JSON-parsinta
|
||||||
|
// 2. "type"-kenttä pakollinen
|
||||||
|
// 3. Tyyppi sallittujen listalla
|
||||||
|
// 4. Tyyppikohtainen validointi (esim. pair_done: token_count <= 10000)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.4 Rate limiting
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// Per-IP liukuva ikkuna: max 10 pyyntöä per 60s
|
||||||
|
let entry = limits.entry(addr.ip()).or_insert((now, 0));
|
||||||
|
if now.duration_since(entry.0).as_secs() >= 60 {
|
||||||
|
*entry = (now, 1);
|
||||||
|
} else {
|
||||||
|
entry.1 += 1;
|
||||||
|
if entry.1 > 10 { return 429 Too Many Requests; }
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.5 Gamification-huijauksen esto
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// Hub jakaa task_id:n → tallentaa pending_task_ids:hen
|
||||||
|
// Merkkejä jaetaan VAIN jos llm_done sisältää validin task_id:n
|
||||||
|
let valid_task = state.pending_task_ids.lock().unwrap().remove(tid);
|
||||||
|
if active_incentives && valid_task {
|
||||||
|
*balance += 20;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. Syntaksikorostus
|
||||||
|
|
||||||
|
### 7.1 Highlight.js-integraatio
|
||||||
|
|
||||||
|
```html
|
||||||
|
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.11.1/styles/github-dark.min.css">
|
||||||
|
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.11.1/highlight.min.js"></script>
|
||||||
|
```
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
function highlightCode(code) {
|
||||||
|
if (typeof hljs !== 'undefined') {
|
||||||
|
return hljs.highlightAuto(code).value; // Automaattinen kielentunnistus
|
||||||
|
}
|
||||||
|
return esc(code); // Fallback
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Käyttöpaikat:** Codelab-tulokset, agents-terminaalin vastaukset, network-chat.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. Agenttien orkestrointi
|
||||||
|
|
||||||
|
### 8.1 Multi-agent pipeline
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────┐ ┌──────────┐ ┌──────────┐
|
||||||
|
│ Manageri │ ──→ │ Koodari │ ──→ │ Testaaja │
|
||||||
|
│ Analysoi │ │ Koodaa │ │ Arvioi │
|
||||||
|
│ tehtävä │ │ ratkaisu │ │ koodi │
|
||||||
|
└──────────┘ └──────────┘ └──────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
async function kpnPipeline(task) {
|
||||||
|
const plan = await kpnRun('qwen-coder', `Analysoi: ${task}`);
|
||||||
|
if (!plan) return;
|
||||||
|
const code = await kpnRun('qwen-coder', `Koodaa: ${plan}`);
|
||||||
|
if (!code) return;
|
||||||
|
await kpnRun('smollm-135m', `Arvioi: ${code}`);
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 8.2 Agenttien promptien hallinta
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
const agentPrompts = {
|
||||||
|
manager: { model: 'qwen-coder', prompt: 'Olet projektipäällikkö...' },
|
||||||
|
coder: { model: 'qwen-coder', prompt: 'Olet ohjelmistokehittäjä...' },
|
||||||
|
// ...
|
||||||
|
};
|
||||||
|
// Tallennetaan localStorage:en per agentti
|
||||||
|
localStorage.setItem('kpn-agent-prompt-coder', customPrompt);
|
||||||
|
```
|
||||||
|
|
||||||
|
### 8.3 Yhteinen promptikonteksti
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
async function kpnRun(model, prompt) {
|
||||||
|
const parts = [];
|
||||||
|
if (sharedPrompt) parts.push(sharedPrompt); // Kaikille yhteinen
|
||||||
|
if (agent.prompt) parts.push(agent.prompt); // Agenttikohtainen
|
||||||
|
parts.push(prompt); // Käyttäjän pyyntö
|
||||||
|
const fullPrompt = parts.join('\n\n');
|
||||||
|
// → HTTP POST /api/v1/chat/completions
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 9. Teknologiapino
|
||||||
|
|
||||||
|
| Kerros | Teknologia | Tarkoitus |
|
||||||
|
|--------|------------|-----------|
|
||||||
|
| Frontend | Vanilla JS + HTML + CSS | Ei build-steppiä, toimii suoraan |
|
||||||
|
| Wasm | Rust + wasm-bindgen | Inferenssi selaimessa |
|
||||||
|
| LLM | Candle (Rust) | Transformer-inferenssi CPU:lla |
|
||||||
|
| Tensorit | Burn (Rust) | GPU-tensorilaskenta (WebGPU/NdArray) |
|
||||||
|
| Backend | Axum + Tokio (Rust) | Async WebSocket + HTTP -palvelin |
|
||||||
|
| Tietokanta | SQLite (rusqlite) | Sessiot ja tulokset |
|
||||||
|
| Cache | IndexedDB | Mallipainot selaimen pysyvässä muistissa |
|
||||||
|
| Korostus | Highlight.js (CDN) | Syntaksikorostus, automaattinen kielentunnistus |
|
||||||
|
| Tokenizer | HuggingFace tokenizers | BPE-tokenisaatio Wasmissa |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 10. Jatkokehitysideoita
|
||||||
|
|
||||||
|
Näiden rakennuspalasten pohjalta voi rakentaa:
|
||||||
|
|
||||||
|
- **Oma chat-UI:** WebSocket + streaming + syntaksikorostus
|
||||||
|
- **Hajautettu laskentaverkko:** Hub + node-rekisteri + busy-state + työjono
|
||||||
|
- **Selain-LLM:** Wasm + Candle + IndexedDB-cache + warmup
|
||||||
|
- **Agenttipohjainen työnkulku:** Pipeline + prompt-orkestrointi + reititys
|
||||||
|
- **Terminaaliemulasttori:** Input + historia + tab-completion + dropdown
|
||||||
|
- **Reaaliaikadashboard:** WebSocket broadcast + tilaindikaattorit + metriikat
|
||||||
@@ -1,6 +1,13 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
|
if [ "$1" == "local" ]; then
|
||||||
|
echo "=== Kipinä Studio Local Development ==="
|
||||||
|
echo "Käynnistetään kokonaisuus puhtaasti Docker-kontissa..."
|
||||||
|
docker compose up agentic-poc
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
SERVER="ubuntu@86.50.252.98"
|
SERVER="ubuntu@86.50.252.98"
|
||||||
REMOTE_DIR="~/code/agentic-studio/network-poc"
|
REMOTE_DIR="~/code/agentic-studio/network-poc"
|
||||||
KEY="$HOME/.ssh/id_rsa"
|
KEY="$HOME/.ssh/id_rsa"
|
||||||
@@ -14,9 +21,23 @@ fi
|
|||||||
|
|
||||||
echo "=== Kipinä Studio Deploy ==="
|
echo "=== Kipinä Studio Deploy ==="
|
||||||
|
|
||||||
|
# 0. Commitoidaan uncommitted muutokset ennen deployta
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||||
|
if ! git -C "$SCRIPT_DIR" diff --quiet HEAD 2>/dev/null || \
|
||||||
|
[ -n "$(git -C "$SCRIPT_DIR" ls-files --others --exclude-standard 2>/dev/null)" ]; then
|
||||||
|
echo "[0] Uncommitted muutoksia havaittu — commitoidaan..."
|
||||||
|
read -rp " Commit-viesti: " DEPLOY_MSG
|
||||||
|
if [ -z "$DEPLOY_MSG" ]; then
|
||||||
|
DEPLOY_MSG="Deploy $(date +%Y-%m-%d\ %H:%M)"
|
||||||
|
fi
|
||||||
|
git -C "$SCRIPT_DIR" add -A
|
||||||
|
git -C "$SCRIPT_DIR" commit -m "$DEPLOY_MSG"
|
||||||
|
echo " Commitoitu: $DEPLOY_MSG"
|
||||||
|
fi
|
||||||
|
|
||||||
# 1. Rakennetaan Docker-image lokaalisti
|
# 1. Rakennetaan Docker-image lokaalisti
|
||||||
echo "[1/4] Rakennetaan image lokaalisti..."
|
echo "[1/4] Rakennetaan image lokaalisti..."
|
||||||
docker build -f Dockerfile.prod -t kipina-agentic:latest .
|
docker build --platform linux/amd64 -f Dockerfile.prod -t kipina-agentic:latest .
|
||||||
|
|
||||||
# 2. Tallennetaan tiedostoon
|
# 2. Tallennetaan tiedostoon
|
||||||
echo "[2/5] Pakataan image..."
|
echo "[2/5] Pakataan image..."
|
||||||
@@ -39,7 +60,11 @@ echo "=== Valmis! https://kipina.studio ==="
|
|||||||
|
|
||||||
# Discord-notifikaatio
|
# Discord-notifikaatio
|
||||||
DISCORD_WEBHOOK="https://discord.com/api/webhooks/1489504066898755687/8U02d0wug-3MkVax0xMmRoj0s_-V1psnNLPWdSOjnGnKRBUpPjaU6XiX9Iu8DgJI69AP"
|
DISCORD_WEBHOOK="https://discord.com/api/webhooks/1489504066898755687/8U02d0wug-3MkVax0xMmRoj0s_-V1psnNLPWdSOjnGnKRBUpPjaU6XiX9Iu8DgJI69AP"
|
||||||
COMMIT_MSG=$(git log -1 --pretty=format:"%s" 2>/dev/null || echo "?")
|
COMMIT_HASH=$(git -C "$SCRIPT_DIR" log -1 --pretty=format:"%h" 2>/dev/null || echo "?")
|
||||||
curl -s -H "Content-Type: application/json" \
|
COMMIT_MSG=$(git -C "$SCRIPT_DIR" log -1 --pretty=format:"%s" 2>/dev/null || echo "?")
|
||||||
-d "{\"content\":\"🚀 **Kipinä Studio julkaistu!**\n> ${COMMIT_MSG}\n> https://kipina.studio\n> Admin: https://kipina.studio/admin (salasana: kipina)\"}" \
|
# python3 escapettaa erikoismerkit JSON-turvallisesti
|
||||||
"$DISCORD_WEBHOOK" > /dev/null
|
PAYLOAD=$(python3 -c "import json,sys; print(json.dumps({'content': sys.argv[1]}))" \
|
||||||
|
"🚀 **Kipinä Studio julkaistu!**
|
||||||
|
> \`${COMMIT_HASH}\` ${COMMIT_MSG}
|
||||||
|
> https://kipina.studio")
|
||||||
|
curl -s -H "Content-Type: application/json" -d "$PAYLOAD" "$DISCORD_WEBHOOK" > /dev/null
|
||||||
|
|||||||
@@ -15,3 +15,4 @@ uuid = { version = "1.7.0", features = ["v4", "serde"] }
|
|||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
rusqlite = { version = "0.31", features = ["bundled"] }
|
rusqlite = { version = "0.31", features = ["bundled"] }
|
||||||
chrono = "0.4"
|
chrono = "0.4"
|
||||||
|
base64 = "0.22"
|
||||||
|
|||||||
Binary file not shown.
@@ -25,16 +25,23 @@ const ALLOWED_ORIGINS: &[&str] = &[
|
|||||||
];
|
];
|
||||||
|
|
||||||
// Sallitut viestityyypit clientilta
|
// Sallitut viestityyypit clientilta
|
||||||
const ALLOWED_MSG_TYPES: &[&str] = &["auth", "result", "pair_done", "llm_chunk", "llm_done", "download_progress", "user_text", "single_tokenize_done"];
|
const ALLOWED_MSG_TYPES: &[&str] = &["auth", "result", "pair_done", "llm_chunk", "llm_done", "llm_error", "download_progress", "user_text", "single_tokenize_done"];
|
||||||
|
|
||||||
struct AppState {
|
struct AppState {
|
||||||
next_node_id: Mutex<u64>,
|
next_node_id: Mutex<u64>,
|
||||||
nodes_vram: Mutex<HashMap<u64, u32>>,
|
nodes_vram: Mutex<HashMap<u64, u32>>,
|
||||||
|
nodes_tokens: Mutex<HashMap<u64, u32>>, // Gamification: Kipinä Tokens
|
||||||
total_tasks: Mutex<u64>,
|
total_tasks: Mutex<u64>,
|
||||||
stats_tx: broadcast::Sender<String>,
|
stats_tx: broadcast::Sender<String>,
|
||||||
|
node_channels: tokio::sync::RwLock<HashMap<u64, tokio::sync::mpsc::UnboundedSender<String>>>, // Kohdennettu reititys
|
||||||
|
pending_consensus: tokio::sync::RwLock<HashMap<String, Vec<serde_json::Value>>>, // Proof of Compute -konsensus
|
||||||
|
feature_flags: tokio::sync::RwLock<HashMap<String, bool>>, // Tuntee TODO.md:n ruksit lennosta
|
||||||
ip_connections: Mutex<HashMap<IpAddr, u32>>,
|
ip_connections: Mutex<HashMap<IpAddr, u32>>,
|
||||||
node_ips: Mutex<HashMap<u64, IpAddr>>,
|
node_ips: Mutex<HashMap<u64, IpAddr>>,
|
||||||
node_tasks: Mutex<HashMap<u64, String>>, // node_id → selected_task
|
node_tasks: Mutex<HashMap<u64, String>>, // node_id → selected_task
|
||||||
|
node_busy: Mutex<std::collections::HashSet<u64>>, // Solmut joilla on aktiivinen tehtävä
|
||||||
|
pending_task_ids: Mutex<std::collections::HashSet<String>>, // Hubin jakamat task_id:t (gamification-validointi)
|
||||||
|
api_rate_limits: Mutex<HashMap<IpAddr, (std::time::Instant, u32)>>, // IP → (ikkuna-alku, pyyntömäärä)
|
||||||
db: db::NodeDb,
|
db: db::NodeDb,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -244,16 +251,51 @@ async fn main() {
|
|||||||
let state = Arc::new(AppState {
|
let state = Arc::new(AppState {
|
||||||
next_node_id: Mutex::new(1),
|
next_node_id: Mutex::new(1),
|
||||||
nodes_vram: Mutex::new(HashMap::new()),
|
nodes_vram: Mutex::new(HashMap::new()),
|
||||||
|
nodes_tokens: Mutex::new(HashMap::new()),
|
||||||
total_tasks: Mutex::new(0),
|
total_tasks: Mutex::new(0),
|
||||||
stats_tx: stats_tx.clone(),
|
stats_tx: stats_tx.clone(),
|
||||||
|
node_channels: tokio::sync::RwLock::new(HashMap::new()),
|
||||||
|
pending_consensus: tokio::sync::RwLock::new(HashMap::new()),
|
||||||
|
feature_flags: tokio::sync::RwLock::new(HashMap::new()),
|
||||||
ip_connections: Mutex::new(HashMap::new()),
|
ip_connections: Mutex::new(HashMap::new()),
|
||||||
node_ips: Mutex::new(HashMap::new()),
|
node_ips: Mutex::new(HashMap::new()),
|
||||||
node_tasks: Mutex::new(HashMap::new()),
|
node_tasks: Mutex::new(HashMap::new()),
|
||||||
|
node_busy: Mutex::new(std::collections::HashSet::new()),
|
||||||
|
pending_task_ids: Mutex::new(std::collections::HashSet::new()),
|
||||||
|
api_rate_limits: Mutex::new(HashMap::new()),
|
||||||
db: db::NodeDb::new(&std::env::var("DATABASE_PATH").unwrap_or_else(|_| "nodes.db".to_string())),
|
db: db::NodeDb::new(&std::env::var("DATABASE_PATH").unwrap_or_else(|_| "nodes.db".to_string())),
|
||||||
});
|
});
|
||||||
|
|
||||||
tracing::info!("Tietokanta alustettu");
|
tracing::info!("Tietokanta alustettu");
|
||||||
|
|
||||||
|
let state_for_watcher = state.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
// Ensimmäinen luku heti, sitten 3s välein
|
||||||
|
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(3));
|
||||||
|
let file_path = std::env::var("FEATURE_FLAGS_FILE").unwrap_or_else(|_| "../TODO.md".to_string());
|
||||||
|
|
||||||
|
loop {
|
||||||
|
interval.tick().await;
|
||||||
|
if let Ok(content) = tokio::fs::read_to_string(&file_path).await {
|
||||||
|
let mut flags = HashMap::new();
|
||||||
|
for line in content.lines() {
|
||||||
|
if line.starts_with("- [ ] **") || line.starts_with("- [x] **") {
|
||||||
|
let is_active = line.starts_with("- [x]");
|
||||||
|
if let Some(start_idx) = line.find("**") {
|
||||||
|
let start = start_idx + 2;
|
||||||
|
if let Some(end_idx) = line[start..].find("**") {
|
||||||
|
let end = end_idx + start;
|
||||||
|
let feature_name = line[start..end].trim_end_matches(':').trim().to_string();
|
||||||
|
flags.insert(feature_name, is_active);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*state_for_watcher.feature_flags.write().await = flags;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
let state_for_task = state.clone();
|
let state_for_task = state.clone();
|
||||||
|
|
||||||
// Ajastin, joka jakaa satunnaisia tekoälytehtäviä eri pituuksilla
|
// Ajastin, joka jakaa satunnaisia tekoälytehtäviä eri pituuksilla
|
||||||
@@ -376,20 +418,30 @@ async fn api_stats(
|
|||||||
) -> axum::response::Response {
|
) -> axum::response::Response {
|
||||||
if !check_admin_auth(&headers) { return admin_unauthorized(); }
|
if !check_admin_auth(&headers) { return admin_unauthorized(); }
|
||||||
let mut stats = state.db.get_stats();
|
let mut stats = state.db.get_stats();
|
||||||
stats.as_object_mut().unwrap().insert("version".to_string(), serde_json::json!(env!("CARGO_PKG_VERSION")));
|
if let Some(obj) = stats.as_object_mut() {
|
||||||
|
obj.insert("version".to_string(), serde_json::json!(env!("CARGO_PKG_VERSION")));
|
||||||
|
}
|
||||||
axum::Json(stats).into_response()
|
axum::Json(stats).into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_admin_auth(headers: &axum::http::HeaderMap) -> bool {
|
fn check_admin_auth(headers: &axum::http::HeaderMap) -> bool {
|
||||||
let password = std::env::var("ADMIN_PASSWORD").unwrap_or_else(|_| "kipina".to_string());
|
let password = match std::env::var("ADMIN_PASSWORD") {
|
||||||
|
Ok(p) if !p.is_empty() => p,
|
||||||
|
_ => {
|
||||||
|
tracing::warn!("ADMIN_PASSWORD ei ole asetettu — käytetään oletusta 'kipina' (ÄLÄ käytä tuotannossa!)");
|
||||||
|
"kipina".to_string()
|
||||||
|
}
|
||||||
|
};
|
||||||
if let Some(auth) = headers.get("authorization").and_then(|v| v.to_str().ok()) {
|
if let Some(auth) = headers.get("authorization").and_then(|v| v.to_str().ok()) {
|
||||||
if auth.starts_with("Basic ") {
|
if auth.starts_with("Basic ") {
|
||||||
if let Ok(decoded) = String::from_utf8(
|
use base64::Engine;
|
||||||
base64_decode(auth.trim_start_matches("Basic ").trim())
|
if let Ok(decoded_bytes) = base64::engine::general_purpose::STANDARD
|
||||||
) {
|
.decode(auth.trim_start_matches("Basic ").trim())
|
||||||
// Tarkistetaan "user:password" — käyttäjänimi ei väliä
|
{
|
||||||
if let Some(pass) = decoded.split(':').nth(1) {
|
if let Ok(decoded) = String::from_utf8(decoded_bytes) {
|
||||||
return pass == password;
|
if let Some(pass) = decoded.split(':').nth(1) {
|
||||||
|
return pass == password;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -397,20 +449,6 @@ fn check_admin_auth(headers: &axum::http::HeaderMap) -> bool {
|
|||||||
false
|
false
|
||||||
}
|
}
|
||||||
|
|
||||||
fn base64_decode(input: &str) -> Vec<u8> {
|
|
||||||
// Yksinkertainen base64-dekooderi
|
|
||||||
const TABLE: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
|
|
||||||
let mut out = Vec::new();
|
|
||||||
let bytes: Vec<u8> = input.bytes().filter(|&b| b != b'=').collect();
|
|
||||||
for chunk in bytes.chunks(4) {
|
|
||||||
let vals: Vec<u8> = chunk.iter().filter_map(|&b| TABLE.iter().position(|&t| t == b).map(|p| p as u8)).collect();
|
|
||||||
if vals.len() >= 2 { out.push((vals[0] << 2) | (vals[1] >> 4)); }
|
|
||||||
if vals.len() >= 3 { out.push((vals[1] << 4) | (vals[2] >> 2)); }
|
|
||||||
if vals.len() >= 4 { out.push((vals[2] << 6) | vals[3]); }
|
|
||||||
}
|
|
||||||
out
|
|
||||||
}
|
|
||||||
|
|
||||||
fn admin_unauthorized() -> axum::response::Response {
|
fn admin_unauthorized() -> axum::response::Response {
|
||||||
axum::response::Response::builder()
|
axum::response::Response::builder()
|
||||||
.status(401)
|
.status(401)
|
||||||
@@ -555,22 +593,35 @@ async fn handle_socket(socket: WebSocket, state: Arc<AppState>, ip: IpAddr) {
|
|||||||
|
|
||||||
tracing::info!("Solmu {} yhdistyi osoitteesta {}", node_id, ip);
|
tracing::info!("Solmu {} yhdistyi osoitteesta {}", node_id, ip);
|
||||||
|
|
||||||
|
let (node_tx, mut node_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
|
||||||
|
|
||||||
|
// Tallennetaan node channel reititystä varten
|
||||||
|
{
|
||||||
|
state.node_channels.write().await.insert(node_id, node_tx);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Yksinkertaistettu broadcast tx vastaanotto
|
||||||
let mut rx = state.stats_tx.subscribe();
|
let mut rx = state.stats_tx.subscribe();
|
||||||
|
|
||||||
let sender_task = tokio::spawn(async move {
|
let sender_task = tokio::spawn(async move {
|
||||||
loop {
|
loop {
|
||||||
match rx.recv().await {
|
tokio::select! {
|
||||||
Ok(msg) => {
|
result = rx.recv() => {
|
||||||
if sender.send(Message::Text(msg)).await.is_err() {
|
match result {
|
||||||
break;
|
Ok(msg) => {
|
||||||
|
if sender.send(Message::Text(msg)).await.is_err() { break; }
|
||||||
|
}
|
||||||
|
Err(broadcast::error::RecvError::Lagged(n)) => {
|
||||||
|
tracing::debug!("Broadcast lagged {} viestiä — ohitetaan", n);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
Err(_) => break, // Kanava suljettu
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
|
Some(direct_msg) = node_rx.recv() => {
|
||||||
continue;
|
if sender.send(Message::Text(direct_msg)).await.is_err() { break; }
|
||||||
}
|
|
||||||
Err(_) => {
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
|
else => break,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -592,7 +643,8 @@ async fn handle_socket(socket: WebSocket, state: Arc<AppState>, ip: IpAddr) {
|
|||||||
let json = match validate_message(&text) {
|
let json = match validate_message(&text) {
|
||||||
Ok(j) => j,
|
Ok(j) => j,
|
||||||
Err(reason) => {
|
Err(reason) => {
|
||||||
tracing::warn!("Solmu {} ({}) lähetti virheellisen viestin: {} — {:?}", node_id, ip, reason, &text[..text.len().min(100)]);
|
let preview: String = text.chars().take(100).collect();
|
||||||
|
tracing::warn!("Solmu {} ({}) lähetti virheellisen viestin: {} — {:?}", node_id, ip, reason, preview);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -722,10 +774,32 @@ async fn handle_socket(socket: WebSocket, state: Arc<AppState>, ip: IpAddr) {
|
|||||||
}
|
}
|
||||||
let _ = state.stats_tx.send(json.to_string());
|
let _ = state.stats_tx.send(json.to_string());
|
||||||
|
|
||||||
|
let active_incentives = state.feature_flags.read().await.get("Insentiivit").copied().unwrap_or(false);
|
||||||
|
let ui_sync = state.feature_flags.read().await.get("Pelimerkkien UI-synkkaus").copied().unwrap_or(false);
|
||||||
|
let mut current_balance = 0;
|
||||||
|
|
||||||
{
|
{
|
||||||
let mut task_count = state.total_tasks.lock().unwrap();
|
let mut task_count = state.total_tasks.lock().unwrap();
|
||||||
*task_count += 1;
|
*task_count += 1;
|
||||||
|
|
||||||
|
if active_incentives {
|
||||||
|
let mut tokens = state.nodes_tokens.lock().unwrap();
|
||||||
|
let balance = tokens.entry(node_id).or_insert(0);
|
||||||
|
*balance += 5; // Palkkio: 5 Kipinä-merkkiä
|
||||||
|
current_balance = *balance;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if active_incentives && ui_sync {
|
||||||
|
if let Some(tx) = state.node_channels.read().await.get(&node_id) {
|
||||||
|
let msg = serde_json::json!({
|
||||||
|
"type": "token_balance",
|
||||||
|
"balance": current_balance
|
||||||
|
});
|
||||||
|
let _ = tx.send(msg.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
broadcast_stats(&state).await;
|
broadcast_stats(&state).await;
|
||||||
}
|
}
|
||||||
} else if msg_type == "single_tokenize_done" {
|
} else if msg_type == "single_tokenize_done" {
|
||||||
@@ -745,6 +819,13 @@ async fn handle_socket(socket: WebSocket, state: Arc<AppState>, ip: IpAddr) {
|
|||||||
let _ = state.stats_tx.send(json.to_string());
|
let _ = state.stats_tx.send(json.to_string());
|
||||||
}
|
}
|
||||||
} else if msg_type == "llm_done" {
|
} else if msg_type == "llm_done" {
|
||||||
|
// Vapautetaan solmu ja tarkistetaan task_id:n aitous
|
||||||
|
state.node_busy.lock().unwrap().remove(&node_id);
|
||||||
|
let valid_task = if let Some(tid) = json.get("task_id").and_then(|v| v.as_str()) {
|
||||||
|
state.pending_task_ids.lock().unwrap().remove(tid)
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
};
|
||||||
{
|
{
|
||||||
let mut json = json;
|
let mut json = json;
|
||||||
if let Some(obj) = json.as_object_mut() {
|
if let Some(obj) = json.as_object_mut() {
|
||||||
@@ -766,18 +847,53 @@ async fn handle_socket(socket: WebSocket, state: Arc<AppState>, ip: IpAddr) {
|
|||||||
}
|
}
|
||||||
let _ = state.stats_tx.send(json.to_string());
|
let _ = state.stats_tx.send(json.to_string());
|
||||||
|
|
||||||
|
let active_incentives = state.feature_flags.read().await.get("Insentiivit").copied().unwrap_or(false);
|
||||||
|
let ui_sync = state.feature_flags.read().await.get("Pelimerkkien UI-synkkaus").copied().unwrap_or(false);
|
||||||
|
let mut current_balance = 0;
|
||||||
|
|
||||||
{
|
{
|
||||||
let mut task_count = state.total_tasks.lock().unwrap();
|
let mut task_count = state.total_tasks.lock().unwrap();
|
||||||
*task_count += 1;
|
*task_count += 1;
|
||||||
|
|
||||||
|
if active_incentives && valid_task {
|
||||||
|
let mut tokens = state.nodes_tokens.lock().unwrap();
|
||||||
|
let balance = tokens.entry(node_id).or_insert(0);
|
||||||
|
*balance += 20; // Palkkio: 20 Kipinä-merkkiä
|
||||||
|
current_balance = *balance;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if active_incentives && ui_sync {
|
||||||
|
if let Some(tx) = state.node_channels.read().await.get(&node_id) {
|
||||||
|
let msg = serde_json::json!({
|
||||||
|
"type": "token_balance",
|
||||||
|
"balance": current_balance
|
||||||
|
});
|
||||||
|
let _ = tx.send(msg.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
broadcast_stats(&state).await;
|
broadcast_stats(&state).await;
|
||||||
}
|
}
|
||||||
|
} else if msg_type == "llm_error" {
|
||||||
|
state.node_busy.lock().unwrap().remove(&node_id);
|
||||||
|
if let Some(tid) = json.get("task_id").and_then(|v| v.as_str()) {
|
||||||
|
state.pending_task_ids.lock().unwrap().remove(tid);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
let mut json = json;
|
||||||
|
if let Some(obj) = json.as_object_mut() {
|
||||||
|
obj.insert("node_id".to_string(), serde_json::json!(node_id));
|
||||||
|
}
|
||||||
|
let _ = state.stats_tx.send(json.to_string());
|
||||||
|
}
|
||||||
} else if msg_type == "user_text" {
|
} else if msg_type == "user_text" {
|
||||||
// Käyttäjän lähettämä teksti — broadcastataan pair_taskina ja llm_promptina
|
// Käyttäjän lähettämä teksti — broadcastataan pair_taskina ja llm_promptina
|
||||||
let text = json.get("text").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
let text = json.get("text").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||||
let task_type = json.get("task_type").and_then(|v| v.as_str()).unwrap_or("tokenize");
|
let task_type = json.get("task_type").and_then(|v| v.as_str()).unwrap_or("tokenize");
|
||||||
if !text.is_empty() {
|
if !text.is_empty() {
|
||||||
tracing::info!("Solmu {} lähetti oman tekstin ({}): \"{}\"", node_id, task_type, &text[..text.len().min(80)]);
|
let preview: String = text.chars().take(80).collect();
|
||||||
|
tracing::info!("Solmu {} lähetti oman tekstin ({}): \"{}\"", node_id, task_type, preview);
|
||||||
match task_type {
|
match task_type {
|
||||||
"tokenize" => {
|
"tokenize" => {
|
||||||
let msg = serde_json::json!({
|
let msg = serde_json::json!({
|
||||||
@@ -787,38 +903,36 @@ async fn handle_socket(socket: WebSocket, state: Arc<AppState>, ip: IpAddr) {
|
|||||||
let _ = state.stats_tx.send(msg.to_string());
|
let _ = state.stats_tx.send(msg.to_string());
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
// LLM-prompti
|
// LLM-prompti: lähetetään VAIN valitulle mallille, ei kaikille (välttää turhaa ruuhkaa ja busy-tiloja)
|
||||||
for model in &["smollm-135m", "qwen-05b", "phi3-mini", "qwen-coder"] {
|
let prompt = serde_json::json!({
|
||||||
let prompt = serde_json::json!({
|
"type": "llm_prompt",
|
||||||
"type": "llm_prompt",
|
"prompt": text,
|
||||||
"prompt": text,
|
"model": task_type,
|
||||||
"model": model,
|
});
|
||||||
});
|
let _ = state.stats_tx.send(prompt.to_string());
|
||||||
let _ = state.stats_tx.send(prompt.to_string());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Yhteys katkesi — merkitään session päättyneeksi ja siivotaan
|
// Yhteys katkesi — merkitään session päättyneeksi ja siivotaan atomisesti
|
||||||
state.db.close_session(node_id);
|
state.db.close_session(node_id);
|
||||||
state.node_tasks.lock().unwrap().remove(&node_id);
|
|
||||||
{
|
{
|
||||||
|
// Lukitaan kaikki kerralla, jotta solmu ei ole osittain siivottu
|
||||||
|
let mut tasks = state.node_tasks.lock().unwrap();
|
||||||
let mut conns = state.ip_connections.lock().unwrap();
|
let mut conns = state.ip_connections.lock().unwrap();
|
||||||
|
let mut ips = state.node_ips.lock().unwrap();
|
||||||
|
let mut vram = state.nodes_vram.lock().unwrap();
|
||||||
|
let mut busy = state.node_busy.lock().unwrap();
|
||||||
|
tasks.remove(&node_id);
|
||||||
|
busy.remove(&node_id);
|
||||||
if let Some(count) = conns.get_mut(&ip) {
|
if let Some(count) = conns.get_mut(&ip) {
|
||||||
*count = count.saturating_sub(1);
|
*count = count.saturating_sub(1);
|
||||||
if *count == 0 {
|
if *count == 0 { conns.remove(&ip); }
|
||||||
conns.remove(&ip);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
ips.remove(&node_id);
|
||||||
{
|
vram.remove(&node_id);
|
||||||
state.node_ips.lock().unwrap().remove(&node_id);
|
|
||||||
}
|
|
||||||
{
|
|
||||||
state.nodes_vram.lock().unwrap().remove(&node_id);
|
|
||||||
}
|
}
|
||||||
tracing::info!("Solmu {} ({}) poistui verkosta.", node_id, ip);
|
tracing::info!("Solmu {} ({}) poistui verkosta.", node_id, ip);
|
||||||
broadcast_stats(&state).await;
|
broadcast_stats(&state).await;
|
||||||
@@ -840,8 +954,111 @@ struct ChatCompletionResponse {
|
|||||||
|
|
||||||
async fn api_chat_completions(
|
async fn api_chat_completions(
|
||||||
axum::extract::State(state): axum::extract::State<Arc<AppState>>,
|
axum::extract::State(state): axum::extract::State<Arc<AppState>>,
|
||||||
|
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||||
axum::Json(payload): axum::Json<ChatCompletionRequest>,
|
axum::Json(payload): axum::Json<ChatCompletionRequest>,
|
||||||
) -> axum::response::Response {
|
) -> axum::response::Response {
|
||||||
|
// Rate limiting: max 10 pyyntöä per IP per minuutti
|
||||||
|
{
|
||||||
|
let mut limits = state.api_rate_limits.lock().unwrap();
|
||||||
|
let now = std::time::Instant::now();
|
||||||
|
let entry = limits.entry(addr.ip()).or_insert((now, 0));
|
||||||
|
if now.duration_since(entry.0).as_secs() >= 60 {
|
||||||
|
*entry = (now, 1); // Uusi ikkuna
|
||||||
|
} else {
|
||||||
|
entry.1 += 1;
|
||||||
|
if entry.1 > 10 {
|
||||||
|
return (axum::http::StatusCode::TOO_MANY_REQUESTS, "Liian monta pyyntöä — yritä minuutin kuluttua").into_response();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Etsitään vapaa tai varattu solmu, joka vastaa pyydettyä mallia
|
||||||
|
let (target_node_free, target_node_any, total_matching) = {
|
||||||
|
let tasks = state.node_tasks.lock().unwrap();
|
||||||
|
let busy = state.node_busy.lock().unwrap();
|
||||||
|
let matching: Vec<u64> = tasks.iter().filter(|(_, task)| {
|
||||||
|
if payload.model == "qwen-coder" {
|
||||||
|
*task == "qwen-coder-05b" || *task == "qwen-coder"
|
||||||
|
} else {
|
||||||
|
**task == payload.model
|
||||||
|
}
|
||||||
|
}).map(|(k, _)| *k).collect();
|
||||||
|
let free = matching.iter().find(|id| !busy.contains(id)).copied();
|
||||||
|
let any = matching.first().copied();
|
||||||
|
(free, any, matching.len())
|
||||||
|
};
|
||||||
|
|
||||||
|
// Broadcastataan reititystila UI:lle
|
||||||
|
let task_id = payload.task_id.clone();
|
||||||
|
|
||||||
|
if target_node_any.is_none() {
|
||||||
|
// Ei yhtään solmua tälle mallille
|
||||||
|
return (axum::http::StatusCode::SERVICE_UNAVAILABLE, "Ei solmua tälle mallille (käynnistä malli selaimessa)").into_response();
|
||||||
|
}
|
||||||
|
|
||||||
|
let target_node_id;
|
||||||
|
if let Some(free_id) = target_node_free {
|
||||||
|
// Vapaa solmu löytyi — reititetään suoraan
|
||||||
|
target_node_id = free_id;
|
||||||
|
let node_type = if state.node_tasks.lock().unwrap().get(&free_id).map(|t| t.contains("native")).unwrap_or(false) { "natiivi" } else { "selain" };
|
||||||
|
let routing_msg = serde_json::json!({
|
||||||
|
"type": "task_routed",
|
||||||
|
"task_id": task_id,
|
||||||
|
"node_id": free_id,
|
||||||
|
"node_type": node_type,
|
||||||
|
"status": "routed",
|
||||||
|
"message": format!("Reititetty solmulle #{}", free_id),
|
||||||
|
});
|
||||||
|
let _ = state.stats_tx.send(routing_msg.to_string());
|
||||||
|
} else {
|
||||||
|
// Kaikki solmut varattuja — odotetaan vapautumista (max 30s)
|
||||||
|
let queue_msg = serde_json::json!({
|
||||||
|
"type": "task_routed",
|
||||||
|
"task_id": task_id,
|
||||||
|
"status": "queued",
|
||||||
|
"message": format!("Kaikki {} solmua varattuja — odotetaan vapautumista...", total_matching),
|
||||||
|
});
|
||||||
|
let _ = state.stats_tx.send(queue_msg.to_string());
|
||||||
|
|
||||||
|
// Pollaa busy-tilaa 500ms välein, max 30s
|
||||||
|
let mut waited = 0u32;
|
||||||
|
loop {
|
||||||
|
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
|
||||||
|
waited += 500;
|
||||||
|
let free = {
|
||||||
|
let tasks = state.node_tasks.lock().unwrap();
|
||||||
|
let busy = state.node_busy.lock().unwrap();
|
||||||
|
tasks.iter().find(|(node_id, task)| {
|
||||||
|
let model_match = if payload.model == "qwen-coder" {
|
||||||
|
*task == "qwen-coder-05b" || *task == "qwen-coder"
|
||||||
|
} else {
|
||||||
|
**task == payload.model
|
||||||
|
};
|
||||||
|
model_match && !busy.contains(node_id)
|
||||||
|
}).map(|(k, _)| *k)
|
||||||
|
};
|
||||||
|
if let Some(id) = free {
|
||||||
|
target_node_id = id;
|
||||||
|
let routing_msg = serde_json::json!({
|
||||||
|
"type": "task_routed",
|
||||||
|
"task_id": task_id,
|
||||||
|
"node_id": id,
|
||||||
|
"status": "routed",
|
||||||
|
"message": format!("Solmu #{} vapautui — reititetään ({:.1}s jonossa)", id, waited as f64 / 1000.0),
|
||||||
|
});
|
||||||
|
let _ = state.stats_tx.send(routing_msg.to_string());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if waited >= 30000 {
|
||||||
|
return (axum::http::StatusCode::SERVICE_UNAVAILABLE, "Aikakatkaisu: kaikki solmut varattuja 30s ajan").into_response();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Merkitään solmu varatuksi ja task_id jaetuksi
|
||||||
|
state.node_busy.lock().unwrap().insert(target_node_id);
|
||||||
|
state.pending_task_ids.lock().unwrap().insert(payload.task_id.clone());
|
||||||
|
|
||||||
let msg = serde_json::json!({
|
let msg = serde_json::json!({
|
||||||
"type": "llm_prompt",
|
"type": "llm_prompt",
|
||||||
"prompt": payload.prompt,
|
"prompt": payload.prompt,
|
||||||
@@ -849,31 +1066,58 @@ async fn api_chat_completions(
|
|||||||
"task_id": payload.task_id,
|
"task_id": payload.task_id,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Odotuskanava valmiiksi (solmu palauttaa tuloksen stats_tx kautta)
|
||||||
let mut rx = state.stats_tx.subscribe();
|
let mut rx = state.stats_tx.subscribe();
|
||||||
let _ = state.stats_tx.send(msg.to_string());
|
|
||||||
|
// Kohdennettu reititys: lähetetään AI-tehtävä suoraan VAIN valitulle solmulle
|
||||||
|
{
|
||||||
|
let channels = state.node_channels.read().await;
|
||||||
|
if let Some(tx) = channels.get(&target_node_id) {
|
||||||
|
let _ = tx.send(msg.to_string());
|
||||||
|
tracing::info!("Reititettiin API-pyyntö solmulle {} (Malli: {})", target_node_id, payload.model);
|
||||||
|
} else {
|
||||||
|
return (axum::http::StatusCode::SERVICE_UNAVAILABLE, "Verkkovirhe: solmun yhteys katkesi reitityksen aikana").into_response();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let timeout = tokio::time::timeout(std::time::Duration::from_secs(120), async move {
|
let timeout = tokio::time::timeout(std::time::Duration::from_secs(120), async move {
|
||||||
while let Ok(msg_str) = rx.recv().await {
|
loop {
|
||||||
|
let msg_str = match rx.recv().await {
|
||||||
|
Ok(msg) => msg,
|
||||||
|
Err(broadcast::error::RecvError::Lagged(n)) => {
|
||||||
|
tracing::debug!("API-kanava lagged {} viestiä", n);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
Err(_) => return Ok(None), // Kanava suljettu
|
||||||
|
};
|
||||||
if let Ok(v) = serde_json::from_str::<serde_json::Value>(&msg_str) {
|
if let Ok(v) = serde_json::from_str::<serde_json::Value>(&msg_str) {
|
||||||
if v["type"].as_str() == Some("llm_done") {
|
if v["type"].as_str() == Some("llm_done") {
|
||||||
if let Some(tid) = v["task_id"].as_str() {
|
if let Some(tid) = v["task_id"].as_str() {
|
||||||
if tid == payload.task_id {
|
if tid == payload.task_id {
|
||||||
return Some(ChatCompletionResponse {
|
return Ok(Some(ChatCompletionResponse {
|
||||||
response: v["response"].as_str().unwrap_or("").to_string(),
|
response: v["response"].as_str().unwrap_or("").to_string(),
|
||||||
model: v["model"].as_str().unwrap_or("").to_string(),
|
model: v["model"].as_str().unwrap_or("").to_string(),
|
||||||
tokens_generated: v["tokens_generated"].as_u64().unwrap_or(0),
|
tokens_generated: v["tokens_generated"].as_u64().unwrap_or(0),
|
||||||
});
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if v["type"].as_str() == Some("llm_error") {
|
||||||
|
if let Some(tid) = v["task_id"].as_str() {
|
||||||
|
if tid == payload.task_id {
|
||||||
|
return Err(v["error"].as_str().unwrap_or("Määrittelemätön virhe solmussa").to_string());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
None
|
#[allow(unreachable_code)]
|
||||||
|
Ok(None)
|
||||||
}).await;
|
}).await;
|
||||||
|
|
||||||
match timeout {
|
match timeout {
|
||||||
Ok(Some(res)) => axum::Json(res).into_response(),
|
Ok(Ok(Some(res))) => axum::Json(res).into_response(),
|
||||||
Ok(None) => (axum::http::StatusCode::INTERNAL_SERVER_ERROR, "Verkkovirhe: yhteys katkesi").into_response(),
|
Ok(Ok(None)) => (axum::http::StatusCode::INTERNAL_SERVER_ERROR, "Verkkovirhe: yhteys katkesi").into_response(),
|
||||||
Err(_) => (axum::http::StatusCode::GATEWAY_TIMEOUT, "Aikakatkaisu: yksikään solmu ei vastannut 120s sisällä").into_response(),
|
Ok(Err(err)) => (axum::http::StatusCode::CONFLICT, err).into_response(),
|
||||||
|
Err(_) => (axum::http::StatusCode::GATEWAY_TIMEOUT, "Aikakatkaisu: solmu ei saanut tehtävää ajoissa valmiiksi").into_response(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,15 +2,68 @@ use candle_core::{Device, Tensor, DType};
|
|||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
use candle_transformers::models::qwen2::{Config as QwenConfig, ModelForCausalLM as QwenModel};
|
use candle_transformers::models::qwen2::{Config as QwenConfig, ModelForCausalLM as QwenModel};
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
use std::path::PathBuf;
|
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
|
|
||||||
|
/// Top-k sampling with temperature and repetition penalty
|
||||||
|
fn sample_top_k(logits: &Tensor, k: usize, temperature: f64, generated_tokens: &[u32], repetition_penalty: f64, rng_state: &mut u64) -> Result<u32, String> {
|
||||||
|
let mut logits_vec: Vec<f32> = logits.to_vec1::<f32>().map_err(|e| format!("to_vec1: {}", e))?;
|
||||||
|
if logits_vec.is_empty() { return Err("Tyhjä logits".to_string()); }
|
||||||
|
|
||||||
|
// Repetition penalty: rankaisee jo generoituja tokeneita
|
||||||
|
for &token_id in generated_tokens {
|
||||||
|
if (token_id as usize) < logits_vec.len() {
|
||||||
|
let logit = &mut logits_vec[token_id as usize];
|
||||||
|
if *logit > 0.0 {
|
||||||
|
*logit /= repetition_penalty as f32;
|
||||||
|
} else {
|
||||||
|
*logit *= repetition_penalty as f32;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Temperature scaling
|
||||||
|
if temperature > 0.0 && temperature != 1.0 {
|
||||||
|
for logit in logits_vec.iter_mut() {
|
||||||
|
*logit /= temperature as f32;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Top-k: etsitään k suurinta
|
||||||
|
let mut indexed: Vec<(usize, f32)> = logits_vec.iter().enumerate().map(|(i, &v)| (i, v)).collect();
|
||||||
|
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||||
|
indexed.truncate(k);
|
||||||
|
|
||||||
|
if k == 1 || temperature == 0.0 {
|
||||||
|
return Ok(indexed[0].0 as u32);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Softmax top-k:lle
|
||||||
|
let max_logit = indexed[0].1;
|
||||||
|
let exps: Vec<f32> = indexed.iter().map(|x| (x.1 - max_logit).exp()).collect();
|
||||||
|
let sum: f32 = exps.iter().sum();
|
||||||
|
let probs: Vec<f32> = exps.iter().map(|e| e / sum).collect();
|
||||||
|
|
||||||
|
// XorShift64 RNG
|
||||||
|
*rng_state ^= *rng_state << 13;
|
||||||
|
*rng_state ^= *rng_state >> 7;
|
||||||
|
*rng_state ^= *rng_state << 17;
|
||||||
|
let rand_val = (*rng_state % 10000) as f32 / 10000.0;
|
||||||
|
|
||||||
|
let mut cumulative = 0.0;
|
||||||
|
for (i, p) in probs.iter().enumerate() {
|
||||||
|
cumulative += p;
|
||||||
|
if rand_val < cumulative {
|
||||||
|
return Ok(indexed[i].0 as u32);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(indexed[0].0 as u32)
|
||||||
|
}
|
||||||
|
|
||||||
pub struct LlmEngine {
|
pub struct LlmEngine {
|
||||||
tokenizer: tokenizers::Tokenizer,
|
tokenizer: tokenizers::Tokenizer,
|
||||||
model_path: PathBuf,
|
model: QwenModel,
|
||||||
device: Device,
|
device: Device,
|
||||||
dtype: DType,
|
|
||||||
config: QwenConfig,
|
|
||||||
eos_token: u32,
|
eos_token: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -22,10 +75,10 @@ impl LlmEngine {
|
|||||||
|
|
||||||
let dtype = if device.is_cuda() { DType::F16 } else { DType::F32 };
|
let dtype = if device.is_cuda() { DType::F16 } else { DType::F32 };
|
||||||
|
|
||||||
tracing::info!("Ladataan Qwen2.5-0.5B-Instruct...");
|
tracing::info!("Ladataan Qwen2.5-Coder-0.5B-Instruct...");
|
||||||
let api = Api::new().map_err(|e| format!("HF API: {}", e))?;
|
let api = Api::new().map_err(|e| format!("HF API: {}", e))?;
|
||||||
let repo = api.repo(Repo::with_revision(
|
let repo = api.repo(Repo::with_revision(
|
||||||
"Qwen/Qwen2.5-0.5B-Instruct".to_string(),
|
"Qwen/Qwen2.5-Coder-0.5B-Instruct".to_string(),
|
||||||
RepoType::Model,
|
RepoType::Model,
|
||||||
"main".to_string(),
|
"main".to_string(),
|
||||||
));
|
));
|
||||||
@@ -54,44 +107,42 @@ impl LlmEngine {
|
|||||||
hidden_act: candle_nn::Activation::Silu,
|
hidden_act: candle_nn::Activation::Silu,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Testi-lataus varmistaa, että painot toimivat
|
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
let vb = unsafe {
|
let vb = unsafe {
|
||||||
VarBuilder::from_mmaped_safetensors(&[model_path.clone()], dtype, &device)
|
VarBuilder::from_mmaped_safetensors(&[model_path.clone()], dtype, &device)
|
||||||
.map_err(|e| format!("VarBuilder: {}", e))?
|
.map_err(|e| format!("VarBuilder: {}", e))?
|
||||||
};
|
};
|
||||||
let _model = QwenModel::new(&config, vb).map_err(|e| format!("Malli: {}", e))?;
|
let model = QwenModel::new(&config, vb).map_err(|e| format!("Malli: {}", e))?;
|
||||||
tracing::info!("Malli ladattu ({:.1}s) — {}", start.elapsed().as_secs_f64(), device_name);
|
tracing::info!("Malli ladattu ({:.1}s) — {}", start.elapsed().as_secs_f64(), device_name);
|
||||||
|
|
||||||
Ok(LlmEngine {
|
Ok(LlmEngine {
|
||||||
tokenizer,
|
tokenizer,
|
||||||
model_path,
|
model,
|
||||||
device,
|
device,
|
||||||
dtype,
|
|
||||||
config,
|
|
||||||
eos_token: 151645,
|
eos_token: 151645,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Luo tuore malliinstanssi (nollaa KV-cachen)
|
|
||||||
fn fresh_model(&self) -> Result<QwenModel, String> {
|
|
||||||
let vb = unsafe {
|
|
||||||
VarBuilder::from_mmaped_safetensors(&[self.model_path.clone()], self.dtype, &self.device)
|
|
||||||
.map_err(|e| format!("VarBuilder: {}", e))?
|
|
||||||
};
|
|
||||||
QwenModel::new(&self.config, vb).map_err(|e| format!("Malli: {}", e))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn generate(&mut self, prompt: &str, max_tokens: usize) -> Result<GenerateResult, String> {
|
pub fn generate(&mut self, prompt: &str, max_tokens: usize) -> Result<GenerateResult, String> {
|
||||||
let formatted = format!("<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", prompt);
|
// Prefill: aloitetaan vastaus ```-koodiblokkilla → malli jatkaa suoraan koodilla
|
||||||
|
let formatted = format!("<|im_start|>system\nYou are a coding assistant. Respond with ONLY code. No explanations, no markdown, no comments unless asked.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n```\n", prompt);
|
||||||
|
|
||||||
let encoding = self.tokenizer.encode(formatted.as_str(), true)
|
let encoding = self.tokenizer.encode(formatted.as_str(), true)
|
||||||
.map_err(|e| format!("Encode: {}", e))?;
|
.map_err(|e| format!("Encode: {}", e))?;
|
||||||
let input_ids: Vec<u32> = encoding.get_ids().to_vec();
|
let input_ids: Vec<u32> = encoding.get_ids().to_vec();
|
||||||
let input_len = input_ids.len();
|
let input_len = input_ids.len();
|
||||||
|
|
||||||
// Tuore malli joka promptille (nollaa KV-cachen)
|
// Nollataan KV-cache edellisestä promptista
|
||||||
let mut model = self.fresh_model()?;
|
self.model.clear_kv_cache();
|
||||||
|
|
||||||
|
// Sampling-parametrit
|
||||||
|
let temperature = 0.7;
|
||||||
|
let top_k = 40;
|
||||||
|
let repetition_penalty = 1.15;
|
||||||
|
let mut rng_state: u64 = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_nanos() as u64;
|
||||||
|
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
|
||||||
@@ -100,24 +151,24 @@ impl LlmEngine {
|
|||||||
.and_then(|t| t.unsqueeze(0))
|
.and_then(|t| t.unsqueeze(0))
|
||||||
.map_err(|e| format!("Tensor: {}", e))?;
|
.map_err(|e| format!("Tensor: {}", e))?;
|
||||||
|
|
||||||
let logits = model.forward(&input, 0)
|
let logits = self.model.forward(&input, 0)
|
||||||
.map_err(|e| format!("Forward prefill: {}", e))?;
|
.map_err(|e| format!("Forward prefill: {}", e))?;
|
||||||
|
|
||||||
let logits = logits.squeeze(0).map_err(|e| format!("Squeeze: {}", e))?;
|
let logits = logits.squeeze(0).map_err(|e| format!("Squeeze: {}", e))?;
|
||||||
let logits = if logits.dims().len() == 2 {
|
let logits = if logits.dims().len() == 2 {
|
||||||
logits.get(logits.dim(0).unwrap() - 1).map_err(|e| format!("Get: {}", e))?
|
let seq_len = logits.dim(0).map_err(|e| format!("Dim: {}", e))?;
|
||||||
|
if seq_len == 0 { return Err("Tyhjä tensori".to_string()); }
|
||||||
|
logits.get(seq_len - 1).map_err(|e| format!("Get: {}", e))?
|
||||||
} else {
|
} else {
|
||||||
logits
|
logits
|
||||||
};
|
};
|
||||||
let mut next_token = logits.argmax(0)
|
|
||||||
.map_err(|e| format!("Argmax: {}", e))?
|
|
||||||
.to_vec0::<u32>()
|
|
||||||
.map_err(|e| format!("to_vec0: {}", e))?;
|
|
||||||
|
|
||||||
let mut generated_text = String::new();
|
let mut generated_text = String::new();
|
||||||
let mut tokens_generated: usize = 0;
|
let mut tokens_generated: usize = 0;
|
||||||
let mut all_tokens: Vec<u32> = Vec::new();
|
let mut all_tokens: Vec<u32> = Vec::new();
|
||||||
|
|
||||||
|
let mut next_token = sample_top_k(&logits, top_k, temperature, &all_tokens, repetition_penalty, &mut rng_state)?;
|
||||||
|
|
||||||
if next_token != self.eos_token {
|
if next_token != self.eos_token {
|
||||||
if let Ok(text) = self.tokenizer.decode(&[next_token], true) {
|
if let Ok(text) = self.tokenizer.decode(&[next_token], true) {
|
||||||
generated_text.push_str(&text);
|
generated_text.push_str(&text);
|
||||||
@@ -135,25 +186,35 @@ impl LlmEngine {
|
|||||||
.and_then(|t| t.unsqueeze(0))
|
.and_then(|t| t.unsqueeze(0))
|
||||||
.map_err(|e| format!("Tensor: {}", e))?;
|
.map_err(|e| format!("Tensor: {}", e))?;
|
||||||
|
|
||||||
let logits = model.forward(&input, pos)
|
let logits = self.model.forward(&input, pos)
|
||||||
.map_err(|e| format!("Forward pos {}: {}", pos, e))?;
|
.map_err(|e| format!("Forward pos {}: {}", pos, e))?;
|
||||||
|
|
||||||
let logits = logits.squeeze(0).map_err(|e| format!("Squeeze: {}", e))?;
|
let logits = logits.squeeze(0).map_err(|e| format!("Squeeze: {}", e))?;
|
||||||
let logits = if logits.dims().len() == 2 {
|
let logits = if logits.dims().len() == 2 {
|
||||||
logits.get(logits.dim(0).unwrap() - 1).map_err(|e| format!("Get: {}", e))?
|
let seq_len = logits.dim(0).map_err(|e| format!("Dim: {}", e))?;
|
||||||
|
if seq_len == 0 { break; }
|
||||||
|
logits.get(seq_len - 1).map_err(|e| format!("Get: {}", e))?
|
||||||
} else {
|
} else {
|
||||||
logits
|
logits
|
||||||
};
|
};
|
||||||
next_token = logits.argmax(0)
|
next_token = sample_top_k(&logits, top_k, temperature, &all_tokens, repetition_penalty, &mut rng_state)?;
|
||||||
.map_err(|e| format!("Argmax: {}", e))?
|
|
||||||
.to_vec0::<u32>()
|
|
||||||
.map_err(|e| format!("to_vec0: {}", e))?;
|
|
||||||
pos += 1;
|
pos += 1;
|
||||||
|
|
||||||
if next_token == self.eos_token { break; }
|
if next_token == self.eos_token { break; }
|
||||||
|
|
||||||
if let Ok(text) = self.tokenizer.decode(&[next_token], true) {
|
if let Ok(text) = self.tokenizer.decode(&[next_token], true) {
|
||||||
generated_text.push_str(&text);
|
generated_text.push_str(&text);
|
||||||
|
|
||||||
|
// Stop-sekvenssit: katkaistaan kun malli alkaa selittää
|
||||||
|
let lower = generated_text.to_lowercase();
|
||||||
|
if lower.contains("\n###") || lower.contains("\nexplanation") || lower.contains("\nnote:") || lower.contains("\noutput:") || lower.contains("\n```\n\n") || lower.contains("\n// example") || lower.contains("\n# example") {
|
||||||
|
for stop in &["\n###", "\nExplanation", "\nNote:", "\nOutput:", "\n```\n\n", "\n// Example", "\n// example", "\n# Example", "\n# example"] {
|
||||||
|
if let Some(pos) = generated_text.find(stop) {
|
||||||
|
generated_text.truncate(pos);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
all_tokens.push(next_token);
|
all_tokens.push(next_token);
|
||||||
tokens_generated += 1;
|
tokens_generated += 1;
|
||||||
@@ -165,7 +226,7 @@ impl LlmEngine {
|
|||||||
} else { 0.0 };
|
} else { 0.0 };
|
||||||
|
|
||||||
Ok(GenerateResult {
|
Ok(GenerateResult {
|
||||||
text: generated_text,
|
text: strip_markdown_wrapper(&generated_text),
|
||||||
tokens_generated,
|
tokens_generated,
|
||||||
duration_ms: gen_time.as_millis() as f64,
|
duration_ms: gen_time.as_millis() as f64,
|
||||||
tokens_per_sec,
|
tokens_per_sec,
|
||||||
@@ -173,6 +234,61 @@ impl LlmEngine {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const LANG_TAGS: &[&str] = &[
|
||||||
|
"python", "py", "rust", "rs", "javascript", "js", "typescript", "ts",
|
||||||
|
"java", "kotlin", "scala", "go", "ruby", "rb", "php", "swift",
|
||||||
|
"c", "cpp", "c++", "c#", "csharp", "r", "sql", "bash", "sh", "zsh",
|
||||||
|
"html", "css", "json", "yaml", "yml", "toml", "xml", "markdown", "md",
|
||||||
|
"lua", "perl", "dart", "elixir", "haskell", "hs", "ocaml", "zig",
|
||||||
|
"plaintext", "text", "txt",
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Siivoa mallin tuottama vastaus (prefill-yhteensopiva).
|
||||||
|
fn strip_markdown_wrapper(text: &str) -> String {
|
||||||
|
let mut result = text.trim().to_string();
|
||||||
|
|
||||||
|
// 1. Kielitunniste — VAIN tunnettu kieli
|
||||||
|
if let Some(nl) = result.find('\n') {
|
||||||
|
let first = result[..nl].trim().to_lowercase();
|
||||||
|
if LANG_TAGS.contains(&first.as_str()) {
|
||||||
|
result = result[nl + 1..].to_string();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Sulkeva ``` — VAIN omalla rivillään lopussa
|
||||||
|
let trimmed = result.trim_end();
|
||||||
|
if trimmed.ends_with("```") {
|
||||||
|
let before = &trimmed[..trimmed.len() - 3];
|
||||||
|
if before.is_empty() || before.ends_with('\n') {
|
||||||
|
result = before.trim_end().to_string();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Johdantolauseet
|
||||||
|
let lower = result.trim().to_lowercase();
|
||||||
|
for prefix in &["sure!", "here is", "here's", "certainly!", "below is"] {
|
||||||
|
if lower.starts_with(prefix) {
|
||||||
|
if let Some(nl) = result.find('\n') { result = result[nl + 1..].to_string(); }
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Selityskommentit alusta
|
||||||
|
let mut lines: Vec<&str> = result.trim().lines().collect();
|
||||||
|
while !lines.is_empty() {
|
||||||
|
let first = lines[0].trim();
|
||||||
|
let is_preamble = first.starts_with("# ") && !first.starts_with("#!")
|
||||||
|
&& (first.to_lowercase().contains("this is")
|
||||||
|
|| first.to_lowercase().contains("simple")
|
||||||
|
|| first.to_lowercase().contains("program that")
|
||||||
|
|| first.to_lowercase().contains("here is")
|
||||||
|
|| first.to_lowercase().contains("the following")
|
||||||
|
|| first.to_lowercase().contains("below"));
|
||||||
|
if is_preamble { lines.remove(0); } else { break; }
|
||||||
|
}
|
||||||
|
lines.join("\n").trim().to_string()
|
||||||
|
}
|
||||||
|
|
||||||
pub struct GenerateResult {
|
pub struct GenerateResult {
|
||||||
pub text: String,
|
pub text: String,
|
||||||
pub tokens_generated: usize,
|
pub tokens_generated: usize,
|
||||||
|
|||||||
@@ -227,6 +227,7 @@ fn build_auth_message(allocated_gb: u32) -> String {
|
|||||||
"status": "agent_ready",
|
"status": "agent_ready",
|
||||||
"node_type": "native",
|
"node_type": "native",
|
||||||
"allocated_gb": allocated_gb,
|
"allocated_gb": allocated_gb,
|
||||||
|
"selected_task": "qwen-coder-05b",
|
||||||
"system": sys,
|
"system": sys,
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -318,10 +319,14 @@ async fn main() {
|
|||||||
if text.contains("llm_prompt") && !busy {
|
if text.contains("llm_prompt") && !busy {
|
||||||
if let Ok(task) = serde_json::from_str::<serde_json::Value>(&text) {
|
if let Ok(task) = serde_json::from_str::<serde_json::Value>(&text) {
|
||||||
let prompt = task.get("prompt").and_then(|v| v.as_str()).unwrap_or("");
|
let prompt = task.get("prompt").and_then(|v| v.as_str()).unwrap_or("");
|
||||||
if !prompt.is_empty() {
|
let task_id = task.get("task_id").and_then(|v| v.as_str()).unwrap_or("?");
|
||||||
|
let msg_model = task.get("model").and_then(|v| v.as_str()).unwrap_or("");
|
||||||
|
|
||||||
|
if !prompt.is_empty() && msg_model.starts_with("qwen-coder") {
|
||||||
|
|
||||||
if let Some(ref mut engine) = llm {
|
if let Some(ref mut engine) = llm {
|
||||||
busy = true;
|
busy = true;
|
||||||
tracing::info!("Generoidaan: \"{}\"", prompt);
|
tracing::info!("Generoidaan (task_id: {}): \"{}\"", task_id, prompt);
|
||||||
|
|
||||||
match engine.generate(prompt, 64) {
|
match engine.generate(prompt, 64) {
|
||||||
Ok(result) => {
|
Ok(result) => {
|
||||||
@@ -336,12 +341,13 @@ async fn main() {
|
|||||||
let done = json!({
|
let done = json!({
|
||||||
"type": "llm_done",
|
"type": "llm_done",
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"model": "Qwen2.5-0.5B-Instruct (native/GPU)",
|
"model": "Qwen2.5-Coder-0.5B (native/GPU)",
|
||||||
"response": result.text,
|
"response": result.text,
|
||||||
"tokens_generated": result.tokens_generated,
|
"tokens_generated": result.tokens_generated,
|
||||||
"duration_ms": result.duration_ms,
|
"duration_ms": result.duration_ms,
|
||||||
"tokens_per_sec": (result.tokens_per_sec * 10.0).round() / 10.0,
|
"tokens_per_sec": (result.tokens_per_sec * 10.0).round() / 10.0,
|
||||||
"load_time_ms": 0,
|
"load_time_ms": 0,
|
||||||
|
"task_id": task_id,
|
||||||
});
|
});
|
||||||
let _ = write.send(Message::Text(done.to_string())).await;
|
let _ = write.send(Message::Text(done.to_string())).await;
|
||||||
}
|
}
|
||||||
|
|||||||
BIN
network-poc/node/nodes.db
Normal file
BIN
network-poc/node/nodes.db
Normal file
Binary file not shown.
@@ -130,8 +130,9 @@ async fn run_single_tokenize(text: String, ws: Rc<RefCell<WebSocket>>) {
|
|||||||
|
|
||||||
let token_count = result["token_count"].as_u64().unwrap_or(0);
|
let token_count = result["token_count"].as_u64().unwrap_or(0);
|
||||||
let cpt = result["chars_per_token"].as_f64().unwrap_or(0.0);
|
let cpt = result["chars_per_token"].as_f64().unwrap_or(0.0);
|
||||||
|
let preview: String = text.chars().take(50).collect();
|
||||||
console_log!("Tokenisaatio: \"{}\" → {} tokenia | {:.2} m/t | {:.2}ms",
|
console_log!("Tokenisaatio: \"{}\" → {} tokenia | {:.2} m/t | {:.2}ms",
|
||||||
&text[..text.len().min(50)], token_count, cpt, duration_ms);
|
preview, token_count, cpt, duration_ms);
|
||||||
|
|
||||||
let msg = serde_json::json!({
|
let msg = serde_json::json!({
|
||||||
"type": "single_tokenize_done",
|
"type": "single_tokenize_done",
|
||||||
@@ -270,7 +271,8 @@ pub async fn start_agent_node(hub_url: String, has_webgpu: bool, device_info_jso
|
|||||||
if LLM_BUSY.load(Ordering::SeqCst) {
|
if LLM_BUSY.load(Ordering::SeqCst) {
|
||||||
} else if let Ok(task) = serde_json::from_str::<serde_json::Value>(&msg) {
|
} else if let Ok(task) = serde_json::from_str::<serde_json::Value>(&msg) {
|
||||||
let prompt = task.get("prompt").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
let prompt = task.get("prompt").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||||
if !prompt.is_empty() {
|
let model = task.get("model").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||||
|
if !prompt.is_empty() && model == "qwen-05b" {
|
||||||
LLM_BUSY.store(true, Ordering::SeqCst);
|
LLM_BUSY.store(true, Ordering::SeqCst);
|
||||||
let ws_for_async = ws_clone.clone();
|
let ws_for_async = ws_clone.clone();
|
||||||
wasm_bindgen_futures::spawn_local(async move {
|
wasm_bindgen_futures::spawn_local(async move {
|
||||||
@@ -284,7 +286,8 @@ pub async fn start_agent_node(hub_url: String, has_webgpu: bool, device_info_jso
|
|||||||
if LLM_BUSY.load(Ordering::SeqCst) {
|
if LLM_BUSY.load(Ordering::SeqCst) {
|
||||||
} else if let Ok(task) = serde_json::from_str::<serde_json::Value>(&msg) {
|
} else if let Ok(task) = serde_json::from_str::<serde_json::Value>(&msg) {
|
||||||
let prompt = task.get("prompt").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
let prompt = task.get("prompt").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||||
if !prompt.is_empty() {
|
let model = task.get("model").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||||
|
if !prompt.is_empty() && model.starts_with("phi3-mini") {
|
||||||
LLM_BUSY.store(true, Ordering::SeqCst);
|
LLM_BUSY.store(true, Ordering::SeqCst);
|
||||||
let ws_for_async = ws_clone.clone();
|
let ws_for_async = ws_clone.clone();
|
||||||
wasm_bindgen_futures::spawn_local(async move {
|
wasm_bindgen_futures::spawn_local(async move {
|
||||||
@@ -295,18 +298,30 @@ pub async fn start_agent_node(hub_url: String, has_webgpu: bool, device_info_jso
|
|||||||
}
|
}
|
||||||
} else if msg.contains("llm_prompt") && (current_task == 4 || current_task == 5) {
|
} else if msg.contains("llm_prompt") && (current_task == 4 || current_task == 5) {
|
||||||
// Qwen2.5-Coder: 4 = 0.5B, 5 = 3B
|
// Qwen2.5-Coder: 4 = 0.5B, 5 = 3B
|
||||||
if LLM_BUSY.load(Ordering::SeqCst) {
|
if let Ok(task) = serde_json::from_str::<serde_json::Value>(&msg) {
|
||||||
} else if let Ok(task) = serde_json::from_str::<serde_json::Value>(&msg) {
|
|
||||||
let prompt = task.get("prompt").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
let prompt = task.get("prompt").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||||
|
let model = task.get("model").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||||
let task_id = task.get("task_id").and_then(|v| v.as_str()).map(|s| s.to_string());
|
let task_id = task.get("task_id").and_then(|v| v.as_str()).map(|s| s.to_string());
|
||||||
if !prompt.is_empty() {
|
|
||||||
let use_3b = current_task == 5;
|
if !prompt.is_empty() && model.starts_with("qwen-coder") {
|
||||||
LLM_BUSY.store(true, Ordering::SeqCst);
|
if LLM_BUSY.load(Ordering::SeqCst) {
|
||||||
let ws_for_async = ws_clone.clone();
|
if let Some(tid) = task_id {
|
||||||
wasm_bindgen_futures::spawn_local(async move {
|
let err_msg = serde_json::json!({
|
||||||
qwen_coder::run_coder_inference(prompt, ws_for_async, use_3b, task_id).await;
|
"type": "llm_error",
|
||||||
LLM_BUSY.store(false, Ordering::SeqCst);
|
"task_id": tid,
|
||||||
});
|
"error": "Solmu on paraikaa varattuna toisen tehtävän suorittamiseen"
|
||||||
|
});
|
||||||
|
let _ = ws_clone.borrow().send_with_str(&err_msg.to_string());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let use_3b = current_task == 5;
|
||||||
|
LLM_BUSY.store(true, Ordering::SeqCst);
|
||||||
|
let ws_for_async = ws_clone.clone();
|
||||||
|
wasm_bindgen_futures::spawn_local(async move {
|
||||||
|
qwen_coder::run_coder_inference(prompt, ws_for_async, use_3b, task_id).await;
|
||||||
|
LLM_BUSY.store(false, Ordering::SeqCst);
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if msg.contains("ai_task") {
|
} else if msg.contains("ai_task") {
|
||||||
|
|||||||
@@ -21,12 +21,98 @@ const MODEL_3B_PART1_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-Coder-3B-I
|
|||||||
const MODEL_3B_PART2_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-Coder-3B-Instruct/resolve/main/model-00002-of-00002.safetensors";
|
const MODEL_3B_PART2_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-Coder-3B-Instruct/resolve/main/model-00002-of-00002.safetensors";
|
||||||
const TOKENIZER_3B_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-Coder-3B-Instruct/resolve/main/tokenizer.json";
|
const TOKENIZER_3B_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-Coder-3B-Instruct/resolve/main/tokenizer.json";
|
||||||
|
|
||||||
async fn ensure_cached(key: &str, url: &str, ws: &Rc<RefCell<WebSocket>>) -> Result<Vec<u8>, String> {
|
struct CachedModel {
|
||||||
if let Ok(Some(bytes)) = storage::load_from_idb(key).await {
|
model: QwenModel,
|
||||||
console_log!("[Coder] {} löytyi välimuistista ({} MB)", key, bytes.len() / 1024 / 1024);
|
tokenizer: tokenizers::Tokenizer,
|
||||||
|
is_3b: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tunnetut kielitunnisteet joita malli voi tuottaa prefill-backtickien jälkeen.
|
||||||
|
const LANG_TAGS: &[&str] = &[
|
||||||
|
"python", "py", "rust", "rs", "javascript", "js", "typescript", "ts",
|
||||||
|
"java", "kotlin", "scala", "go", "ruby", "rb", "php", "swift",
|
||||||
|
"c", "cpp", "c++", "c#", "csharp", "r", "sql", "bash", "sh", "zsh",
|
||||||
|
"html", "css", "json", "yaml", "yml", "toml", "xml", "markdown", "md",
|
||||||
|
"lua", "perl", "dart", "elixir", "haskell", "hs", "ocaml", "zig",
|
||||||
|
"plaintext", "text", "txt",
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Siivoa mallin tuottama vastaus.
|
||||||
|
/// Prefill-tekniikan vuoksi malli tuottaa: "rust\nfn main() {...}\n```"
|
||||||
|
/// eli kielitunniste alussa + sulkeva ``` lopussa. Molemmat poistetaan.
|
||||||
|
fn strip_markdown_wrapper(text: &str) -> String {
|
||||||
|
let mut result = text.trim().to_string();
|
||||||
|
|
||||||
|
// 1. Poistetaan kielitunniste ensimmäiseltä riviltä — VAIN jos se on tunnettu kieli
|
||||||
|
if let Some(first_newline) = result.find('\n') {
|
||||||
|
let first_line = result[..first_newline].trim().to_lowercase();
|
||||||
|
if LANG_TAGS.contains(&first_line.as_str()) {
|
||||||
|
result = result[first_newline + 1..].to_string();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Poistetaan sulkeva ``` VAIN jos se on omalla rivillään lopussa
|
||||||
|
let trimmed = result.trim_end();
|
||||||
|
if trimmed.ends_with("```") {
|
||||||
|
let before = &trimmed[..trimmed.len() - 3];
|
||||||
|
// Varmistetaan: edellinen merkki on rivinvaihto tai alku (eli ``` on oma rivinsä)
|
||||||
|
if before.is_empty() || before.ends_with('\n') {
|
||||||
|
result = before.trim_end().to_string();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Poistetaan johdantolauseet: "Sure! Here is...", "Certainly!" jne.
|
||||||
|
let lower = result.trim().to_lowercase();
|
||||||
|
for prefix in &["sure!", "here is", "here's", "certainly!", "below is"] {
|
||||||
|
if lower.starts_with(prefix) {
|
||||||
|
if let Some(newline) = result.find('\n') {
|
||||||
|
result = result[newline + 1..].to_string();
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Poistetaan selityskommentit alusta: "# This is a simple program..."
|
||||||
|
let mut lines: Vec<&str> = result.trim().lines().collect();
|
||||||
|
while !lines.is_empty() {
|
||||||
|
let first = lines[0].trim();
|
||||||
|
let is_preamble = first.starts_with("# ")
|
||||||
|
&& !first.starts_with("#!")
|
||||||
|
&& (first.to_lowercase().contains("this is")
|
||||||
|
|| first.to_lowercase().contains("simple")
|
||||||
|
|| first.to_lowercase().contains("program that")
|
||||||
|
|| first.to_lowercase().contains("here is")
|
||||||
|
|| first.to_lowercase().contains("the following")
|
||||||
|
|| first.to_lowercase().contains("below"));
|
||||||
|
if is_preamble { lines.remove(0); } else { break; }
|
||||||
|
}
|
||||||
|
|
||||||
|
lines.join("\n").trim().to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
thread_local! {
|
||||||
|
static RAM_CACHE: RefCell<std::collections::HashMap<String, Rc<Vec<u8>>>> = RefCell::new(std::collections::HashMap::new());
|
||||||
|
static MODEL_CACHE: RefCell<Option<CachedModel>> = RefCell::new(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn ensure_cached(key: &str, url: &str, ws: &Rc<RefCell<WebSocket>>) -> Result<Rc<Vec<u8>>, String> {
|
||||||
|
// 1. Tarkistetaan RAM välimuisti (estää OOM ja levy-I/O pullonkaulat)
|
||||||
|
let ram_hit = RAM_CACHE.with(|cache| {
|
||||||
|
cache.borrow().get(key).cloned()
|
||||||
|
});
|
||||||
|
if let Some(bytes) = ram_hit {
|
||||||
|
console_log!("[Coder] {} löytyi nopeasta RAM-välimuistista!", key);
|
||||||
return Ok(bytes);
|
return Ok(bytes);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 2. Tarkistetaan IndexedDB (jos selain on suljettu aikaisemmin)
|
||||||
|
if let Ok(Some(bytes)) = storage::load_from_idb(key).await {
|
||||||
|
console_log!("[Coder] {} löytyi IndexedDB-välimuistista ({} MB)", key, bytes.len() / 1024 / 1024);
|
||||||
|
let rc_bytes = Rc::new(bytes);
|
||||||
|
RAM_CACHE.with(|cache| cache.borrow_mut().insert(key.to_string(), rc_bytes.clone()));
|
||||||
|
return Ok(rc_bytes);
|
||||||
|
}
|
||||||
|
|
||||||
console_log!("[Coder] Ladataan {}...", key);
|
console_log!("[Coder] Ladataan {}...", key);
|
||||||
|
|
||||||
let window = web_sys::window().unwrap();
|
let window = web_sys::window().unwrap();
|
||||||
@@ -68,11 +154,85 @@ async fn ensure_cached(key: &str, url: &str, ws: &Rc<RefCell<WebSocket>>) -> Res
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
console_log!("[Coder] Tallennetaan {} ({} MB)...", key, data.len() / 1024 / 1024);
|
console_log!("[Coder] Tallennetaan {} ({} MB) IndexedDB:hen...", key, data.len() / 1024 / 1024);
|
||||||
let _ = storage::save_to_idb(key, &data).await;
|
let _ = storage::save_to_idb(key, &data).await;
|
||||||
console_log!("[Coder] {} tallennettu!", key);
|
console_log!("[Coder] {} tallennettu!", key);
|
||||||
|
|
||||||
Ok(data)
|
let rc_data = Rc::new(data);
|
||||||
|
RAM_CACHE.with(|cache| cache.borrow_mut().insert(key.to_string(), rc_data.clone()));
|
||||||
|
|
||||||
|
Ok(rc_data)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Lataa tai palauttaa välimuistista valmiin mallin + tokenizerin
|
||||||
|
async fn get_or_build_model(use_3b: bool, ws: &Rc<RefCell<WebSocket>>) -> Result<(), String> {
|
||||||
|
// Tarkistetaan onko oikea malli jo muistissa
|
||||||
|
let cache_hit = MODEL_CACHE.with(|c| {
|
||||||
|
c.borrow().as_ref().map(|m| m.is_3b == use_3b).unwrap_or(false)
|
||||||
|
});
|
||||||
|
if cache_hit {
|
||||||
|
// Logitetaan kaikki välivaiheet valmiiksi, jotta pipeline-UI päivittyy
|
||||||
|
console_log!("[Coder] tokenizer löytyi (cache)");
|
||||||
|
console_log!("[Coder] model löytyi (cache)");
|
||||||
|
console_log!("[Coder] Malli ladattu (välimuistista)");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let device = Device::Cpu;
|
||||||
|
let dtype = DType::F32;
|
||||||
|
|
||||||
|
// Tokenizer
|
||||||
|
let tok_url = if use_3b { TOKENIZER_3B_URL } else { TOKENIZER_05B_URL };
|
||||||
|
let tok_key = if use_3b { "coder3b-tokenizer.json" } else { "coder05b-tokenizer.json" };
|
||||||
|
let tok_bytes = ensure_cached(tok_key, tok_url, ws).await?;
|
||||||
|
let tokenizer = tokenizers::Tokenizer::from_bytes(&tok_bytes[..])
|
||||||
|
.map_err(|e| format!("Tokenizer: {}", e))?;
|
||||||
|
|
||||||
|
// Painot
|
||||||
|
let tensors = if use_3b {
|
||||||
|
let part1 = ensure_cached("coder3b-model-part1.safetensors", MODEL_3B_PART1_URL, ws).await?;
|
||||||
|
let part2 = ensure_cached("coder3b-model-part2.safetensors", MODEL_3B_PART2_URL, ws).await?;
|
||||||
|
console_log!("[Coder] Rakennetaan 3B-mallia...");
|
||||||
|
let mut all_tensors = candle_core::safetensors::load_buffer(&part1[..], &device)
|
||||||
|
.map_err(|e| format!("Part1: {}", e))?;
|
||||||
|
let tensors2 = candle_core::safetensors::load_buffer(&part2[..], &device)
|
||||||
|
.map_err(|e| format!("Part2: {}", e))?;
|
||||||
|
all_tensors.extend(tensors2);
|
||||||
|
all_tensors
|
||||||
|
} else {
|
||||||
|
let model_bytes = ensure_cached("coder05b-model.safetensors", MODEL_05B_URL, ws).await?;
|
||||||
|
console_log!("[Coder] Rakennetaan 0.5B-mallia...");
|
||||||
|
candle_core::safetensors::load_buffer(&model_bytes[..], &device)
|
||||||
|
.map_err(|e| format!("Safetensors: {}", e))?
|
||||||
|
};
|
||||||
|
|
||||||
|
let vb = VarBuilder::from_tensors(tensors, dtype, &device);
|
||||||
|
let config = if use_3b {
|
||||||
|
QwenConfig {
|
||||||
|
vocab_size: 151936, hidden_size: 2048, intermediate_size: 11008,
|
||||||
|
num_hidden_layers: 36, num_attention_heads: 16, num_key_value_heads: 2,
|
||||||
|
max_position_embeddings: 32768, sliding_window: 32768, max_window_layers: 36,
|
||||||
|
tie_word_embeddings: true, rope_theta: 1000000.0, rms_norm_eps: 1e-6,
|
||||||
|
use_sliding_window: false, hidden_act: candle_nn::Activation::Silu,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
QwenConfig {
|
||||||
|
vocab_size: 151936, hidden_size: 896, intermediate_size: 4864,
|
||||||
|
num_hidden_layers: 24, num_attention_heads: 14, num_key_value_heads: 2,
|
||||||
|
max_position_embeddings: 32768, sliding_window: 32768, max_window_layers: 21,
|
||||||
|
tie_word_embeddings: true, rope_theta: 1000000.0, rms_norm_eps: 1e-6,
|
||||||
|
use_sliding_window: false, hidden_act: candle_nn::Activation::Silu,
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let model = QwenModel::new(&config, vb).map_err(|e| format!("Malli: {}", e))?;
|
||||||
|
console_log!("[Coder] Malli ladattu ja välimuistitettu");
|
||||||
|
|
||||||
|
MODEL_CACHE.with(|c| {
|
||||||
|
*c.borrow_mut() = Some(CachedModel { model, tokenizer, is_3b: use_3b });
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// use_3b: false = 0.5B (nopea), true = 3B (laadukas)
|
/// use_3b: false = 0.5B (nopea), true = 3B (laadukas)
|
||||||
@@ -80,196 +240,133 @@ pub async fn run_coder_inference(prompt: String, ws: Rc<RefCell<WebSocket>>, use
|
|||||||
let perf = web_sys::window().unwrap().performance().unwrap();
|
let perf = web_sys::window().unwrap().performance().unwrap();
|
||||||
let size_label = if use_3b { "3B" } else { "0.5B" };
|
let size_label = if use_3b { "3B" } else { "0.5B" };
|
||||||
|
|
||||||
// Tokenizer (sama molemmille)
|
|
||||||
let tok_url = if use_3b { TOKENIZER_3B_URL } else { TOKENIZER_05B_URL };
|
|
||||||
let tok_key = if use_3b { "coder3b-tokenizer.json" } else { "coder05b-tokenizer.json" };
|
|
||||||
let tok_bytes = match ensure_cached(tok_key, tok_url, &ws).await {
|
|
||||||
Ok(b) => b,
|
|
||||||
Err(e) => { console_log!("[Coder] Tokenizer-virhe: {}", e); return; }
|
|
||||||
};
|
|
||||||
let tokenizer = match tokenizers::Tokenizer::from_bytes(&tok_bytes) {
|
|
||||||
Ok(t) => t,
|
|
||||||
Err(e) => { console_log!("[Coder] Tokenizer-parsinta: {}", e); return; }
|
|
||||||
};
|
|
||||||
|
|
||||||
// Mallin painot
|
|
||||||
let device = Device::Cpu;
|
|
||||||
let dtype = DType::F32;
|
|
||||||
|
|
||||||
let tensors = if use_3b {
|
|
||||||
// 3B: kaksi osaa
|
|
||||||
let part1 = match ensure_cached("coder3b-model-part1.safetensors", MODEL_3B_PART1_URL, &ws).await {
|
|
||||||
Ok(b) => b,
|
|
||||||
Err(e) => { console_log!("[Coder] Malli osa 1 virhe: {}", e); return; }
|
|
||||||
};
|
|
||||||
let part2 = match ensure_cached("coder3b-model-part2.safetensors", MODEL_3B_PART2_URL, &ws).await {
|
|
||||||
Ok(b) => b,
|
|
||||||
Err(e) => { console_log!("[Coder] Malli osa 2 virhe: {}", e); return; }
|
|
||||||
};
|
|
||||||
console_log!("[Coder] Rakennetaan 3B-mallia...");
|
|
||||||
let mut all_tensors = candle_core::safetensors::load_buffer(&part1, &device)
|
|
||||||
.map_err(|e| format!("Part1: {}", e)).unwrap();
|
|
||||||
let tensors2 = candle_core::safetensors::load_buffer(&part2, &device)
|
|
||||||
.map_err(|e| format!("Part2: {}", e)).unwrap();
|
|
||||||
all_tensors.extend(tensors2);
|
|
||||||
all_tensors
|
|
||||||
} else {
|
|
||||||
// 0.5B: yksi osa
|
|
||||||
let model_bytes = match ensure_cached("coder05b-model.safetensors", MODEL_05B_URL, &ws).await {
|
|
||||||
Ok(b) => b,
|
|
||||||
Err(e) => { console_log!("[Coder] Malli-virhe: {}", e); return; }
|
|
||||||
};
|
|
||||||
console_log!("[Coder] Rakennetaan 0.5B-mallia...");
|
|
||||||
match candle_core::safetensors::load_buffer(&model_bytes, &device) {
|
|
||||||
Ok(t) => t,
|
|
||||||
Err(e) => { console_log!("[Coder] Safetensors: {}", e); return; }
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let start_load = perf.now();
|
let start_load = perf.now();
|
||||||
let vb = VarBuilder::from_tensors(tensors, dtype, &device);
|
|
||||||
|
|
||||||
let config = if use_3b {
|
if let Err(e) = get_or_build_model(use_3b, &ws).await {
|
||||||
QwenConfig {
|
console_log!("[Coder] Mallin lataus: {}", e);
|
||||||
vocab_size: 151936,
|
return;
|
||||||
hidden_size: 2048,
|
}
|
||||||
intermediate_size: 11008,
|
|
||||||
num_hidden_layers: 36,
|
|
||||||
num_attention_heads: 16,
|
|
||||||
num_key_value_heads: 2,
|
|
||||||
max_position_embeddings: 32768,
|
|
||||||
sliding_window: 32768,
|
|
||||||
max_window_layers: 36,
|
|
||||||
tie_word_embeddings: true,
|
|
||||||
rope_theta: 1000000.0,
|
|
||||||
rms_norm_eps: 1e-6,
|
|
||||||
use_sliding_window: false,
|
|
||||||
hidden_act: candle_nn::Activation::Silu,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
QwenConfig {
|
|
||||||
vocab_size: 151936,
|
|
||||||
hidden_size: 896,
|
|
||||||
intermediate_size: 4864,
|
|
||||||
num_hidden_layers: 24,
|
|
||||||
num_attention_heads: 14,
|
|
||||||
num_key_value_heads: 2,
|
|
||||||
max_position_embeddings: 32768,
|
|
||||||
sliding_window: 32768,
|
|
||||||
max_window_layers: 21,
|
|
||||||
tie_word_embeddings: true,
|
|
||||||
rope_theta: 1000000.0,
|
|
||||||
rms_norm_eps: 1e-6,
|
|
||||||
use_sliding_window: false,
|
|
||||||
hidden_act: candle_nn::Activation::Silu,
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut model = match QwenModel::new(&config, vb) {
|
|
||||||
Ok(m) => m,
|
|
||||||
Err(e) => { console_log!("[Coder] Mallin lataus: {}", e); return; }
|
|
||||||
};
|
|
||||||
|
|
||||||
let load_time = perf.now() - start_load;
|
let load_time = perf.now() - start_load;
|
||||||
console_log!("[Coder] Malli ladattu ({:.0}ms). Generoidaan...", load_time);
|
if load_time > 100.0 {
|
||||||
|
console_log!("[Coder] Malli ladattu ({:.0}ms). Generoidaan...", load_time);
|
||||||
|
}
|
||||||
|
|
||||||
// Parsitaan JSON-prompti tai käytetään teksti sellaisenaan
|
// Parsitaan JSON-prompti tai käytetään teksti sellaisenaan
|
||||||
|
let default_system = "You are a coding assistant. Respond with ONLY code. No explanations, no markdown, no comments unless asked.";
|
||||||
let (actual_prompt, system_msg, max_new_tokens) = if prompt.starts_with('{') {
|
let (actual_prompt, system_msg, max_new_tokens) = if prompt.starts_with('{') {
|
||||||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&prompt) {
|
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&prompt) {
|
||||||
let p = json.get("prompt").and_then(|v| v.as_str()).unwrap_or(&prompt).to_string();
|
let p = json.get("prompt").and_then(|v| v.as_str()).unwrap_or(&prompt).to_string();
|
||||||
let s = json.get("system").and_then(|v| v.as_str())
|
let s = json.get("system").and_then(|v| v.as_str()).unwrap_or(default_system).to_string();
|
||||||
.unwrap_or("You are a Python coding assistant. Write only code, no explanations.").to_string();
|
let m = json.get("max_tokens").and_then(|v| v.as_u64()).unwrap_or(256) as usize;
|
||||||
let m = json.get("max_tokens").and_then(|v| v.as_u64()).unwrap_or(128) as usize;
|
|
||||||
(p, s, m)
|
(p, s, m)
|
||||||
} else {
|
} else {
|
||||||
(prompt.clone(), "You are a Python coding assistant. Write only code, no explanations.".to_string(), 128)
|
(prompt.clone(), default_system.to_string(), 256)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
(prompt.clone(), "You are a Python coding assistant. Write only code, no explanations.".to_string(), 128)
|
(prompt.clone(), default_system.to_string(), 256)
|
||||||
};
|
};
|
||||||
|
|
||||||
let formatted = format!("<|im_start|>system\n{}<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", system_msg, actual_prompt);
|
// Prefill: aloitetaan vastaus ```-koodiblokkilla, jolloin malli jatkaa suoraan koodilla
|
||||||
|
// eikä tuota "Sure! Here is..." -johdantoa. strip_markdown_wrapper poistaa ``` jälkikäteen.
|
||||||
|
let formatted = format!("<|im_start|>system\n{}<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n```\n", system_msg, actual_prompt);
|
||||||
|
|
||||||
let encoding = match tokenizer.encode(formatted.as_str(), true) {
|
// Inferenssi: käytetään välimuistissa olevaa mallia
|
||||||
Ok(e) => e,
|
let (generated_text, tokens_generated, gen_time) = MODEL_CACHE.with(|cache| {
|
||||||
Err(e) => { console_log!("[Coder] Tokenisointivirhe: {}", e); return; }
|
let mut cache = cache.borrow_mut();
|
||||||
};
|
let cached = cache.as_mut().expect("Malli pitää olla ladattu");
|
||||||
let input_ids: Vec<u32> = encoding.get_ids().to_vec();
|
|
||||||
let input_len = input_ids.len();
|
|
||||||
console_log!("[Coder] Syöte: {} tokenia", input_len);
|
|
||||||
|
|
||||||
let start_gen = perf.now();
|
let encoding = cached.tokenizer.encode(formatted.as_str(), true)
|
||||||
// max_new_tokens tulee JSON-promptista tai oletuksena 128
|
.map_err(|e| format!("Encode: {}", e)).unwrap();
|
||||||
let mut generated_text = String::new();
|
let input_ids: Vec<u32> = encoding.get_ids().to_vec();
|
||||||
let mut tokens_generated: usize = 0;
|
let input_len = input_ids.len();
|
||||||
let eos_token = 151645u32;
|
console_log!("[Coder] Syöte: {} tokenia", input_len);
|
||||||
|
|
||||||
// Prefill
|
let device = Device::Cpu;
|
||||||
let input = match Tensor::new(input_ids.as_slice(), &device).and_then(|t| t.unsqueeze(0)) {
|
let start_gen = perf.now();
|
||||||
Ok(t) => t,
|
let eos_token = 151645u32;
|
||||||
Err(e) => { console_log!("[Coder] Tensor: {}", e); return; }
|
let temperature: f32 = 0.7;
|
||||||
};
|
let top_k: usize = 40;
|
||||||
let logits = match model.forward(&input, 0) {
|
let repetition_penalty: f32 = 1.15;
|
||||||
Ok(l) => l,
|
|
||||||
Err(e) => { console_log!("[Coder] Forward (prefill): {}", e); return; }
|
|
||||||
};
|
|
||||||
|
|
||||||
let logits = logits.squeeze(0).unwrap();
|
// Nollataan KV-cache edellisestä promptista
|
||||||
let logits = if logits.dims().len() == 2 {
|
cached.model.clear_kv_cache();
|
||||||
logits.get(logits.dim(0).unwrap() - 1).unwrap()
|
|
||||||
} else {
|
|
||||||
logits
|
|
||||||
};
|
|
||||||
let mut next_token = crate::sampling::sample_top_k(&logits, 10, 5.0);
|
|
||||||
|
|
||||||
if next_token != eos_token {
|
let mut generated_text = String::new();
|
||||||
if let Ok(text) = tokenizer.decode(&[next_token], true) {
|
let mut tokens_generated: usize = 0;
|
||||||
generated_text.push_str(&text);
|
let mut all_generated: Vec<u32> = Vec::new();
|
||||||
let mut chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "Qwen2.5-Coder" });
|
|
||||||
if let Some(ref tid) = task_id { chunk.as_object_mut().unwrap().insert("task_id".to_string(), serde_json::json!(tid)); }
|
|
||||||
let _ = ws.borrow().send_with_str(&chunk.to_string());
|
|
||||||
}
|
|
||||||
tokens_generated += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Autoregressive
|
|
||||||
let mut pos = input_len;
|
|
||||||
for _ in 1..max_new_tokens {
|
|
||||||
if next_token == eos_token { break; }
|
|
||||||
|
|
||||||
let input = match Tensor::new(&[next_token], &device).and_then(|t| t.unsqueeze(0)) {
|
|
||||||
Ok(t) => t,
|
|
||||||
Err(e) => { console_log!("[Coder] Tensor: {}", e); break; }
|
|
||||||
};
|
|
||||||
let logits = match model.forward(&input, pos) {
|
|
||||||
Ok(l) => l,
|
|
||||||
Err(e) => { console_log!("[Coder] Forward pos {}: {}", pos, e); break; }
|
|
||||||
};
|
|
||||||
|
|
||||||
|
// Prefill
|
||||||
|
let input = Tensor::new(input_ids.as_slice(), &device).and_then(|t| t.unsqueeze(0)).unwrap();
|
||||||
|
let logits = cached.model.forward(&input, 0).unwrap();
|
||||||
let logits = logits.squeeze(0).unwrap();
|
let logits = logits.squeeze(0).unwrap();
|
||||||
let logits = if logits.dims().len() == 2 {
|
let logits = if logits.dims().len() == 2 {
|
||||||
logits.get(logits.dim(0).unwrap() - 1).unwrap()
|
logits.get(logits.dim(0).unwrap() - 1).unwrap()
|
||||||
} else {
|
} else { logits };
|
||||||
logits
|
|
||||||
};
|
|
||||||
next_token = crate::sampling::sample_top_k(&logits, 10, 5.0);
|
|
||||||
pos += 1;
|
|
||||||
|
|
||||||
if next_token == eos_token { break; }
|
let mut next_token = crate::sampling::sample_top_k_with_penalty(&logits, top_k, temperature, &all_generated, repetition_penalty);
|
||||||
|
|
||||||
if let Ok(text) = tokenizer.decode(&[next_token], true) {
|
if next_token != eos_token {
|
||||||
generated_text.push_str(&text);
|
if let Ok(text) = cached.tokenizer.decode(&[next_token], true) {
|
||||||
let mut chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "Qwen2.5-Coder" });
|
generated_text.push_str(&text);
|
||||||
|
let mut chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "Qwen2.5-Coder" });
|
||||||
if let Some(ref tid) = task_id { chunk.as_object_mut().unwrap().insert("task_id".to_string(), serde_json::json!(tid)); }
|
if let Some(ref tid) = task_id { chunk.as_object_mut().unwrap().insert("task_id".to_string(), serde_json::json!(tid)); }
|
||||||
let _ = ws.borrow().send_with_str(&chunk.to_string());
|
let _ = ws.borrow().send_with_str(&chunk.to_string());
|
||||||
|
}
|
||||||
|
all_generated.push(next_token);
|
||||||
|
tokens_generated += 1;
|
||||||
}
|
}
|
||||||
tokens_generated += 1;
|
|
||||||
|
|
||||||
// Yield — vapautetaan selaimen event loop joka tokenin jälkeen
|
// Autoregressive
|
||||||
crate::sleep_ms(0).await;
|
let mut pos = input_len;
|
||||||
}
|
for _ in 1..max_new_tokens {
|
||||||
|
if next_token == eos_token { break; }
|
||||||
|
|
||||||
|
let input = Tensor::new(&[next_token], &device).and_then(|t| t.unsqueeze(0)).unwrap();
|
||||||
|
let logits = match cached.model.forward(&input, pos) {
|
||||||
|
Ok(l) => l,
|
||||||
|
Err(e) => { console_log!("[Coder] Forward pos {}: {}", pos, e); break; }
|
||||||
|
};
|
||||||
|
|
||||||
|
let logits = logits.squeeze(0).unwrap();
|
||||||
|
let logits = if logits.dims().len() == 2 {
|
||||||
|
logits.get(logits.dim(0).unwrap() - 1).unwrap()
|
||||||
|
} else { logits };
|
||||||
|
next_token = crate::sampling::sample_top_k_with_penalty(&logits, top_k, temperature, &all_generated, repetition_penalty);
|
||||||
|
pos += 1;
|
||||||
|
|
||||||
|
if next_token == eos_token { break; }
|
||||||
|
|
||||||
|
if let Ok(text) = cached.tokenizer.decode(&[next_token], true) {
|
||||||
|
generated_text.push_str(&text);
|
||||||
|
|
||||||
|
// Stop-sekvenssit: katkaistaan kun malli alkaa selittää
|
||||||
|
let lower = generated_text.to_lowercase();
|
||||||
|
if lower.contains("\n###") || lower.contains("\nexplanation") || lower.contains("\nnote:") || lower.contains("\noutput:") || lower.contains("\n```\n\n") || lower.contains("\n// example") || lower.contains("\n# example") {
|
||||||
|
for stop in &["\n###", "\nExplanation", "\nNote:", "\nOutput:", "\n```\n\n", "\n// Example", "\n// example", "\n# Example", "\n# example"] {
|
||||||
|
if let Some(pos) = generated_text.find(stop) {
|
||||||
|
generated_text.truncate(pos);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "Qwen2.5-Coder" });
|
||||||
|
if let Some(ref tid) = task_id { chunk.as_object_mut().unwrap().insert("task_id".to_string(), serde_json::json!(tid)); }
|
||||||
|
let _ = ws.borrow().send_with_str(&chunk.to_string());
|
||||||
|
}
|
||||||
|
all_generated.push(next_token);
|
||||||
|
tokens_generated += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
let gen_time = perf.now() - start_gen;
|
||||||
|
|
||||||
|
// Siivotaan vastaus: poista markdown-koodiblokit ja johdantotekstit
|
||||||
|
let cleaned = strip_markdown_wrapper(&generated_text);
|
||||||
|
|
||||||
|
(cleaned, tokens_generated, gen_time)
|
||||||
|
});
|
||||||
|
|
||||||
let gen_time = perf.now() - start_gen;
|
|
||||||
let tokens_per_sec = if gen_time > 0.0 { (tokens_generated as f64 / gen_time) * 1000.0 } else { 0.0 };
|
let tokens_per_sec = if gen_time > 0.0 { (tokens_generated as f64 / gen_time) * 1000.0 } else { 0.0 };
|
||||||
console_log!("[Coder] {} tokenia | {:.0}ms | {:.1} tok/s", tokens_generated, gen_time, tokens_per_sec);
|
console_log!("[Coder] {} tokenia | {:.0}ms | {:.1} tok/s", tokens_generated, gen_time, tokens_per_sec);
|
||||||
|
|
||||||
|
|||||||
@@ -1,39 +1,105 @@
|
|||||||
use candle_core::Tensor;
|
use candle_core::Tensor;
|
||||||
|
use std::cell::Cell;
|
||||||
|
|
||||||
/// Top-k sampling ilman softmaxia — kiertää Candlen SoftmaxLastDim Wasm-bugin.
|
thread_local! {
|
||||||
/// Valitsee top-k logiteista ja poimii satunnaisen (painotettu).
|
static RNG_STATE: Cell<u64> = Cell::new(0);
|
||||||
/// Jos k=1, toimii kuten argmax (greedy).
|
}
|
||||||
pub fn sample_top_k(logits: &Tensor, k: usize, eos_penalty: f32) -> u32 {
|
|
||||||
// Muunnetaan Vec<f32>:ksi
|
fn next_rand() -> f32 {
|
||||||
let logits_vec: Vec<f32> = logits.to_vec1::<f32>().unwrap_or_default();
|
RNG_STATE.with(|state| {
|
||||||
|
let mut s = state.get();
|
||||||
|
if s == 0 {
|
||||||
|
s = (js_sys::Date::now() * 1000.0) as u64 | 1;
|
||||||
|
}
|
||||||
|
s ^= s << 13;
|
||||||
|
s ^= s >> 7;
|
||||||
|
s ^= s << 17;
|
||||||
|
state.set(s);
|
||||||
|
(s % 10000) as f32 / 10000.0
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Top-k sampling with temperature and repetition penalty.
|
||||||
|
/// `generated_tokens` sisältää aiemmin generoidut token-id:t toiston estämiseksi.
|
||||||
|
pub fn sample_top_k_with_penalty(logits: &Tensor, k: usize, temperature: f32, generated_tokens: &[u32], repetition_penalty: f32) -> u32 {
|
||||||
|
let mut logits_vec: Vec<f32> = logits.to_vec1::<f32>().unwrap_or_default();
|
||||||
if logits_vec.is_empty() { return 0; }
|
if logits_vec.is_empty() { return 0; }
|
||||||
|
|
||||||
// Rangotaan ja otetaan top-k indeksit
|
// Repetition penalty
|
||||||
|
if repetition_penalty != 1.0 {
|
||||||
|
for &token_id in generated_tokens {
|
||||||
|
if (token_id as usize) < logits_vec.len() {
|
||||||
|
let logit = &mut logits_vec[token_id as usize];
|
||||||
|
if *logit > 0.0 {
|
||||||
|
*logit /= repetition_penalty;
|
||||||
|
} else {
|
||||||
|
*logit *= repetition_penalty;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Temperature scaling
|
||||||
|
if temperature > 0.0 && temperature != 1.0 {
|
||||||
|
for logit in logits_vec.iter_mut() {
|
||||||
|
*logit /= temperature;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Top-k
|
||||||
let mut indexed: Vec<(usize, f32)> = logits_vec.iter().enumerate().map(|(i, &v)| (i, v)).collect();
|
let mut indexed: Vec<(usize, f32)> = logits_vec.iter().enumerate().map(|(i, &v)| (i, v)).collect();
|
||||||
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||||
indexed.truncate(k);
|
indexed.truncate(k);
|
||||||
|
|
||||||
// EOS-penaltti: vähennetään EOS-tokenin logitia
|
if k == 1 || temperature == 0.0 {
|
||||||
for item in indexed.iter_mut() {
|
|
||||||
if item.0 == 2 || item.0 == 151645 { // SmolLM EOS=2, Qwen EOS=151645
|
|
||||||
item.1 -= eos_penalty;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if k == 1 {
|
|
||||||
return indexed[0].0 as u32;
|
return indexed[0].0 as u32;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Yksinkertainen "softmax" top-k:lle CPU:lla
|
// Softmax top-k:lle
|
||||||
let max_logit = indexed.iter().map(|x| x.1).fold(f32::NEG_INFINITY, f32::max);
|
let max_logit = indexed[0].1;
|
||||||
let exps: Vec<f32> = indexed.iter().map(|x| (x.1 - max_logit).exp()).collect();
|
let exps: Vec<f32> = indexed.iter().map(|x| (x.1 - max_logit).exp()).collect();
|
||||||
let sum: f32 = exps.iter().sum();
|
let sum: f32 = exps.iter().sum();
|
||||||
let probs: Vec<f32> = exps.iter().map(|e| e / sum).collect();
|
let probs: Vec<f32> = exps.iter().map(|e| e / sum).collect();
|
||||||
|
|
||||||
// Satunnainen valinta kumulatiivisella todennäköisyydellä
|
let rand_val = next_rand();
|
||||||
// Käytetään yksinkertaista XorShift-satunnaislukugeneraattoria (ei tarvita getrandom)
|
|
||||||
let seed = (js_sys::Date::now() * 1000.0) as u64;
|
let mut cumulative = 0.0;
|
||||||
let rand_val = ((seed ^ (seed >> 13) ^ (seed << 7)) % 10000) as f32 / 10000.0;
|
for (i, p) in probs.iter().enumerate() {
|
||||||
|
cumulative += p;
|
||||||
|
if rand_val < cumulative {
|
||||||
|
return indexed[i].0 as u32;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
indexed[0].0 as u32
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Alkuperäinen API yhteensopivuudeksi SmolLM/Qwen-moduulien kanssa
|
||||||
|
pub fn sample_top_k(logits: &Tensor, k: usize, eos_penalty: f32) -> u32 {
|
||||||
|
let mut logits_vec: Vec<f32> = logits.to_vec1::<f32>().unwrap_or_default();
|
||||||
|
if logits_vec.is_empty() { return 0; }
|
||||||
|
|
||||||
|
// EOS-penaltti
|
||||||
|
for &eos_id in &[2u32, 151645] {
|
||||||
|
if (eos_id as usize) < logits_vec.len() {
|
||||||
|
logits_vec[eos_id as usize] -= eos_penalty;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut indexed: Vec<(usize, f32)> = logits_vec.iter().enumerate().map(|(i, &v)| (i, v)).collect();
|
||||||
|
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||||
|
indexed.truncate(k);
|
||||||
|
|
||||||
|
if k == 1 {
|
||||||
|
return indexed[0].0 as u32;
|
||||||
|
}
|
||||||
|
|
||||||
|
let max_logit = indexed[0].1;
|
||||||
|
let exps: Vec<f32> = indexed.iter().map(|x| (x.1 - max_logit).exp()).collect();
|
||||||
|
let sum: f32 = exps.iter().sum();
|
||||||
|
let probs: Vec<f32> = exps.iter().map(|e| e / sum).collect();
|
||||||
|
|
||||||
|
let rand_val = next_rand();
|
||||||
|
|
||||||
let mut cumulative = 0.0;
|
let mut cumulative = 0.0;
|
||||||
for (i, p) in probs.iter().enumerate() {
|
for (i, p) in probs.iter().enumerate() {
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user