candle_transformers/models/
llama2_c_weights.rs1use byteorder::{LittleEndian, ReadBytesExt};
8use candle::{DType, Device, IndexOp, Result, Shape, Tensor};
9use candle_nn::VarBuilder;
10
11use super::llama2_c::Config;
12
13pub struct TransformerWeights {
14 token_embedding_table: Tensor, rms_att_weight: Tensor, rms_ffn_weight: Tensor, wq: Tensor, wk: Tensor, wv: Tensor, wo: Tensor, w1: Tensor, w2: Tensor, w3: Tensor, rms_final_weight: Tensor, freq_cis_real: Tensor, freq_cis_imag: Tensor, }
34
35fn read_i32<R: std::io::Read>(r: &mut R) -> Result<i32> {
36 let mut buf = [0u8; 4];
37 r.read_exact(&mut buf)?;
38 Ok(i32::from_le_bytes(buf))
39}
40
41fn read_tensor<R: std::io::Read, S: Into<Shape>>(
42 r: &mut R,
43 shape: S,
44 dev: &Device,
45) -> Result<Tensor> {
46 let shape = shape.into();
47 let mut data_t = vec![0f32; shape.elem_count()];
48 r.read_f32_into::<LittleEndian>(&mut data_t)?;
49 let tensor = Tensor::from_vec(data_t, shape, dev)?;
50 Ok(tensor)
51}
52
53impl Config {
54 pub fn from_reader<R: std::io::Read>(r: &mut R) -> Result<Self> {
55 let dim = read_i32(r)? as usize;
56 let hidden_dim = read_i32(r)? as usize;
57 let n_layers = read_i32(r)? as usize;
58 let n_heads = read_i32(r)? as usize;
59 let n_kv_heads = read_i32(r)? as usize;
60 let vocab_size = read_i32(r)? as usize;
61 let seq_len = read_i32(r)? as usize;
62 Ok(Self {
63 dim,
64 hidden_dim,
65 n_layers,
66 n_heads,
67 n_kv_heads,
68 vocab_size,
69 seq_len,
70 norm_eps: 1e-5,
71 })
72 }
73
74 pub fn head_size(&self) -> usize {
75 self.dim / self.n_heads
76 }
77}
78
79impl TransformerWeights {
80 pub fn from_reader<R: std::io::Read>(r: &mut R, c: &Config, dev: &Device) -> Result<Self> {
81 let token_embedding_table = read_tensor(r, (c.vocab_size, c.dim), dev)?;
82 let rms_att_weight = read_tensor(r, (c.n_layers, c.dim), dev)?;
83 let wq = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
84 let wk = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
85 let wv = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
86 let wo = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
87 let rms_ffn_weight = read_tensor(r, (c.n_layers, c.dim), dev)?;
88 let w1 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?;
89 let w2 = read_tensor(r, (c.n_layers, c.dim, c.hidden_dim), dev)?;
90 let w3 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?;
91 let rms_final_weight = read_tensor(r, c.dim, dev)?;
92 let head_size = c.head_size();
93 let freq_cis_real = read_tensor(r, (c.seq_len, head_size / 2), dev)?;
94 let freq_cis_imag = read_tensor(r, (c.seq_len, head_size / 2), dev)?;
95 Ok(Self {
96 token_embedding_table,
97 rms_att_weight,
98 wq,
99 wk,
100 wv,
101 wo,
102 rms_ffn_weight,
103 w1,
104 w2,
105 w3,
106 rms_final_weight,
107 freq_cis_real,
108 freq_cis_imag,
109 })
110 }
111
112 pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder<'static>> {
113 let tr = device.is_cpu() && !candle::utils::has_mkl();
119 let tr = |x: Tensor| if tr { x.t()?.contiguous()?.t() } else { Ok(x) };
120 let mut ws = std::collections::HashMap::new();
121 let mut insert = |name: &str, t: Tensor| {
122 ws.insert(name.to_string(), t);
123 };
124 insert("rot.freq_cis_real", self.freq_cis_real.clone());
125 insert("rot.freq_cis_imag", self.freq_cis_imag.clone());
126 insert(
127 "model.embed_tokens.weight",
128 self.token_embedding_table.clone(),
129 );
130 insert("lm_head.weight", tr(self.token_embedding_table.clone())?);
131 insert("model.norm.weight", self.rms_final_weight.clone());
132 for layer in 0..cfg.n_layers {
133 ws.insert(
134 format!("model.layers.{layer}.self_attn.q_proj.weight"),
135 tr(self.wq.i(layer)?)?,
136 );
137 ws.insert(
138 format!("model.layers.{layer}.self_attn.k_proj.weight"),
139 tr(self.wk.i(layer)?)?,
140 );
141 ws.insert(
142 format!("model.layers.{layer}.self_attn.v_proj.weight"),
143 tr(self.wv.i(layer)?)?,
144 );
145 ws.insert(
146 format!("model.layers.{layer}.self_attn.o_proj.weight"),
147 tr(self.wo.i(layer)?)?,
148 );
149 ws.insert(
150 format!("model.layers.{layer}.mlp.gate_proj.weight"),
151 tr(self.w1.i(layer)?)?,
152 );
153 ws.insert(
154 format!("model.layers.{layer}.mlp.down_proj.weight"),
155 tr(self.w2.i(layer)?)?,
156 );
157 ws.insert(
158 format!("model.layers.{layer}.mlp.up_proj.weight"),
159 tr(self.w3.i(layer)?)?,
160 );
161 ws.insert(
162 format!("model.layers.{layer}.input_layernorm.weight"),
163 self.rms_att_weight.i(layer)?,
164 );
165 ws.insert(
166 format!("model.layers.{layer}.post_attention_layernorm.weight"),
167 self.rms_ffn_weight.i(layer)?,
168 );
169 }
170 let vb = VarBuilder::from_tensors(ws, DType::F32, device);
171 Ok(vb)
172 }
173}