candle_transformers/models/segment_anything/tiny_vit.rs

// Adapted from:
// https://github.com/ChaoningZhang/MobileSAM/blob/master/mobile_sam/modeling/tiny_vit_sam.py
3use candle::{IndexOp, Result, Tensor, D};
4use candle_nn::{Conv2dConfig, Module, VarBuilder};
5
/// Number of channels of the input image consumed by the patch embedding stem.
const IN_CHANNELS: usize = 3;
/// Side length, in pixels, of the square input image; every stage's token-grid
/// resolution is derived from it.
const IMG_SIZE: usize = 1024;
/// Channel expansion factor of the `MBConv` blocks in the first (convolutional) stage.
const MBCONV_EXPAND_RATIO: usize = 4;
/// Hidden-width multiplier of the MLP inside each transformer block.
const MLP_RATIO: usize = 4;
/// Kernel size of the depthwise "local" convolution applied between attention and MLP.
const LOCAL_CONV_SIZE: usize = 3;
11
12#[derive(Debug)]
13struct Conv2dBN {
14    c: candle_nn::Conv2d,
15    bn: candle_nn::BatchNorm,
16    span: tracing::Span,
17}
18
19impl Conv2dBN {
20    fn new(in_: usize, out: usize, ks: usize, cfg: Conv2dConfig, vb: VarBuilder) -> Result<Self> {
21        let c = candle_nn::conv2d_no_bias(in_, out, ks, cfg, vb.pp("c"))?;
22        let bn = candle_nn::batch_norm(out, 1e-5, vb.pp("bn"))?;
23        let span = tracing::span!(tracing::Level::TRACE, "conv2d-bn");
24        Ok(Self { c, bn, span })
25    }
26}
27
28impl Module for Conv2dBN {
29    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
30        let _enter = self.span.enter();
31        xs.apply(&self.c)?.apply_t(&self.bn, false)
32    }
33}
34
35#[derive(Debug)]
36struct PatchEmbed {
37    conv1: Conv2dBN,
38    conv2: Conv2dBN,
39    span: tracing::Span,
40}
41
42impl PatchEmbed {
43    fn new(in_chans: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
44        let cfg = candle_nn::Conv2dConfig {
45            stride: 2,
46            padding: 1,
47            ..Default::default()
48        };
49        let conv1 = Conv2dBN::new(in_chans, embed_dim / 2, 3, cfg, vb.pp("seq.0"))?;
50        let conv2 = Conv2dBN::new(embed_dim / 2, embed_dim, 3, cfg, vb.pp("seq.2"))?;
51        let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
52        Ok(Self { conv1, conv2, span })
53    }
54}
55
56impl Module for PatchEmbed {
57    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
58        let _enter = self.span.enter();
59        xs.apply(&self.conv1)?.gelu()?.apply(&self.conv2)
60    }
61}
62
63#[derive(Debug)]
64struct MBConv {
65    conv1: Conv2dBN,
66    conv2: Conv2dBN,
67    conv3: Conv2dBN,
68    span: tracing::Span,
69}
70
71impl MBConv {
72    fn new(in_: usize, out: usize, expand_ratio: usize, vb: VarBuilder) -> Result<Self> {
73        let hidden = in_ * expand_ratio;
74        let cfg2 = candle_nn::Conv2dConfig {
75            padding: 1,
76            groups: hidden,
77            ..Default::default()
78        };
79        let conv1 = Conv2dBN::new(in_, hidden, 1, Default::default(), vb.pp("conv1"))?;
80        let conv2 = Conv2dBN::new(hidden, hidden, 3, cfg2, vb.pp("conv2"))?;
81        let conv3 = Conv2dBN::new(hidden, out, 1, Default::default(), vb.pp("conv3"))?;
82        let span = tracing::span!(tracing::Level::TRACE, "mb-conv");
83        Ok(Self {
84            conv1,
85            conv2,
86            conv3,
87            span,
88        })
89    }
90}
91
92impl Module for MBConv {
93    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
94        let _enter = self.span.enter();
95        let shortcut = xs;
96        let xs = xs
97            .apply(&self.conv1)?
98            .gelu()?
99            .apply(&self.conv2)?
100            .gelu()?
101            .apply(&self.conv3)?;
102        (xs + shortcut)?.gelu()
103    }
104}
105
106#[derive(Debug)]
107struct PatchMerging {
108    conv1: Conv2dBN,
109    conv2: Conv2dBN,
110    conv3: Conv2dBN,
111    input_resolution: (usize, usize),
112    span: tracing::Span,
113}
114
115impl PatchMerging {
116    fn new(
117        input_resolution: (usize, usize),
118        dim: usize,
119        out: usize,
120        vb: VarBuilder,
121    ) -> Result<Self> {
122        let stride = if [320, 448, 576].contains(&out) { 1 } else { 2 };
123        let cfg2 = candle_nn::Conv2dConfig {
124            padding: 1,
125            stride,
126            groups: out,
127            ..Default::default()
128        };
129        let conv1 = Conv2dBN::new(dim, out, 1, Default::default(), vb.pp("conv1"))?;
130        let conv2 = Conv2dBN::new(out, out, 3, cfg2, vb.pp("conv2"))?;
131        let conv3 = Conv2dBN::new(out, out, 1, Default::default(), vb.pp("conv3"))?;
132        let span = tracing::span!(tracing::Level::TRACE, "patch-merging");
133        Ok(Self {
134            conv1,
135            conv2,
136            conv3,
137            input_resolution,
138            span,
139        })
140    }
141}
142
143impl Module for PatchMerging {
144    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
145        let _enter = self.span.enter();
146        let xs = if xs.rank() == 3 {
147            let (h, w) = self.input_resolution;
148            let b = xs.dim(0)?;
149            xs.reshape((b, h, w, ()))?.permute((0, 3, 1, 2))?
150        } else {
151            xs.clone()
152        };
153        xs.apply(&self.conv1)?
154            .gelu()?
155            .apply(&self.conv2)?
156            .gelu()?
157            .apply(&self.conv3)?
158            .flatten_from(2)?
159            .transpose(1, 2)
160    }
161}
162
163#[derive(Debug)]
164struct ConvLayer {
165    blocks: Vec<MBConv>,
166    downsample: Option<PatchMerging>,
167    span: tracing::Span,
168}
169
170impl ConvLayer {
171    fn new(
172        dim: usize,
173        out: usize,
174        input_resolution: (usize, usize),
175        depth: usize,
176        downsample: bool,
177        conv_expand_ratio: usize,
178        vb: VarBuilder,
179    ) -> Result<Self> {
180        let vb_b = vb.pp("blocks");
181        let mut blocks = Vec::with_capacity(depth);
182        for index in 0..depth {
183            let block = MBConv::new(dim, dim, conv_expand_ratio, vb_b.pp(index))?;
184            blocks.push(block)
185        }
186        let downsample = if downsample {
187            let downsample = PatchMerging::new(input_resolution, dim, out, vb.pp("downsample"))?;
188            Some(downsample)
189        } else {
190            None
191        };
192        let span = tracing::span!(tracing::Level::TRACE, "conv-layer");
193        Ok(Self {
194            blocks,
195            downsample,
196            span,
197        })
198    }
199}
200
201impl Module for ConvLayer {
202    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
203        let _enter = self.span.enter();
204        let mut xs = xs.clone();
205        for block in self.blocks.iter() {
206            xs = block.forward(&xs)?
207        }
208        match &self.downsample {
209            None => Ok(xs),
210            Some(downsample) => downsample.forward(&xs),
211        }
212    }
213}
214
215#[derive(Debug)]
216struct Mlp {
217    norm: candle_nn::LayerNorm,
218    fc1: super::Linear,
219    fc2: super::Linear,
220    span: tracing::Span,
221}
222
223impl Mlp {
224    fn new(in_: usize, hidden: usize, vb: VarBuilder) -> Result<Self> {
225        let norm = candle_nn::layer_norm(in_, 1e-5, vb.pp("norm"))?;
226        let fc1 = super::linear(vb.pp("fc1"), in_, hidden, true)?;
227        let fc2 = super::linear(vb.pp("fc2"), hidden, in_, true)?;
228        let span = tracing::span!(tracing::Level::TRACE, "mlp");
229        Ok(Self {
230            norm,
231            fc1,
232            fc2,
233            span,
234        })
235    }
236}
237
238impl Module for Mlp {
239    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
240        let _enter = self.span.enter();
241        xs.apply(&self.norm)?
242            .apply(&self.fc1)?
243            .gelu()?
244            .apply(&self.fc2)
245    }
246}
247
248#[derive(Debug)]
249struct Attention {
250    norm: candle_nn::LayerNorm,
251    qkv: super::Linear,
252    proj: super::Linear,
253    ab: Tensor,
254    key_dim: usize,
255    num_heads: usize,
256    d: usize,
257    dh: usize,
258    scale: f64,
259    span: tracing::Span,
260    span_matmul: tracing::Span,
261    span_softmax: tracing::Span,
262}
263
264impl Attention {
265    fn new(
266        dim: usize,
267        key_dim: usize,
268        num_heads: usize,
269        attn_ratio: usize,
270        resolution: (usize, usize),
271        vb: VarBuilder,
272    ) -> Result<Self> {
273        let d = attn_ratio * key_dim;
274        let dh = d * num_heads;
275        let nh_kd = key_dim * num_heads;
276        let h = dh + nh_kd * 2;
277        let norm = candle_nn::layer_norm(dim, 1e-5, vb.pp("norm"))?;
278        let qkv = super::linear(vb.pp("qkv"), dim, h, true)?;
279        let proj = super::linear(vb.pp("proj"), dh, dim, true)?;
280
281        let points = (0..resolution.0)
282            .flat_map(|x| (0..resolution.1).map(move |y| (x as i64, y as i64)))
283            .collect::<Vec<_>>();
284        let mut idxs = Vec::with_capacity(points.len() * points.len());
285        let mut attention_offsets = std::collections::HashMap::new();
286        for &(x1, y1) in points.iter() {
287            for &(x2, y2) in points.iter() {
288                let offset = ((x2 - x1).abs(), (y2 - y1).abs());
289                let l = attention_offsets.len();
290                let idx = attention_offsets.entry(offset).or_insert(l);
291                idxs.push(*idx as u32)
292            }
293        }
294        let attention_biases = vb.get((num_heads, attention_offsets.len()), "attention_biases")?;
295        let idxs = Tensor::new(idxs, attention_biases.device())?;
296        let ab =
297            attention_biases
298                .index_select(&idxs, 1)?
299                .reshape(((), points.len(), points.len()))?;
300        let span = tracing::span!(tracing::Level::TRACE, "attention");
301        let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul");
302        let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
303        Ok(Self {
304            norm,
305            qkv,
306            proj,
307            ab,
308            key_dim,
309            num_heads,
310            d,
311            dh,
312            scale: 1f64 / (key_dim as f64).sqrt(),
313            span,
314            span_matmul,
315            span_softmax,
316        })
317    }
318}
319
320impl Module for Attention {
321    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
322        let _enter = self.span.enter();
323        let (b, n, _) = xs.dims3()?;
324        let xs = xs.apply(&self.norm)?;
325        let qkv = xs.apply(&self.qkv)?.reshape((b, n, self.num_heads, ()))?;
326        let q = qkv
327            .narrow(D::Minus1, 0, self.key_dim)?
328            .permute((0, 2, 1, 3))?
329            .contiguous()?;
330        let k = qkv
331            .narrow(D::Minus1, self.key_dim, self.key_dim)?
332            .permute((0, 2, 1, 3))?
333            .contiguous()?;
334        let v = qkv
335            .narrow(D::Minus1, 2 * self.key_dim, self.d)?
336            .permute((0, 2, 1, 3))?
337            .contiguous()?;
338        let attn = {
339            let _enter = self.span_matmul.enter();
340            (q.matmul(&k.t()?)? * self.scale)?
341        };
342        let attn = attn.broadcast_add(&self.ab)?;
343        let attn = {
344            let _enter = self.span_softmax.enter();
345            candle_nn::ops::softmax_last_dim(&attn)?
346        };
347        let attn = {
348            let _enter = self.span_matmul.enter();
349            attn.matmul(&v)?
350        };
351        attn.transpose(1, 2)?
352            .reshape((b, n, self.dh))?
353            .apply(&self.proj)
354    }
355}
356
357#[derive(Debug)]
358struct TinyViTBlock {
359    attn: Attention,
360    local_conv: Conv2dBN,
361    mlp: Mlp,
362    window_size: usize,
363    input_resolution: (usize, usize),
364    span: tracing::Span,
365}
366
367impl TinyViTBlock {
368    fn new(
369        dim: usize,
370        input_resolution: (usize, usize),
371        num_heads: usize,
372        window_size: usize,
373        vb: VarBuilder,
374    ) -> Result<Self> {
375        let head_dim = dim / num_heads;
376        let attn = Attention::new(
377            dim,
378            head_dim,
379            num_heads,
380            1,
381            (window_size, window_size),
382            vb.pp("attn"),
383        )?;
384        let mlp = Mlp::new(dim, dim * MLP_RATIO, vb.pp("mlp"))?;
385        let cfg = candle_nn::Conv2dConfig {
386            padding: LOCAL_CONV_SIZE / 2,
387            groups: dim,
388            ..Default::default()
389        };
390        let local_conv = Conv2dBN::new(dim, dim, LOCAL_CONV_SIZE, cfg, vb.pp("local_conv"))?;
391        let span = tracing::span!(tracing::Level::TRACE, "attention");
392        Ok(Self {
393            attn,
394            local_conv,
395            mlp,
396            window_size,
397            input_resolution,
398            span,
399        })
400    }
401}
402
403impl Module for TinyViTBlock {
404    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
405        let _enter = self.span.enter();
406        let (h, w) = self.input_resolution;
407        let (b, l, c) = xs.dims3()?;
408        let res_x = xs;
409        let xs = if h == self.window_size && w == self.window_size {
410            self.attn.forward(xs)?
411        } else {
412            let xs = xs.reshape((b, h, w, c))?;
413            let pad_b = (self.window_size - h % self.window_size) % self.window_size;
414            let pad_r = (self.window_size - w % self.window_size) % self.window_size;
415
416            let xs = if pad_b > 0 {
417                xs.pad_with_zeros(1, 0, pad_b)?
418            } else {
419                xs
420            };
421            let xs = if pad_r > 0 {
422                xs.pad_with_zeros(2, 0, pad_r)?
423            } else {
424                xs
425            };
426            let (p_h, p_w) = (h + pad_b, w + pad_r);
427            let n_h = p_h / self.window_size;
428            let n_w = p_w / self.window_size;
429            let xs = xs
430                .reshape((b, n_h, self.window_size, n_w, self.window_size, c))?
431                .transpose(2, 3)?
432                .reshape((b * n_h * n_w, self.window_size * self.window_size, c))?;
433            let xs = self.attn.forward(&xs)?;
434            let xs = xs
435                .reshape((b, n_h, n_w, self.window_size, self.window_size, c))?
436                .transpose(2, 3)?
437                .reshape((b, p_h, p_w, c))?;
438            let xs = if pad_r > 0 {
439                xs.i((.., .., ..w))?.contiguous()?
440            } else {
441                xs
442            };
443            let xs = if pad_b > 0 {
444                xs.i((.., ..h, ..))?.contiguous()?
445            } else {
446                xs
447            };
448            xs.reshape((b, l, c))?
449        };
450        let xs = (xs + res_x)?;
451        let xs = xs
452            .transpose(1, 2)?
453            .reshape((b, c, h, w))?
454            .apply(&self.local_conv)?
455            .reshape((b, c, l))?
456            .transpose(1, 2)?;
457        &xs + self.mlp.forward(&xs)?
458    }
459}
460
461#[derive(Debug)]
462struct BasicLayer {
463    blocks: Vec<TinyViTBlock>,
464    downsample: Option<PatchMerging>,
465    span: tracing::Span,
466}
467
468impl BasicLayer {
469    #[allow(clippy::too_many_arguments)]
470    fn new(
471        dim: usize,
472        input_resolution: (usize, usize),
473        depth: usize,
474        num_heads: usize,
475        window_size: usize,
476        downsample: bool,
477        out: usize,
478        vb: VarBuilder,
479    ) -> Result<Self> {
480        let vb_b = vb.pp("blocks");
481        let mut blocks = Vec::with_capacity(depth);
482        for index in 0..depth {
483            let block = TinyViTBlock::new(
484                dim,
485                input_resolution,
486                num_heads,
487                window_size,
488                vb_b.pp(index),
489            )?;
490            blocks.push(block)
491        }
492        let downsample = if downsample {
493            let downsample = PatchMerging::new(input_resolution, dim, out, vb.pp("downsample"))?;
494            Some(downsample)
495        } else {
496            None
497        };
498        let span = tracing::span!(tracing::Level::TRACE, "basic-layer");
499        Ok(Self {
500            blocks,
501            downsample,
502            span,
503        })
504    }
505}
506
507impl Module for BasicLayer {
508    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
509        let _enter = self.span.enter();
510        let mut xs = xs.clone();
511        for block in self.blocks.iter() {
512            xs = block.forward(&xs)?
513        }
514        match &self.downsample {
515            None => Ok(xs),
516            Some(downsample) => downsample.forward(&xs),
517        }
518    }
519}
520
521#[derive(Debug)]
522pub struct TinyViT {
523    patch_embed: PatchEmbed,
524    layer0: ConvLayer,
525    layers: Vec<BasicLayer>,
526    // norm_head: candle_nn::LayerNorm,
527    // head: candle_nn::Linear,
528    neck_conv1: candle_nn::Conv2d,
529    neck_ln1: super::LayerNorm2d,
530    neck_conv2: candle_nn::Conv2d,
531    neck_ln2: super::LayerNorm2d,
532    span: tracing::Span,
533    span_neck: tracing::Span,
534}
535
536impl TinyViT {
537    pub fn new(
538        embed_dims: &[usize],
539        depths: &[usize],
540        num_heads: &[usize],
541        window_sizes: &[usize],
542        _num_classes: usize,
543        vb: VarBuilder,
544    ) -> Result<Self> {
545        let patch_embed = PatchEmbed::new(IN_CHANNELS, embed_dims[0], vb.pp("patch_embed"))?;
546        let patches_resolution = IMG_SIZE / 4;
547
548        let vb_l = vb.pp("layers");
549        let layer0 = ConvLayer::new(
550            /* dim */ embed_dims[0],
551            /* out */ embed_dims[1],
552            /* input_resolution */ (patches_resolution, patches_resolution),
553            /* depth */ depths[0],
554            /* downsample */ true,
555            /* conv_expand_ratio */ MBCONV_EXPAND_RATIO,
556            vb_l.pp(0),
557        )?;
558
559        let num_layers = embed_dims.len();
560        let mut layers = Vec::with_capacity(num_layers - 1);
561        for i_layer in 1..num_layers {
562            let patches_resolution = patches_resolution / (1 << usize::min(i_layer, 2));
563            let layer = BasicLayer::new(
564                /* dim */ embed_dims[i_layer],
565                /* input_resolution */ (patches_resolution, patches_resolution),
566                /* depth */ depths[i_layer],
567                /* num_heads */ num_heads[i_layer],
568                /* window_size */ window_sizes[i_layer],
569                /* downsample */ i_layer < num_layers - 1,
570                /* out */ embed_dims[usize::min(i_layer + 1, num_layers - 1)],
571                vb_l.pp(i_layer),
572            )?;
573            layers.push(layer)
574        }
575
576        let last_embed_dim = embed_dims[embed_dims.len() - 1];
577        // let norm_head = candle_nn::layer_norm(last_embed_dim, 1e-5, vb.pp("norm_head"))?;
578        // let head = candle_nn::linear(last_embed_dim, num_classes, vb.pp("head"))?;
579        let neck_conv1 =
580            candle_nn::conv2d_no_bias(last_embed_dim, 256, 1, Default::default(), vb.pp("neck.0"))?;
581        let neck_ln1 = super::LayerNorm2d::new(256, 1e-6, vb.pp("neck.1"))?;
582        let cfg = candle_nn::Conv2dConfig {
583            padding: 1,
584            ..Default::default()
585        };
586        let neck_conv2 = candle_nn::conv2d_no_bias(256, 256, 3, cfg, vb.pp("neck.2"))?;
587        let neck_ln2 = super::LayerNorm2d::new(256, 1e-6, vb.pp("neck.3"))?;
588
589        let span = tracing::span!(tracing::Level::TRACE, "tiny-vit");
590        let span_neck = tracing::span!(tracing::Level::TRACE, "neck");
591        Ok(Self {
592            patch_embed,
593            layer0,
594            layers,
595            neck_conv1,
596            neck_ln1,
597            neck_conv2,
598            neck_ln2,
599            span,
600            span_neck,
601        })
602    }
603}
604
605impl Module for TinyViT {
606    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
607        let _enter = self.span.enter();
608        let xs = self.patch_embed.forward(xs)?;
609        let mut xs = self.layer0.forward(&xs)?;
610        for layer in self.layers.iter() {
611            xs = layer.forward(&xs)?
612        }
613        let (b, _, c) = xs.dims3()?;
614        let _enter = self.span_neck.enter();
615        xs.reshape((b, 64, 64, c))?
616            .permute((0, 3, 1, 2))?
617            .apply(&self.neck_conv1)?
618            .apply(&self.neck_ln1)?
619            .apply(&self.neck_conv2)?
620            .apply(&self.neck_ln2)
621    }
622}
623
624pub fn tiny_vit_5m(vb: VarBuilder) -> Result<TinyViT> {
625    TinyViT::new(
626        /* embed_dims */ &[64, 128, 160, 320],
627        /* depths */ &[2, 2, 6, 2],
628        /* num_heads */ &[2, 4, 5, 10],
629        /* window_sizes */ &[7, 7, 14, 7],
630        /* num_classes */ 1000,
631        vb,
632    )
633}