91 lines
4.6 KiB
Rust
91 lines
4.6 KiB
Rust
use burn::tensor::{backend::Backend, Tensor, TensorData};
|
|
use candle_core::safetensors;
|
|
use candle_core::Device as CandleDevice;
|
|
use burn::module::Param;
|
|
use super::model::LlamaModel;
|
|
use super::config::SmolLMConfig;
|
|
|
|
fn load_tensor_2d<B: Backend>(
|
|
tensors_map: &std::collections::HashMap<String, candle_core::Tensor>,
|
|
name: &str,
|
|
device: &B::Device,
|
|
shape_out_in: [usize; 2]
|
|
) -> Result<Param<Tensor<B, 2>>, String> {
|
|
let t = tensors_map.get(name).ok_or_else(|| format!("Puuttuu: {}", name))?;
|
|
let t = t.to_dtype(candle_core::DType::F32).unwrap();
|
|
let vec = t.flatten_all().unwrap().to_vec1::<f32>().unwrap();
|
|
let t_burn = Tensor::<B, 2>::from_data(burn::tensor::TensorData::new(vec, shape_out_in), device);
|
|
// transpose from [out, in] to [in, out]
|
|
Ok(Param::from_tensor(t_burn.transpose()))
|
|
}
|
|
|
|
fn load_tensor_1d<B: Backend>(
|
|
tensors_map: &std::collections::HashMap<String, candle_core::Tensor>,
|
|
name: &str,
|
|
device: &B::Device,
|
|
_shape: [usize; 1]
|
|
) -> Result<Param<Tensor<B, 1>>, String> {
|
|
let t = tensors_map.get(name).ok_or_else(|| format!("Puuttuu: {}", name))?;
|
|
let t = t.to_dtype(candle_core::DType::F32).unwrap();
|
|
let vec = t.flatten_all().unwrap().to_vec1::<f32>().unwrap();
|
|
Ok(Param::from_tensor(Tensor::<B, 1>::from_floats(vec.as_slice(), device)))
|
|
}
|
|
|
|
fn load_embed<B: Backend>(
|
|
tensors_map: &std::collections::HashMap<String, candle_core::Tensor>,
|
|
name: &str,
|
|
device: &B::Device,
|
|
shape: [usize; 2]
|
|
) -> Result<Param<Tensor<B, 2>>, String> {
|
|
let t = tensors_map.get(name).ok_or_else(|| format!("Puuttuu: {}", name))?;
|
|
let t = t.to_dtype(candle_core::DType::F32).unwrap();
|
|
let vec = t.flatten_all().unwrap().to_vec1::<f32>().unwrap();
|
|
// Embed ei transponoi samalla tavalla, se pysyy [vocab, hidden]
|
|
Ok(Param::from_tensor(Tensor::<B, 2>::from_data(burn::tensor::TensorData::new(vec, shape), device)))
|
|
}
|
|
|
|
pub fn load_safetensors_to_model<B: Backend>(
|
|
buffer: &[u8],
|
|
config: &SmolLMConfig,
|
|
device: &B::Device
|
|
) -> Result<LlamaModel<B>, String> {
|
|
|
|
let mut model = LlamaModel::new(config, device);
|
|
let tensors_map = safetensors::load_buffer(buffer, &CandleDevice::Cpu)
|
|
.map_err(|e| format!("Virhe Safetensors luennassa: {}", e))?;
|
|
|
|
// Embeddings
|
|
model.embed_tokens = load_embed(&tensors_map, "model.embed_tokens.weight", device, [config.vocab_size, config.hidden_size])?;
|
|
model.norm.weight = load_tensor_1d(&tensors_map, "model.norm.weight", device, [config.hidden_size])?;
|
|
model.lm_head = load_embed(&tensors_map, "lm_head.weight", device, [config.vocab_size, config.hidden_size]).or_else(|_| {
|
|
load_embed(&tensors_map, "model.embed_tokens.weight", device, [config.vocab_size, config.hidden_size])
|
|
})?;
|
|
|
|
let head_dim = config.hidden_size / config.num_attention_heads;
|
|
|
|
for i in 0..config.num_hidden_layers {
|
|
let prefix = format!("model.layers.{}", i);
|
|
|
|
let layer = &mut model.layers[i];
|
|
|
|
// Norms
|
|
layer.input_layernorm.weight = load_tensor_1d(&tensors_map, &format!("{}.input_layernorm.weight", prefix), device, [config.hidden_size])?;
|
|
layer.post_attention_layernorm.weight = load_tensor_1d(&tensors_map, &format!("{}.post_attention_layernorm.weight", prefix), device, [config.hidden_size])?;
|
|
|
|
// Attention
|
|
let num_heads = config.num_attention_heads;
|
|
let num_kv_heads = config.num_key_value_heads;
|
|
layer.self_attn.q_proj = load_tensor_2d(&tensors_map, &format!("{}.self_attn.q_proj.weight", prefix), device, [num_heads * head_dim, config.hidden_size])?;
|
|
layer.self_attn.k_proj = load_tensor_2d(&tensors_map, &format!("{}.self_attn.k_proj.weight", prefix), device, [num_kv_heads * head_dim, config.hidden_size])?;
|
|
layer.self_attn.v_proj = load_tensor_2d(&tensors_map, &format!("{}.self_attn.v_proj.weight", prefix), device, [num_kv_heads * head_dim, config.hidden_size])?;
|
|
layer.self_attn.o_proj = load_tensor_2d(&tensors_map, &format!("{}.self_attn.o_proj.weight", prefix), device, [config.hidden_size, num_heads * head_dim])?;
|
|
|
|
// MLP
|
|
layer.mlp.gate_proj = load_tensor_2d(&tensors_map, &format!("{}.mlp.gate_proj.weight", prefix), device, [config.intermediate_size, config.hidden_size])?;
|
|
layer.mlp.up_proj = load_tensor_2d(&tensors_map, &format!("{}.mlp.up_proj.weight", prefix), device, [config.intermediate_size, config.hidden_size])?;
|
|
layer.mlp.down_proj = load_tensor_2d(&tensors_map, &format!("{}.mlp.down_proj.weight", prefix), device, [config.hidden_size, config.intermediate_size])?;
|
|
}
|
|
|
|
Ok(model)
|
|
}
|