candle_transformers/models/pixtral/
vision_model.rs

1use candle::{DType, Device, Module, Result, Tensor, D};
2use candle_nn::{linear_b, rms_norm, Linear, RmsNorm, VarBuilder};
3
4fn default_act() -> candle_nn::Activation {
5    candle_nn::Activation::Silu
6}
7
/// Value used for `Config::hidden_size` when the serialized config omits that field.
fn default_hidden_size() -> usize {
    const HIDDEN_SIZE: usize = 1024;
    HIDDEN_SIZE
}
11
/// Value used for `Config::intermediate_size` when the serialized config omits that field.
fn default_intermediate_size() -> usize {
    const INTERMEDIATE_SIZE: usize = 4096;
    INTERMEDIATE_SIZE
}
15
/// Value used for `Config::num_channels` when the serialized config omits that field.
fn default_num_channels() -> usize {
    const NUM_CHANNELS: usize = 3;
    NUM_CHANNELS
}
19
/// Value used for `Config::num_hidden_layers` when the serialized config omits that field.
fn default_num_hidden_layers() -> usize {
    const NUM_HIDDEN_LAYERS: usize = 24;
    NUM_HIDDEN_LAYERS
}
23
/// Value used for `Config::num_attention_heads` when the serialized config omits that field.
fn default_num_attention_heads() -> usize {
    const NUM_ATTENTION_HEADS: usize = 16;
    NUM_ATTENTION_HEADS
}
27
28#[derive(serde::Deserialize, Debug, Clone)]
29pub struct Config {
30    #[serde(default = "default_hidden_size")]
31    pub hidden_size: usize,
32    #[serde(default = "default_num_channels")]
33    pub num_channels: usize,
34    pub image_size: usize,
35    pub patch_size: usize,
36    pub rope_theta: f64,
37    #[serde(default = "default_intermediate_size")]
38    pub intermediate_size: usize,
39    #[serde(default = "default_num_hidden_layers")]
40    pub num_hidden_layers: usize,
41    pub head_dim: Option<usize>,
42    #[serde(default = "default_num_attention_heads")]
43    pub num_attention_heads: usize,
44    #[serde(default = "default_act")]
45    pub hidden_act: candle_nn::Activation,
46}
47
48impl Config {
49    pub fn pixtral_12b_2409() -> Self {
50        Self {
51            hidden_size: 1024,
52            num_channels: 3,
53            image_size: 1024,
54            patch_size: 16,
55            rope_theta: 10000.0,
56            intermediate_size: 4096,
57            num_hidden_layers: 24,
58            num_attention_heads: 16,
59            head_dim: None,
60            // Default
61            hidden_act: candle_nn::Activation::Silu,
62        }
63    }
64
65    fn head_dim(&self) -> usize {
66        self.head_dim
67            .unwrap_or(self.hidden_size / self.num_attention_heads)
68    }
69}
70
71#[derive(Debug, Clone)]
72struct Attention {
73    q_proj: Linear,
74    k_proj: Linear,
75    v_proj: Linear,
76    o_proj: Linear,
77    scale: f64,
78    num_heads: usize,
79    head_dim: usize,
80}
81
82impl Attention {
83    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
84        let h = cfg.hidden_size;
85        let num_heads = cfg.num_attention_heads;
86        let head_dim = cfg.head_dim();
87        let q_proj = linear_b(h, h, false, vb.pp("q_proj"))?;
88        let k_proj = linear_b(h, h, false, vb.pp("k_proj"))?;
89        let v_proj = linear_b(h, h, false, vb.pp("v_proj"))?;
90        let o_proj = linear_b(h, h, false, vb.pp("o_proj"))?;
91        let scale = (head_dim as f64).powf(-0.5);
92        Ok(Self {
93            q_proj,
94            k_proj,
95            v_proj,
96            o_proj,
97            scale,
98            num_heads,
99            head_dim,
100        })
101    }
102
103    fn forward(
104        &self,
105        xs: &Tensor,
106        emb: &RotaryEmbedding,
107        subsampled_positions: Option<&Tensor>,
108        attention_mask: Option<&Tensor>,
109    ) -> Result<Tensor> {
110        let (b, patches, _) = xs.dims3()?;
111        let query_states = xs.apply(&self.q_proj)?;
112        let key_states = xs.apply(&self.k_proj)?;
113        let value_states = xs.apply(&self.v_proj)?;
114
115        let shape = (b, patches, self.num_heads, self.head_dim);
116        let query_states = query_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
117        let key_states = key_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
118        let value_states = value_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
119
120        let (query_states, key_states) =
121            emb.apply_rotary_emb_qkv(&query_states, &key_states, subsampled_positions)?;
122        let attn_weights = (query_states.matmul(&key_states.t()?)? * self.scale)?;
123
124        let attn_weights = match attention_mask {
125            None => attn_weights,
126            Some(mask) => attn_weights.broadcast_add(mask)?,
127        };
128
129        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
130        attn_weights
131            .matmul(&value_states)?
132            .transpose(1, 2)?
133            .reshape((b, patches, ()))?
134            .apply(&self.o_proj)
135    }
136}
137
138#[derive(Debug, Clone)]
139struct Mlp {
140    gate_proj: Linear,
141    up_proj: Linear,
142    down_proj: Linear,
143    act_fn: candle_nn::Activation,
144}
145
146impl Mlp {
147    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
148        let (h, i) = (cfg.hidden_size, cfg.intermediate_size);
149        let gate_proj = linear_b(h, i, false, vb.pp("gate_proj"))?;
150        let up_proj = linear_b(h, i, false, vb.pp("up_proj"))?;
151        let down_proj = linear_b(i, h, false, vb.pp("down_proj"))?;
152        Ok(Self {
153            gate_proj,
154            up_proj,
155            down_proj,
156            act_fn: cfg.hidden_act,
157        })
158    }
159}
160
161impl Module for Mlp {
162    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
163        (xs.apply(&self.gate_proj)?.apply(&self.act_fn)? * xs.apply(&self.up_proj))?
164            .apply(&self.down_proj)
165    }
166}
167
168#[derive(Debug, Clone)]
169struct AttentionLayer {
170    attention_norm: RmsNorm,
171    feed_forward: Mlp,
172    attention: Attention,
173    ffn_norm: RmsNorm,
174}
175
176impl AttentionLayer {
177    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
178        let attention_norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("attention_norm"))?;
179        let feed_forward = Mlp::new(cfg, vb.pp("feed_forward"))?;
180        let attention = Attention::new(cfg, vb.pp("attention"))?;
181        let ffn_norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("ffn_norm"))?;
182        Ok(Self {
183            attention_norm,
184            feed_forward,
185            attention,
186            ffn_norm,
187        })
188    }
189
190    fn forward(
191        &self,
192        xs: &Tensor,
193        emb: &RotaryEmbedding,
194        subsampled_positions: Option<&Tensor>,
195        attention_mask: Option<&Tensor>,
196    ) -> Result<Tensor> {
197        let residual = xs;
198        let xs = self.attention.forward(
199            &xs.apply(&self.attention_norm)?,
200            emb,
201            subsampled_positions,
202            attention_mask,
203        )?;
204        let xs = (residual + xs)?;
205        let residual = &xs;
206        let xs = xs.apply(&self.ffn_norm)?.apply(&self.feed_forward)?;
207        xs + residual
208    }
209}
210
211#[derive(Debug, Clone)]
212struct Transformer {
213    layers: Vec<AttentionLayer>,
214}
215
216impl Transformer {
217    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
218        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
219        let vb = vb.pp("layers");
220        for layer_idx in 0..cfg.num_hidden_layers {
221            let layer = AttentionLayer::new(cfg, vb.pp(layer_idx))?;
222            layers.push(layer)
223        }
224        Ok(Self { layers })
225    }
226
227    fn forward(
228        &self,
229        xs: &Tensor,
230        emb: &RotaryEmbedding,
231        subsampled_positions: Option<&Tensor>,
232        attention_mask: Option<&Tensor>,
233    ) -> Result<Tensor> {
234        let mut xs = xs.clone();
235        for layer in self.layers.iter() {
236            xs = layer.forward(&xs, emb, subsampled_positions, attention_mask)?
237        }
238        Ok(xs)
239    }
240}
241
242#[derive(Debug, Clone)]
243struct RotaryEmbedding {
244    cos: Tensor,
245    sin: Tensor,
246}
247
248impl RotaryEmbedding {
249    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
250        let dtype = vb.dtype();
251        let dev = vb.device();
252        let dim = cfg.head_dim();
253        let rope_theta = cfg.rope_theta as f32;
254        let max_patches_per_side = cfg.image_size / cfg.patch_size;
255        let freqs: Vec<_> = (0..dim)
256            .step_by(2)
257            .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
258            .collect();
259        let freqs_h = freqs.iter().step_by(2).copied().collect::<Vec<_>>();
260        let freqs_h = Tensor::new(freqs_h, dev)?;
261        let freqs_w = freqs.iter().skip(1).step_by(2).copied().collect::<Vec<_>>();
262        let freqs_w = Tensor::new(freqs_w, dev)?;
263        let h = Tensor::arange(0u32, max_patches_per_side as u32, dev)?.to_dtype(DType::F32)?;
264        let w = Tensor::arange(0u32, max_patches_per_side as u32, dev)?.to_dtype(DType::F32)?;
265        let freqs_h = h.unsqueeze(1)?.matmul(&freqs_h.unsqueeze(0)?)?;
266        let freqs_w = w.unsqueeze(1)?.matmul(&freqs_w.unsqueeze(0)?)?;
267        let inv_freq = Tensor::cat(
268            &[
269                freqs_h.unsqueeze(1)?.repeat((1, max_patches_per_side, 1))?,
270                freqs_w.unsqueeze(0)?.repeat((max_patches_per_side, 1, 1))?,
271            ],
272            D::Minus1,
273        )?
274        .reshape(((), dim / 2))?;
275        let cos = inv_freq.cos()?.to_dtype(dtype)?;
276        let sin = inv_freq.sin()?.to_dtype(dtype)?;
277        Ok(Self { cos, sin })
278    }
279
280    fn apply_rotary_emb_qkv(
281        &self,
282        q: &Tensor,
283        k: &Tensor,
284        subsampled_positions: Option<&Tensor>,
285    ) -> Result<(Tensor, Tensor)> {
286        let (_b_sz, _h, _seq_len, _n_embd) = q.dims4()?;
287        let (cos, sin) = match subsampled_positions {
288            None => (&self.cos, &self.sin),
289            Some(pos) => (
290                &self.cos.index_select(pos, 0)?,
291                &self.sin.index_select(pos, 0)?,
292            ),
293        };
294        let q_embed = candle_nn::rotary_emb::rope(q, cos, sin)?;
295        let k_embed = candle_nn::rotary_emb::rope(k, cos, sin)?;
296        Ok((q_embed, k_embed))
297    }
298}
299
300#[derive(Debug, Clone)]
301pub struct Model {
302    patch_conv: candle_nn::Conv2d,
303    ln_pre: RmsNorm,
304    transformer: Transformer,
305    patch_positional_embedding: RotaryEmbedding,
306    max_image_width: u32,
307}
308
309impl Model {
310    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
311        let conv2d_cfg = candle_nn::Conv2dConfig {
312            stride: cfg.patch_size,
313            ..Default::default()
314        };
315        let patch_conv = candle_nn::conv2d_no_bias(
316            cfg.num_channels,
317            cfg.hidden_size,
318            cfg.patch_size,
319            conv2d_cfg,
320            vb.pp("patch_conv"),
321        )?;
322        let ln_pre = candle_nn::rms_norm(cfg.hidden_size, 1e-5, vb.pp("ln_pre"))?;
323        let transformer = Transformer::new(cfg, vb.pp("transformer"))?;
324        let patch_positional_embedding =
325            RotaryEmbedding::new(cfg, vb.pp("patch_positional_embedding"))?;
326        let max_image_width = (cfg.image_size / cfg.patch_size) as u32;
327        Ok(Self {
328            patch_conv,
329            ln_pre,
330            transformer,
331            patch_positional_embedding,
332            max_image_width,
333        })
334    }
335
336    pub fn position_ids_in_meshgrid(
337        &self,
338        num_patches_h: usize,
339        num_patches_w: usize,
340        device: &Device,
341    ) -> Result<Tensor> {
342        let idx = Tensor::arange(0, num_patches_h as u32, device)?;
343        let idy = Tensor::arange(0, num_patches_w as u32, device)?;
344        let mesh = Tensor::meshgrid(&[idx, idy], false)?;
345        let ids = (&mesh[0] * (self.max_image_width as f64) + &mesh[1])?.flatten_all()?;
346        Ok(ids)
347    }
348}
349
350impl Module for Model {
351    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
352        let patch_embeds = xs.apply(&self.patch_conv)?;
353        let subsampled_positions = Some(self.position_ids_in_meshgrid(
354            patch_embeds.dim(2)?,
355            patch_embeds.dim(3)?,
356            patch_embeds.device(),
357        )?);
358        let patch_embeds = patch_embeds.flatten_from(2)?.t()?.apply(&self.ln_pre)?;
359        self.transformer.forward(
360            &patch_embeds,
361            &self.patch_positional_embedding,
362            subsampled_positions.as_ref(),
363            None,
364        )
365    }
366}