Files
agentic-studio/network-poc/node/src/burn_smollm/loader.rs
2026-04-02 16:19:57 +03:00

91 lines
4.6 KiB
Rust

use burn::tensor::{backend::Backend, Tensor, TensorData};
use candle_core::safetensors;
use candle_core::Device as CandleDevice;
use burn::module::Param;
use super::model::LlamaModel;
use super::config::SmolLMConfig;
/// Loads the 2-D weight `name` from `tensors_map` as a burn parameter.
///
/// The stored values are converted to `f32`, interpreted with the shape
/// `shape_out_in` (`[out, in]`), and then transposed so the returned
/// parameter has the layout `[in, out]`.
///
/// # Errors
///
/// Returns an error string when the tensor is missing from the map, when the
/// candle conversion to a flat `f32` vector fails, or when the number of
/// stored elements does not match `shape_out_in`.
fn load_tensor_2d<B: Backend>(
    tensors_map: &std::collections::HashMap<String, candle_core::Tensor>,
    name: &str,
    device: &B::Device,
    shape_out_in: [usize; 2],
) -> Result<Param<Tensor<B, 2>>, String> {
    let tensor = tensors_map
        .get(name)
        .ok_or_else(|| format!("Puuttuu: {}", name))?;

    // Propagate candle failures instead of panicking; the message reads
    // "conversion of tensor <name> failed".
    let values = tensor
        .to_dtype(candle_core::DType::F32)
        .and_then(|t| t.flatten_all())
        .and_then(|t| t.to_vec1::<f32>())
        .map_err(|e| format!("Tensorin {} muunnos epäonnistui: {}", name, e))?;

    // `TensorData::new` panics on an element-count mismatch, so check it here
    // and report "size of tensor <name> does not match" as a regular error.
    let expected: usize = shape_out_in.iter().product();
    if values.len() != expected {
        return Err(format!(
            "Tensorin {} koko ei täsmää: odotettu {:?} ({} alkiota), saatu {} alkiota",
            name,
            shape_out_in,
            expected,
            values.len()
        ));
    }

    let weight = Tensor::<B, 2>::from_data(
        burn::tensor::TensorData::new(values, shape_out_in),
        device,
    );
    // transpose from [out, in] to [in, out]
    Ok(Param::from_tensor(weight.transpose()))
}
/// Loads the 1-D weight `name` from `tensors_map` as a burn parameter.
///
/// The stored values are converted to `f32`, flattened, and copied into a
/// rank-1 tensor on `device`. `shape` gives the expected length and is used
/// to validate the stored data.
///
/// # Errors
///
/// Returns an error string when the tensor is missing from the map, when the
/// candle conversion to a flat `f32` vector fails, or when the number of
/// stored elements differs from `shape[0]`.
fn load_tensor_1d<B: Backend>(
    tensors_map: &std::collections::HashMap<String, candle_core::Tensor>,
    name: &str,
    device: &B::Device,
    shape: [usize; 1],
) -> Result<Param<Tensor<B, 1>>, String> {
    let tensor = tensors_map
        .get(name)
        .ok_or_else(|| format!("Puuttuu: {}", name))?;

    // Propagate candle failures instead of panicking; the message reads
    // "conversion of tensor <name> failed".
    let values = tensor
        .to_dtype(candle_core::DType::F32)
        .and_then(|t| t.flatten_all())
        .and_then(|t| t.to_vec1::<f32>())
        .map_err(|e| format!("Tensorin {} muunnos epäonnistui: {}", name, e))?;

    // Reject a wrong-length vector here rather than letting it surface later
    // as a shape error deep inside the model ("size of tensor <name> does not
    // match: expected N elements, got M").
    if values.len() != shape[0] {
        return Err(format!(
            "Tensorin {} koko ei täsmää: odotettu {} alkiota, saatu {}",
            name,
            shape[0],
            values.len()
        ));
    }

    Ok(Param::from_tensor(Tensor::<B, 1>::from_floats(
        values.as_slice(),
        device,
    )))
}
/// Loads the 2-D embedding-style weight `name` from `tensors_map` as a burn
/// parameter.
///
/// The stored values are converted to `f32` and interpreted with the given
/// `shape`; no transpose is applied.
///
/// # Errors
///
/// Returns an error string when the tensor is missing from the map, when the
/// candle conversion to a flat `f32` vector fails, or when the number of
/// stored elements does not match `shape`.
fn load_embed<B: Backend>(
    tensors_map: &std::collections::HashMap<String, candle_core::Tensor>,
    name: &str,
    device: &B::Device,
    shape: [usize; 2],
) -> Result<Param<Tensor<B, 2>>, String> {
    let tensor = tensors_map
        .get(name)
        .ok_or_else(|| format!("Puuttuu: {}", name))?;

    // Propagate candle failures instead of panicking; the message reads
    // "conversion of tensor <name> failed".
    let values = tensor
        .to_dtype(candle_core::DType::F32)
        .and_then(|t| t.flatten_all())
        .and_then(|t| t.to_vec1::<f32>())
        .map_err(|e| format!("Tensorin {} muunnos epäonnistui: {}", name, e))?;

    // `TensorData::new` panics on an element-count mismatch, so check it here
    // and report "size of tensor <name> does not match" as a regular error.
    let expected: usize = shape.iter().product();
    if values.len() != expected {
        return Err(format!(
            "Tensorin {} koko ei täsmää: odotettu {:?} ({} alkiota), saatu {} alkiota",
            name,
            shape,
            expected,
            values.len()
        ));
    }

    // The embedding is not transposed the way projection weights are; it
    // stays [vocab, hidden].
    Ok(Param::from_tensor(Tensor::<B, 2>::from_data(
        burn::tensor::TensorData::new(values, shape),
        device,
    )))
}
/// Builds a [`LlamaModel`] from `config` and fills its parameters with the
/// weights stored in the safetensors `buffer`.
///
/// The buffer is decoded on the candle CPU device, and each weight is then
/// copied into a burn tensor on `device`. When `lm_head.weight` cannot be
/// loaded, the token-embedding weight is used for the output head instead.
///
/// # Errors
///
/// Returns an error string when the buffer cannot be parsed as safetensors or
/// when one of the weight loaders reports an error.
pub fn load_safetensors_to_model<B: Backend>(
    buffer: &[u8],
    config: &SmolLMConfig,
    device: &B::Device,
) -> Result<LlamaModel<B>, String> {
    let mut model = LlamaModel::new(config, device);
    let weights = match safetensors::load_buffer(buffer, &CandleDevice::Cpu) {
        Ok(map) => map,
        Err(e) => return Err(format!("Virhe Safetensors luennassa: {}", e)),
    };

    let hidden = config.hidden_size;
    let intermediate = config.intermediate_size;
    let embed_shape = [config.vocab_size, hidden];

    // Embeddings, final norm and output head.
    model.embed_tokens =
        load_embed::<B>(&weights, "model.embed_tokens.weight", device, embed_shape)?;
    model.norm.weight = load_tensor_1d::<B>(&weights, "model.norm.weight", device, [hidden])?;
    model.lm_head = match load_embed::<B>(&weights, "lm_head.weight", device, embed_shape) {
        Ok(head) => head,
        // No usable separate head: fall back to the embedding weight.
        Err(_) => load_embed::<B>(&weights, "model.embed_tokens.weight", device, embed_shape)?,
    };

    let head_dim = hidden / config.num_attention_heads;

    for layer_idx in 0..config.num_hidden_layers {
        // Full tensor name for a weight of the current layer.
        let key = |suffix: &str| format!("model.layers.{}.{}", layer_idx, suffix);
        let block = &mut model.layers[layer_idx];

        // Norms
        block.input_layernorm.weight =
            load_tensor_1d::<B>(&weights, &key("input_layernorm.weight"), device, [hidden])?;
        block.post_attention_layernorm.weight = load_tensor_1d::<B>(
            &weights,
            &key("post_attention_layernorm.weight"),
            device,
            [hidden],
        )?;

        // Attention: the key/value projections use the key-value head count.
        let q_width = config.num_attention_heads * head_dim;
        let kv_width = config.num_key_value_heads * head_dim;
        block.self_attn.q_proj = load_tensor_2d::<B>(
            &weights,
            &key("self_attn.q_proj.weight"),
            device,
            [q_width, hidden],
        )?;
        block.self_attn.k_proj = load_tensor_2d::<B>(
            &weights,
            &key("self_attn.k_proj.weight"),
            device,
            [kv_width, hidden],
        )?;
        block.self_attn.v_proj = load_tensor_2d::<B>(
            &weights,
            &key("self_attn.v_proj.weight"),
            device,
            [kv_width, hidden],
        )?;
        block.self_attn.o_proj = load_tensor_2d::<B>(
            &weights,
            &key("self_attn.o_proj.weight"),
            device,
            [hidden, q_width],
        )?;

        // MLP
        block.mlp.gate_proj = load_tensor_2d::<B>(
            &weights,
            &key("mlp.gate_proj.weight"),
            device,
            [intermediate, hidden],
        )?;
        block.mlp.up_proj = load_tensor_2d::<B>(
            &weights,
            &key("mlp.up_proj.weight"),
            device,
            [intermediate, hidden],
        )?;
        block.mlp.down_proj = load_tensor_2d::<B>(
            &weights,
            &key("mlp.down_proj.weight"),
            device,
            [hidden, intermediate],
        )?;
    }

    Ok(model)
}