use burn::tensor::{backend::Backend, Tensor, TensorData}; use candle_core::safetensors; use candle_core::Device as CandleDevice; use burn::module::Param; use super::model::LlamaModel; use super::config::SmolLMConfig; fn load_tensor_2d( tensors_map: &std::collections::HashMap, name: &str, device: &B::Device, shape_out_in: [usize; 2] ) -> Result>, String> { let t = tensors_map.get(name).ok_or_else(|| format!("Puuttuu: {}", name))?; let t = t.to_dtype(candle_core::DType::F32).unwrap(); let vec = t.flatten_all().unwrap().to_vec1::().unwrap(); let t_burn = Tensor::::from_data(burn::tensor::TensorData::new(vec, shape_out_in), device); // transpose from [out, in] to [in, out] Ok(Param::from_tensor(t_burn.transpose())) } fn load_tensor_1d( tensors_map: &std::collections::HashMap, name: &str, device: &B::Device, _shape: [usize; 1] ) -> Result>, String> { let t = tensors_map.get(name).ok_or_else(|| format!("Puuttuu: {}", name))?; let t = t.to_dtype(candle_core::DType::F32).unwrap(); let vec = t.flatten_all().unwrap().to_vec1::().unwrap(); Ok(Param::from_tensor(Tensor::::from_floats(vec.as_slice(), device))) } fn load_embed( tensors_map: &std::collections::HashMap, name: &str, device: &B::Device, shape: [usize; 2] ) -> Result>, String> { let t = tensors_map.get(name).ok_or_else(|| format!("Puuttuu: {}", name))?; let t = t.to_dtype(candle_core::DType::F32).unwrap(); let vec = t.flatten_all().unwrap().to_vec1::().unwrap(); // Embed ei transponoi samalla tavalla, se pysyy [vocab, hidden] Ok(Param::from_tensor(Tensor::::from_data(burn::tensor::TensorData::new(vec, shape), device))) } pub fn load_safetensors_to_model( buffer: &[u8], config: &SmolLMConfig, device: &B::Device ) -> Result, String> { let mut model = LlamaModel::new(config, device); let tensors_map = safetensors::load_buffer(buffer, &CandleDevice::Cpu) .map_err(|e| format!("Virhe Safetensors luennassa: {}", e))?; // Embeddings model.embed_tokens = load_embed(&tensors_map, "model.embed_tokens.weight", device, [config.vocab_size, config.hidden_size])?; model.norm.weight = load_tensor_1d(&tensors_map, "model.norm.weight", device, [config.hidden_size])?; model.lm_head = load_embed(&tensors_map, "lm_head.weight", device, [config.vocab_size, config.hidden_size]).or_else(|_| { load_embed(&tensors_map, "model.embed_tokens.weight", device, [config.vocab_size, config.hidden_size]) })?; let head_dim = config.hidden_size / config.num_attention_heads; for i in 0..config.num_hidden_layers { let prefix = format!("model.layers.{}", i); let layer = &mut model.layers[i]; // Norms layer.input_layernorm.weight = load_tensor_1d(&tensors_map, &format!("{}.input_layernorm.weight", prefix), device, [config.hidden_size])?; layer.post_attention_layernorm.weight = load_tensor_1d(&tensors_map, &format!("{}.post_attention_layernorm.weight", prefix), device, [config.hidden_size])?; // Attention let num_heads = config.num_attention_heads; let num_kv_heads = config.num_key_value_heads; layer.self_attn.q_proj = load_tensor_2d(&tensors_map, &format!("{}.self_attn.q_proj.weight", prefix), device, [num_heads * head_dim, config.hidden_size])?; layer.self_attn.k_proj = load_tensor_2d(&tensors_map, &format!("{}.self_attn.k_proj.weight", prefix), device, [num_kv_heads * head_dim, config.hidden_size])?; layer.self_attn.v_proj = load_tensor_2d(&tensors_map, &format!("{}.self_attn.v_proj.weight", prefix), device, [num_kv_heads * head_dim, config.hidden_size])?; layer.self_attn.o_proj = load_tensor_2d(&tensors_map, &format!("{}.self_attn.o_proj.weight", prefix), device, [config.hidden_size, num_heads * head_dim])?; // MLP layer.mlp.gate_proj = load_tensor_2d(&tensors_map, &format!("{}.mlp.gate_proj.weight", prefix), device, [config.intermediate_size, config.hidden_size])?; layer.mlp.up_proj = load_tensor_2d(&tensors_map, &format!("{}.mlp.up_proj.weight", prefix), device, [config.intermediate_size, config.hidden_size])?; layer.mlp.down_proj = load_tensor_2d(&tensors_map, &format!("{}.mlp.down_proj.weight", prefix), device, [config.hidden_size, config.intermediate_size])?; } Ok(model) }