candle_transformers/models/
mobileone.rs

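//! # MobileOne
//!
//! MobileOne inference implementation, based on the timm implementation and on candle-repvgg.
//! Training-time multi-branch weights are fused into single convolutions at load time.
//!
//! See ["MobileOne: An Improved One millisecond Mobile Backbone"](https://arxiv.org/abs/2206.04040)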
6
7use candle::{DType, Result, Tensor, D};
8use candle_nn::{
9    batch_norm, conv2d, conv2d_no_bias, linear, ops::sigmoid, BatchNorm, Conv2d, Conv2dConfig,
10    Func, VarBuilder,
11};
12
/// Base layout of one stage, before the per-variant width multipliers are applied.
struct StageConfig {
    /// Number of blocks in the stage (for stages 1..=4 each block is a depthwise + pointwise pair).
    blocks: usize,
    /// Base channel width; scaled by the variant's `alphas` entry for this stage.
    channels: usize,
}

/// Shorthand constructor used to keep the `STAGES` table compact.
const fn stage_config(blocks: usize, channels: usize) -> StageConfig {
    StageConfig { blocks, channels }
}

/// Stage layout shared by every MobileOne variant; entry 0 describes the stem.
///
/// The paper's architecture has 6 stages. The timm implementation uses an equivalent form in
/// which the paper's 5th stage (which starts with stride 1) is merged into the preceding one,
/// giving the 10-block entry below.
const STAGES: [StageConfig; 5] = [
    stage_config(1, 64),
    stage_config(2, 64),
    stage_config(8, 128),
    stage_config(10, 256),
    stage_config(1, 512),
];
42
/// Hyper-parameters selecting one of the MobileOne variants (S0-S4).
#[derive(Clone, Debug, PartialEq)]
pub struct Config {
    /// Training-time overparameterization factor: the number of parallel kxk conv branches
    /// that are fused into each block when the weights are loaded.
    k: usize,
    /// Per-stage channel width multipliers, indexed like `STAGES` (entry 0 is the stem).
    alphas: [f32; 5],
}

impl Config {
    /// MobileOne-S0: narrowest widths, 4 overparameterized kxk branches per block.
    pub fn s0() -> Self {
        Self {
            k: 4,
            alphas: [0.75, 0.75, 1.0, 1.0, 2.0],
        }
    }

    /// MobileOne-S1.
    pub fn s1() -> Self {
        Self {
            k: 1,
            alphas: [1.5, 1.5, 1.5, 2.0, 2.5],
        }
    }

    /// MobileOne-S2.
    pub fn s2() -> Self {
        Self {
            k: 1,
            alphas: [1.5, 1.5, 2.0, 2.5, 4.0],
        }
    }

    /// MobileOne-S3.
    pub fn s3() -> Self {
        Self {
            k: 1,
            alphas: [2.0, 2.0, 2.5, 3.0, 4.0],
        }
    }

    /// MobileOne-S4: widest variant; its checkpoints also carry squeeze-and-excitation
    /// weights in the last layers.
    pub fn s4() -> Self {
        Self {
            k: 1,
            alphas: [3.0, 3.0, 3.5, 3.5, 4.0],
        }
    }
}
83
84// SE blocks are used in the last stages of the s4 variant.
85fn squeeze_and_excitation(
86    in_channels: usize,
87    squeeze_channels: usize,
88    vb: VarBuilder,
89) -> Result<Func<'static>> {
90    let conv2d_cfg = Conv2dConfig {
91        ..Default::default()
92    };
93    let fc1 = conv2d(in_channels, squeeze_channels, 1, conv2d_cfg, vb.pp("fc1"))?;
94    let fc2 = conv2d(squeeze_channels, in_channels, 1, conv2d_cfg, vb.pp("fc2"))?;
95
96    Ok(Func::new(move |xs| {
97        let residual = xs;
98        let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
99        let xs = sigmoid(&xs.apply(&fc1)?.relu()?.apply(&fc2)?)?;
100
101        residual.broadcast_mul(&xs)
102    }))
103}
104
105// fuses a convolutional kernel and a batchnorm layer into a convolutional layer
106// based on the _fuse_bn_tensor method in timm
107// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602
108fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
109    let (gamma, beta) = bn.weight_and_bias().unwrap();
110    let mu = bn.running_mean();
111    let sigma = (bn.running_var() + bn.eps())?.sqrt();
112    let gps = (gamma / sigma)?;
113    let bias = (beta - mu * &gps)?;
114    let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?;
115
116    Ok((weights, bias))
117}
118
119// A mobileone block has a different training time and inference time architecture.
120// The latter is a simple and efficient equivalent transformation of the former
121// realized by a structural reparameterization technique, where convolutions
122// along with identity branches and batchnorm layers are fused into a single convolution.
123#[allow(clippy::too_many_arguments)]
124fn mobileone_block(
125    has_identity: bool,
126    k: usize,
127    dim: usize,
128    stride: usize,
129    padding: usize,
130    groups: usize,
131    kernel: usize,
132    in_channels: usize,
133    out_channels: usize,
134    vb: VarBuilder,
135) -> Result<Func<'static>> {
136    let conv2d_cfg = Conv2dConfig {
137        stride,
138        padding,
139        groups,
140        ..Default::default()
141    };
142
143    let mut w = Tensor::zeros(
144        (out_channels, in_channels / groups, kernel, kernel),
145        DType::F32,
146        vb.device(),
147    )?;
148    let mut b = Tensor::zeros(dim, DType::F32, vb.device())?;
149
150    // k is the training-time overparameterization factor, larger than 1 only in the s0 variant
151    for i in 0..k {
152        let conv_kxk_bn = batch_norm(dim, 1e-5, vb.pp(format!("conv_kxk.{i}.bn")))?;
153        let conv_kxk = conv2d_no_bias(
154            in_channels,
155            out_channels,
156            kernel,
157            conv2d_cfg,
158            vb.pp(format!("conv_kxk.{i}.conv")),
159        )?;
160        let (wk, bk) = fuse_conv_bn(conv_kxk.weight(), conv_kxk_bn)?;
161        w = (w + wk)?;
162        b = (b + bk)?;
163    }
164
165    if kernel > 1 {
166        let conv_scale_bn = batch_norm(dim, 1e-5, vb.pp("conv_scale.bn"))?;
167        let conv_scale = conv2d_no_bias(
168            in_channels,
169            out_channels,
170            1,
171            conv2d_cfg,
172            vb.pp("conv_scale.conv"),
173        )?;
174
175        let (mut ws, bs) = fuse_conv_bn(conv_scale.weight(), conv_scale_bn)?;
176        // resize to 3x3
177        ws = ws.pad_with_zeros(D::Minus1, 1, 1)?;
178        ws = ws.pad_with_zeros(D::Minus2, 1, 1)?;
179
180        w = (w + ws)?;
181        b = (b + bs)?;
182    }
183
184    // Use SE blocks if present (last layers of the s4 variant)
185    let se = squeeze_and_excitation(out_channels, out_channels / 16, vb.pp("attn"));
186
187    // read and reparameterize the identity bn into wi and bi
188    if has_identity {
189        let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"))?;
190
191        let mut weights: Vec<f32> = vec![0.0; w.elem_count()];
192
193        let id = in_channels / groups;
194        // See https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L809
195        for i in 0..in_channels {
196            if kernel > 1 {
197                weights[i * kernel * kernel + 4] = 1.0;
198            } else {
199                weights[i * (id + 1)] = 1.0;
200            }
201        }
202
203        let weights = &Tensor::from_vec(weights, w.shape(), w.device())?;
204        let (wi, bi) = fuse_conv_bn(weights, identity_bn)?;
205
206        w = (w + wi)?;
207        b = (b + bi)?;
208    }
209
210    let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg);
211
212    Ok(Func::new(move |xs| {
213        let mut xs = xs.apply(&reparam_conv)?;
214        if let Ok(f) = &se {
215            xs = xs.apply(f)?;
216        }
217        xs = xs.relu()?;
218        Ok(xs)
219    }))
220}
221
222// Get the number of output channels per stage taking into account the multipliers
223fn output_channels_per_stage(cfg: &Config, stage: usize) -> usize {
224    let channels = STAGES[stage].channels as f32;
225    let alpha = cfg.alphas[stage];
226
227    match stage {
228        0 => std::cmp::min(64, (channels * alpha) as usize),
229        _ => (channels * alpha) as usize,
230    }
231}
232
233// Each stage is made of blocks. The first layer always downsamples with stride 2.
234// All but the first block have a residual connection.
235fn mobileone_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
236    let nblocks = STAGES[idx].blocks;
237    let mut blocks = Vec::with_capacity(nblocks);
238
239    let mut in_channels = output_channels_per_stage(cfg, idx - 1);
240
241    for block_idx in 0..nblocks {
242        let out_channels = output_channels_per_stage(cfg, idx);
243        let (has_identity, stride) = if block_idx == 0 {
244            (false, 2)
245        } else {
246            (true, 1)
247        };
248
249        // depthwise convolution layer
250        blocks.push(mobileone_block(
251            has_identity,
252            cfg.k,
253            in_channels,
254            stride,
255            1,
256            in_channels,
257            3,
258            in_channels,
259            in_channels,
260            vb.pp(block_idx * 2),
261        )?);
262
263        // pointwise convolution layer
264        blocks.push(mobileone_block(
265            has_identity,
266            cfg.k,
267            out_channels,
268            1, // stride
269            0, // padding
270            1, // groups
271            1, // kernel
272            in_channels,
273            out_channels,
274            vb.pp(block_idx * 2 + 1),
275        )?);
276
277        in_channels = out_channels;
278    }
279
280    Ok(Func::new(move |xs| {
281        let mut xs = xs.clone();
282        for block in blocks.iter() {
283            xs = xs.apply(block)?
284        }
285        Ok(xs)
286    }))
287}
288
289// Build a mobileone model for a given configuration.
290fn mobileone_model(
291    config: &Config,
292    nclasses: Option<usize>,
293    vb: VarBuilder,
294) -> Result<Func<'static>> {
295    let cls = match nclasses {
296        None => None,
297        Some(nclasses) => {
298            let outputs = output_channels_per_stage(config, 4);
299            let linear = linear(outputs, nclasses, vb.pp("head.fc"))?;
300            Some(linear)
301        }
302    };
303
304    let stem_dim = output_channels_per_stage(config, 0);
305    let stem = mobileone_block(false, 1, stem_dim, 2, 1, 1, 3, 3, stem_dim, vb.pp("stem"))?;
306    let vb = vb.pp("stages");
307    let stage1 = mobileone_stage(config, 1, vb.pp(0))?;
308    let stage2 = mobileone_stage(config, 2, vb.pp(1))?;
309    let stage3 = mobileone_stage(config, 3, vb.pp(2))?;
310    let stage4 = mobileone_stage(config, 4, vb.pp(3))?;
311
312    Ok(Func::new(move |xs| {
313        let xs = xs
314            .apply(&stem)?
315            .apply(&stage1)?
316            .apply(&stage2)?
317            .apply(&stage3)?
318            .apply(&stage4)?
319            .mean(D::Minus2)?
320            .mean(D::Minus1)?;
321        match &cls {
322            None => Ok(xs),
323            Some(cls) => xs.apply(cls),
324        }
325    }))
326}
327
328pub fn mobileone(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
329    mobileone_model(cfg, Some(nclasses), vb)
330}
331
332pub fn mobileone_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
333    mobileone_model(cfg, None, vb)
334}