1use candle::{DType, Device, Module, Result, Tensor, D};
2use candle_nn::{linear_b, rms_norm, Linear, RmsNorm, VarBuilder};
3
4fn default_act() -> candle_nn::Activation {
5 candle_nn::Activation::Silu
6}
7
/// Serde fallback for `Config::hidden_size` when the field is missing.
fn default_hidden_size() -> usize {
    const HIDDEN_SIZE: usize = 1024;
    HIDDEN_SIZE
}
11
/// Serde fallback for `Config::intermediate_size` when the field is missing.
fn default_intermediate_size() -> usize {
    const INTERMEDIATE_SIZE: usize = 4096;
    INTERMEDIATE_SIZE
}
15
/// Serde fallback for `Config::num_channels` when the field is missing.
fn default_num_channels() -> usize {
    const NUM_CHANNELS: usize = 3;
    NUM_CHANNELS
}
19
/// Serde fallback for `Config::num_hidden_layers` when the field is missing.
fn default_num_hidden_layers() -> usize {
    const NUM_HIDDEN_LAYERS: usize = 24;
    NUM_HIDDEN_LAYERS
}
23
/// Serde fallback for `Config::num_attention_heads` when the field is missing.
fn default_num_attention_heads() -> usize {
    const NUM_ATTENTION_HEADS: usize = 16;
    NUM_ATTENTION_HEADS
}
27
28#[derive(serde::Deserialize, Debug, Clone)]
29pub struct Config {
30 #[serde(default = "default_hidden_size")]
31 pub hidden_size: usize,
32 #[serde(default = "default_num_channels")]
33 pub num_channels: usize,
34 pub image_size: usize,
35 pub patch_size: usize,
36 pub rope_theta: f64,
37 #[serde(default = "default_intermediate_size")]
38 pub intermediate_size: usize,
39 #[serde(default = "default_num_hidden_layers")]
40 pub num_hidden_layers: usize,
41 pub head_dim: Option<usize>,
42 #[serde(default = "default_num_attention_heads")]
43 pub num_attention_heads: usize,
44 #[serde(default = "default_act")]
45 pub hidden_act: candle_nn::Activation,
46}
47
48impl Config {
49 pub fn pixtral_12b_2409() -> Self {
50 Self {
51 hidden_size: 1024,
52 num_channels: 3,
53 image_size: 1024,
54 patch_size: 16,
55 rope_theta: 10000.0,
56 intermediate_size: 4096,
57 num_hidden_layers: 24,
58 num_attention_heads: 16,
59 head_dim: None,
60 hidden_act: candle_nn::Activation::Silu,
62 }
63 }
64
65 fn head_dim(&self) -> usize {
66 self.head_dim
67 .unwrap_or(self.hidden_size / self.num_attention_heads)
68 }
69}
70
71#[derive(Debug, Clone)]
72struct Attention {
73 q_proj: Linear,
74 k_proj: Linear,
75 v_proj: Linear,
76 o_proj: Linear,
77 scale: f64,
78 num_heads: usize,
79 head_dim: usize,
80}
81
82impl Attention {
83 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
84 let h = cfg.hidden_size;
85 let num_heads = cfg.num_attention_heads;
86 let head_dim = cfg.head_dim();
87 let q_proj = linear_b(h, h, false, vb.pp("q_proj"))?;
88 let k_proj = linear_b(h, h, false, vb.pp("k_proj"))?;
89 let v_proj = linear_b(h, h, false, vb.pp("v_proj"))?;
90 let o_proj = linear_b(h, h, false, vb.pp("o_proj"))?;
91 let scale = (head_dim as f64).powf(-0.5);
92 Ok(Self {
93 q_proj,
94 k_proj,
95 v_proj,
96 o_proj,
97 scale,
98 num_heads,
99 head_dim,
100 })
101 }
102
103 fn forward(
104 &self,
105 xs: &Tensor,
106 emb: &RotaryEmbedding,
107 subsampled_positions: Option<&Tensor>,
108 attention_mask: Option<&Tensor>,
109 ) -> Result<Tensor> {
110 let (b, patches, _) = xs.dims3()?;
111 let query_states = xs.apply(&self.q_proj)?;
112 let key_states = xs.apply(&self.k_proj)?;
113 let value_states = xs.apply(&self.v_proj)?;
114
115 let shape = (b, patches, self.num_heads, self.head_dim);
116 let query_states = query_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
117 let key_states = key_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
118 let value_states = value_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
119
120 let (query_states, key_states) =
121 emb.apply_rotary_emb_qkv(&query_states, &key_states, subsampled_positions)?;
122 let attn_weights = (query_states.matmul(&key_states.t()?)? * self.scale)?;
123
124 let attn_weights = match attention_mask {
125 None => attn_weights,
126 Some(mask) => attn_weights.broadcast_add(mask)?,
127 };
128
129 let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
130 attn_weights
131 .matmul(&value_states)?
132 .transpose(1, 2)?
133 .reshape((b, patches, ()))?
134 .apply(&self.o_proj)
135 }
136}
137
138#[derive(Debug, Clone)]
139struct Mlp {
140 gate_proj: Linear,
141 up_proj: Linear,
142 down_proj: Linear,
143 act_fn: candle_nn::Activation,
144}
145
146impl Mlp {
147 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
148 let (h, i) = (cfg.hidden_size, cfg.intermediate_size);
149 let gate_proj = linear_b(h, i, false, vb.pp("gate_proj"))?;
150 let up_proj = linear_b(h, i, false, vb.pp("up_proj"))?;
151 let down_proj = linear_b(i, h, false, vb.pp("down_proj"))?;
152 Ok(Self {
153 gate_proj,
154 up_proj,
155 down_proj,
156 act_fn: cfg.hidden_act,
157 })
158 }
159}
160
161impl Module for Mlp {
162 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
163 (xs.apply(&self.gate_proj)?.apply(&self.act_fn)? * xs.apply(&self.up_proj))?
164 .apply(&self.down_proj)
165 }
166}
167
168#[derive(Debug, Clone)]
169struct AttentionLayer {
170 attention_norm: RmsNorm,
171 feed_forward: Mlp,
172 attention: Attention,
173 ffn_norm: RmsNorm,
174}
175
176impl AttentionLayer {
177 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
178 let attention_norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("attention_norm"))?;
179 let feed_forward = Mlp::new(cfg, vb.pp("feed_forward"))?;
180 let attention = Attention::new(cfg, vb.pp("attention"))?;
181 let ffn_norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("ffn_norm"))?;
182 Ok(Self {
183 attention_norm,
184 feed_forward,
185 attention,
186 ffn_norm,
187 })
188 }
189
190 fn forward(
191 &self,
192 xs: &Tensor,
193 emb: &RotaryEmbedding,
194 subsampled_positions: Option<&Tensor>,
195 attention_mask: Option<&Tensor>,
196 ) -> Result<Tensor> {
197 let residual = xs;
198 let xs = self.attention.forward(
199 &xs.apply(&self.attention_norm)?,
200 emb,
201 subsampled_positions,
202 attention_mask,
203 )?;
204 let xs = (residual + xs)?;
205 let residual = &xs;
206 let xs = xs.apply(&self.ffn_norm)?.apply(&self.feed_forward)?;
207 xs + residual
208 }
209}
210
211#[derive(Debug, Clone)]
212struct Transformer {
213 layers: Vec<AttentionLayer>,
214}
215
216impl Transformer {
217 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
218 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
219 let vb = vb.pp("layers");
220 for layer_idx in 0..cfg.num_hidden_layers {
221 let layer = AttentionLayer::new(cfg, vb.pp(layer_idx))?;
222 layers.push(layer)
223 }
224 Ok(Self { layers })
225 }
226
227 fn forward(
228 &self,
229 xs: &Tensor,
230 emb: &RotaryEmbedding,
231 subsampled_positions: Option<&Tensor>,
232 attention_mask: Option<&Tensor>,
233 ) -> Result<Tensor> {
234 let mut xs = xs.clone();
235 for layer in self.layers.iter() {
236 xs = layer.forward(&xs, emb, subsampled_positions, attention_mask)?
237 }
238 Ok(xs)
239 }
240}
241
242#[derive(Debug, Clone)]
243struct RotaryEmbedding {
244 cos: Tensor,
245 sin: Tensor,
246}
247
248impl RotaryEmbedding {
249 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
250 let dtype = vb.dtype();
251 let dev = vb.device();
252 let dim = cfg.head_dim();
253 let rope_theta = cfg.rope_theta as f32;
254 let max_patches_per_side = cfg.image_size / cfg.patch_size;
255 let freqs: Vec<_> = (0..dim)
256 .step_by(2)
257 .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
258 .collect();
259 let freqs_h = freqs.iter().step_by(2).copied().collect::<Vec<_>>();
260 let freqs_h = Tensor::new(freqs_h, dev)?;
261 let freqs_w = freqs.iter().skip(1).step_by(2).copied().collect::<Vec<_>>();
262 let freqs_w = Tensor::new(freqs_w, dev)?;
263 let h = Tensor::arange(0u32, max_patches_per_side as u32, dev)?.to_dtype(DType::F32)?;
264 let w = Tensor::arange(0u32, max_patches_per_side as u32, dev)?.to_dtype(DType::F32)?;
265 let freqs_h = h.unsqueeze(1)?.matmul(&freqs_h.unsqueeze(0)?)?;
266 let freqs_w = w.unsqueeze(1)?.matmul(&freqs_w.unsqueeze(0)?)?;
267 let inv_freq = Tensor::cat(
268 &[
269 freqs_h.unsqueeze(1)?.repeat((1, max_patches_per_side, 1))?,
270 freqs_w.unsqueeze(0)?.repeat((max_patches_per_side, 1, 1))?,
271 ],
272 D::Minus1,
273 )?
274 .reshape(((), dim / 2))?;
275 let cos = inv_freq.cos()?.to_dtype(dtype)?;
276 let sin = inv_freq.sin()?.to_dtype(dtype)?;
277 Ok(Self { cos, sin })
278 }
279
280 fn apply_rotary_emb_qkv(
281 &self,
282 q: &Tensor,
283 k: &Tensor,
284 subsampled_positions: Option<&Tensor>,
285 ) -> Result<(Tensor, Tensor)> {
286 let (_b_sz, _h, _seq_len, _n_embd) = q.dims4()?;
287 let (cos, sin) = match subsampled_positions {
288 None => (&self.cos, &self.sin),
289 Some(pos) => (
290 &self.cos.index_select(pos, 0)?,
291 &self.sin.index_select(pos, 0)?,
292 ),
293 };
294 let q_embed = candle_nn::rotary_emb::rope(q, cos, sin)?;
295 let k_embed = candle_nn::rotary_emb::rope(k, cos, sin)?;
296 Ok((q_embed, k_embed))
297 }
298}
299
300#[derive(Debug, Clone)]
301pub struct Model {
302 patch_conv: candle_nn::Conv2d,
303 ln_pre: RmsNorm,
304 transformer: Transformer,
305 patch_positional_embedding: RotaryEmbedding,
306 max_image_width: u32,
307}
308
309impl Model {
310 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
311 let conv2d_cfg = candle_nn::Conv2dConfig {
312 stride: cfg.patch_size,
313 ..Default::default()
314 };
315 let patch_conv = candle_nn::conv2d_no_bias(
316 cfg.num_channels,
317 cfg.hidden_size,
318 cfg.patch_size,
319 conv2d_cfg,
320 vb.pp("patch_conv"),
321 )?;
322 let ln_pre = candle_nn::rms_norm(cfg.hidden_size, 1e-5, vb.pp("ln_pre"))?;
323 let transformer = Transformer::new(cfg, vb.pp("transformer"))?;
324 let patch_positional_embedding =
325 RotaryEmbedding::new(cfg, vb.pp("patch_positional_embedding"))?;
326 let max_image_width = (cfg.image_size / cfg.patch_size) as u32;
327 Ok(Self {
328 patch_conv,
329 ln_pre,
330 transformer,
331 patch_positional_embedding,
332 max_image_width,
333 })
334 }
335
336 pub fn position_ids_in_meshgrid(
337 &self,
338 num_patches_h: usize,
339 num_patches_w: usize,
340 device: &Device,
341 ) -> Result<Tensor> {
342 let idx = Tensor::arange(0, num_patches_h as u32, device)?;
343 let idy = Tensor::arange(0, num_patches_w as u32, device)?;
344 let mesh = Tensor::meshgrid(&[idx, idy], false)?;
345 let ids = (&mesh[0] * (self.max_image_width as f64) + &mesh[1])?.flatten_all()?;
346 Ok(ids)
347 }
348}
349
350impl Module for Model {
351 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
352 let patch_embeds = xs.apply(&self.patch_conv)?;
353 let subsampled_positions = Some(self.position_ids_in_meshgrid(
354 patch_embeds.dim(2)?,
355 patch_embeds.dim(3)?,
356 patch_embeds.device(),
357 )?);
358 let patch_embeds = patch_embeds.flatten_from(2)?.t()?.apply(&self.ln_pre)?;
359 self.transformer.forward(
360 &patch_embeds,
361 &self.patch_positional_embedding,
362 subsampled_positions.as_ref(),
363 None,
364 )
365 }
366}