1use candle::{IndexOp, Result, Tensor, D};
41use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
42
/// Image side length passed to `PatchEmbed::new`; it determines how many patch positions the
/// stored positional embedding covers.
const IMG_SIZE: usize = 518;
/// Side length of one patch; used as the kernel size and stride of the patch convolution and
/// to derive the token grid when resampling positional embeddings.
const PATCH_SIZE: usize = 14;
/// Output width of the classification head.
const NUM_CLASSES: usize = 1000;
46
47fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
48 if bias {
49 candle_nn::linear(in_dim, out_dim, vb)
50 } else {
51 candle_nn::linear_no_bias(in_dim, out_dim, vb)
52 }
53}
54
55#[derive(Debug)]
56struct Attention {
57 qkv: Linear,
58 proj: Linear,
59 num_heads: usize,
60 scale: f64,
61}
62
63impl Attention {
64 fn new(
65 vb: VarBuilder,
66 dim: usize,
67 num_heads: usize,
68 qkv_bias: bool,
69 proj_bias: bool,
70 ) -> Result<Self> {
71 let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
72 let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
73 let scale = 1. / ((dim / num_heads) as f64).sqrt();
74 Ok(Self {
75 qkv,
76 proj,
77 num_heads,
78 scale,
79 })
80 }
81}
82
83impl Module for Attention {
84 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
85 let (b, n, c) = xs.dims3()?;
86 let qkv = self
87 .qkv
88 .forward(xs)?
89 .reshape((b, n, 3, self.num_heads, c / self.num_heads))?
90 .transpose(1, 2)? .transpose(0, 1)? .transpose(2, 3)?; let q = (qkv.i(0)? * self.scale)?;
94 let k = qkv.i(1)?.contiguous()?;
95 let v = qkv.i(2)?.contiguous()?;
96 let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?;
97 let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
98 self.proj.forward(&attn)
99 }
100}
101
102#[derive(Debug)]
103struct LayerScale {
104 gamma: Tensor,
105}
106
107impl LayerScale {
108 fn new(vb: VarBuilder, dim: usize) -> Result<Self> {
109 let gamma = vb.get(dim, "gamma")?;
110 Ok(Self { gamma })
111 }
112}
113
114impl Module for LayerScale {
115 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
116 xs.broadcast_mul(&self.gamma)
117 }
118}
119
120#[derive(Debug)]
121struct Mlp {
122 fc1: Linear,
123 fc2: Linear,
124}
125
126impl Mlp {
127 fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> {
128 let out_features = in_features;
129 let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?;
130 let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?;
131 Ok(Self { fc1, fc2 })
132 }
133}
134
135impl Module for Mlp {
136 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
137 let xs = self.fc1.forward(xs)?.gelu()?;
138 self.fc2.forward(&xs)
139 }
140}
141
142#[derive(Debug)]
143struct Block {
144 norm1: LayerNorm,
145 attn: Attention,
146 ls1: LayerScale,
147 norm2: LayerNorm,
148 mlp: Mlp,
149 ls2: LayerScale,
150}
151
152impl Block {
153 fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
154 let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
155 let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?;
156 let ls1 = LayerScale::new(vb.pp("ls1"), dim)?;
157 let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
158 let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?;
159 let ls2 = LayerScale::new(vb.pp("ls2"), dim)?;
160 Ok(Self {
161 norm1,
162 attn,
163 ls1,
164 norm2,
165 mlp,
166 ls2,
167 })
168 }
169}
170
171impl Module for Block {
172 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
173 let residual = xs;
174 let xs = self
175 .ls1
176 .forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?;
177 let xs = (xs + residual)?;
178 let residual = &xs;
179 let xs = self
180 .ls2
181 .forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?;
182 xs + residual
183 }
184}
185
186#[derive(Debug)]
187struct PatchEmbed {
188 proj: candle_nn::Conv2d,
189 patch_size: (usize, usize),
190 num_patches: usize,
191}
192
193impl PatchEmbed {
194 fn new(
195 vb: VarBuilder,
196 img_size: usize,
197 patch_size: usize,
198 in_chans: usize,
199 embed_dim: usize,
200 ) -> Result<Self> {
201 let config = candle_nn::Conv2dConfig {
202 stride: patch_size,
203 ..Default::default()
204 };
205 let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?;
206 let num_patches = (img_size / patch_size) * (img_size / patch_size);
207 Ok(Self {
208 proj,
209 patch_size: (patch_size, patch_size),
210 num_patches,
211 })
212 }
213}
214
215impl Module for PatchEmbed {
216 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
217 let (_b, _c, h, w) = xs.dims4()?;
218 let (patch_h, patch_w) = self.patch_size;
219 if (h % patch_h) != 0 {
220 candle::bail!("image height {h} is not a multiple of patch height {patch_h}")
221 }
222 if (w % patch_w) != 0 {
223 candle::bail!("image width {w} is not a multiple of patch width {patch_w}")
224 }
225 let xs = self.proj.forward(xs)?;
226 let (b, c, h, w) = xs.dims4()?;
227 xs.reshape((b, c, h * w))?.transpose(1, 2)
229 }
230}
231
232#[derive(Debug)]
233pub struct DinoVisionTransformer {
234 patch_embed: PatchEmbed,
235 cls_token: Tensor,
236 pos_embed: Tensor,
237 blocks: Vec<Block>,
238 norm: LayerNorm,
239 head: Linear,
240}
241
242impl DinoVisionTransformer {
243 pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> {
244 let patch_embed =
245 PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?;
246 let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
247 let num_tokens = 1;
248 let pos_embed = vb.get(
249 (1, patch_embed.num_patches + num_tokens, embed_dim),
250 "pos_embed",
251 )?;
252 let head = linear(vb.pp("head"), 2 * embed_dim, NUM_CLASSES, true)?;
253 let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?;
254 let vb_b = vb.pp("blocks");
255 let blocks = (0..depth)
256 .map(|i| Block::new(vb_b.pp(i.to_string()), embed_dim, num_heads))
257 .collect::<Result<Vec<_>>>()?;
258 Ok(Self {
259 patch_embed,
260 cls_token,
261 pos_embed,
262 blocks,
263 norm,
264 head,
265 })
266 }
267
268 fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
269 let npatch = xs.dim(1)? - 1;
270 let n = self.pos_embed.dim(1)? - 1;
271 let sqrt_n = (n as f64).sqrt();
272 if npatch == n && w == h {
273 return Ok(xs.clone());
274 }
275 let class_pos_embed = self.pos_embed.i((.., ..1))?;
276 let patch_pos_embed = self.pos_embed.i((.., 1..))?;
277 let dim = xs.dim(D::Minus1)?;
278 let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
279 let patch_pos_embed = patch_pos_embed
280 .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
281 .transpose(2, 3)?
282 .transpose(1, 2)?;
283 let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
285 let el_count = patch_pos_embed.shape().elem_count();
286 let patch_pos_embed =
287 patch_pos_embed
288 .transpose(1, 2)?
289 .transpose(2, 3)?
290 .reshape((1, el_count / dim, dim))?;
291 Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1)
292 }
293
294 fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
295 let (_b, _nc, w, h) = xs.dims4()?;
296 let xs = self.patch_embed.forward(xs)?;
297 let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
298 &xs + &self.interpolate_pos_encoding(&xs, w, h)?
299 }
300
301 fn get_intermediate_layers_not_chunked(
302 &self,
303 xs: &Tensor,
304 blocks_to_take: &[usize],
305 ) -> Result<Vec<Tensor>> {
306 let mut xs = self.prepare_tokens_with_mask(xs)?;
307 let mut output = Vec::new();
308 for (i, blk) in self.blocks.iter().enumerate() {
309 xs = blk.forward(&xs)?;
310 if blocks_to_take.contains(&i) {
311 output.push(xs.clone());
312 }
313 }
314 if output.len() != blocks_to_take.len() {
315 candle::bail!(
316 "only {} / {} blocks found",
317 output.len(),
318 blocks_to_take.len()
319 );
320 }
321 Ok(output)
322 }
323
324 pub fn get_intermediate_layers(
325 &self,
326 xs: &Tensor,
327 blocks_to_take: &[usize],
328 reshape: bool,
329 return_class_token: bool,
330 norm: bool,
331 ) -> Result<Tensor> {
332 let outputs = self.get_intermediate_layers_not_chunked(xs, blocks_to_take)?;
333 let outputs = if norm {
334 outputs
335 .iter()
336 .map(|out| self.norm.forward(out))
337 .collect::<Result<Vec<_>>>()?
338 } else {
339 outputs
340 };
341 let class_tokens = outputs
342 .iter()
343 .map(|out| out.i((.., 0)))
344 .collect::<Result<Vec<_>>>()?;
345 let outputs = outputs
346 .iter()
347 .map(|out| out.i((.., 1..)))
348 .collect::<Result<Vec<_>>>()?;
349
350 let outputs = if reshape {
351 let (b, _c, w, h) = xs.dims4()?;
352 let patch_size = self.patch_embed.patch_size.0;
353 let num_channels = outputs[0].elem_count() / (b * (w / patch_size) * (h / patch_size));
354 outputs
355 .iter()
356 .map(|out| {
357 out.reshape((b, w / patch_size, h / patch_size, num_channels))?
358 .transpose(2, 3)?
359 .transpose(1, 2)
360 })
361 .collect::<Result<Vec<_>>>()?
362 } else {
363 outputs
364 };
365
366 let outputs = if return_class_token {
367 outputs
368 .iter()
369 .zip(class_tokens.iter())
370 .map(|(out, class_token)| Tensor::cat(&[out, class_token], D::Minus1))
371 .collect::<Result<Vec<_>>>()?
372 } else {
373 outputs
374 };
375
376 Tensor::stack(&outputs[..], 0)
377 }
378}
379
380impl Module for DinoVisionTransformer {
381 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
382 let mut xs = self.prepare_tokens_with_mask(xs)?;
383 for blk in self.blocks.iter() {
384 xs = blk.forward(&xs)?
385 }
386 let xs = self.norm.forward(&xs)?;
387 let xs_norm_clstoken = xs.i((.., 0))?;
388 let xs_norm_patchtokens = xs.i((.., 1..))?.mean(1)?;
389 let xs = Tensor::cat(&[xs_norm_clstoken, xs_norm_patchtokens], D::Minus1)?;
390 self.head.forward(&xs)
391 }
392}
393
394pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> {
395 DinoVisionTransformer::new(vb, 12, 384, 6)
396}