candle_transformers/models/
convmixer.rs1use candle::Result;
9use candle_nn::{batch_norm, Conv2dConfig, Module, VarBuilder};
10
11#[allow(clippy::many_single_char_names)]
12fn conv2d_same(
13 i: usize,
14 o: usize,
15 k: usize,
16 c: Conv2dConfig,
17 vb: VarBuilder,
18) -> Result<impl Module> {
19 let conv2d = candle_nn::conv2d(i, o, k, c, vb)?;
20 let s = c.stride;
21 let module = candle_nn::func(move |xs| {
22 let ih = xs.dim(2)?;
23 let iw = xs.dim(3)?;
24 let oh = ih.div_ceil(s);
25 let ow = iw.div_ceil(s);
26 let pad_h = usize::max((oh - 1) * s + k - ih, 0);
27 let pad_w = usize::max((ow - 1) * s + k - iw, 0);
28 if pad_h > 0 || pad_w > 0 {
29 xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?
30 .pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?
31 .apply(&conv2d)
32 } else {
33 xs.apply(&conv2d)
34 }
35 });
36 Ok(module)
37}
38
39fn block(dim: usize, kernel_size: usize, vb: VarBuilder) -> Result<impl Module> {
40 let conv2d_cfg = Conv2dConfig {
41 groups: dim,
42 ..Default::default()
43 };
44 let vb_fn = vb.pp(0).pp("fn");
45 let conv1 = conv2d_same(dim, dim, kernel_size, conv2d_cfg, vb_fn.pp(0))?;
46 let bn1 = batch_norm(dim, 1e-5, vb_fn.pp(2))?;
47 let conv2 = candle_nn::conv2d(dim, dim, 1, Default::default(), vb.pp(1))?;
48 let bn2 = batch_norm(dim, 1e-5, vb.pp(3))?;
49 Ok(candle_nn::func(move |xs| {
50 let ys = xs.apply(&conv1)?.gelu_erf()?.apply_t(&bn1, false)?;
51 (xs + ys)?.apply(&conv2)?.gelu_erf()?.apply_t(&bn2, false)
52 }))
53}
54
55fn convmixer(
56 nclasses: usize,
57 dim: usize,
58 depth: usize,
59 kernel_size: usize,
60 patch_size: usize,
61 vb: VarBuilder,
62) -> Result<candle_nn::Func<'static>> {
63 let conv2d_cfg = Conv2dConfig {
64 stride: patch_size,
65 ..Default::default()
66 };
67 let conv1 = candle_nn::conv2d(3, dim, patch_size, conv2d_cfg, vb.pp(0))?;
68 let bn1 = batch_norm(dim, 1e-5, vb.pp(2))?;
69 let blocks: Vec<_> = (0..depth)
70 .map(|index| block(dim, kernel_size, vb.pp(3 + index)))
71 .collect::<Result<Vec<_>>>()?;
72 let fc = candle_nn::linear(dim, nclasses, vb.pp(25))?;
73 Ok(candle_nn::func(move |xs| {
74 let mut xs = xs.apply(&conv1)?.gelu_erf()?.apply_t(&bn1, false)?;
75 for block in blocks.iter() {
76 xs = xs.apply(block)?
77 }
78 xs.mean(3)?.mean(2)?.apply(&fc)
80 }))
81}
82
83pub fn c1536_20(nclasses: usize, vb: VarBuilder) -> Result<candle_nn::Func<'static>> {
84 convmixer(nclasses, 1536, 20, 9, 7, vb)
85}
86
87pub fn c1024_20(nclasses: usize, vb: VarBuilder) -> Result<candle_nn::Func<'static>> {
88 convmixer(nclasses, 1024, 20, 9, 14, vb)
89}