// candle_transformers/models/llama2_c_weights.rs

//! Llama2 inference support: loading model weights from llama2.c checkpoints.
//!
//! See ["LLaMA 2: Open Foundation and Fine-Tuned Chat Models"](https://arxiv.org/abs/2307.09288)
//!
//! Based on the [llama2.c](https://github.com/karpathy/llama2.c) implementation.

use std::collections::HashMap;

use byteorder::{LittleEndian, ReadBytesExt};
use candle::{DType, Device, IndexOp, Result, Shape, Tensor};
use candle_nn::VarBuilder;

use super::llama2_c::Config;

13pub struct TransformerWeights {
14    // token embedding table
15    token_embedding_table: Tensor, // (vocab_size, dim)
16    // weights for rmsnorms
17    rms_att_weight: Tensor, // (layer, dim) rmsnorm weights
18    rms_ffn_weight: Tensor, // (layer, dim)
19    // weights for matmuls
20    wq: Tensor, // (layer, dim, dim)
21    wk: Tensor, // (layer, dim, dim)
22    wv: Tensor, // (layer, dim, dim)
23    wo: Tensor, // (layer, dim, dim)
24    // weights for ffn
25    w1: Tensor, // (layer, hidden_dim, dim)
26    w2: Tensor, // (layer, dim, hidden_dim)
27    w3: Tensor, // (layer, hidden_dim, dim)
28    // final rmsnorm
29    rms_final_weight: Tensor, // (dim,)
30    // freq_cis for RoPE relatively positional embeddings
31    freq_cis_real: Tensor, // (seq_len, head_size/2)
32    freq_cis_imag: Tensor, // (seq_len, head_size/2)
33}
34
/// Reads one little-endian `i32`, the encoding of every llama2.c header field.
///
/// Errors if the reader is exhausted before four bytes are available.
fn read_i32<R: std::io::Read>(r: &mut R) -> Result<i32> {
    let mut le_bytes = [0u8; std::mem::size_of::<i32>()];
    r.read_exact(&mut le_bytes)
        .map(|()| i32::from_le_bytes(le_bytes))
        .map_err(Into::into)
}

41fn read_tensor<R: std::io::Read, S: Into<Shape>>(
42    r: &mut R,
43    shape: S,
44    dev: &Device,
45) -> Result<Tensor> {
46    let shape = shape.into();
47    let mut data_t = vec![0f32; shape.elem_count()];
48    r.read_f32_into::<LittleEndian>(&mut data_t)?;
49    let tensor = Tensor::from_vec(data_t, shape, dev)?;
50    Ok(tensor)
51}
52
53impl Config {
54    pub fn from_reader<R: std::io::Read>(r: &mut R) -> Result<Self> {
55        let dim = read_i32(r)? as usize;
56        let hidden_dim = read_i32(r)? as usize;
57        let n_layers = read_i32(r)? as usize;
58        let n_heads = read_i32(r)? as usize;
59        let n_kv_heads = read_i32(r)? as usize;
60        let vocab_size = read_i32(r)? as usize;
61        let seq_len = read_i32(r)? as usize;
62        Ok(Self {
63            dim,
64            hidden_dim,
65            n_layers,
66            n_heads,
67            n_kv_heads,
68            vocab_size,
69            seq_len,
70            norm_eps: 1e-5,
71        })
72    }
73
74    pub fn head_size(&self) -> usize {
75        self.dim / self.n_heads
76    }
77}
78
79impl TransformerWeights {
80    pub fn from_reader<R: std::io::Read>(r: &mut R, c: &Config, dev: &Device) -> Result<Self> {
81        let token_embedding_table = read_tensor(r, (c.vocab_size, c.dim), dev)?;
82        let rms_att_weight = read_tensor(r, (c.n_layers, c.dim), dev)?;
83        let wq = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
84        let wk = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
85        let wv = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
86        let wo = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
87        let rms_ffn_weight = read_tensor(r, (c.n_layers, c.dim), dev)?;
88        let w1 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?;
89        let w2 = read_tensor(r, (c.n_layers, c.dim, c.hidden_dim), dev)?;
90        let w3 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?;
91        let rms_final_weight = read_tensor(r, c.dim, dev)?;
92        let head_size = c.head_size();
93        let freq_cis_real = read_tensor(r, (c.seq_len, head_size / 2), dev)?;
94        let freq_cis_imag = read_tensor(r, (c.seq_len, head_size / 2), dev)?;
95        Ok(Self {
96            token_embedding_table,
97            rms_att_weight,
98            wq,
99            wk,
100            wv,
101            wo,
102            rms_ffn_weight,
103            w1,
104            w2,
105            w3,
106            rms_final_weight,
107            freq_cis_real,
108            freq_cis_imag,
109        })
110    }
111
112    pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder<'static>> {
113        // TODO: As of 2023-08-04, gemm is slower than expected when multiplying a matrix of
114        // size (1, k) with the transpose of a matrix of size (k, n) as it ends up transposing the
115        // second matrix back. We detect this case here and as a temporary hack make the weight
116        // matrix column major rather than row major. This ends up speeding up text generation from
117        // 120 token/s to 220 token/s on a Ryzen 2600X.
118        let tr = device.is_cpu() && !candle::utils::has_mkl();
119        let tr = |x: Tensor| if tr { x.t()?.contiguous()?.t() } else { Ok(x) };
120        let mut ws = std::collections::HashMap::new();
121        let mut insert = |name: &str, t: Tensor| {
122            ws.insert(name.to_string(), t);
123        };
124        insert("rot.freq_cis_real", self.freq_cis_real.clone());
125        insert("rot.freq_cis_imag", self.freq_cis_imag.clone());
126        insert(
127            "model.embed_tokens.weight",
128            self.token_embedding_table.clone(),
129        );
130        insert("lm_head.weight", tr(self.token_embedding_table.clone())?);
131        insert("model.norm.weight", self.rms_final_weight.clone());
132        for layer in 0..cfg.n_layers {
133            ws.insert(
134                format!("model.layers.{layer}.self_attn.q_proj.weight"),
135                tr(self.wq.i(layer)?)?,
136            );
137            ws.insert(
138                format!("model.layers.{layer}.self_attn.k_proj.weight"),
139                tr(self.wk.i(layer)?)?,
140            );
141            ws.insert(
142                format!("model.layers.{layer}.self_attn.v_proj.weight"),
143                tr(self.wv.i(layer)?)?,
144            );
145            ws.insert(
146                format!("model.layers.{layer}.self_attn.o_proj.weight"),
147                tr(self.wo.i(layer)?)?,
148            );
149            ws.insert(
150                format!("model.layers.{layer}.mlp.gate_proj.weight"),
151                tr(self.w1.i(layer)?)?,
152            );
153            ws.insert(
154                format!("model.layers.{layer}.mlp.down_proj.weight"),
155                tr(self.w2.i(layer)?)?,
156            );
157            ws.insert(
158                format!("model.layers.{layer}.mlp.up_proj.weight"),
159                tr(self.w3.i(layer)?)?,
160            );
161            ws.insert(
162                format!("model.layers.{layer}.input_layernorm.weight"),
163                self.rms_att_weight.i(layer)?,
164            );
165            ws.insert(
166                format!("model.layers.{layer}.post_attention_layernorm.weight"),
167                self.rms_ffn_weight.i(layer)?,
168            );
169        }
170        let vb = VarBuilder::from_tensors(ws, DType::F32, device);
171        Ok(vb)
172    }
173}