candle_transformers/models/
phi3.rs

1//! Microsoft Phi-3 model implementation
2//!
3//! See Phi model details at:
4//! - [Phi-3 Model](https://huggingface.co/microsoft/phi-3)
5//!
6//! The Phi series are decoder-only transformers designed for code and language tasks.
7//! Key characteristics:
8//! - Decoder-only transformer architecture
9//! - RoPE embeddings
10//! - RMS layer normalization (`RmsNorm`)
11//! - Fused QKV and gate/up projections
12//! - Mixed activation functions
13//! - Improved context window handling
14//!
15//! References:
16//! - [Hugging Face Implementation](https://huggingface.co/microsoft/phi-3)
17//! - [Alternative Implementation](https://huggingface.co/microsoft/phi-3/tree/main)
18//!
19
20// This implementation is based on:
21// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
22use crate::models::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
23use candle::{DType, Device, Module, Result, Tensor, D};
24use candle_nn::VarBuilder;
25use std::sync::Arc;
26
27// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json
28#[derive(Debug, Clone, serde::Deserialize)]
29pub struct Config {
30    pub vocab_size: usize,
31    pub hidden_act: candle_nn::Activation,
32    pub hidden_size: usize,
33    pub intermediate_size: usize,
34    pub num_hidden_layers: usize,
35    pub num_attention_heads: usize,
36    pub num_key_value_heads: usize,
37    pub rms_norm_eps: f64,
38    pub rope_theta: f64,
39    pub bos_token_id: Option<u32>,
40    pub eos_token_id: Option<u32>,
41    pub rope_scaling: Option<String>,
42    pub max_position_embeddings: usize,
43}
44
45impl Config {
46    pub fn head_dim(&self) -> usize {
47        self.hidden_size / self.num_attention_heads
48    }
49}
50
51#[derive(Debug, Clone)]
52pub struct RotaryEmbedding {
53    sin: Tensor,
54    cos: Tensor,
55}
56
57impl RotaryEmbedding {
58    pub fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
59        let dim = cfg.head_dim();
60        let max_seq_len = cfg.max_position_embeddings;
61        let inv_freq: Vec<_> = (0..dim)
62            .step_by(2)
63            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
64            .collect();
65        let inv_freq_len = inv_freq.len();
66        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
67        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
68            .to_dtype(dtype)?
69            .reshape((max_seq_len, 1))?;
70        let freqs = t.matmul(&inv_freq)?;
71        Ok(Self {
72            sin: freqs.sin()?,
73            cos: freqs.cos()?,
74        })
75    }
76
77    pub fn apply_rotary_emb_qkv(
78        &self,
79        q: &Tensor,
80        k: &Tensor,
81        seqlen_offset: usize,
82    ) -> Result<(Tensor, Tensor)> {
83        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
84        let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
85        let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
86        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
87        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
88        Ok((q_embed, k_embed))
89    }
90}
91
92#[derive(Debug, Clone)]
93struct Attention {
94    qkv_proj: Linear,
95    o_proj: Linear,
96    num_heads: usize,
97    num_kv_heads: usize,
98    num_kv_groups: usize,
99    head_dim: usize,
100    rotary_emb: Arc<RotaryEmbedding>,
101    kv_cache: Option<(Tensor, Tensor)>,
102}
103
104impl Attention {
105    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
106        let num_heads = cfg.num_attention_heads;
107        let num_kv_heads = cfg.num_key_value_heads;
108        let head_dim = cfg.head_dim();
109        let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim;
110        let qkv_proj = linear(cfg.hidden_size, op_size, vb.pp("qkv_proj"))?;
111        let o_proj = linear(num_heads * head_dim, cfg.hidden_size, vb.pp("o_proj"))?;
112        Ok(Self {
113            qkv_proj,
114            o_proj,
115            rotary_emb,
116            kv_cache: None,
117            num_heads,
118            num_kv_heads,
119            num_kv_groups: num_heads / num_kv_heads,
120            head_dim,
121        })
122    }
123
124    fn forward(
125        &mut self,
126        xs: &Tensor,
127        attention_mask: Option<&Tensor>,
128        seqlen_offset: usize,
129    ) -> Result<Tensor> {
130        let (b_sz, q_len, _) = xs.dims3()?;
131
132        let qkv = self.qkv_proj.forward(xs)?;
133        let query_pos = self.num_heads * self.head_dim;
134        let query_states = qkv.narrow(D::Minus1, 0, query_pos)?;
135        let key_states = qkv.narrow(D::Minus1, query_pos, self.num_kv_heads * self.head_dim)?;
136        let value_states = qkv.narrow(
137            D::Minus1,
138            query_pos + self.num_kv_heads * self.head_dim,
139            self.num_kv_heads * self.head_dim,
140        )?;
141
142        let query_states = query_states
143            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
144            .transpose(1, 2)?;
145        let key_states = key_states
146            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
147            .transpose(1, 2)?;
148        let value_states = value_states
149            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
150            .transpose(1, 2)?;
151
152        let (query_states, key_states) =
153            self.rotary_emb
154                .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
155
156        let (key_states, value_states) = match &self.kv_cache {
157            None => (key_states, value_states),
158            Some((prev_k, prev_v)) => {
159                let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
160                let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
161                (key_states, value_states)
162            }
163        };
164        self.kv_cache = Some((key_states.clone(), value_states.clone()));
165
166        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
167        let value_states =
168            crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
169
170        let attn_output = {
171            let scale = 1f64 / f64::sqrt(self.head_dim as f64);
172            let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
173
174            let attn_weights = match attention_mask {
175                None => attn_weights,
176                Some(mask) => attn_weights.broadcast_add(mask)?,
177            };
178            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
179            attn_weights.matmul(&value_states)?
180        };
181        attn_output
182            .transpose(1, 2)?
183            .reshape((b_sz, q_len, ()))?
184            .apply(&self.o_proj)
185    }
186
187    fn clear_kv_cache(&mut self) {
188        self.kv_cache = None
189    }
190}
191
192#[derive(Debug, Clone)]
193struct Mlp {
194    gate_up_proj: Linear,
195    down_proj: Linear,
196    act_fn: candle_nn::Activation,
197    i_size: usize,
198}
199
200impl Mlp {
201    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
202        let hidden_size = cfg.hidden_size;
203        let i_size = cfg.intermediate_size;
204        let gate_up_proj = linear(hidden_size, 2 * i_size, vb.pp("gate_up_proj"))?;
205        let down_proj = linear(i_size, hidden_size, vb.pp("down_proj"))?;
206        Ok(Self {
207            gate_up_proj,
208            down_proj,
209            act_fn: cfg.hidden_act,
210            i_size,
211        })
212    }
213}
214
215impl Module for Mlp {
216    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
217        let up_states = xs.apply(&self.gate_up_proj)?;
218        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
219        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
220        let up_states = (up_states * gate.apply(&self.act_fn))?;
221        up_states.apply(&self.down_proj)
222    }
223}
224
225#[derive(Debug, Clone)]
226struct DecoderLayer {
227    self_attn: Attention,
228    mlp: Mlp,
229    input_layernorm: RmsNorm,
230    post_attention_layernorm: RmsNorm,
231}
232
233impl DecoderLayer {
234    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
235        let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
236        let mlp = Mlp::new(cfg, vb.pp("mlp"))?;
237        let input_layernorm =
238            RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
239        let post_attention_layernorm = RmsNorm::new(
240            cfg.hidden_size,
241            cfg.rms_norm_eps,
242            vb.pp("post_attention_layernorm"),
243        )?;
244        Ok(Self {
245            self_attn,
246            mlp,
247            input_layernorm,
248            post_attention_layernorm,
249        })
250    }
251
252    fn forward(
253        &mut self,
254        xs: &Tensor,
255        attention_mask: Option<&Tensor>,
256        seqlen_offset: usize,
257    ) -> Result<Tensor> {
258        let residual = xs;
259        let xs = self.input_layernorm.forward(xs)?;
260        let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
261        let xs = (xs + residual)?;
262        let residual = &xs;
263        let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
264        residual + xs
265    }
266
267    fn clear_kv_cache(&mut self) {
268        self.self_attn.clear_kv_cache()
269    }
270}
271
272#[derive(Debug, Clone)]
273pub struct Model {
274    embed_tokens: candle_nn::Embedding,
275    layers: Vec<DecoderLayer>,
276    norm: RmsNorm,
277    lm_head: Linear,
278    device: Device,
279    dtype: DType,
280}
281
282impl Model {
283    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
284        let vb_m = vb.pp("model");
285        let embed_tokens =
286            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
287        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
288        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
289        let vb_l = vb_m.pp("layers");
290        for layer_idx in 0..cfg.num_hidden_layers {
291            let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
292            layers.push(layer)
293        }
294        let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
295        let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
296        Ok(Self {
297            embed_tokens,
298            layers,
299            norm,
300            lm_head,
301            device: vb.device().clone(),
302            dtype: vb.dtype(),
303        })
304    }
305
306    fn prepare_decoder_attention_mask(
307        &self,
308        b_size: usize,
309        tgt_len: usize,
310        seqlen_offset: usize,
311    ) -> Result<Tensor> {
312        let mask: Vec<_> = (0..tgt_len)
313            .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
314            .collect();
315        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
316        let mask = if seqlen_offset > 0 {
317            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
318            Tensor::cat(&[&mask0, &mask], D::Minus1)?
319        } else {
320            mask
321        };
322        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
323            .to_dtype(self.dtype)
324    }
325
326    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
327        let (b_size, seq_len) = input_ids.dims2()?;
328        let attention_mask = if seq_len <= 1 {
329            None
330        } else {
331            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
332            Some(mask)
333        };
334        let mut xs = self.embed_tokens.forward(input_ids)?;
335        for layer in self.layers.iter_mut() {
336            xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
337        }
338        xs.narrow(1, seq_len - 1, 1)?
339            .apply(&self.norm)?
340            .apply(&self.lm_head)
341    }
342
343    pub fn clear_kv_cache(&mut self) {
344        for layer in self.layers.iter_mut() {
345            layer.clear_kv_cache()
346        }
347    }
348}