candle_transformers/models/
mobilenetv4.rs

1//! # MobileNet-v4
2//!
3//! MobileNet-v4 inference implementation based on timm.
4//!
5//! ## Paper
6//!
7//! ["MobileNetV4 - Universal Models for the Mobile Ecosystem"](https://arxiv.org/abs/2404.10518)
8//!
9//! ## References
10//!
//! - [PyTorch implementation in timm](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/mobilenetv3.py) (timm defines its MobileNet-v4 variants in `mobilenetv3.py`, alongside MobileNet-v3)
12
13use candle::{Result, Tensor, D};
14use candle_nn::{
15    batch_norm, conv2d_no_bias, linear, ops::softmax, Activation, Conv2dConfig, Func, VarBuilder,
16};
17
/// One building block of a MobileNet-v4 stage.
///
/// Every variant records the number of channels it produces (`out_channels`) so the
/// model builder can chain blocks without extra bookkeeping.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum BlockType {
    /// Plain `kernel x kernel` convolution + batch norm + activation.
    Convolutional {
        out_channels: usize,
        kernel: usize,
        stride: usize,
    },
    /// Universal inverted bottleneck: optional depthwise conv, pointwise expansion by
    /// `expand`, optional depthwise conv, pointwise projection.
    /// A kernel size of `0` marks the corresponding depthwise convolution as absent.
    UniversalBottleneck {
        out_channels: usize,
        start_kernel: usize,
        mid_kernel: usize,
        stride: usize,
        expand: usize,
    },
    /// Edge residual ("fused" inverted bottleneck): a full `kernel x kernel` expansion
    /// conv by `expand` followed by a pointwise projection.
    EdgeResidual {
        out_channels: usize,
        kernel: usize,
        stride: usize,
        expand: usize,
    },
    /// Mobile multi-query attention with `heads` query heads of width `kv_dim` and a single
    /// shared key/value head, optionally downsampled spatially by `kv_stride` using a
    /// depthwise conv of size `kernel`.
    Attention {
        out_channels: usize,
        heads: usize,
        kernel: usize,
        stride: usize,
        kv_dim: usize,
        kv_stride: usize,
    },
}
47
48#[derive(Clone, Debug)]
49pub struct Config {
50    stem_dim: usize,
51    activation: Activation,
52    stages: [Vec<BlockType>; 5],
53}
54
55#[rustfmt::skip]
56impl Config {
57    pub fn small() -> Self {
58        Self {
59            stem_dim: 32,
60            activation: Activation::Relu,
61            stages: [
62                vec![
63                    BlockType::Convolutional { out_channels: 32, kernel: 3, stride: 2},
64                    BlockType::Convolutional { out_channels: 32, kernel: 1, stride: 1},
65                ],
66                vec![
67                    BlockType::Convolutional { out_channels: 96, kernel: 3, stride: 2},
68                    BlockType::Convolutional { out_channels: 64, kernel: 1, stride: 1},
69                ],
70                vec![
71                    BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 5, mid_kernel: 5, stride: 2, expand: 3},
72                    BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 0, mid_kernel: 3, stride: 1, expand: 2},
73                    BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 0, mid_kernel: 3, stride: 1, expand: 2},
74                    BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 0, mid_kernel: 3, stride: 1, expand: 2},
75                    BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 0, mid_kernel: 3, stride: 1, expand: 2},
76                    BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4},
77                ],
78                vec![
79                    BlockType::UniversalBottleneck { out_channels: 128, start_kernel: 3, mid_kernel: 3, stride: 2, expand: 6},
80                    BlockType::UniversalBottleneck { out_channels: 128, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
81                    BlockType::UniversalBottleneck { out_channels: 128, start_kernel: 0, mid_kernel: 5, stride: 1, expand: 4},
82                    BlockType::UniversalBottleneck { out_channels: 128, start_kernel: 0, mid_kernel: 5, stride: 1, expand: 3},
83                    BlockType::UniversalBottleneck { out_channels: 128, start_kernel: 0, mid_kernel: 3, stride: 1, expand: 4},
84                    BlockType::UniversalBottleneck { out_channels: 128, start_kernel: 0, mid_kernel: 3, stride: 1, expand: 4},
85                ],
86                vec![
87                    BlockType::Convolutional { out_channels: 960, kernel: 1, stride: 1},
88                ],
89            ],
90        }
91    }
92
93    pub fn medium() -> Self {
94        Self {
95            stem_dim: 32,
96            activation: Activation::Relu,
97            stages: [
98                 vec![
99                    BlockType::EdgeResidual { out_channels: 48, kernel: 3, stride: 2, expand: 4},
100                ],
101                vec![
102                    BlockType::UniversalBottleneck { out_channels: 80, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 4},
103                    BlockType::UniversalBottleneck { out_channels: 80, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 2},
104                ],
105                vec![
106                    BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 6},
107                    BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
108                    BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
109                    BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4},
110                    BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
111                    BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4},
112                    BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 2},
113                    BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4},
114                ],
115                vec![
116                    BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 5, stride: 2, expand: 6},
117                    BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
118                    BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4},
119                    BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4},
120                    BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 4},
121                    BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4},
122                    BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 2},
123                    BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
124                    BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 4},
125                    BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 4},
126                    BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 2},
127
128               ],
129                vec![
130                    BlockType::Convolutional { out_channels: 960, kernel: 1, stride: 1},
131                ],
132            ],
133        }
134    }
135
136    pub fn hybrid_medium() -> Self {
137        Self {
138            stem_dim: 32,
139            activation: Activation::Relu,
140            stages: [
141                 vec![
142                    BlockType::EdgeResidual { out_channels: 48, kernel: 3, stride: 2, expand: 4},
143                ],
144                vec![
145                    BlockType::UniversalBottleneck { out_channels: 80, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 4},
146                    BlockType::UniversalBottleneck { out_channels: 80, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 2},
147                ],
148                vec![
149                    BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 6},
150                    BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 2},
151                    BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
152                    BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4},
153                    BlockType::Attention { out_channels: 160, heads: 4, kernel: 3, stride: 1, kv_stride:2, kv_dim: 64},
154                    BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
155                    BlockType::Attention { out_channels: 160, heads: 4, kernel: 3, stride: 1, kv_stride:2, kv_dim: 64},
156                    BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4},
157                    BlockType::Attention { out_channels: 160, heads: 4, kernel: 3, stride: 1, kv_stride:2, kv_dim: 64},
158                    BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
159                    BlockType::Attention { out_channels: 160, heads: 4, kernel: 3, stride: 1, kv_stride:2, kv_dim: 64},
160                    BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4},
161                ],
162
163               vec![
164                    BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 5, stride: 2, expand: 6},
165                    BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
166                    BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4},
167                    BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4},
168                    BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 2},
169                    BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 2},
170                    BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 2},
171                    BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 4},
172                    BlockType::Attention { out_channels: 256, heads: 4, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64},
173                    BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4},
174                    BlockType::Attention { out_channels: 256, heads: 4, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64},
175                    BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
176                    BlockType::Attention { out_channels: 256, heads: 4, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64},
177                    BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
178                    BlockType::Attention { out_channels: 256, heads: 4, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64},
179                    BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
180               ],
181                vec![
182                    BlockType::Convolutional { out_channels: 960, kernel: 1, stride: 1},
183                ],
184            ],
185        }
186    }
187
188    pub fn large() -> Self {
189        Self {
190            stem_dim: 24,
191            activation: Activation::Relu,
192            stages: [
193                vec![
194                    BlockType::EdgeResidual { out_channels: 48, kernel: 3, stride: 2, expand: 4},
195                ],
196                vec![
197                    BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 4},
198                    BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
199                ],
200                vec![
201                    BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 4},
202                    BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
203                    BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
204                    BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
205                    BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4},
206                    BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
207                    BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
208                    BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
209                    BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
210                    BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
211                    BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4},
212                ],
213                vec![
214                    BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 2, expand: 4},
215                    BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
216                    BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
217                    BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
218                    BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
219                    BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
220                    BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
221                    BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
222                    BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
223                    BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
224                    BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
225                    BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
226                    BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
227                ],
228                vec![
229                    BlockType::Convolutional { out_channels: 960, kernel: 1, stride: 1},
230                ],
231            ],
232        }
233    }
234
235    pub fn hybrid_large() -> Self {
236        Self {
237            stem_dim: 24,
238            activation: Activation::Gelu,
239            stages: [
240                vec![
241                    BlockType::EdgeResidual { out_channels: 48, kernel: 3, stride: 2, expand: 4},
242                ],
243                vec![
244                    BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 4},
245                    BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
246                ],
247                vec![
248                    BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 4},
249                    BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
250                    BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
251                    BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
252                    BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4},
253                    BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
254                    BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
255                    BlockType::Attention { out_channels: 192, heads: 8, kernel: 3, stride: 1, kv_stride:2, kv_dim: 48},
256                    BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
257                    BlockType::Attention { out_channels: 192, heads: 8, kernel: 3, stride: 1, kv_stride:2, kv_dim: 48},
258                    BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
259                    BlockType::Attention { out_channels: 192, heads: 8, kernel: 3, stride: 1, kv_stride:2, kv_dim: 48},
260                    BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
261                    BlockType::Attention { out_channels: 192, heads: 8, kernel: 3, stride: 1, kv_stride:2, kv_dim: 48},
262                    BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4},
263                ],
264
265                vec![
266                    BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 2, expand: 4},
267                    BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
268                    BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
269                    BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
270                    BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
271                    BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
272                    BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
273                    BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
274                    BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
275                    BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
276                    BlockType::Attention { out_channels: 512, heads: 8, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64},
277                    BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
278                    BlockType::Attention { out_channels: 512, heads: 8, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64},
279                    BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
280                    BlockType::Attention { out_channels: 512, heads: 8, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64},
281                    BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
282                    BlockType::Attention { out_channels: 512, heads: 8, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64},
283                    BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
284                ],
285                vec![
286                    BlockType::Convolutional { out_channels: 960, kernel: 1, stride: 1},
287                ],
288            ],
289          }
290    }
291}
292
293fn depthwise_conv(
294    channels: usize,
295    kernel: usize,
296    stride: usize,
297    padding: usize,
298    vb: VarBuilder,
299) -> Result<Func<'static>> {
300    let conv2d_cfg = Conv2dConfig {
301        stride,
302        padding,
303        groups: channels,
304        ..Default::default()
305    };
306
307    let bn = batch_norm(channels, 1e-5, vb.pp("bn"))?;
308    let conv = conv2d_no_bias(channels, channels, kernel, conv2d_cfg, vb.pp("conv"))?;
309
310    Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&bn, false)))
311}
312
313fn pointwise_conv(
314    in_channels: usize,
315    out_channels: usize,
316    vb: VarBuilder,
317) -> Result<Func<'static>> {
318    let conv2d_cfg = Conv2dConfig {
319        ..Default::default()
320    };
321
322    let bn = batch_norm(out_channels, 1e-5, vb.pp("bn"))?;
323    let conv = conv2d_no_bias(in_channels, out_channels, 1, conv2d_cfg, vb.pp("conv"))?;
324
325    Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&bn, false)))
326}
327
328//Universal block that uses two pointwise convolutions and all combinations of two depthwise convolutions.
329#[allow(clippy::too_many_arguments)]
330fn universal_inverted_bottleneck_block(
331    cfg: &Config,
332    in_channels: usize,
333    out_channels: usize,
334    expand: usize,
335    start_kernel: usize,
336    mid_kernel: usize,
337    stride: usize,
338    vb: VarBuilder,
339) -> Result<Func<'static>> {
340    let act = cfg.activation;
341    let skip_connection = (in_channels == out_channels) && (stride == 1);
342
343    let dw_start_stride = if mid_kernel > 0 { 1 } else { stride };
344    let dw_start = depthwise_conv(
345        in_channels,
346        start_kernel,
347        dw_start_stride,
348        start_kernel / 2,
349        vb.pp("dw_start"),
350    );
351    let pw_exp = pointwise_conv(in_channels, in_channels * expand, vb.pp("pw_exp"))?;
352    let dw_mid = depthwise_conv(
353        in_channels * expand,
354        mid_kernel,
355        stride,
356        mid_kernel / 2,
357        vb.pp("dw_mid"),
358    );
359    let pw_proj = pointwise_conv(in_channels * expand, out_channels, vb.pp("pw_proj"))?;
360
361    let gamma = vb.get(out_channels, "layer_scale.gamma");
362
363    Ok(Func::new(move |xs| {
364        let residual = xs.clone();
365
366        let mut xs = xs.clone();
367
368        if let Ok(f) = &dw_start {
369            xs = xs.apply(f)?;
370        }
371
372        xs = xs.apply(&pw_exp)?.apply(&act)?;
373
374        if let Ok(f) = &dw_mid {
375            xs = xs.apply(f)?.apply(&act)?;
376        }
377
378        xs = xs.apply(&pw_proj)?;
379
380        if let Ok(g) = &gamma {
381            xs = xs.broadcast_mul(&g.reshape((1, (), 1, 1))?)?;
382        };
383
384        if skip_connection {
385            xs = (xs + residual)?;
386        }
387
388        Ok(xs)
389    }))
390}
391
392// Convolutional block including norm and activation.
393fn conv_block(
394    cfg: &Config,
395    in_channels: usize,
396    out_channels: usize,
397    kernel: usize,
398    stride: usize,
399    vb: VarBuilder,
400) -> Result<Func<'static>> {
401    let conv2d_cfg = Conv2dConfig {
402        stride,
403        padding: kernel / 2,
404        ..Default::default()
405    };
406
407    let act = cfg.activation;
408    let bn = batch_norm(out_channels, 1e-5, vb.pp("bn1"))?;
409    let conv = conv2d_no_bias(in_channels, out_channels, kernel, conv2d_cfg, vb.pp("conv"))?;
410
411    Ok(Func::new(move |xs| {
412        xs.apply(&conv)?.apply_t(&bn, false)?.apply(&act)
413    }))
414}
415
416fn edge_residual_block(
417    cfg: &Config,
418    in_channels: usize,
419    out_channels: usize,
420    kernel: usize,
421    stride: usize,
422    expand: usize,
423    vb: VarBuilder,
424) -> Result<Func<'static>> {
425    let conv_exp_cfg = Conv2dConfig {
426        stride,
427        padding: kernel / 2,
428        ..Default::default()
429    };
430
431    let conv_pwl_cfg = Conv2dConfig {
432        ..Default::default()
433    };
434
435    let act = cfg.activation;
436    let mid_channels = in_channels * expand;
437    let conv_exp = conv2d_no_bias(
438        in_channels,
439        mid_channels,
440        kernel,
441        conv_exp_cfg,
442        vb.pp("conv_exp"),
443    )?;
444    let bn1 = batch_norm(mid_channels, 1e-5, vb.pp("bn1"))?;
445
446    let conv_pwl = conv2d_no_bias(
447        mid_channels,
448        out_channels,
449        1,
450        conv_pwl_cfg,
451        vb.pp("conv_pwl"),
452    )?;
453    let bn2 = batch_norm(out_channels, 1e-5, vb.pp("bn2"))?;
454
455    Ok(Func::new(move |xs| {
456        let xs = xs
457            .apply(&conv_exp)?
458            .apply_t(&bn1, false)?
459            .apply(&act)?
460            .apply(&conv_pwl)?
461            .apply_t(&bn2, false)?;
462
463        Ok(xs)
464    }))
465}
466
467fn reshape_kv(t: &Tensor) -> Result<Tensor> {
468    let d = t.dims4()?;
469    let t = t
470        .reshape((d.0, d.1, ()))?
471        .transpose(1, 2)?
472        .unsqueeze(1)?
473        .contiguous()?;
474    Ok(t)
475}
476
477fn reshape_query(t: &Tensor, heads: usize, kv_dim: usize) -> Result<Tensor> {
478    let d = t.dims4()?;
479
480    let t = t
481        .reshape((d.0, heads, kv_dim, ()))?
482        .transpose(D::Minus1, D::Minus2)?
483        .contiguous()?;
484    Ok(t)
485}
486
487fn reshape_output(t: &Tensor, heads: usize, h: usize, w: usize) -> Result<Tensor> {
488    let d = t.dims4()?;
489    let t = t.transpose(1, 2)?;
490    let t = t
491        .reshape((d.0, h, w, d.3 * heads))?
492        .permute((0, 3, 1, 2))?
493        .contiguous()?;
494    Ok(t)
495}
496
497// Mobile multi-query attention
498#[allow(clippy::too_many_arguments)]
499fn mqa_block(
500    in_channels: usize,
501    out_channels: usize,
502    heads: usize,
503    kernel: usize,
504    stride: usize,
505    kv_dim: usize,
506    kv_stride: usize,
507    vb: VarBuilder,
508) -> Result<Func<'static>> {
509    let down_conv2d_cfg = Conv2dConfig {
510        stride: kv_stride,
511        padding: kernel / 2,
512        groups: in_channels,
513        ..Default::default()
514    };
515
516    let proj_conv2d_cfg = Conv2dConfig {
517        stride,
518        ..Default::default()
519    };
520
521    let skip_connection = (in_channels == out_channels) && (stride == 1);
522    let gamma = vb.get(out_channels, "layer_scale.gamma");
523    let norm = batch_norm(out_channels, 1e-5, vb.pp("norm"))?;
524    let scale = (kv_dim as f64).powf(-0.5);
525
526    let vb = vb.pp("attn");
527
528    let query_proj = conv2d_no_bias(
529        out_channels,
530        kv_dim * heads,
531        1,
532        proj_conv2d_cfg,
533        vb.pp("query.proj"),
534    )?;
535
536    let key_down_conv = conv2d_no_bias(
537        in_channels,
538        out_channels,
539        kernel,
540        down_conv2d_cfg,
541        vb.pp("key.down_conv"),
542    );
543    let key_norm = batch_norm(out_channels, 1e-5, vb.pp("key.norm"));
544
545    let key_proj = conv2d_no_bias(out_channels, kv_dim, 1, proj_conv2d_cfg, vb.pp("key.proj"))?;
546
547    let value_down_conv = conv2d_no_bias(
548        in_channels,
549        out_channels,
550        kernel,
551        down_conv2d_cfg,
552        vb.pp("value.down_conv"),
553    );
554
555    let value_norm = batch_norm(out_channels, 1e-5, vb.pp("value.norm"));
556    let value_proj = conv2d_no_bias(
557        out_channels,
558        kv_dim,
559        1,
560        proj_conv2d_cfg,
561        vb.pp("value.proj"),
562    )?;
563
564    let output_proj = conv2d_no_bias(
565        kv_dim * heads,
566        out_channels,
567        1,
568        proj_conv2d_cfg,
569        vb.pp("output.proj"),
570    )?;
571
572    Ok(Func::new(move |xs| {
573        let (_, _, h, w) = xs.dims4()?;
574
575        let residual = xs.clone();
576
577        let xs = xs.apply_t(&norm, false)?;
578
579        // Query
580        let q = xs.apply(&query_proj)?;
581
582        let q = reshape_query(&q, heads, kv_dim)?;
583        let q = (q * scale)?;
584
585        // Keys
586        let mut k = xs.clone();
587
588        if let (Ok(kd), Ok(n)) = (&key_down_conv, &key_norm) {
589            k = k.apply(kd)?.apply_t(n, false)?;
590        }
591
592        let k = k.apply(&key_proj)?;
593
594        let k = reshape_kv(&k)?;
595
596        // Value
597        let mut v = xs.clone();
598
599        if let (Ok(vd), Ok(n)) = (&value_down_conv, &value_norm) {
600            v = v.apply(vd)?;
601            v = v.apply_t(n, false)?;
602        }
603
604        let v = v.apply(&value_proj)?;
605        let v = reshape_kv(&v)?;
606
607        let attn = q.broadcast_matmul(&(k.transpose(D::Minus2, D::Minus1)?))?;
608        let attn = softmax(&attn, D::Minus1)?;
609        let o = attn.broadcast_matmul(&v)?;
610
611        let o = reshape_output(&o, heads, h, w)?;
612
613        let mut xs = o.apply(&output_proj)?;
614
615        // Layer scale
616
617        if let Ok(g) = &gamma {
618            xs = xs.broadcast_mul(&g.reshape((1, (), 1, 1))?)?;
619        };
620
621        if skip_connection {
622            xs = (xs + residual)?;
623        }
624        Ok(xs)
625    }))
626}
627
628// Stem.
629fn mobilenetv4_stem(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
630    let conv2d_cfg = Conv2dConfig {
631        stride: 2,
632        padding: 1,
633        ..Default::default()
634    };
635
636    let act = cfg.activation;
637    let out_channels = cfg.stem_dim;
638    let bn = batch_norm(out_channels, 1e-5, vb.pp("bn1"))?;
639    let conv = conv2d_no_bias(3, out_channels, 3, conv2d_cfg, vb.pp("conv_stem"))?;
640
641    Ok(Func::new(move |xs| {
642        let xs = xs.apply(&conv)?.apply_t(&bn, false)?.apply(&act)?;
643        Ok(xs)
644    }))
645}
646
647// The blocks in all the 5 stages of the model.
648fn mobilenetv4_blocks(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
649    let mut in_channels = cfg.stem_dim;
650    let mut blocks = Vec::new();
651
652    for stage in 0..5 {
653        let nblocks = cfg.stages[stage].len();
654
655        for block in 0..nblocks {
656            match cfg.stages[stage][block] {
657                BlockType::Convolutional {
658                    out_channels,
659                    kernel,
660                    stride,
661                } => {
662                    blocks.push(conv_block(
663                        cfg,
664                        in_channels,
665                        out_channels,
666                        kernel,
667                        stride,
668                        vb.pp(format!("{stage}.{block}")),
669                    )?);
670                    in_channels = out_channels;
671                }
672
673                BlockType::EdgeResidual {
674                    out_channels,
675                    kernel,
676                    stride,
677                    expand,
678                } => {
679                    blocks.push(edge_residual_block(
680                        cfg,
681                        in_channels,
682                        out_channels,
683                        kernel,
684                        stride,
685                        expand,
686                        vb.pp(format!("{stage}.{block}")),
687                    )?);
688                    in_channels = out_channels;
689                }
690
691                BlockType::UniversalBottleneck {
692                    out_channels,
693                    start_kernel,
694                    mid_kernel,
695                    stride,
696                    expand,
697                } => {
698                    blocks.push(universal_inverted_bottleneck_block(
699                        cfg,
700                        in_channels,
701                        out_channels,
702                        expand,
703                        start_kernel,
704                        mid_kernel,
705                        stride,
706                        vb.pp(format!("{stage}.{block}")),
707                    )?);
708                    in_channels = out_channels;
709                }
710
711                BlockType::Attention {
712                    out_channels,
713                    heads,
714                    kernel,
715                    stride,
716                    kv_dim,
717                    kv_stride,
718                } => {
719                    blocks.push(mqa_block(
720                        in_channels,
721                        out_channels,
722                        heads,
723                        kernel,
724                        stride,
725                        kv_dim,
726                        kv_stride,
727                        vb.pp(format!("{stage}.{block}")),
728                    )?);
729                    in_channels = out_channels;
730                }
731            }
732        }
733    }
734
735    Ok(Func::new(move |xs| {
736        let mut xs = xs.clone();
737        for block in blocks.iter() {
738            xs = xs.apply(block)?
739        }
740        Ok(xs)
741    }))
742}
743
744// Classification head.
745fn mobilenetv4_head(
746    cfg: &Config,
747    outputs: usize,
748    nclasses: usize,
749    vb: VarBuilder,
750) -> Result<Func<'static>> {
751    let conv2d_cfg = Conv2dConfig {
752        ..Default::default()
753    };
754
755    let act = cfg.activation;
756    let conv = conv2d_no_bias(960, outputs, 1, conv2d_cfg, vb.pp("conv_head"))?;
757    let norm = batch_norm(outputs, 1e-5, vb.pp("norm_head"))?;
758    let cls = linear(outputs, nclasses, vb.pp("classifier"))?;
759
760    Ok(Func::new(move |xs| {
761        let mut xs = xs.clone();
762        xs = xs.apply(&conv)?;
763        xs = xs.apply_t(&norm, false)?.apply(&act)?;
764        xs = xs.flatten_from(1)?;
765        xs = xs.apply(&cls)?;
766        Ok(xs)
767    }))
768}
769
770// Build a mobilenetv4 model for a given configuration.
771fn mobilenetv4_model(
772    cfg: &Config,
773    nclasses: Option<usize>,
774    vb: VarBuilder,
775) -> Result<Func<'static>> {
776    let cls = match nclasses {
777        None => None,
778        Some(nclasses) => {
779            let outputs = 1280;
780            let head = mobilenetv4_head(cfg, outputs, nclasses, vb.clone())?;
781            Some(head)
782        }
783    };
784
785    let stem = mobilenetv4_stem(cfg, vb.clone())?;
786
787    let blocks = mobilenetv4_blocks(cfg, vb.pp("blocks"))?;
788
789    Ok(Func::new(move |xs| {
790        let xs = xs.apply(&stem)?.apply(&blocks)?;
791        let xs = xs.mean_keepdim(D::Minus1)?.mean_keepdim(D::Minus2)?;
792        match &cls {
793            None => Ok(xs),
794            Some(cls) => xs.apply(cls),
795        }
796    }))
797}
798
799pub fn mobilenetv4(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
800    mobilenetv4_model(cfg, Some(nclasses), vb)
801}
802
803pub fn mobilenetv4_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
804    mobilenetv4_model(cfg, None, vb)
805}