use burn::module::Module; use burn::tensor::{backend::Backend, Tensor}; #[derive(Module, Debug)] pub struct RoPE { cos_cache: Tensor, sin_cache: Tensor, } impl RoPE { pub fn new(head_dim: usize, max_seq_len: usize, theta: f32, device: &B::Device) -> Self { // (head_dim / 2) values let half_dim = head_dim / 2; let inv_freq: Vec = (0..half_dim) .map(|i| 1.0 / theta.powf((2 * i) as f32 / head_dim as f32)) .collect(); let inv_freq = Tensor::::from_floats(inv_freq.as_slice(), device).unsqueeze::<2>(); let t_floats: Vec = (0..max_seq_len).map(|v| v as f32).collect(); let t = Tensor::::from_floats(t_floats.as_slice(), device).unsqueeze::<2>().transpose(); // t shape: [max_seq_len, 1] // inv_freq shape: [1, half_dim] // freqs shape: [max_seq_len, half_dim] let freqs = t.matmul(inv_freq); let cos_cache = freqs.clone().cos(); let sin_cache = freqs.sin(); Self { cos_cache, sin_cache, } } pub fn forward(&self, x: Tensor, offset: usize) -> Tensor { let [batch, heads, seq_len, head_dim] = x.dims(); let half_dim = head_dim / 2; // x shape: [batch, heads, seq_len, head_dim] // valitaan viipaleet (x1 ja x2) jotta saadaan pyƶritettyƤ rotaatiot let x1 = x.clone().slice([0..batch, 0..heads, 0..seq_len, 0..half_dim]); let x2 = x.clone().slice([0..batch, 0..heads, 0..seq_len, half_dim..head_dim]); // haetaan vastaava seq offsetista alkaen let cos = self.cos_cache.clone().slice([offset..offset+seq_len, 0..half_dim]) .unsqueeze::<4>() // [seq, half_dim, 1] .reshape([1, 1, seq_len, half_dim]); let sin = self.sin_cache.clone().slice([offset..offset+seq_len, 0..half_dim]) .reshape([1, 1, seq_len, half_dim]); // x1 * cos - x2 * sin let o1 = x1.clone().mul(cos.clone()) - x2.clone().mul(sin.clone()); // x2 * cos + x1 * sin let o2 = x2.mul(cos) + x1.mul(sin); Tensor::cat(vec![o1, o2], 3) } }