candle_transformers/models/
convmixer.rs

//! ConvMixer implementation.
//!
//! See "Patches Are All You Need?" by Trockman and Kolter, 2022.
//!
//! - 📝 [arXiv](https://arxiv.org/abs/2201.09792)
//! - 💻 [GitHub](https://github.com/locuslab/convmixer)
//!
8use candle::Result;
9use candle_nn::{batch_norm, Conv2dConfig, Module, VarBuilder};
10
11#[allow(clippy::many_single_char_names)]
12fn conv2d_same(
13    i: usize,
14    o: usize,
15    k: usize,
16    c: Conv2dConfig,
17    vb: VarBuilder,
18) -> Result<impl Module> {
19    let conv2d = candle_nn::conv2d(i, o, k, c, vb)?;
20    let s = c.stride;
21    let module = candle_nn::func(move |xs| {
22        let ih = xs.dim(2)?;
23        let iw = xs.dim(3)?;
24        let oh = ih.div_ceil(s);
25        let ow = iw.div_ceil(s);
26        let pad_h = usize::max((oh - 1) * s + k - ih, 0);
27        let pad_w = usize::max((ow - 1) * s + k - iw, 0);
28        if pad_h > 0 || pad_w > 0 {
29            xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?
30                .pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?
31                .apply(&conv2d)
32        } else {
33            xs.apply(&conv2d)
34        }
35    });
36    Ok(module)
37}
38
39fn block(dim: usize, kernel_size: usize, vb: VarBuilder) -> Result<impl Module> {
40    let conv2d_cfg = Conv2dConfig {
41        groups: dim,
42        ..Default::default()
43    };
44    let vb_fn = vb.pp(0).pp("fn");
45    let conv1 = conv2d_same(dim, dim, kernel_size, conv2d_cfg, vb_fn.pp(0))?;
46    let bn1 = batch_norm(dim, 1e-5, vb_fn.pp(2))?;
47    let conv2 = candle_nn::conv2d(dim, dim, 1, Default::default(), vb.pp(1))?;
48    let bn2 = batch_norm(dim, 1e-5, vb.pp(3))?;
49    Ok(candle_nn::func(move |xs| {
50        let ys = xs.apply(&conv1)?.gelu_erf()?.apply_t(&bn1, false)?;
51        (xs + ys)?.apply(&conv2)?.gelu_erf()?.apply_t(&bn2, false)
52    }))
53}
54
55fn convmixer(
56    nclasses: usize,
57    dim: usize,
58    depth: usize,
59    kernel_size: usize,
60    patch_size: usize,
61    vb: VarBuilder,
62) -> Result<candle_nn::Func<'static>> {
63    let conv2d_cfg = Conv2dConfig {
64        stride: patch_size,
65        ..Default::default()
66    };
67    let conv1 = candle_nn::conv2d(3, dim, patch_size, conv2d_cfg, vb.pp(0))?;
68    let bn1 = batch_norm(dim, 1e-5, vb.pp(2))?;
69    let blocks: Vec<_> = (0..depth)
70        .map(|index| block(dim, kernel_size, vb.pp(3 + index)))
71        .collect::<Result<Vec<_>>>()?;
72    let fc = candle_nn::linear(dim, nclasses, vb.pp(25))?;
73    Ok(candle_nn::func(move |xs| {
74        let mut xs = xs.apply(&conv1)?.gelu_erf()?.apply_t(&bn1, false)?;
75        for block in blocks.iter() {
76            xs = xs.apply(block)?
77        }
78        // This performs the adaptive average pooling with a target size of (1, 1).
79        xs.mean(3)?.mean(2)?.apply(&fc)
80    }))
81}
82
83pub fn c1536_20(nclasses: usize, vb: VarBuilder) -> Result<candle_nn::Func<'static>> {
84    convmixer(nclasses, 1536, 20, 9, 7, vb)
85}
86
87pub fn c1024_20(nclasses: usize, vb: VarBuilder) -> Result<candle_nn::Func<'static>> {
88    convmixer(nclasses, 1024, 20, 9, 14, vb)
89}