// File listing metadata: 60 lines, 2.2 KiB, Rust.
use burn::module::Module;
|
|
use burn::tensor::{backend::Backend, Tensor};
|
|
|
|
#[derive(Module, Debug)]
|
|
pub struct RoPE<B: Backend> {
|
|
cos_cache: Tensor<B, 2>,
|
|
sin_cache: Tensor<B, 2>,
|
|
}
|
|
|
|
impl<B: Backend> RoPE<B> {
|
|
pub fn new(head_dim: usize, max_seq_len: usize, theta: f32, device: &B::Device) -> Self {
|
|
// (head_dim / 2) values
|
|
let half_dim = head_dim / 2;
|
|
let inv_freq: Vec<f32> = (0..half_dim)
|
|
.map(|i| 1.0 / theta.powf((2 * i) as f32 / head_dim as f32))
|
|
.collect();
|
|
|
|
let inv_freq = Tensor::<B, 1>::from_floats(inv_freq.as_slice(), device).unsqueeze::<2>();
|
|
let t_floats: Vec<f32> = (0..max_seq_len).map(|v| v as f32).collect();
|
|
let t = Tensor::<B, 1>::from_floats(t_floats.as_slice(), device).unsqueeze::<2>().transpose();
|
|
// t shape: [max_seq_len, 1]
|
|
// inv_freq shape: [1, half_dim]
|
|
|
|
// freqs shape: [max_seq_len, half_dim]
|
|
let freqs = t.matmul(inv_freq);
|
|
|
|
let cos_cache = freqs.clone().cos();
|
|
let sin_cache = freqs.sin();
|
|
|
|
Self {
|
|
cos_cache,
|
|
sin_cache,
|
|
}
|
|
}
|
|
|
|
pub fn forward(&self, x: Tensor<B, 4>, offset: usize) -> Tensor<B, 4> {
|
|
let [batch, heads, seq_len, head_dim] = x.dims();
|
|
let half_dim = head_dim / 2;
|
|
|
|
// x shape: [batch, heads, seq_len, head_dim]
|
|
// valitaan viipaleet (x1 ja x2) jotta saadaan pyöritettyä rotaatiot
|
|
let x1 = x.clone().slice([0..batch, 0..heads, 0..seq_len, 0..half_dim]);
|
|
let x2 = x.clone().slice([0..batch, 0..heads, 0..seq_len, half_dim..head_dim]);
|
|
|
|
// haetaan vastaava seq offsetista alkaen
|
|
let cos = self.cos_cache.clone().slice([offset..offset+seq_len, 0..half_dim])
|
|
.unsqueeze::<4>() // [seq, half_dim, 1]
|
|
.reshape([1, 1, seq_len, half_dim]);
|
|
let sin = self.sin_cache.clone().slice([offset..offset+seq_len, 0..half_dim])
|
|
.reshape([1, 1, seq_len, half_dim]);
|
|
|
|
// x1 * cos - x2 * sin
|
|
let o1 = x1.clone().mul(cos.clone()) - x2.clone().mul(sin.clone());
|
|
// x2 * cos + x1 * sin
|
|
let o2 = x2.mul(cos) + x1.mul(sin);
|
|
|
|
Tensor::cat(vec![o1, o2], 3)
|
|
}
|
|
}
|