From 92c952c07a3093d67e8b98166f83eb4834b25964 Mon Sep 17 00:00:00 2001 From: jaakko Date: Thu, 2 Apr 2026 15:47:48 +0300 Subject: [PATCH] =?UTF-8?q?kyl=C3=A4=20l=C3=A4htee!?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- network-poc/cargo-errors.log | 4 + network-poc/docker-compose.yml | 2 +- network-poc/hub/nodes.db | Bin 28672 -> 32768 bytes network-poc/node/src/burn_smollm/attention.rs | 118 ++++++++++++++ network-poc/node/src/burn_smollm/config.rs | 28 ++++ network-poc/node/src/burn_smollm/loader.rs | 90 +++++++++++ network-poc/node/src/burn_smollm/mod.rs | 6 + network-poc/node/src/burn_smollm/model.rs | 96 ++++++++++++ network-poc/node/src/burn_smollm/modules.rs | 59 +++++++ network-poc/node/src/burn_smollm/rope.rs | 59 +++++++ network-poc/node/src/lib.rs | 1 + network-poc/node/src/smollm.rs | 145 ++++++++---------- 12 files changed, 524 insertions(+), 84 deletions(-) create mode 100644 network-poc/cargo-errors.log create mode 100644 network-poc/node/src/burn_smollm/attention.rs create mode 100644 network-poc/node/src/burn_smollm/config.rs create mode 100644 network-poc/node/src/burn_smollm/loader.rs create mode 100644 network-poc/node/src/burn_smollm/mod.rs create mode 100644 network-poc/node/src/burn_smollm/model.rs create mode 100644 network-poc/node/src/burn_smollm/modules.rs create mode 100644 network-poc/node/src/burn_smollm/rope.rs diff --git a/network-poc/cargo-errors.log b/network-poc/cargo-errors.log new file mode 100644 index 0000000..d4edf5f --- /dev/null +++ b/network-poc/cargo-errors.log @@ -0,0 +1,4 @@ +error: failed to write `/home/jaakko/code/kipinä/digikipinae/agentic-office/network-poc/target/wasm32-unknown-unknown/debug/.fingerprint/num-traits-0a015ef9fd3732e0/run-build-script-build-script-build` + +Caused by: + Permission denied (os error 13) diff --git a/network-poc/docker-compose.yml b/network-poc/docker-compose.yml index f85ae0b..46d4c4b 100644 --- a/network-poc/docker-compose.yml +++ b/network-poc/docker-compose.yml @@ -9,7 +9,7 @@ services: volumes: - .:/app # Käännetään aina käynnistyksen yhteydessä varmuuden vuoksi Wasm uusimmista koodeista, ja päälle pyöräytetään Hub! - command: bash -c "cd node && wasm-pack build --dev --target web --out-dir ../static/pkg && cd ../hub && cargo run" + command: bash -c "cd node && wasm-pack build --target web --out-dir ../static/pkg && cd ../hub && cargo run" # Valinnainen natiivi-solmu — kerää oikeat laitteistotiedot (nvidia-smi-taso) native-node: diff --git a/network-poc/hub/nodes.db b/network-poc/hub/nodes.db index 9e5d54a746774e3fcd8f2714d2e7b7e2f6b0aa3d..b30b85ca6cce9b85de0811eb6cc705dad30cf06c 100644 GIT binary patch literal 32768 zcmeHQTWlQHdEQ;_E|*+VR9xAVtXLV>kts&xcJ7yT94!^mimaO}QnFnKC5Fo(ceIzG zc4jHlPo5Y-VzfZozBCEiKD109>$VR?TeN^%q6J#CNzo!fFE4%x47BJ&(Z(o}0_}I^ zLe9>fy@;@4N1c&PF1crCzVrS6`OkmO|Nn<8KYPt|3}(}6H+6>*QpZ#2bgIHIsZ{1> zDwQhW|L_{ei){DCQ*c)$4e#8RKe$*&Uz5(W|m5(W|m z5(W|mK1>E)$>wLzpHG{PzTPm_>}t(u>T5ej+cvFM?8n^aZ!IsaE;Fl37q2Zdu{W7> z!H1YDH&&M~FW;Jf%g&F_o;{oX%Di{)ZQX3IwGF$|aO~*oOxInb?=a^knb>7+gtup2 zx^?CH(ydpStIMx4OSe~VUb%rAU0=Sjy3q5fmbGE5MLv4t<|=dJ_O)xhEvs!qcZ`iS zczmnNub@@mmPTvMF<$qZ_1vtW`H!cGwyE4TdzqU8*&`ZLe(`ZU24@L<=f(UDdzOlAlb^6?Eqtn(Mq8iNu@LhDzzm)Ox=~YkW_Pqvt=r!B8*D+dEBlhE1 zm*it_o@?n%V*#I>|EiN4pPida|MHx7PfwV(VG{+lyMD-a-LLDdk3f~o?O`h9A#`@P zjqZlbrR5ivZeLqv=GNQRecNcy^{PkJYPF2&XqviV+C9w|CcC6YM%m&b^Wy$q`>tU- zMx(jE=a|NZ<@7Y$Zs-nK6M2|>fz>SAA(HF!wq^Huqq^N$t6DG*{f7i;(!@w^^tnY_ zZ>}}h4?o`n9X-7K((RsKZ{fZJAd|OujMj$L?(5Sk#xZKn{@%|1UJpTf zqhVD&v$z$ZH~vi9X!siG=yu&B;!u%2YiGV;K$Cm!4yn9stV1x1PN#16yj|1nwfn|8 zdd*zCdGp%x(hd4ORzvvcwr+#$o;wj-Hk7Qn-7qk6LiP?&vR$8;Cu{TL>G9Oe*QdWf z#TPG4u1&mMaP!~JJNc>X|BPM9yq5mY^e>>nqkXyQ<5Q{h%5%@B)8n|ZJj=_AthmVX zs~lGr*fJ}YI89(BUKQ2zEPiH}Ux$$}TY*@a^|onjHg-!Z1}203z3|tK-P$f}jakYU zD#w27d;g28GCukIo1n8&`Rz?|RZ@RLuD{)!ArR)~354``Aqs&l^GcZ$N`fGZBFk&M z58>q+W^mJHZgp(u!gaHb$677_r)<75>5HtA{^C4&#aCtp@7g>z_&Kal;#5HrR8=|B zb1wYRH^}u(w?2Ig9dbE$)@L=|CAlnlk5V{6lsLaj7k3$_W-z;k-e!1KU>BI}wzUPz z(uA|d+&3Eyre!e=tF>jcnfq%2V9XXQyS_p0vg@_K)4?pFdpBdO z+Fe+E%e)KT3=>uW-&*^7;H_@9?^8vc<9-9~b%J1!4>=fzkmoC3m?AKe7^p=-ZPLl&-uQc^7vt=+kbKhz=HkkW0qirw- z9F$#jtGV5Aj5ZjkSuJCMLH8KXH*m;%>-e0_n2t>dg3|%CTg*~64W_B< zW}~SuFn2B6*73Ct#Wz|ECQsNr58QUA2E}iACYJcvj$<;rHB8DF3!oPmi8EsXnz>_{ zO}&Nt;C-WRn}FGH^c|hNbd0*Szt`9W19c0(0Ob8WOw)|hX&5`YX?I}8q2rDHJ-G1O z`+JZcQ}5ux9lLHhbx3Q0v7sgVduLv(ocOmdGtm6X*zcTplYC#l;|;hA@8JhrUz#p@ z3aY#3`-w~-BGY)8lciwTFSQ+W)2y0$gYaxL8s-*p`Uss)8{Hya0}%z$Vqd5LhyFqt z^<5j}b{qz-Y)#+B2hgSoR|1H)TI8;J3$leYw#fZf11^|mr{g%f9*O}kKzXBq$qcOF z`vO{^0fV_260sQ!l{zGcX^uP^ABI0-wI2IqW%g@dWYF(Qp7b8yulk1P!asPLENo5= zG7NH=_eLwLaGWF@#Q=qd;gt#UobzYX{Al&%`> zw#7g@p&+^ZSms1{=eO9xutDfqH&K8s=Vbrp$sqhl2Vri!YYxhB{V@~RY4Ga zkv)Qg&|S%UJj}XeP~ghP8~~|I<`0eI;Q>g&{Ba5PwQa(r7#qQSK}-)uJ+Umox{6pP zo!Qjws@cM#7Vl%7R5iAp(9VBBzpLX8j?-b_QxYShGjMZtvk8l>6Kh*H;dWX(WO8iT z_@udjS9V=5&Q?yHc;h`{O0wsE@GklO+6%sk{g1#0SkHOAI_;kJd)4Jx!L5}ymx9i@t2T!-%++|wkR?P_}LUho=jAQ`NT};)$&rTV%2(mRN5X z<(mzYv3FZ_;ugX4Z|=73`mQ%`?e6~p<}w`vi(L(dq3ZF`XL6a=a`BR_7 zd+Ql+MT7sQDv6E#2pye&3VqYeFp*(edi>@5yZIM$zngnL`>pKc*uB)RW-h0#%%40&LfqT4CsXOCPpzz^&d;Tn z$DTWNDwDy+c&bns;}jn8p%Po-$bR;Yaj-YSt@4P~BYwn|SdNo5kq=|_u_jts6ib@I zD~Rf8!FE^O+{T9QVr$1l_;-IIxm)wh=aIZ~ZuR;+V`__> zsIlzT7g9g0-MDk*(v>CVvhjk|Mkw&s>MM-MYOMe3JDo;dZ>{~;esLQ zkbWs;fkVFqiRTqD*jA18ODT(NNzo)kD<$#A)vx05vkO`KB++nwZfJ6*5vH;kUEOIihNpcYHyJogJuQRW!@|rB77R48epFfk`u9?CjLagS` zx%`tpnW-Kn89#g~O-V$qEJ`I25eJcFW!B%Le&`xdy!7l?)_vj$pQ@D}sxpV5=~PK& zRw;2T&tjie(|puDY>H}0R1x3Q#IR!{v3bTl`8XBWGZX1yz<3NO=qAUrqJl676&PQ} z7Bj0!n!xdV5Ca(m%x+qZMstx9q-NfI`~(%+QzJkV5Mx5#j3TS56o$=WoINPCd-sjj zA}g(r-}@;aR!r-knkZz4h>~Z^f?84)1#xITjHVBwBYdVi`xupqXC?~6q{L}uNiAtS zbb}LBs+0&U83qrbR#^@;Y!Ix7l$=jc$pElnWN>O(luMc_aU!x~g4BW$pX#1CPUQnK z7|I7S8rYHqNvSHb0ccuFkRkC)7Fgs%X+d~$5VBBA_hzVM0M+CW&EUi`CzdcsBt*MM z8NuNr-4n<9F;N)yOn;DYstCKH1QSR1AR(~0k|ZPlOp&z_WON8?x*x2_j7PP19A6fA zP(kzpCXXs0Acr@=Y8>Lna$pii5s*7OMTIqrx*}}Nmqh5Q!tzR>6fw5=GB1}TUPGW* z;YQG#$T)Imi~W!d(`Z&L^BCbU#40O?NrEwCY#F|j3S_J($$=X&h_DDUXVM22Gk2hD zdWeFtGLa;wvB;&;!la%NpX$y|^y8yA?6JrlkxDEQiQo?L)DaG3oCHaMC57ik(CY{v z?rfnSvWelyM82d6JaRcOG0{{T$S~W(DdN0Vev^?o8}ElK;u%o~9V?YN9#a4mM9?&< zpNS?C50_2QR8Cb!7<65n<@+HU<|(pbnb%4hk|;ok8l>ZOYd4U=0~V3Gs_~NcU|xnh zo9l;c=(vM%6iPBKO=8s`Hz$S+;AQyciYAJ(5P%<9n%N^GgKwjh1SBotbo|C9KCcy5!I3rYfrN&G*Oj|ne4;_D{y|7>)9ki`FwkeQIg z|9=wW|A)D?Wc`0KwUe4LkA3HuH2oXX`KfyG?Z8V{zJsFgn9;6RagZ%NU>zH5C?&ZQ+7lR=lTt% zBjKnrY3aE>TY6)#qL{!V1QUVo<$SU+4eNih564KbP zi-jW!1KBUg0cx5;sW~->gdDg`@LsTgmUJ{ zbbBXup}GkL6$i&V9rSCo0p6GH3+~EO6uKv%-;s=3D%}#vuVh(Ol93rlt&Jfg;3J^> zyK`CllRn$A#A#GV4ciVN8{RPO7`2L~@f5q=m&%jw>ghhFB5JazTymr&OXOj%u*ptM zn??Z@%seY{iXy>Q__p(*t8VeNv#6kZRJZXT&B{RciJ&wv77#jWY2bmW@lB{75W>-A z_`O7N_4LXiXniLp5G6jb2iJ-zTHh7M@*H0ZhqeoK@YPvqQMTj5c@bkyBP;q+f z+o{v3ce2yJf2=$+IrV2#!uWf|GsSO>zcTqt6F*=0PX6yEbGbjxU(E5@y+SRkja|@#NY3gQ|jy~MYW(nSrGE@4vABvq2yT8I4M~F z#Ig&5#ro96F8?{`9BMcQ)k?!T_HbZN8nucNOaApP)(>%@s!^50QvVo^=+5P{ZmvRu zGkgyNW$-B7W>Lw)Nuq)xp4pyh;g)Tog!ph(+3_Js;0Fb}1 zAh9y35AR)zV;+xo`AK)`d0Ll8kpLrNP-L(m)d1i1)|Ufg{L@5AuObfb~y&8 z=w?5AM7;zW0f#oo8Y>Ed6tBJYR0iY_G^p7oYXCYqzFtmzKQ)ff@cQRyau~jSo z82ti$!57$|oZJt`;EhY7aAj77&f!w|J~*(`ila{yfm0-{Pla26qo~kukb1Vkd$}GC z6sxINM#6vZse%nD^Kc-q3LCz32S2JTZce138AbPTiUl^36a`ywd~Zc!h=ve{l3XZ@ zGFUq7k#S%P6~%>ugcG!JQUlFC1p*ui$KsK3px}_yK}sTON5jq@c#sds$9ma^5p013 zQv(r8WMik_5XOi=Y901RwQ$2fW{U%p^oHS13AFBxGA;1lSr0NM1r3LVP^X1JLwFpG z!E$UMnHZF)GG^T@f5d44Xs{QH^HMPHid=j@%OgKTBZ-)!`jlbuu7#gJ(s7#ck? za0?QZHSDtwo4AQ4V)`OB0tYU|OH#X(1QWZ8Ch1zcx z`}tH?CLA{dp+sZC+0s$v=fxjnOe-23D~g*3JMWKk9;;Qw2&}=^0|f z92M&FB-ZUg{-dg25`~^K+C04tai9&h^n(rNC?8mK92kwInO{p?iljt_3y=9idhqN5DBj;y zL#$YFXFfyi&jZ~AVVs@~^RE6#66mL?HR+GHze4_;S(_X$bspWQl91yAD-(6{h?0lp Ji!a3X{{fSjj#2;s delta 522 zcmZo@U}|{4I6;byQHOznfqkNa9TTI@#)KvO91H?L0R{$^&4LOu`Is9y87D7D5}$mj zNXFbS#K^$P*xbs*T+h(R!ot+j*u+fRz`)ADz{tSJOxM6f*8r@<#MH{bT+h_l+}yy> zz|e5AV5t~yNq%-}US?G)8_;N%$$>?}leZ?x2|<*bSeaVr0acnbaxeicGlp1X0v9$j zva~cJ(jpOWq2$DXH@w4XU=j7yW{%wCwoQ1hjgm3b~ zT#3z3!(}+MSw$Iym}D6FCh`8@jpI4aqt88s>o=D&dn(5@_J91EY#*3Uv7TW1!m@x# zW@FwlGZ%$TBy*7D5+BHUn$@ZqAljo;cOcuA3nPkGrt;J@`u$h5VQ8JievR10g zWRsvoHl7D+EX8tAl`4~s(nMgY?K#)MRnJTVtDc;gUctybIWcSY=FJ)F7+HXxD4FbN dD$Fd;$vIgtO>y$!6tMQqAF>2kC*MlH4ge@uk;(u7 diff --git a/network-poc/node/src/burn_smollm/attention.rs b/network-poc/node/src/burn_smollm/attention.rs new file mode 100644 index 0000000..b69acfb --- /dev/null +++ b/network-poc/node/src/burn_smollm/attention.rs @@ -0,0 +1,118 @@ +use burn::module::{Module, Param}; +use burn::tensor::{backend::Backend, Tensor}; +use super::rope::RoPE; +use super::config::SmolLMConfig; + +#[derive(Clone, Debug)] +pub struct KVCache { + pub k: Tensor, + pub v: Tensor, +} + +#[derive(Module, Debug)] +pub struct Attention { + pub q_proj: Param>, // [hidden, num_heads * head_dim] + pub k_proj: Param>, // [hidden, num_kv_heads * head_dim] + pub v_proj: Param>, // [hidden, num_kv_heads * head_dim] + pub o_proj: Param>, // [num_heads * head_dim, hidden] + + num_heads: usize, + num_kv_heads: usize, + head_dim: usize, + + rope: RoPE, +} + +impl Attention { + pub fn new(config: &SmolLMConfig, device: &B::Device) -> Self { + let head_dim = config.hidden_size / config.num_attention_heads; + + Self { + q_proj: Param::from_tensor(Tensor::zeros([config.hidden_size, config.num_attention_heads * head_dim], device)), + k_proj: Param::from_tensor(Tensor::zeros([config.hidden_size, config.num_key_value_heads * head_dim], device)), + v_proj: Param::from_tensor(Tensor::zeros([config.hidden_size, config.num_key_value_heads * head_dim], device)), + o_proj: Param::from_tensor(Tensor::zeros([config.num_attention_heads * head_dim, config.hidden_size], device)), + + num_heads: config.num_attention_heads, + num_kv_heads: config.num_key_value_heads, + head_dim, + + rope: RoPE::new(head_dim, config.max_position_embeddings, config.rope_theta, device), + } + } + + pub fn forward( + &self, + x: Tensor, + offset: usize, + cache: Option> + ) -> (Tensor, KVCache) { + let [batch, seq_len, hidden_dim] = x.dims(); + + // Project Q, K, V: x @ W -> [batch, seq, proj_dim] + let q = x.clone().matmul(self.q_proj.val().unsqueeze()); + let k = x.clone().matmul(self.k_proj.val().unsqueeze()); + let v = x.matmul(self.v_proj.val().unsqueeze()); + + // Reshape: [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim] + let q = q.reshape([batch, seq_len, self.num_heads, self.head_dim]).swap_dims(1, 2); + let k = k.reshape([batch, seq_len, self.num_kv_heads, self.head_dim]).swap_dims(1, 2); + let v = v.reshape([batch, seq_len, self.num_kv_heads, self.head_dim]).swap_dims(1, 2); + + // Apply RoPE + let q = self.rope.forward(q, offset); + let k = self.rope.forward(k, offset); + + // KV cache + let (k, v) = if let Some(c) = cache { + (Tensor::cat(vec![c.k, k], 2), Tensor::cat(vec![c.v, v], 2)) + } else { + (k, v) + }; + + let new_cache = KVCache { k: k.clone(), v: v.clone() }; + let kv_len = k.dims()[2]; + + // GQA: repeat K,V heads — [batch, kv_heads, kv_len, hd] -> [batch, num_heads, kv_len, hd] + let num_reps = self.num_heads / self.num_kv_heads; + let k = if num_reps > 1 { + let [b, kv_h, s, hd] = k.dims(); + k.reshape([b, kv_h, 1, s, hd]).repeat_dim(2, num_reps).reshape([b, self.num_heads, s, hd]) + } else { k }; + let v = if num_reps > 1 { + let [b, kv_h, s, hd] = v.dims(); + v.reshape([b, kv_h, 1, s, hd]).repeat_dim(2, num_reps).reshape([b, self.num_heads, s, hd]) + } else { v }; + + // Attention: Q @ K^T / sqrt(d) + let scale = 1.0 / (self.head_dim as f64).sqrt(); + let scores = q.matmul(k.swap_dims(2, 3)).mul_scalar(scale); + // scores: [batch, heads, seq_len, kv_len] + + // Causal mask for prefill (seq_len > 1) + let scores = if seq_len > 1 { + let mask_data: Vec = (0..seq_len).flat_map(|i| { + (0..kv_len).map(move |j| { + if j > offset + i { f32::NEG_INFINITY } else { 0.0 } + }) + }).collect(); + let mask = Tensor::::from_data( + burn::tensor::TensorData::new(mask_data, [seq_len, kv_len]), + &scores.device() + ).reshape([1, 1, seq_len, kv_len]); + scores + mask + } else { + scores + }; + + let attn_weights = burn::tensor::activation::softmax(scores, 3); + + let context = attn_weights.matmul(v); + // [batch, heads, seq, hd] -> [batch, seq, heads*hd] + let context = context.swap_dims(1, 2).reshape([batch, seq_len, self.num_heads * self.head_dim]); + + let output = context.matmul(self.o_proj.val().unsqueeze()); + + (output, new_cache) + } +} diff --git a/network-poc/node/src/burn_smollm/config.rs b/network-poc/node/src/burn_smollm/config.rs new file mode 100644 index 0000000..ac0b263 --- /dev/null +++ b/network-poc/node/src/burn_smollm/config.rs @@ -0,0 +1,28 @@ +#[derive(Clone, Debug)] +pub struct SmolLMConfig { + pub hidden_size: usize, + pub intermediate_size: usize, + pub vocab_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub num_key_value_heads: usize, + pub rms_norm_eps: f64, + pub rope_theta: f32, + pub max_position_embeddings: usize, +} + +impl Default for SmolLMConfig { + fn default() -> Self { + Self { + hidden_size: 576, + intermediate_size: 1536, + vocab_size: 49152, + num_hidden_layers: 30, + num_attention_heads: 9, + num_key_value_heads: 3, + rms_norm_eps: 1e-5, + rope_theta: 10000.0, + max_position_embeddings: 2048, + } + } +} diff --git a/network-poc/node/src/burn_smollm/loader.rs b/network-poc/node/src/burn_smollm/loader.rs new file mode 100644 index 0000000..d5cf29b --- /dev/null +++ b/network-poc/node/src/burn_smollm/loader.rs @@ -0,0 +1,90 @@ +use burn::tensor::{backend::Backend, Tensor, Data}; +use candle_core::safetensors; +use candle_core::Device as CandleDevice; +use burn::module::Param; +use super::model::LlamaModel; +use super::config::SmolLMConfig; + +fn load_tensor_2d( + tensors_map: &std::collections::HashMap, + name: &str, + device: &B::Device, + shape_out_in: [usize; 2] +) -> Result>, String> { + let t = tensors_map.get(name).ok_or_else(|| format!("Puuttuu: {}", name))?; + let t = t.to_dtype(candle_core::DType::F32).unwrap(); + let vec = t.flatten_all().unwrap().to_vec1::().unwrap(); + let t_burn = Tensor::::from_data(burn::tensor::TensorData::new(vec, shape_out_in), device); + // transpose from [out, in] to [in, out] + Ok(Param::from_tensor(t_burn.transpose())) +} + +fn load_tensor_1d( + tensors_map: &std::collections::HashMap, + name: &str, + device: &B::Device, + _shape: [usize; 1] +) -> Result>, String> { + let t = tensors_map.get(name).ok_or_else(|| format!("Puuttuu: {}", name))?; + let t = t.to_dtype(candle_core::DType::F32).unwrap(); + let vec = t.flatten_all().unwrap().to_vec1::().unwrap(); + Ok(Param::from_tensor(Tensor::::from_floats(vec.as_slice(), device))) +} + +fn load_embed( + tensors_map: &std::collections::HashMap, + name: &str, + device: &B::Device, + shape: [usize; 2] +) -> Result>, String> { + let t = tensors_map.get(name).ok_or_else(|| format!("Puuttuu: {}", name))?; + let t = t.to_dtype(candle_core::DType::F32).unwrap(); + let vec = t.flatten_all().unwrap().to_vec1::().unwrap(); + // Embed ei transponoi samalla tavalla, se pysyy [vocab, hidden] + Ok(Param::from_tensor(Tensor::::from_data(burn::tensor::TensorData::new(vec, shape), device))) +} + +pub fn load_safetensors_to_model( + buffer: &[u8], + config: &SmolLMConfig, + device: &B::Device +) -> Result, String> { + + let mut model = LlamaModel::new(config, device); + let tensors_map = safetensors::load_buffer(buffer, &CandleDevice::Cpu) + .map_err(|e| format!("Virhe Safetensors luennassa: {}", e))?; + + // Embeddings + model.embed_tokens = load_embed(&tensors_map, "model.embed_tokens.weight", device, [config.vocab_size, config.hidden_size])?; + model.norm.weight = load_tensor_1d(&tensors_map, "model.norm.weight", device, [config.hidden_size])?; + model.lm_head = load_embed(&tensors_map, "lm_head.weight", device, [config.vocab_size, config.hidden_size]).or_else(|_| { + load_embed(&tensors_map, "model.embed_tokens.weight", device, [config.vocab_size, config.hidden_size]) + })?; + + let head_dim = config.hidden_size / config.num_attention_heads; + + for i in 0..config.num_hidden_layers { + let prefix = format!("model.layers.{}", i); + + let layer = &mut model.layers[i]; + + // Norms + layer.input_layernorm.weight = load_tensor_1d(&tensors_map, &format!("{}.input_layernorm.weight", prefix), device, [config.hidden_size])?; + layer.post_attention_layernorm.weight = load_tensor_1d(&tensors_map, &format!("{}.post_attention_layernorm.weight", prefix), device, [config.hidden_size])?; + + // Attention + let num_heads = config.num_attention_heads; + let num_kv_heads = config.num_key_value_heads; + layer.self_attn.q_proj = load_tensor_2d(&tensors_map, &format!("{}.self_attn.q_proj.weight", prefix), device, [num_heads * head_dim, config.hidden_size])?; + layer.self_attn.k_proj = load_tensor_2d(&tensors_map, &format!("{}.self_attn.k_proj.weight", prefix), device, [num_kv_heads * head_dim, config.hidden_size])?; + layer.self_attn.v_proj = load_tensor_2d(&tensors_map, &format!("{}.self_attn.v_proj.weight", prefix), device, [num_kv_heads * head_dim, config.hidden_size])?; + layer.self_attn.o_proj = load_tensor_2d(&tensors_map, &format!("{}.self_attn.o_proj.weight", prefix), device, [config.hidden_size, num_heads * head_dim])?; + + // MLP + layer.mlp.gate_proj = load_tensor_2d(&tensors_map, &format!("{}.mlp.gate_proj.weight", prefix), device, [config.intermediate_size, config.hidden_size])?; + layer.mlp.up_proj = load_tensor_2d(&tensors_map, &format!("{}.mlp.up_proj.weight", prefix), device, [config.intermediate_size, config.hidden_size])?; + layer.mlp.down_proj = load_tensor_2d(&tensors_map, &format!("{}.mlp.down_proj.weight", prefix), device, [config.hidden_size, config.intermediate_size])?; + } + + Ok(model) +} diff --git a/network-poc/node/src/burn_smollm/mod.rs b/network-poc/node/src/burn_smollm/mod.rs new file mode 100644 index 0000000..3664e61 --- /dev/null +++ b/network-poc/node/src/burn_smollm/mod.rs @@ -0,0 +1,6 @@ +pub mod attention; +pub mod config; +pub mod loader; +pub mod model; +pub mod modules; +pub mod rope; diff --git a/network-poc/node/src/burn_smollm/model.rs b/network-poc/node/src/burn_smollm/model.rs new file mode 100644 index 0000000..9a4f485 --- /dev/null +++ b/network-poc/node/src/burn_smollm/model.rs @@ -0,0 +1,96 @@ +use burn::module::{Module, Param}; +use burn::tensor::{backend::Backend, Tensor, Int}; +use super::modules::{RmsNorm, Mlp}; +use super::attention::{Attention, KVCache}; +use super::config::SmolLMConfig; + +#[derive(Module, Debug)] +pub struct LlamaBlock { + pub self_attn: Attention, + pub mlp: Mlp, + pub input_layernorm: RmsNorm, + pub post_attention_layernorm: RmsNorm, +} + +impl LlamaBlock { + pub fn new(config: &SmolLMConfig, device: &B::Device) -> Self { + Self { + self_attn: Attention::new(config, device), + mlp: Mlp::new(config.hidden_size, config.intermediate_size, device), + input_layernorm: RmsNorm::new(config.hidden_size, config.rms_norm_eps, device), + post_attention_layernorm: RmsNorm::new(config.hidden_size, config.rms_norm_eps, device), + } + } + + pub fn forward( + &self, + x: Tensor, + offset: usize, + cache: Option> + ) -> (Tensor, KVCache) { + let residual = x.clone(); + let x_norm = self.input_layernorm.forward(x); + + let (attn_out, new_cache) = self.self_attn.forward(x_norm, offset, cache); + + let x = residual + attn_out; + + let residual = x.clone(); + let x_norm = self.post_attention_layernorm.forward(x); + let mlp_out = self.mlp.forward(x_norm); + + let x = residual + mlp_out; + (x, new_cache) + } +} + +#[derive(Module, Debug)] +pub struct LlamaModel { + pub embed_tokens: Param>, + pub layers: Vec>, + pub norm: RmsNorm, + pub lm_head: Param>, // For tie_word_embeddings this can point to embed_tokens +} + +impl LlamaModel { + pub fn new(config: &SmolLMConfig, device: &B::Device) -> Self { + let embed = Tensor::zeros([config.vocab_size, config.hidden_size], device); + let lm_head = Tensor::zeros([config.vocab_size, config.hidden_size], device); + + let mut layers = Vec::new(); + for _ in 0..config.num_hidden_layers { + layers.push(LlamaBlock::new(config, device)); + } + + Self { + embed_tokens: Param::from_tensor(embed), + layers, + norm: RmsNorm::new(config.hidden_size, config.rms_norm_eps, device), + lm_head: Param::from_tensor(lm_head), + } + } + + pub fn forward( + &self, + input_ids: Tensor, + offset: usize, + caches: &mut Vec>> + ) -> Tensor { + let [_batch, _seq_len] = input_ids.dims(); + + let mut x = burn::tensor::module::embedding(self.embed_tokens.val(), input_ids); + + for (i, layer) in self.layers.iter().enumerate() { + let cache = caches[i].take(); + let (out, new_cache) = layer.forward(x, offset, cache); + x = out; + caches[i] = Some(new_cache); + } + + x = self.norm.forward(x); + + // Matmul with lm_head (or embed_tokens if tied) to get logits + // Notice: lm_head is typically [vocab_size, hidden_size] in HF, so we swap dims + x.matmul(self.lm_head.val().swap_dims(0, 1).unsqueeze()) + } +} diff --git a/network-poc/node/src/burn_smollm/modules.rs b/network-poc/node/src/burn_smollm/modules.rs new file mode 100644 index 0000000..b1dc9cb --- /dev/null +++ b/network-poc/node/src/burn_smollm/modules.rs @@ -0,0 +1,59 @@ +use burn::module::{Module, Param}; +use burn::tensor::{backend::Backend, Tensor}; + +#[derive(Module, Debug)] +pub struct RmsNorm { + pub weight: Param>, + epsilon: f64, +} + +impl RmsNorm { + pub fn new(size: usize, epsilon: f64, device: &B::Device) -> Self { + let weight = Param::from_tensor(Tensor::ones([size], device)); + Self { weight, epsilon } + } + + pub fn forward(&self, x: Tensor) -> Tensor { + // x: [batch, seq_len, dim] + // RMSNorm: x * weight / sqrt(mean(x^2) + eps) + let x_sq = x.clone().powf_scalar(2.0); + // mean over last dim, keeping dims for broadcast + let [b, s, d] = x_sq.dims(); + let variance = x_sq.sum_dim(2).div_scalar(d as f32); + let norm = x.div(variance.add_scalar(self.epsilon).sqrt()); + + let w = self.weight.val().unsqueeze::<2>().unsqueeze::<3>().reshape([1, 1, d]); + norm * w + } +} + +#[derive(Module, Debug)] +pub struct Mlp { + pub gate_proj: Param>, // [in, intermediate] + pub up_proj: Param>, // [in, intermediate] + pub down_proj: Param>, // [intermediate, out] +} + +impl Mlp { + pub fn new(hidden_size: usize, intermediate_size: usize, device: &B::Device) -> Self { + Self { + gate_proj: Param::from_tensor(Tensor::zeros([hidden_size, intermediate_size], device)), + up_proj: Param::from_tensor(Tensor::zeros([hidden_size, intermediate_size], device)), + down_proj: Param::from_tensor(Tensor::zeros([intermediate_size, hidden_size], device)), + } + } + + pub fn forward(&self, x: Tensor) -> Tensor { + // x: [batch, seq, hidden] + // gate = x @ gate_proj -> [batch, seq, intermediate] + let gate = x.clone().matmul(self.gate_proj.val().unsqueeze()); + let up = x.matmul(self.up_proj.val().unsqueeze()); + + // SiLU(gate) * up + let silu = gate.clone() * burn::tensor::activation::sigmoid(gate); + let intermediate = silu * up; + + // intermediate @ down_proj -> [batch, seq, hidden] + intermediate.matmul(self.down_proj.val().unsqueeze()) + } +} diff --git a/network-poc/node/src/burn_smollm/rope.rs b/network-poc/node/src/burn_smollm/rope.rs new file mode 100644 index 0000000..2ed2993 --- /dev/null +++ b/network-poc/node/src/burn_smollm/rope.rs @@ -0,0 +1,59 @@ +use burn::module::Module; +use burn::tensor::{backend::Backend, Tensor}; + +#[derive(Module, Debug)] +pub struct RoPE { + cos_cache: Tensor, + sin_cache: Tensor, +} + +impl RoPE { + pub fn new(head_dim: usize, max_seq_len: usize, theta: f32, device: &B::Device) -> Self { + // (head_dim / 2) values + let half_dim = head_dim / 2; + let inv_freq: Vec = (0..half_dim) + .map(|i| 1.0 / theta.powf((2 * i) as f32 / head_dim as f32)) + .collect(); + + let inv_freq = Tensor::::from_floats(inv_freq.as_slice(), device).unsqueeze::<2>(); + let t_floats: Vec = (0..max_seq_len).map(|v| v as f32).collect(); + let t = Tensor::::from_floats(t_floats.as_slice(), device).unsqueeze::<2>().transpose(); + // t shape: [max_seq_len, 1] + // inv_freq shape: [1, half_dim] + + // freqs shape: [max_seq_len, half_dim] + let freqs = t.matmul(inv_freq); + + let cos_cache = freqs.clone().cos(); + let sin_cache = freqs.sin(); + + Self { + cos_cache, + sin_cache, + } + } + + pub fn forward(&self, x: Tensor, offset: usize) -> Tensor { + let [batch, heads, seq_len, head_dim] = x.dims(); + let half_dim = head_dim / 2; + + // x shape: [batch, heads, seq_len, head_dim] + // valitaan viipaleet (x1 ja x2) jotta saadaan pyöritettyä rotaatiot + let x1 = x.clone().slice([0..batch, 0..heads, 0..seq_len, 0..half_dim]); + let x2 = x.clone().slice([0..batch, 0..heads, 0..seq_len, half_dim..head_dim]); + + // haetaan vastaava seq offsetista alkaen + let cos = self.cos_cache.clone().slice([offset..offset+seq_len, 0..half_dim]) + .unsqueeze::<4>() // [seq, half_dim, 1] + .reshape([1, 1, seq_len, half_dim]); + let sin = self.sin_cache.clone().slice([offset..offset+seq_len, 0..half_dim]) + .reshape([1, 1, seq_len, half_dim]); + + // x1 * cos - x2 * sin + let o1 = x1.clone().mul(cos.clone()) - x2.clone().mul(sin.clone()); + // x2 * cos + x1 * sin + let o2 = x2.mul(cos) + x1.mul(sin); + + Tensor::cat(vec![o1, o2], 3) + } +} diff --git a/network-poc/node/src/lib.rs b/network-poc/node/src/lib.rs index 03df2db..7df3a50 100644 --- a/network-poc/node/src/lib.rs +++ b/network-poc/node/src/lib.rs @@ -12,6 +12,7 @@ pub mod smollm; pub mod qwen; pub mod qwen_coder; pub mod phi3; +pub mod burn_smollm; #[macro_export] macro_rules! console_log { diff --git a/network-poc/node/src/smollm.rs b/network-poc/node/src/smollm.rs index 2176467..0a622d4 100644 --- a/network-poc/node/src/smollm.rs +++ b/network-poc/node/src/smollm.rs @@ -118,125 +118,106 @@ pub async fn run_smollm_inference(prompt: String, ws: Rc>) { Err(e) => { console_log!("[SmolLM] Malli-virhe: {}", e); return; } }; - console_log!("[SmolLM] Rakennetaan mallia..."); + let use_gpu = crate::HAS_WEBGPU.load(std::sync::atomic::Ordering::SeqCst); + if use_gpu { + console_log!("[SmolLM] Alustetaan Burn WebGPU..."); + burn_wgpu::init_async::(&Default::default(), Default::default()).await; + run_burn_inference::(prompt, model_bytes, tokenizer, ws, perf.clone()).await; + } else { + console_log!("[SmolLM] Käytetään CPU NdArrayta (vanha tapa)..."); + run_burn_inference::(prompt, model_bytes, tokenizer, ws, perf.clone()).await; + } +} + +async fn run_burn_inference( + prompt: String, + model_bytes: Vec, + tokenizer: tokenizers::Tokenizer, + ws: Rc>, + perf: web_sys::Performance, // Korjattu Wasm-performanssi välitettäväksi +) { let start_load = perf.now(); - let device = Device::Cpu; - let dtype = DType::F32; - - // Parsitaan safetensors - let tensors = match candle_core::safetensors::load_buffer(&model_bytes, &device) { - Ok(t) => t, - Err(e) => { console_log!("[SmolLM] Safetensors-parsinta epäonnistui: {}", e); return; } - }; - - let vb = VarBuilder::from_tensors(tensors, dtype, &device); - - // SmolLM-135M config (Llama-arkkitehtuuri) - let config = LlamaConfig { - hidden_size: 576, - intermediate_size: 1536, - vocab_size: 49152, - num_hidden_layers: 30, - num_attention_heads: 9, - num_key_value_heads: Some(3), - rms_norm_eps: 1e-5, - rope_theta: 10000.0, - max_position_embeddings: 2048, - tie_word_embeddings: Some(true), - bos_token_id: Some(1u32), - eos_token_id: Some(LlamaEosToks::Single(2)), - rope_scaling: None, - }; - - let llama_config = config.into_config(false); // false = ei flash attention - let mut cache = Cache::new(true, dtype, &llama_config, &device).unwrap(); - - let model = match Llama::load(vb, &llama_config) { + let device = Default::default(); + let config = crate::burn_smollm::config::SmolLMConfig::default(); + + console_log!("[SmolLM] Injektoidaan Safetensors -> Burn Params..."); + let model = match crate::burn_smollm::loader::load_safetensors_to_model::(&model_bytes, &config, &device) { Ok(m) => m, - Err(e) => { console_log!("[SmolLM] Mallin lataus epäonnistui: {}", e); return; } + Err(e) => { console_log!("[SmolLM] Lataus epäonnistui: {}", e); return; } }; let load_time = perf.now() - start_load; - console_log!("[SmolLM] Malli ladattu ({:.0}ms). Generoidaan...", load_time); + console_log!("[SmolLM] Burn-malli ladattu ({:.0}ms). Generoidaan...", load_time); - // 3. Tokenisoi syöte (Käytetään ChatML-formaattia SmolLM-Instructille) let formatted_prompt = format!("<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", prompt); let encoding = match tokenizer.encode(formatted_prompt.as_str(), true) { Ok(e) => e, Err(e) => { console_log!("[SmolLM] Tokenisointivirhe: {}", e); return; } }; - let input_ids: Vec = encoding.get_ids().to_vec(); + let mut input_ids: Vec = encoding.get_ids().to_vec(); let input_len = input_ids.len(); console_log!("[SmolLM] Syöte: {} tokenia", input_len); - // 4. Generoi tokeneita let start_gen = perf.now(); let max_new_tokens = 32; let mut generated_text = String::new(); let mut tokens_generated: usize = 0; - let mut pos: usize = 0; + + // KV-välimuistin taulukko kerroksittain + let mut caches: Vec>> = vec![None; config.num_hidden_layers]; + let mut current_offset = 0; - // Ensimmäinen forward: koko syöte kerralla - let input = match Tensor::new(input_ids.as_slice(), &device).and_then(|t| t.unsqueeze(0)) { - Ok(t) => t, - Err(e) => { console_log!("[SmolLM] Tensor-virhe: {}", e); return; } - }; + // Prefill: yksitellen, vältetään future token leakage koska ei causal maskia + let input_ids_i32: Vec = input_ids.iter().map(|&x| x as i32).collect(); + let mut last_logits = None; + + for &id in &input_ids_i32 { + let input_tensor = burn::tensor::Tensor::::from_data( + burn::tensor::TensorData::from([id]), + &device + ).unsqueeze::<2>(); // [1, 1] + + last_logits = Some(model.forward(input_tensor, current_offset, &mut caches)); + current_offset += 1; + } - let logits = match model.forward(&input, 0, &mut cache) { - Ok(l) => l, - Err(e) => { console_log!("[SmolLM] Forward-virhe (prefill): {}", e); return; } - }; + let mut logits = last_logits.unwrap(); - // Llama forward voi palauttaa [batch, vocab] tai [batch, seq_len, vocab] - let logits = logits.squeeze(0).unwrap(); - let logits = if logits.dims().len() == 2 { - logits.get(logits.dim(0).unwrap() - 1).unwrap() - } else { - logits - }; - let mut next_token = crate::sampling::sample_top_k(&logits, 10, 5.0); - console_log!("[SmolLM] Ensimmäinen generoitu token: {}", next_token); - pos = input_len; + // Argmax sämpläys + let next_token_tensor = logits.clone().argmax(2); + let mut next_token: u32 = next_token_tensor.into_scalar().to_string().parse().unwrap_or(2); // Yksinkertainen cast koska int scalar if next_token != 2 { if let Ok(text) = tokenizer.decode(&[next_token], true) { generated_text.push_str(&text); - let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "SmolLM-135M" }); + let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "SmolLM-135M (WebGPU)" }); let _ = ws.borrow().send_with_str(&chunk.to_string()); } tokens_generated += 1; } - // Autoregressiivinen generointi: yksi token kerrallaan + // Autoregressiivinen luuppi for _ in 1..max_new_tokens { if next_token == 2 { break; } - - let input = match Tensor::new(&[next_token], &device).and_then(|t| t.unsqueeze(0)) { - Ok(t) => t, - Err(e) => { console_log!("[SmolLM] Tensor-virhe: {}", e); break; } - }; - - let logits = match model.forward(&input, pos, &mut cache) { - Ok(l) => l, - Err(e) => { console_log!("[SmolLM] Forward-virhe pos {}: {}", pos, e); break; } - }; - - let logits = logits.squeeze(0).unwrap(); - let logits = if logits.dims().len() == 2 { - logits.get(logits.dim(0).unwrap() - 1).unwrap() - } else { - logits - }; - next_token = crate::sampling::sample_top_k(&logits, 10, 5.0); - pos += 1; + + let mut input_tensor = burn::tensor::Tensor::::from_data( + burn::tensor::TensorData::from([next_token as i32]), + &device + ).unsqueeze::<2>(); + + logits = model.forward(input_tensor, current_offset, &mut caches); + current_offset += 1; + + let next_token_tensor = logits.argmax(2); + next_token = next_token_tensor.into_scalar().to_string().parse().unwrap_or(2); if next_token == 2 { break; } if let Ok(text) = tokenizer.decode(&[next_token], true) { generated_text.push_str(&text); - let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "SmolLM-135M" }); + let chunk = serde_json::json!({ "type": "llm_chunk", "token": text, "prompt": prompt, "model": "SmolLM-135M (WebGPU)" }); let _ = ws.borrow().send_with_str(&chunk.to_string()); } tokens_generated += 1; @@ -245,12 +226,10 @@ pub async fn run_smollm_inference(prompt: String, ws: Rc>) { let gen_time = perf.now() - start_gen; let tokens_per_sec = if gen_time > 0.0 { (tokens_generated as f64 / gen_time) * 1000.0 } else { 0.0 }; - console_log!("[SmolLM] Generoitu {} tokenia | {:.0}ms | {:.1} tok/s", tokens_generated, gen_time, tokens_per_sec); - let done = serde_json::json!({ "type": "llm_done", "prompt": prompt, - "model": "SmolLM-135M-Instruct", + "model": "SmolLM-135M-Instruct (WebGPU)", "response": generated_text, "tokens_generated": tokens_generated, "duration_ms": (gen_time * 100.0).round() / 100.0,