1use candle::{Result, Tensor, D};
14use candle_nn::{
15 batch_norm, conv2d_no_bias, linear, ops::softmax, Activation, Conv2dConfig, Func, VarBuilder,
16};
17
/// Description of a single block inside a network stage.
///
/// Every variant only carries plain `usize` hyper-parameters, so the type is
/// cheap to copy and can be compared for equality.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum BlockType {
    /// Plain convolution followed by batch normalization and the activation.
    Convolutional {
        /// Number of channels produced by the block.
        out_channels: usize,
        /// Side length of the square convolution kernel.
        kernel: usize,
        /// Stride of the convolution.
        stride: usize,
    },
    /// Universal inverted bottleneck: two pointwise convolutions combined with
    /// up to two depthwise convolutions.
    UniversalBottleneck {
        /// Number of channels produced by the block.
        out_channels: usize,
        /// Kernel size of the depthwise convolution placed before the
        /// expansion; the presets use `0` for blocks without that stage.
        start_kernel: usize,
        /// Kernel size of the depthwise convolution placed after the
        /// expansion; the presets use `0` for blocks without that stage.
        mid_kernel: usize,
        /// Stride of the block.
        stride: usize,
        /// Multiplier applied to the input channel count to obtain the
        /// expanded (middle) channel count.
        expand: usize,
    },
    /// Full convolution expanding the channels, followed by a pointwise
    /// projection.
    EdgeResidual {
        /// Number of channels produced by the block.
        out_channels: usize,
        /// Kernel size of the expanding convolution.
        kernel: usize,
        /// Stride of the expanding convolution.
        stride: usize,
        /// Multiplier applied to the input channel count to obtain the
        /// expanded (middle) channel count.
        expand: usize,
    },
    /// Multi-query attention block.
    Attention {
        /// Number of channels produced by the block.
        out_channels: usize,
        /// Number of query heads.
        heads: usize,
        /// Kernel size of the depthwise key/value down-sampling convolution.
        kernel: usize,
        /// Stride used by the 1x1 projections.
        stride: usize,
        /// Per-head dimension of queries, keys and values.
        kv_dim: usize,
        /// Stride of the key/value down-sampling convolution.
        kv_stride: usize,
    },
}
47
/// Hyper-parameters describing one network variant.
///
/// The fields are private; instances are obtained from the preset
/// constructors defined on this type.
#[derive(Clone, Debug)]
pub struct Config {
    /// Number of output channels of the stem convolution; also the input
    /// channel count of the first block.
    stem_dim: usize,
    /// Activation used by the stem, the blocks and the classification head.
    activation: Activation,
    /// Block descriptions of the five stages, applied in order.
    stages: [Vec<BlockType>; 5],
}
54
55#[rustfmt::skip]
56impl Config {
57 pub fn small() -> Self {
58 Self {
59 stem_dim: 32,
60 activation: Activation::Relu,
61 stages: [
62 vec![
63 BlockType::Convolutional { out_channels: 32, kernel: 3, stride: 2},
64 BlockType::Convolutional { out_channels: 32, kernel: 1, stride: 1},
65 ],
66 vec![
67 BlockType::Convolutional { out_channels: 96, kernel: 3, stride: 2},
68 BlockType::Convolutional { out_channels: 64, kernel: 1, stride: 1},
69 ],
70 vec![
71 BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 5, mid_kernel: 5, stride: 2, expand: 3},
72 BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 0, mid_kernel: 3, stride: 1, expand: 2},
73 BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 0, mid_kernel: 3, stride: 1, expand: 2},
74 BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 0, mid_kernel: 3, stride: 1, expand: 2},
75 BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 0, mid_kernel: 3, stride: 1, expand: 2},
76 BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4},
77 ],
78 vec![
79 BlockType::UniversalBottleneck { out_channels: 128, start_kernel: 3, mid_kernel: 3, stride: 2, expand: 6},
80 BlockType::UniversalBottleneck { out_channels: 128, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
81 BlockType::UniversalBottleneck { out_channels: 128, start_kernel: 0, mid_kernel: 5, stride: 1, expand: 4},
82 BlockType::UniversalBottleneck { out_channels: 128, start_kernel: 0, mid_kernel: 5, stride: 1, expand: 3},
83 BlockType::UniversalBottleneck { out_channels: 128, start_kernel: 0, mid_kernel: 3, stride: 1, expand: 4},
84 BlockType::UniversalBottleneck { out_channels: 128, start_kernel: 0, mid_kernel: 3, stride: 1, expand: 4},
85 ],
86 vec![
87 BlockType::Convolutional { out_channels: 960, kernel: 1, stride: 1},
88 ],
89 ],
90 }
91 }
92
93 pub fn medium() -> Self {
94 Self {
95 stem_dim: 32,
96 activation: Activation::Relu,
97 stages: [
98 vec![
99 BlockType::EdgeResidual { out_channels: 48, kernel: 3, stride: 2, expand: 4},
100 ],
101 vec![
102 BlockType::UniversalBottleneck { out_channels: 80, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 4},
103 BlockType::UniversalBottleneck { out_channels: 80, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 2},
104 ],
105 vec![
106 BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 6},
107 BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
108 BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
109 BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4},
110 BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
111 BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4},
112 BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 2},
113 BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4},
114 ],
115 vec![
116 BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 5, stride: 2, expand: 6},
117 BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
118 BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4},
119 BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4},
120 BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 4},
121 BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4},
122 BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 2},
123 BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
124 BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 4},
125 BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 4},
126 BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 2},
127
128 ],
129 vec![
130 BlockType::Convolutional { out_channels: 960, kernel: 1, stride: 1},
131 ],
132 ],
133 }
134 }
135
136 pub fn hybrid_medium() -> Self {
137 Self {
138 stem_dim: 32,
139 activation: Activation::Relu,
140 stages: [
141 vec![
142 BlockType::EdgeResidual { out_channels: 48, kernel: 3, stride: 2, expand: 4},
143 ],
144 vec![
145 BlockType::UniversalBottleneck { out_channels: 80, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 4},
146 BlockType::UniversalBottleneck { out_channels: 80, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 2},
147 ],
148 vec![
149 BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 6},
150 BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 2},
151 BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
152 BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4},
153 BlockType::Attention { out_channels: 160, heads: 4, kernel: 3, stride: 1, kv_stride:2, kv_dim: 64},
154 BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
155 BlockType::Attention { out_channels: 160, heads: 4, kernel: 3, stride: 1, kv_stride:2, kv_dim: 64},
156 BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4},
157 BlockType::Attention { out_channels: 160, heads: 4, kernel: 3, stride: 1, kv_stride:2, kv_dim: 64},
158 BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
159 BlockType::Attention { out_channels: 160, heads: 4, kernel: 3, stride: 1, kv_stride:2, kv_dim: 64},
160 BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4},
161 ],
162
163 vec![
164 BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 5, stride: 2, expand: 6},
165 BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
166 BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4},
167 BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4},
168 BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 2},
169 BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 2},
170 BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 2},
171 BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 4},
172 BlockType::Attention { out_channels: 256, heads: 4, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64},
173 BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4},
174 BlockType::Attention { out_channels: 256, heads: 4, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64},
175 BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
176 BlockType::Attention { out_channels: 256, heads: 4, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64},
177 BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
178 BlockType::Attention { out_channels: 256, heads: 4, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64},
179 BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
180 ],
181 vec![
182 BlockType::Convolutional { out_channels: 960, kernel: 1, stride: 1},
183 ],
184 ],
185 }
186 }
187
188 pub fn large() -> Self {
189 Self {
190 stem_dim: 24,
191 activation: Activation::Relu,
192 stages: [
193 vec![
194 BlockType::EdgeResidual { out_channels: 48, kernel: 3, stride: 2, expand: 4},
195 ],
196 vec![
197 BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 4},
198 BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
199 ],
200 vec![
201 BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 4},
202 BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
203 BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
204 BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
205 BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4},
206 BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
207 BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
208 BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
209 BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
210 BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
211 BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4},
212 ],
213 vec![
214 BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 2, expand: 4},
215 BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
216 BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
217 BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
218 BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
219 BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
220 BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
221 BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
222 BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
223 BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
224 BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
225 BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
226 BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
227 ],
228 vec![
229 BlockType::Convolutional { out_channels: 960, kernel: 1, stride: 1},
230 ],
231 ],
232 }
233 }
234
235 pub fn hybrid_large() -> Self {
236 Self {
237 stem_dim: 24,
238 activation: Activation::Gelu,
239 stages: [
240 vec![
241 BlockType::EdgeResidual { out_channels: 48, kernel: 3, stride: 2, expand: 4},
242 ],
243 vec![
244 BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 4},
245 BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
246 ],
247 vec![
248 BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 4},
249 BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
250 BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
251 BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
252 BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4},
253 BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
254 BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
255 BlockType::Attention { out_channels: 192, heads: 8, kernel: 3, stride: 1, kv_stride:2, kv_dim: 48},
256 BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
257 BlockType::Attention { out_channels: 192, heads: 8, kernel: 3, stride: 1, kv_stride:2, kv_dim: 48},
258 BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
259 BlockType::Attention { out_channels: 192, heads: 8, kernel: 3, stride: 1, kv_stride:2, kv_dim: 48},
260 BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
261 BlockType::Attention { out_channels: 192, heads: 8, kernel: 3, stride: 1, kv_stride:2, kv_dim: 48},
262 BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4},
263 ],
264
265 vec![
266 BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 2, expand: 4},
267 BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
268 BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
269 BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
270 BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
271 BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
272 BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
273 BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
274 BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
275 BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
276 BlockType::Attention { out_channels: 512, heads: 8, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64},
277 BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
278 BlockType::Attention { out_channels: 512, heads: 8, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64},
279 BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
280 BlockType::Attention { out_channels: 512, heads: 8, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64},
281 BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
282 BlockType::Attention { out_channels: 512, heads: 8, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64},
283 BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
284 ],
285 vec![
286 BlockType::Convolutional { out_channels: 960, kernel: 1, stride: 1},
287 ],
288 ],
289 }
290 }
291}
292
293fn depthwise_conv(
294 channels: usize,
295 kernel: usize,
296 stride: usize,
297 padding: usize,
298 vb: VarBuilder,
299) -> Result<Func<'static>> {
300 let conv2d_cfg = Conv2dConfig {
301 stride,
302 padding,
303 groups: channels,
304 ..Default::default()
305 };
306
307 let bn = batch_norm(channels, 1e-5, vb.pp("bn"))?;
308 let conv = conv2d_no_bias(channels, channels, kernel, conv2d_cfg, vb.pp("conv"))?;
309
310 Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&bn, false)))
311}
312
313fn pointwise_conv(
314 in_channels: usize,
315 out_channels: usize,
316 vb: VarBuilder,
317) -> Result<Func<'static>> {
318 let conv2d_cfg = Conv2dConfig {
319 ..Default::default()
320 };
321
322 let bn = batch_norm(out_channels, 1e-5, vb.pp("bn"))?;
323 let conv = conv2d_no_bias(in_channels, out_channels, 1, conv2d_cfg, vb.pp("conv"))?;
324
325 Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&bn, false)))
326}
327
328#[allow(clippy::too_many_arguments)]
330fn universal_inverted_bottleneck_block(
331 cfg: &Config,
332 in_channels: usize,
333 out_channels: usize,
334 expand: usize,
335 start_kernel: usize,
336 mid_kernel: usize,
337 stride: usize,
338 vb: VarBuilder,
339) -> Result<Func<'static>> {
340 let act = cfg.activation;
341 let skip_connection = (in_channels == out_channels) && (stride == 1);
342
343 let dw_start_stride = if mid_kernel > 0 { 1 } else { stride };
344 let dw_start = depthwise_conv(
345 in_channels,
346 start_kernel,
347 dw_start_stride,
348 start_kernel / 2,
349 vb.pp("dw_start"),
350 );
351 let pw_exp = pointwise_conv(in_channels, in_channels * expand, vb.pp("pw_exp"))?;
352 let dw_mid = depthwise_conv(
353 in_channels * expand,
354 mid_kernel,
355 stride,
356 mid_kernel / 2,
357 vb.pp("dw_mid"),
358 );
359 let pw_proj = pointwise_conv(in_channels * expand, out_channels, vb.pp("pw_proj"))?;
360
361 let gamma = vb.get(out_channels, "layer_scale.gamma");
362
363 Ok(Func::new(move |xs| {
364 let residual = xs.clone();
365
366 let mut xs = xs.clone();
367
368 if let Ok(f) = &dw_start {
369 xs = xs.apply(f)?;
370 }
371
372 xs = xs.apply(&pw_exp)?.apply(&act)?;
373
374 if let Ok(f) = &dw_mid {
375 xs = xs.apply(f)?.apply(&act)?;
376 }
377
378 xs = xs.apply(&pw_proj)?;
379
380 if let Ok(g) = &gamma {
381 xs = xs.broadcast_mul(&g.reshape((1, (), 1, 1))?)?;
382 };
383
384 if skip_connection {
385 xs = (xs + residual)?;
386 }
387
388 Ok(xs)
389 }))
390}
391
392fn conv_block(
394 cfg: &Config,
395 in_channels: usize,
396 out_channels: usize,
397 kernel: usize,
398 stride: usize,
399 vb: VarBuilder,
400) -> Result<Func<'static>> {
401 let conv2d_cfg = Conv2dConfig {
402 stride,
403 padding: kernel / 2,
404 ..Default::default()
405 };
406
407 let act = cfg.activation;
408 let bn = batch_norm(out_channels, 1e-5, vb.pp("bn1"))?;
409 let conv = conv2d_no_bias(in_channels, out_channels, kernel, conv2d_cfg, vb.pp("conv"))?;
410
411 Ok(Func::new(move |xs| {
412 xs.apply(&conv)?.apply_t(&bn, false)?.apply(&act)
413 }))
414}
415
416fn edge_residual_block(
417 cfg: &Config,
418 in_channels: usize,
419 out_channels: usize,
420 kernel: usize,
421 stride: usize,
422 expand: usize,
423 vb: VarBuilder,
424) -> Result<Func<'static>> {
425 let conv_exp_cfg = Conv2dConfig {
426 stride,
427 padding: kernel / 2,
428 ..Default::default()
429 };
430
431 let conv_pwl_cfg = Conv2dConfig {
432 ..Default::default()
433 };
434
435 let act = cfg.activation;
436 let mid_channels = in_channels * expand;
437 let conv_exp = conv2d_no_bias(
438 in_channels,
439 mid_channels,
440 kernel,
441 conv_exp_cfg,
442 vb.pp("conv_exp"),
443 )?;
444 let bn1 = batch_norm(mid_channels, 1e-5, vb.pp("bn1"))?;
445
446 let conv_pwl = conv2d_no_bias(
447 mid_channels,
448 out_channels,
449 1,
450 conv_pwl_cfg,
451 vb.pp("conv_pwl"),
452 )?;
453 let bn2 = batch_norm(out_channels, 1e-5, vb.pp("bn2"))?;
454
455 Ok(Func::new(move |xs| {
456 let xs = xs
457 .apply(&conv_exp)?
458 .apply_t(&bn1, false)?
459 .apply(&act)?
460 .apply(&conv_pwl)?
461 .apply_t(&bn2, false)?;
462
463 Ok(xs)
464 }))
465}
466
467fn reshape_kv(t: &Tensor) -> Result<Tensor> {
468 let d = t.dims4()?;
469 let t = t
470 .reshape((d.0, d.1, ()))?
471 .transpose(1, 2)?
472 .unsqueeze(1)?
473 .contiguous()?;
474 Ok(t)
475}
476
477fn reshape_query(t: &Tensor, heads: usize, kv_dim: usize) -> Result<Tensor> {
478 let d = t.dims4()?;
479
480 let t = t
481 .reshape((d.0, heads, kv_dim, ()))?
482 .transpose(D::Minus1, D::Minus2)?
483 .contiguous()?;
484 Ok(t)
485}
486
487fn reshape_output(t: &Tensor, heads: usize, h: usize, w: usize) -> Result<Tensor> {
488 let d = t.dims4()?;
489 let t = t.transpose(1, 2)?;
490 let t = t
491 .reshape((d.0, h, w, d.3 * heads))?
492 .permute((0, 3, 1, 2))?
493 .contiguous()?;
494 Ok(t)
495}
496
497#[allow(clippy::too_many_arguments)]
499fn mqa_block(
500 in_channels: usize,
501 out_channels: usize,
502 heads: usize,
503 kernel: usize,
504 stride: usize,
505 kv_dim: usize,
506 kv_stride: usize,
507 vb: VarBuilder,
508) -> Result<Func<'static>> {
509 let down_conv2d_cfg = Conv2dConfig {
510 stride: kv_stride,
511 padding: kernel / 2,
512 groups: in_channels,
513 ..Default::default()
514 };
515
516 let proj_conv2d_cfg = Conv2dConfig {
517 stride,
518 ..Default::default()
519 };
520
521 let skip_connection = (in_channels == out_channels) && (stride == 1);
522 let gamma = vb.get(out_channels, "layer_scale.gamma");
523 let norm = batch_norm(out_channels, 1e-5, vb.pp("norm"))?;
524 let scale = (kv_dim as f64).powf(-0.5);
525
526 let vb = vb.pp("attn");
527
528 let query_proj = conv2d_no_bias(
529 out_channels,
530 kv_dim * heads,
531 1,
532 proj_conv2d_cfg,
533 vb.pp("query.proj"),
534 )?;
535
536 let key_down_conv = conv2d_no_bias(
537 in_channels,
538 out_channels,
539 kernel,
540 down_conv2d_cfg,
541 vb.pp("key.down_conv"),
542 );
543 let key_norm = batch_norm(out_channels, 1e-5, vb.pp("key.norm"));
544
545 let key_proj = conv2d_no_bias(out_channels, kv_dim, 1, proj_conv2d_cfg, vb.pp("key.proj"))?;
546
547 let value_down_conv = conv2d_no_bias(
548 in_channels,
549 out_channels,
550 kernel,
551 down_conv2d_cfg,
552 vb.pp("value.down_conv"),
553 );
554
555 let value_norm = batch_norm(out_channels, 1e-5, vb.pp("value.norm"));
556 let value_proj = conv2d_no_bias(
557 out_channels,
558 kv_dim,
559 1,
560 proj_conv2d_cfg,
561 vb.pp("value.proj"),
562 )?;
563
564 let output_proj = conv2d_no_bias(
565 kv_dim * heads,
566 out_channels,
567 1,
568 proj_conv2d_cfg,
569 vb.pp("output.proj"),
570 )?;
571
572 Ok(Func::new(move |xs| {
573 let (_, _, h, w) = xs.dims4()?;
574
575 let residual = xs.clone();
576
577 let xs = xs.apply_t(&norm, false)?;
578
579 let q = xs.apply(&query_proj)?;
581
582 let q = reshape_query(&q, heads, kv_dim)?;
583 let q = (q * scale)?;
584
585 let mut k = xs.clone();
587
588 if let (Ok(kd), Ok(n)) = (&key_down_conv, &key_norm) {
589 k = k.apply(kd)?.apply_t(n, false)?;
590 }
591
592 let k = k.apply(&key_proj)?;
593
594 let k = reshape_kv(&k)?;
595
596 let mut v = xs.clone();
598
599 if let (Ok(vd), Ok(n)) = (&value_down_conv, &value_norm) {
600 v = v.apply(vd)?;
601 v = v.apply_t(n, false)?;
602 }
603
604 let v = v.apply(&value_proj)?;
605 let v = reshape_kv(&v)?;
606
607 let attn = q.broadcast_matmul(&(k.transpose(D::Minus2, D::Minus1)?))?;
608 let attn = softmax(&attn, D::Minus1)?;
609 let o = attn.broadcast_matmul(&v)?;
610
611 let o = reshape_output(&o, heads, h, w)?;
612
613 let mut xs = o.apply(&output_proj)?;
614
615 if let Ok(g) = &gamma {
618 xs = xs.broadcast_mul(&g.reshape((1, (), 1, 1))?)?;
619 };
620
621 if skip_connection {
622 xs = (xs + residual)?;
623 }
624 Ok(xs)
625 }))
626}
627
628fn mobilenetv4_stem(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
630 let conv2d_cfg = Conv2dConfig {
631 stride: 2,
632 padding: 1,
633 ..Default::default()
634 };
635
636 let act = cfg.activation;
637 let out_channels = cfg.stem_dim;
638 let bn = batch_norm(out_channels, 1e-5, vb.pp("bn1"))?;
639 let conv = conv2d_no_bias(3, out_channels, 3, conv2d_cfg, vb.pp("conv_stem"))?;
640
641 Ok(Func::new(move |xs| {
642 let xs = xs.apply(&conv)?.apply_t(&bn, false)?.apply(&act)?;
643 Ok(xs)
644 }))
645}
646
647fn mobilenetv4_blocks(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
649 let mut in_channels = cfg.stem_dim;
650 let mut blocks = Vec::new();
651
652 for stage in 0..5 {
653 let nblocks = cfg.stages[stage].len();
654
655 for block in 0..nblocks {
656 match cfg.stages[stage][block] {
657 BlockType::Convolutional {
658 out_channels,
659 kernel,
660 stride,
661 } => {
662 blocks.push(conv_block(
663 cfg,
664 in_channels,
665 out_channels,
666 kernel,
667 stride,
668 vb.pp(format!("{stage}.{block}")),
669 )?);
670 in_channels = out_channels;
671 }
672
673 BlockType::EdgeResidual {
674 out_channels,
675 kernel,
676 stride,
677 expand,
678 } => {
679 blocks.push(edge_residual_block(
680 cfg,
681 in_channels,
682 out_channels,
683 kernel,
684 stride,
685 expand,
686 vb.pp(format!("{stage}.{block}")),
687 )?);
688 in_channels = out_channels;
689 }
690
691 BlockType::UniversalBottleneck {
692 out_channels,
693 start_kernel,
694 mid_kernel,
695 stride,
696 expand,
697 } => {
698 blocks.push(universal_inverted_bottleneck_block(
699 cfg,
700 in_channels,
701 out_channels,
702 expand,
703 start_kernel,
704 mid_kernel,
705 stride,
706 vb.pp(format!("{stage}.{block}")),
707 )?);
708 in_channels = out_channels;
709 }
710
711 BlockType::Attention {
712 out_channels,
713 heads,
714 kernel,
715 stride,
716 kv_dim,
717 kv_stride,
718 } => {
719 blocks.push(mqa_block(
720 in_channels,
721 out_channels,
722 heads,
723 kernel,
724 stride,
725 kv_dim,
726 kv_stride,
727 vb.pp(format!("{stage}.{block}")),
728 )?);
729 in_channels = out_channels;
730 }
731 }
732 }
733 }
734
735 Ok(Func::new(move |xs| {
736 let mut xs = xs.clone();
737 for block in blocks.iter() {
738 xs = xs.apply(block)?
739 }
740 Ok(xs)
741 }))
742}
743
744fn mobilenetv4_head(
746 cfg: &Config,
747 outputs: usize,
748 nclasses: usize,
749 vb: VarBuilder,
750) -> Result<Func<'static>> {
751 let conv2d_cfg = Conv2dConfig {
752 ..Default::default()
753 };
754
755 let act = cfg.activation;
756 let conv = conv2d_no_bias(960, outputs, 1, conv2d_cfg, vb.pp("conv_head"))?;
757 let norm = batch_norm(outputs, 1e-5, vb.pp("norm_head"))?;
758 let cls = linear(outputs, nclasses, vb.pp("classifier"))?;
759
760 Ok(Func::new(move |xs| {
761 let mut xs = xs.clone();
762 xs = xs.apply(&conv)?;
763 xs = xs.apply_t(&norm, false)?.apply(&act)?;
764 xs = xs.flatten_from(1)?;
765 xs = xs.apply(&cls)?;
766 Ok(xs)
767 }))
768}
769
770fn mobilenetv4_model(
772 cfg: &Config,
773 nclasses: Option<usize>,
774 vb: VarBuilder,
775) -> Result<Func<'static>> {
776 let cls = match nclasses {
777 None => None,
778 Some(nclasses) => {
779 let outputs = 1280;
780 let head = mobilenetv4_head(cfg, outputs, nclasses, vb.clone())?;
781 Some(head)
782 }
783 };
784
785 let stem = mobilenetv4_stem(cfg, vb.clone())?;
786
787 let blocks = mobilenetv4_blocks(cfg, vb.pp("blocks"))?;
788
789 Ok(Func::new(move |xs| {
790 let xs = xs.apply(&stem)?.apply(&blocks)?;
791 let xs = xs.mean_keepdim(D::Minus1)?.mean_keepdim(D::Minus2)?;
792 match &cls {
793 None => Ok(xs),
794 Some(cls) => xs.apply(cls),
795 }
796 }))
797}
798
/// Builds the network described by `cfg`, including the classification head
/// with `nclasses` outputs, reading all weights from `vb`.
///
/// # Errors
///
/// Returns the error produced while constructing the model.
pub fn mobilenetv4(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
    mobilenetv4_model(cfg, Some(nclasses), vb)
}
802
/// Builds the network described by `cfg` without the classification head, so
/// the returned function yields the pooled features; weights are read from
/// `vb`.
///
/// # Errors
///
/// Returns the error produced while constructing the model.
pub fn mobilenetv4_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
    mobilenetv4_model(cfg, None, vb)
}