1use candle::{IndexOp, Result, Tensor, D};
4use candle_nn::{Conv2dConfig, Module, VarBuilder};
5
/// Channel expansion factor: each `MBConv` widens its input to `in_ * MBCONV_EXPAND_RATIO`
/// channels before projecting back.
const MBCONV_EXPAND_RATIO: usize = 4;
/// Hidden-width multiplier of the `Mlp` inside every `TinyViTBlock` (`dim * MLP_RATIO`).
const MLP_RATIO: usize = 4;
/// Kernel size of the depthwise (`groups == dim`) local convolution in each `TinyViTBlock`.
const LOCAL_CONV_SIZE: usize = 3;
/// Input side length assumed by the encoder; stage resolutions are derived from `IMG_SIZE / 4`.
const IMG_SIZE: usize = 1024;
/// Number of input channels consumed by the patch-embedding stem.
const IN_CHANNELS: usize = 3;
11
12#[derive(Debug)]
13struct Conv2dBN {
14 c: candle_nn::Conv2d,
15 bn: candle_nn::BatchNorm,
16 span: tracing::Span,
17}
18
19impl Conv2dBN {
20 fn new(in_: usize, out: usize, ks: usize, cfg: Conv2dConfig, vb: VarBuilder) -> Result<Self> {
21 let c = candle_nn::conv2d_no_bias(in_, out, ks, cfg, vb.pp("c"))?;
22 let bn = candle_nn::batch_norm(out, 1e-5, vb.pp("bn"))?;
23 let span = tracing::span!(tracing::Level::TRACE, "conv2d-bn");
24 Ok(Self { c, bn, span })
25 }
26}
27
28impl Module for Conv2dBN {
29 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
30 let _enter = self.span.enter();
31 xs.apply(&self.c)?.apply_t(&self.bn, false)
32 }
33}
34
35#[derive(Debug)]
36struct PatchEmbed {
37 conv1: Conv2dBN,
38 conv2: Conv2dBN,
39 span: tracing::Span,
40}
41
42impl PatchEmbed {
43 fn new(in_chans: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
44 let cfg = candle_nn::Conv2dConfig {
45 stride: 2,
46 padding: 1,
47 ..Default::default()
48 };
49 let conv1 = Conv2dBN::new(in_chans, embed_dim / 2, 3, cfg, vb.pp("seq.0"))?;
50 let conv2 = Conv2dBN::new(embed_dim / 2, embed_dim, 3, cfg, vb.pp("seq.2"))?;
51 let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
52 Ok(Self { conv1, conv2, span })
53 }
54}
55
56impl Module for PatchEmbed {
57 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
58 let _enter = self.span.enter();
59 xs.apply(&self.conv1)?.gelu()?.apply(&self.conv2)
60 }
61}
62
63#[derive(Debug)]
64struct MBConv {
65 conv1: Conv2dBN,
66 conv2: Conv2dBN,
67 conv3: Conv2dBN,
68 span: tracing::Span,
69}
70
71impl MBConv {
72 fn new(in_: usize, out: usize, expand_ratio: usize, vb: VarBuilder) -> Result<Self> {
73 let hidden = in_ * expand_ratio;
74 let cfg2 = candle_nn::Conv2dConfig {
75 padding: 1,
76 groups: hidden,
77 ..Default::default()
78 };
79 let conv1 = Conv2dBN::new(in_, hidden, 1, Default::default(), vb.pp("conv1"))?;
80 let conv2 = Conv2dBN::new(hidden, hidden, 3, cfg2, vb.pp("conv2"))?;
81 let conv3 = Conv2dBN::new(hidden, out, 1, Default::default(), vb.pp("conv3"))?;
82 let span = tracing::span!(tracing::Level::TRACE, "mb-conv");
83 Ok(Self {
84 conv1,
85 conv2,
86 conv3,
87 span,
88 })
89 }
90}
91
92impl Module for MBConv {
93 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
94 let _enter = self.span.enter();
95 let shortcut = xs;
96 let xs = xs
97 .apply(&self.conv1)?
98 .gelu()?
99 .apply(&self.conv2)?
100 .gelu()?
101 .apply(&self.conv3)?;
102 (xs + shortcut)?.gelu()
103 }
104}
105
106#[derive(Debug)]
107struct PatchMerging {
108 conv1: Conv2dBN,
109 conv2: Conv2dBN,
110 conv3: Conv2dBN,
111 input_resolution: (usize, usize),
112 span: tracing::Span,
113}
114
115impl PatchMerging {
116 fn new(
117 input_resolution: (usize, usize),
118 dim: usize,
119 out: usize,
120 vb: VarBuilder,
121 ) -> Result<Self> {
122 let stride = if [320, 448, 576].contains(&out) { 1 } else { 2 };
123 let cfg2 = candle_nn::Conv2dConfig {
124 padding: 1,
125 stride,
126 groups: out,
127 ..Default::default()
128 };
129 let conv1 = Conv2dBN::new(dim, out, 1, Default::default(), vb.pp("conv1"))?;
130 let conv2 = Conv2dBN::new(out, out, 3, cfg2, vb.pp("conv2"))?;
131 let conv3 = Conv2dBN::new(out, out, 1, Default::default(), vb.pp("conv3"))?;
132 let span = tracing::span!(tracing::Level::TRACE, "patch-merging");
133 Ok(Self {
134 conv1,
135 conv2,
136 conv3,
137 input_resolution,
138 span,
139 })
140 }
141}
142
143impl Module for PatchMerging {
144 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
145 let _enter = self.span.enter();
146 let xs = if xs.rank() == 3 {
147 let (h, w) = self.input_resolution;
148 let b = xs.dim(0)?;
149 xs.reshape((b, h, w, ()))?.permute((0, 3, 1, 2))?
150 } else {
151 xs.clone()
152 };
153 xs.apply(&self.conv1)?
154 .gelu()?
155 .apply(&self.conv2)?
156 .gelu()?
157 .apply(&self.conv3)?
158 .flatten_from(2)?
159 .transpose(1, 2)
160 }
161}
162
163#[derive(Debug)]
164struct ConvLayer {
165 blocks: Vec<MBConv>,
166 downsample: Option<PatchMerging>,
167 span: tracing::Span,
168}
169
170impl ConvLayer {
171 fn new(
172 dim: usize,
173 out: usize,
174 input_resolution: (usize, usize),
175 depth: usize,
176 downsample: bool,
177 conv_expand_ratio: usize,
178 vb: VarBuilder,
179 ) -> Result<Self> {
180 let vb_b = vb.pp("blocks");
181 let mut blocks = Vec::with_capacity(depth);
182 for index in 0..depth {
183 let block = MBConv::new(dim, dim, conv_expand_ratio, vb_b.pp(index))?;
184 blocks.push(block)
185 }
186 let downsample = if downsample {
187 let downsample = PatchMerging::new(input_resolution, dim, out, vb.pp("downsample"))?;
188 Some(downsample)
189 } else {
190 None
191 };
192 let span = tracing::span!(tracing::Level::TRACE, "conv-layer");
193 Ok(Self {
194 blocks,
195 downsample,
196 span,
197 })
198 }
199}
200
201impl Module for ConvLayer {
202 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
203 let _enter = self.span.enter();
204 let mut xs = xs.clone();
205 for block in self.blocks.iter() {
206 xs = block.forward(&xs)?
207 }
208 match &self.downsample {
209 None => Ok(xs),
210 Some(downsample) => downsample.forward(&xs),
211 }
212 }
213}
214
215#[derive(Debug)]
216struct Mlp {
217 norm: candle_nn::LayerNorm,
218 fc1: super::Linear,
219 fc2: super::Linear,
220 span: tracing::Span,
221}
222
223impl Mlp {
224 fn new(in_: usize, hidden: usize, vb: VarBuilder) -> Result<Self> {
225 let norm = candle_nn::layer_norm(in_, 1e-5, vb.pp("norm"))?;
226 let fc1 = super::linear(vb.pp("fc1"), in_, hidden, true)?;
227 let fc2 = super::linear(vb.pp("fc2"), hidden, in_, true)?;
228 let span = tracing::span!(tracing::Level::TRACE, "mlp");
229 Ok(Self {
230 norm,
231 fc1,
232 fc2,
233 span,
234 })
235 }
236}
237
238impl Module for Mlp {
239 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
240 let _enter = self.span.enter();
241 xs.apply(&self.norm)?
242 .apply(&self.fc1)?
243 .gelu()?
244 .apply(&self.fc2)
245 }
246}
247
248#[derive(Debug)]
249struct Attention {
250 norm: candle_nn::LayerNorm,
251 qkv: super::Linear,
252 proj: super::Linear,
253 ab: Tensor,
254 key_dim: usize,
255 num_heads: usize,
256 d: usize,
257 dh: usize,
258 scale: f64,
259 span: tracing::Span,
260 span_matmul: tracing::Span,
261 span_softmax: tracing::Span,
262}
263
264impl Attention {
265 fn new(
266 dim: usize,
267 key_dim: usize,
268 num_heads: usize,
269 attn_ratio: usize,
270 resolution: (usize, usize),
271 vb: VarBuilder,
272 ) -> Result<Self> {
273 let d = attn_ratio * key_dim;
274 let dh = d * num_heads;
275 let nh_kd = key_dim * num_heads;
276 let h = dh + nh_kd * 2;
277 let norm = candle_nn::layer_norm(dim, 1e-5, vb.pp("norm"))?;
278 let qkv = super::linear(vb.pp("qkv"), dim, h, true)?;
279 let proj = super::linear(vb.pp("proj"), dh, dim, true)?;
280
281 let points = (0..resolution.0)
282 .flat_map(|x| (0..resolution.1).map(move |y| (x as i64, y as i64)))
283 .collect::<Vec<_>>();
284 let mut idxs = Vec::with_capacity(points.len() * points.len());
285 let mut attention_offsets = std::collections::HashMap::new();
286 for &(x1, y1) in points.iter() {
287 for &(x2, y2) in points.iter() {
288 let offset = ((x2 - x1).abs(), (y2 - y1).abs());
289 let l = attention_offsets.len();
290 let idx = attention_offsets.entry(offset).or_insert(l);
291 idxs.push(*idx as u32)
292 }
293 }
294 let attention_biases = vb.get((num_heads, attention_offsets.len()), "attention_biases")?;
295 let idxs = Tensor::new(idxs, attention_biases.device())?;
296 let ab =
297 attention_biases
298 .index_select(&idxs, 1)?
299 .reshape(((), points.len(), points.len()))?;
300 let span = tracing::span!(tracing::Level::TRACE, "attention");
301 let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul");
302 let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
303 Ok(Self {
304 norm,
305 qkv,
306 proj,
307 ab,
308 key_dim,
309 num_heads,
310 d,
311 dh,
312 scale: 1f64 / (key_dim as f64).sqrt(),
313 span,
314 span_matmul,
315 span_softmax,
316 })
317 }
318}
319
320impl Module for Attention {
321 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
322 let _enter = self.span.enter();
323 let (b, n, _) = xs.dims3()?;
324 let xs = xs.apply(&self.norm)?;
325 let qkv = xs.apply(&self.qkv)?.reshape((b, n, self.num_heads, ()))?;
326 let q = qkv
327 .narrow(D::Minus1, 0, self.key_dim)?
328 .permute((0, 2, 1, 3))?
329 .contiguous()?;
330 let k = qkv
331 .narrow(D::Minus1, self.key_dim, self.key_dim)?
332 .permute((0, 2, 1, 3))?
333 .contiguous()?;
334 let v = qkv
335 .narrow(D::Minus1, 2 * self.key_dim, self.d)?
336 .permute((0, 2, 1, 3))?
337 .contiguous()?;
338 let attn = {
339 let _enter = self.span_matmul.enter();
340 (q.matmul(&k.t()?)? * self.scale)?
341 };
342 let attn = attn.broadcast_add(&self.ab)?;
343 let attn = {
344 let _enter = self.span_softmax.enter();
345 candle_nn::ops::softmax_last_dim(&attn)?
346 };
347 let attn = {
348 let _enter = self.span_matmul.enter();
349 attn.matmul(&v)?
350 };
351 attn.transpose(1, 2)?
352 .reshape((b, n, self.dh))?
353 .apply(&self.proj)
354 }
355}
356
357#[derive(Debug)]
358struct TinyViTBlock {
359 attn: Attention,
360 local_conv: Conv2dBN,
361 mlp: Mlp,
362 window_size: usize,
363 input_resolution: (usize, usize),
364 span: tracing::Span,
365}
366
367impl TinyViTBlock {
368 fn new(
369 dim: usize,
370 input_resolution: (usize, usize),
371 num_heads: usize,
372 window_size: usize,
373 vb: VarBuilder,
374 ) -> Result<Self> {
375 let head_dim = dim / num_heads;
376 let attn = Attention::new(
377 dim,
378 head_dim,
379 num_heads,
380 1,
381 (window_size, window_size),
382 vb.pp("attn"),
383 )?;
384 let mlp = Mlp::new(dim, dim * MLP_RATIO, vb.pp("mlp"))?;
385 let cfg = candle_nn::Conv2dConfig {
386 padding: LOCAL_CONV_SIZE / 2,
387 groups: dim,
388 ..Default::default()
389 };
390 let local_conv = Conv2dBN::new(dim, dim, LOCAL_CONV_SIZE, cfg, vb.pp("local_conv"))?;
391 let span = tracing::span!(tracing::Level::TRACE, "attention");
392 Ok(Self {
393 attn,
394 local_conv,
395 mlp,
396 window_size,
397 input_resolution,
398 span,
399 })
400 }
401}
402
403impl Module for TinyViTBlock {
404 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
405 let _enter = self.span.enter();
406 let (h, w) = self.input_resolution;
407 let (b, l, c) = xs.dims3()?;
408 let res_x = xs;
409 let xs = if h == self.window_size && w == self.window_size {
410 self.attn.forward(xs)?
411 } else {
412 let xs = xs.reshape((b, h, w, c))?;
413 let pad_b = (self.window_size - h % self.window_size) % self.window_size;
414 let pad_r = (self.window_size - w % self.window_size) % self.window_size;
415
416 let xs = if pad_b > 0 {
417 xs.pad_with_zeros(1, 0, pad_b)?
418 } else {
419 xs
420 };
421 let xs = if pad_r > 0 {
422 xs.pad_with_zeros(2, 0, pad_r)?
423 } else {
424 xs
425 };
426 let (p_h, p_w) = (h + pad_b, w + pad_r);
427 let n_h = p_h / self.window_size;
428 let n_w = p_w / self.window_size;
429 let xs = xs
430 .reshape((b, n_h, self.window_size, n_w, self.window_size, c))?
431 .transpose(2, 3)?
432 .reshape((b * n_h * n_w, self.window_size * self.window_size, c))?;
433 let xs = self.attn.forward(&xs)?;
434 let xs = xs
435 .reshape((b, n_h, n_w, self.window_size, self.window_size, c))?
436 .transpose(2, 3)?
437 .reshape((b, p_h, p_w, c))?;
438 let xs = if pad_r > 0 {
439 xs.i((.., .., ..w))?.contiguous()?
440 } else {
441 xs
442 };
443 let xs = if pad_b > 0 {
444 xs.i((.., ..h, ..))?.contiguous()?
445 } else {
446 xs
447 };
448 xs.reshape((b, l, c))?
449 };
450 let xs = (xs + res_x)?;
451 let xs = xs
452 .transpose(1, 2)?
453 .reshape((b, c, h, w))?
454 .apply(&self.local_conv)?
455 .reshape((b, c, l))?
456 .transpose(1, 2)?;
457 &xs + self.mlp.forward(&xs)?
458 }
459}
460
461#[derive(Debug)]
462struct BasicLayer {
463 blocks: Vec<TinyViTBlock>,
464 downsample: Option<PatchMerging>,
465 span: tracing::Span,
466}
467
468impl BasicLayer {
469 #[allow(clippy::too_many_arguments)]
470 fn new(
471 dim: usize,
472 input_resolution: (usize, usize),
473 depth: usize,
474 num_heads: usize,
475 window_size: usize,
476 downsample: bool,
477 out: usize,
478 vb: VarBuilder,
479 ) -> Result<Self> {
480 let vb_b = vb.pp("blocks");
481 let mut blocks = Vec::with_capacity(depth);
482 for index in 0..depth {
483 let block = TinyViTBlock::new(
484 dim,
485 input_resolution,
486 num_heads,
487 window_size,
488 vb_b.pp(index),
489 )?;
490 blocks.push(block)
491 }
492 let downsample = if downsample {
493 let downsample = PatchMerging::new(input_resolution, dim, out, vb.pp("downsample"))?;
494 Some(downsample)
495 } else {
496 None
497 };
498 let span = tracing::span!(tracing::Level::TRACE, "basic-layer");
499 Ok(Self {
500 blocks,
501 downsample,
502 span,
503 })
504 }
505}
506
507impl Module for BasicLayer {
508 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
509 let _enter = self.span.enter();
510 let mut xs = xs.clone();
511 for block in self.blocks.iter() {
512 xs = block.forward(&xs)?
513 }
514 match &self.downsample {
515 None => Ok(xs),
516 Some(downsample) => downsample.forward(&xs),
517 }
518 }
519}
520
521#[derive(Debug)]
522pub struct TinyViT {
523 patch_embed: PatchEmbed,
524 layer0: ConvLayer,
525 layers: Vec<BasicLayer>,
526 neck_conv1: candle_nn::Conv2d,
529 neck_ln1: super::LayerNorm2d,
530 neck_conv2: candle_nn::Conv2d,
531 neck_ln2: super::LayerNorm2d,
532 span: tracing::Span,
533 span_neck: tracing::Span,
534}
535
536impl TinyViT {
537 pub fn new(
538 embed_dims: &[usize],
539 depths: &[usize],
540 num_heads: &[usize],
541 window_sizes: &[usize],
542 _num_classes: usize,
543 vb: VarBuilder,
544 ) -> Result<Self> {
545 let patch_embed = PatchEmbed::new(IN_CHANNELS, embed_dims[0], vb.pp("patch_embed"))?;
546 let patches_resolution = IMG_SIZE / 4;
547
548 let vb_l = vb.pp("layers");
549 let layer0 = ConvLayer::new(
550 embed_dims[0],
551 embed_dims[1],
552 (patches_resolution, patches_resolution),
553 depths[0],
554 true,
555 MBCONV_EXPAND_RATIO,
556 vb_l.pp(0),
557 )?;
558
559 let num_layers = embed_dims.len();
560 let mut layers = Vec::with_capacity(num_layers - 1);
561 for i_layer in 1..num_layers {
562 let patches_resolution = patches_resolution / (1 << usize::min(i_layer, 2));
563 let layer = BasicLayer::new(
564 embed_dims[i_layer],
565 (patches_resolution, patches_resolution),
566 depths[i_layer],
567 num_heads[i_layer],
568 window_sizes[i_layer],
569 i_layer < num_layers - 1,
570 embed_dims[usize::min(i_layer + 1, num_layers - 1)],
571 vb_l.pp(i_layer),
572 )?;
573 layers.push(layer)
574 }
575
576 let last_embed_dim = embed_dims[embed_dims.len() - 1];
577 let neck_conv1 =
580 candle_nn::conv2d_no_bias(last_embed_dim, 256, 1, Default::default(), vb.pp("neck.0"))?;
581 let neck_ln1 = super::LayerNorm2d::new(256, 1e-6, vb.pp("neck.1"))?;
582 let cfg = candle_nn::Conv2dConfig {
583 padding: 1,
584 ..Default::default()
585 };
586 let neck_conv2 = candle_nn::conv2d_no_bias(256, 256, 3, cfg, vb.pp("neck.2"))?;
587 let neck_ln2 = super::LayerNorm2d::new(256, 1e-6, vb.pp("neck.3"))?;
588
589 let span = tracing::span!(tracing::Level::TRACE, "tiny-vit");
590 let span_neck = tracing::span!(tracing::Level::TRACE, "neck");
591 Ok(Self {
592 patch_embed,
593 layer0,
594 layers,
595 neck_conv1,
596 neck_ln1,
597 neck_conv2,
598 neck_ln2,
599 span,
600 span_neck,
601 })
602 }
603}
604
605impl Module for TinyViT {
606 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
607 let _enter = self.span.enter();
608 let xs = self.patch_embed.forward(xs)?;
609 let mut xs = self.layer0.forward(&xs)?;
610 for layer in self.layers.iter() {
611 xs = layer.forward(&xs)?
612 }
613 let (b, _, c) = xs.dims3()?;
614 let _enter = self.span_neck.enter();
615 xs.reshape((b, 64, 64, c))?
616 .permute((0, 3, 1, 2))?
617 .apply(&self.neck_conv1)?
618 .apply(&self.neck_ln1)?
619 .apply(&self.neck_conv2)?
620 .apply(&self.neck_ln2)
621 }
622}
623
624pub fn tiny_vit_5m(vb: VarBuilder) -> Result<TinyViT> {
625 TinyViT::new(
626 &[64, 128, 160, 320],
627 &[2, 2, 6, 2],
628 &[2, 4, 5, 10],
629 &[7, 7, 14, 7],
630 1000,
631 vb,
632 )
633}