1use candle::{IndexOp, Result, Tensor, D};
30use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
31
/// Side length, in pixels, of the square input image the model accepts.
const IMG_SIZE: usize = 448;
/// Side length, in pixels, of one square patch; also used as the patch-embedding
/// convolution's kernel size and stride.
const PATCH_SIZE: usize = 14;
/// Output width of the final classification head.
const NUM_CLASSES: usize = 1000;
35
36fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
37 if bias {
38 candle_nn::linear(in_dim, out_dim, vb)
39 } else {
40 candle_nn::linear_no_bias(in_dim, out_dim, vb)
41 }
42}
43
44#[derive(Debug)]
45struct Attention {
46 q: Linear,
47 k: Linear,
48 v: Linear,
49 proj: Linear,
50 rot_pos_embed: Tensor,
51 num_heads: usize,
52 scale: f64,
53}
54
55impl Attention {
56 fn new(
57 vb: VarBuilder,
58 dim: usize,
59 num_heads: usize,
60 qkv_bias: bool,
61 proj_bias: bool,
62 rot_pos_embed: &Tensor,
63 ) -> Result<Self> {
64 let q = linear(vb.pp("q_proj"), dim, dim, qkv_bias)?;
65 let k = linear(vb.pp("k_proj"), dim, dim, false)?; let v = linear(vb.pp("v_proj"), dim, dim, qkv_bias)?;
67 let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
68 let rot_pos_embed = rot_pos_embed.clone();
69 let scale = 1. / ((dim / num_heads) as f64).sqrt();
70 Ok(Self {
71 q,
72 k,
73 v,
74 proj,
75 rot_pos_embed,
76 num_heads,
77 scale,
78 })
79 }
80}
81
82impl Attention {
83 fn apply_rot_embed_cat(x: &Tensor, emb: &Tensor) -> Result<Tensor> {
85 let cos_emb = emb.i((0.., 64..128))?; let sin_emb = emb.i((0.., 0..64))?; let index_even: [u32; 32] = (0u32..=63)
88 .step_by(2)
89 .collect::<Vec<_>>()
90 .try_into()
91 .expect("wrong size iterator");
92 let index_odd: [u32; 32] = (1u32..=63)
93 .step_by(2)
94 .collect::<Vec<_>>()
95 .try_into()
96 .expect("wrong size iterator");
97 let t_index_even = Tensor::new(&index_even, x.device())?;
98 let t_index_odd = Tensor::new(&index_odd, x.device())?;
99 let x_c = x.contiguous()?;
100 let rot_x_even = x_c.index_select(&t_index_even, D::Minus1)?;
101 let rot_x_odd_minus = (-1.0 * x_c.index_select(&t_index_odd, D::Minus1)?)?;
102 let rot_x =
103 Tensor::stack(&[&rot_x_odd_minus, &rot_x_even], D::Minus1)?.reshape(x.shape())?;
104 x.broadcast_mul(&cos_emb)? + rot_x.broadcast_mul(&sin_emb)?
105 }
106}
107
108impl Module for Attention {
109 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
110 let (b, n, c) = xs.dims3()?;
111 let qkv = Tensor::cat(
112 &[
113 &self.q.forward(xs)?,
114 &self.k.forward(xs)?,
115 &self.v.forward(xs)?,
116 ],
117 2,
118 )?
119 .reshape((b, n, 3, self.num_heads, c / self.num_heads))?
120 .transpose(1, 2)? .transpose(0, 1)? .transpose(2, 3)?; let q = qkv.i(0)?;
124 let k = qkv.i(1)?.contiguous()?;
125 let v = qkv.i(2)?.contiguous()?;
126
127 let npt = 1; let q = Tensor::cat(
129 &[
130 &q.i((0.., 0.., ..npt, 0..))?,
131 &Self::apply_rot_embed_cat(&q.i((0.., 0.., npt.., 0..))?, &self.rot_pos_embed)?,
132 ],
133 2,
134 )?;
135 let k = Tensor::cat(
136 &[
137 &k.i((0.., 0.., ..npt, 0..))?,
138 &Self::apply_rot_embed_cat(&k.i((0.., 0.., npt.., 0..))?, &self.rot_pos_embed)?,
139 ],
140 2,
141 )?;
142
143 let q = (q * self.scale)?;
144 let attn = &q.matmul(&k.t()?)?;
145 let attn = candle_nn::ops::softmax(attn, D::Minus1)?;
146 let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
147 self.proj.forward(&attn)
148 }
149}
150
151#[derive(Debug)]
152struct Mlp {
153 fc1_g: Linear,
154 fc1_x: Linear,
155 norm: LayerNorm,
156 fc2: Linear,
157}
158
159impl Mlp {
160 fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> {
161 let out_features = in_features;
162 let fc1_g = linear(vb.pp("fc1_g"), in_features, hidden_features, bias)?;
163 let fc1_x = linear(vb.pp("fc1_x"), in_features, hidden_features, bias)?;
164 let norm = layer_norm(hidden_features, 1e-6, vb.pp("norm"))?;
165 let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?;
166 Ok(Self {
167 fc1_g,
168 fc1_x,
169 norm,
170 fc2,
171 })
172 }
173}
174
175impl Module for Mlp {
176 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
177 let xs_g = self.fc1_g.forward(xs)?.silu()?;
178 let xs = self.fc1_x.forward(xs)?;
179 let xs = self.norm.forward(&(xs_g.mul(&xs)?))?;
180 self.fc2.forward(&xs)
181 }
182}
183
184#[derive(Debug)]
185struct Block {
186 norm1: LayerNorm,
187 attn: Attention,
188 norm2: LayerNorm,
189 mlp: Mlp,
190}
191
192impl Block {
193 fn new(vb: VarBuilder, dim: usize, num_heads: usize, rot_pos_embed: &Tensor) -> Result<Self> {
194 let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?;
195 let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true, rot_pos_embed)?;
196 let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?;
197 let hidden_dim = dim * 4 * 2 / 3; let mlp = Mlp::new(vb.pp("mlp"), dim, hidden_dim, true)?;
199 Ok(Self {
200 norm1,
201 attn,
202 norm2,
203 mlp,
204 })
205 }
206}
207
208impl Module for Block {
209 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
210 let residual = xs;
211 let xs = &self.attn.forward(&self.norm1.forward(xs)?)?;
212 let xs = (xs + residual)?;
213 let residual = &xs;
214 let xs = &self.mlp.forward(&self.norm2.forward(&xs)?)?;
215 xs + residual
216 }
217}
218
219#[derive(Debug)]
220struct PatchEmbed {
221 proj: candle_nn::Conv2d,
222 patch_size: (usize, usize),
223 num_patches: usize,
224}
225
226impl PatchEmbed {
227 fn new(
228 vb: VarBuilder,
229 img_size: usize,
230 patch_size: usize,
231 in_chans: usize,
232 embed_dim: usize,
233 ) -> Result<Self> {
234 let config = candle_nn::Conv2dConfig {
235 stride: patch_size,
236 ..Default::default()
237 };
238 let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?;
239 let num_patches = (img_size / patch_size) * (img_size / patch_size);
240 Ok(Self {
241 proj,
242 patch_size: (patch_size, patch_size),
243 num_patches,
244 })
245 }
246}
247
248impl Module for PatchEmbed {
249 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
250 let (_b, _c, h, w) = xs.dims4()?;
251 let (patch_h, patch_w) = self.patch_size;
252 if (h % patch_h) != 0 {
253 candle::bail!("image height {h} is not a multiple of patch height {patch_h}")
254 }
255 if (w % patch_w) != 0 {
256 candle::bail!("image width {w} is not a multiple of patch width {patch_w}")
257 }
258 let xs = self.proj.forward(xs)?;
259 let (b, c, h, w) = xs.dims4()?;
260 xs.reshape((b, c, h * w))?.transpose(1, 2)
262 }
263}
264
265#[derive(Debug)]
266pub struct EVA2VisionTransformer {
267 patch_embed: PatchEmbed,
268 cls_token: Tensor,
269 pos_embed: Tensor,
270 blocks: Vec<Block>,
271 norm: LayerNorm,
272 head: Linear,
273}
274
275impl EVA2VisionTransformer {
276 pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> {
277 let patch_embed =
278 PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?;
279 let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
280 let pos_embed = vb.get((1, patch_embed.num_patches + 1, embed_dim), "pos_embed")?;
281 let rot_pos_embed = vb.get((patch_embed.num_patches, 128), "rot_pos_embed")?;
282 let head = linear(vb.pp("head"), embed_dim, NUM_CLASSES, true)?;
283 let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?;
284 let vb_b = vb.pp("blocks");
285 let blocks = (0..depth)
286 .map(|i| Block::new(vb_b.pp(i.to_string()), embed_dim, num_heads, &rot_pos_embed))
287 .collect::<Result<Vec<_>>>()?;
288 Ok(Self {
289 patch_embed,
290 cls_token,
291 pos_embed,
292 blocks,
293 norm,
294 head,
295 })
296 }
297
298 fn interpolate_pos_encoding(
299 &self,
300 xs: &Tensor,
301 w: usize,
302 h: usize,
303 num_prefix_tokens: usize,
304 ) -> Result<Tensor> {
305 let npatch = xs.dim(1)? - 1;
306 let n = self.pos_embed.dim(1)? - 1;
307 let sqrt_n = (n as f64).sqrt();
308 if npatch == n && w == h {
309 return Ok(self.pos_embed.clone());
310 }
311 let prefix_tokens_pos_embed = self.pos_embed.i((0.., ..num_prefix_tokens, 0..))?.clone();
313 let patch_pos_embed = &self.pos_embed.i((0.., num_prefix_tokens.., 0..))?;
314 let dim = xs.dim(D::Minus1)?;
315 let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
316 let patch_pos_embed = patch_pos_embed
317 .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
318 .transpose(2, 3)?
319 .transpose(1, 2)?;
320 let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
322 let el_count = patch_pos_embed.shape().elem_count();
323 let patch_pos_embed =
324 patch_pos_embed
325 .transpose(1, 2)?
326 .transpose(2, 3)?
327 .reshape((1, el_count / dim, dim))?;
328 Tensor::cat(&[&prefix_tokens_pos_embed, &patch_pos_embed], 1)
329 }
330
331 fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
332 let (_b, _nc, w, h) = xs.dims4()?;
333 if (w != IMG_SIZE) || (h != IMG_SIZE) {
334 panic!("Error: The input tensor should have the shape: Bx3x518x518.");
335 }
336 let xs = self.patch_embed.forward(xs)?;
337 let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
338 let xs = (&xs + &self.interpolate_pos_encoding(&xs, w, h, 1)?)?;
339 Ok(xs)
340 }
341
342 fn get_intermediate_layers_not_chunked(
343 &self,
344 xs: &Tensor,
345 blocks_to_take: &[usize],
346 ) -> Result<Vec<Tensor>> {
347 let mut xs = self.prepare_tokens_with_mask(xs)?;
348 let mut output = Vec::new();
349 for (i, blk) in self.blocks.iter().enumerate() {
350 xs = blk.forward(&xs)?;
351 if blocks_to_take.contains(&i) {
352 output.push(xs.clone());
353 }
354 }
355 if output.len() != blocks_to_take.len() {
356 candle::bail!(
357 "only {} / {} blocks found",
358 output.len(),
359 blocks_to_take.len()
360 );
361 }
362 Ok(output)
363 }
364
365 pub fn get_intermediate_layers(
366 &self,
367 xs: &Tensor,
368 blocks_to_take: &[usize],
369 reshape: bool,
370 return_class_token: bool,
371 norm: bool,
372 ) -> Result<Tensor> {
373 let outputs = self.get_intermediate_layers_not_chunked(xs, blocks_to_take)?;
374 let outputs = if norm {
375 outputs
376 .iter()
377 .map(|out| self.norm.forward(out))
378 .collect::<Result<Vec<_>>>()?
379 } else {
380 outputs
381 };
382 let class_tokens = outputs
383 .iter()
384 .map(|out| out.i((.., 0)))
385 .collect::<Result<Vec<_>>>()?;
386 let outputs = outputs
387 .iter()
388 .map(|out| out.i((.., 1..)))
389 .collect::<Result<Vec<_>>>()?;
390
391 let outputs = if reshape {
392 let (b, _c, w, h) = xs.dims4()?;
393 let patch_size = self.patch_embed.patch_size.0;
394 let num_channels = outputs[0].elem_count() / (b * (w / patch_size) * (h / patch_size));
395 outputs
396 .iter()
397 .map(|out| {
398 out.reshape((b, w / patch_size, h / patch_size, num_channels))?
399 .transpose(2, 3)?
400 .transpose(1, 2)
401 })
402 .collect::<Result<Vec<_>>>()?
403 } else {
404 outputs
405 };
406
407 let outputs = if return_class_token {
408 outputs
409 .iter()
410 .zip(class_tokens.iter())
411 .map(|(out, class_token)| Tensor::cat(&[out, class_token], D::Minus1))
412 .collect::<Result<Vec<_>>>()?
413 } else {
414 outputs
415 };
416
417 Tensor::stack(&outputs[..], 0)
418 }
419}
420
421impl Module for EVA2VisionTransformer {
422 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
423 let mut xs = self.prepare_tokens_with_mask(xs)?;
424 for blk in self.blocks.iter() {
425 xs = blk.forward(&xs)?
426 }
427 let xs_moy_local_tokens = xs.i((.., 1..))?.mean(1)?;
428 let xs_norm = self.norm.forward(&xs_moy_local_tokens)?;
429 self.head.forward(&xs_norm)
430 }
431}
432
433pub fn vit_base(vb: VarBuilder) -> Result<EVA2VisionTransformer> {
434 EVA2VisionTransformer::new(vb, 12, 768, 12)
435}
436
437pub fn vit_large(vb: VarBuilder) -> Result<EVA2VisionTransformer> {
438 EVA2VisionTransformer::new(vb, 24, 1024, 16)
439}