1use candle::{DType, Result, Tensor, D};
8use candle_nn::{
9 batch_norm, conv2d, conv2d_no_bias, linear, ops::sigmoid, BatchNorm, Conv2d, Conv2dConfig,
10 Func, VarBuilder,
11};
12
/// Depth and base width of one MobileOne stage.
struct StageConfig {
    /// Number of (depthwise, pointwise) block pairs built for the stage.
    blocks: usize,
    /// Base output channel count, before the variant's width multiplier is applied.
    channels: usize,
}

/// Stage layout shared by every MobileOne variant.
///
/// Entry 0 supplies the stem width; entries 1..=4 describe the body stages.
const STAGES: [StageConfig; 5] = [
    StageConfig { channels: 64, blocks: 1 },
    StageConfig { channels: 64, blocks: 2 },
    StageConfig { channels: 128, blocks: 8 },
    StageConfig { channels: 256, blocks: 10 },
    StageConfig { channels: 512, blocks: 1 },
];
42
/// Hyper-parameters selecting one MobileOne variant (S0 to S4).
///
/// The fields are private; build a value with one of the `s0()`..`s4()` constructors.
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
    /// Number of parallel `conv_kxk` conv+BN branches summed in each stage block.
    /// The stem is always built with a single branch.
    k: usize,
    /// Per-stage width multipliers, applied to the base channel counts in `STAGES`.
    alphas: [f32; 5],
}

impl Config {
    /// MobileOne-S0: four `kxk` branches per block, narrow early stages.
    pub fn s0() -> Self {
        Self {
            k: 4,
            alphas: [0.75, 0.75, 1.0, 1.0, 2.0],
        }
    }

    /// MobileOne-S1: one `kxk` branch per block.
    pub fn s1() -> Self {
        Self {
            k: 1,
            alphas: [1.5, 1.5, 1.5, 2.0, 2.5],
        }
    }

    /// MobileOne-S2: one `kxk` branch per block.
    pub fn s2() -> Self {
        Self {
            k: 1,
            alphas: [1.5, 1.5, 2.0, 2.5, 4.0],
        }
    }

    /// MobileOne-S3: one `kxk` branch per block.
    pub fn s3() -> Self {
        Self {
            k: 1,
            alphas: [2.0, 2.0, 2.5, 3.0, 4.0],
        }
    }

    /// MobileOne-S4: one `kxk` branch per block, widest stages.
    pub fn s4() -> Self {
        Self {
            k: 1,
            alphas: [3.0, 3.0, 3.5, 3.5, 4.0],
        }
    }
}
83
84fn squeeze_and_excitation(
86 in_channels: usize,
87 squeeze_channels: usize,
88 vb: VarBuilder,
89) -> Result<Func<'static>> {
90 let conv2d_cfg = Conv2dConfig {
91 ..Default::default()
92 };
93 let fc1 = conv2d(in_channels, squeeze_channels, 1, conv2d_cfg, vb.pp("fc1"))?;
94 let fc2 = conv2d(squeeze_channels, in_channels, 1, conv2d_cfg, vb.pp("fc2"))?;
95
96 Ok(Func::new(move |xs| {
97 let residual = xs;
98 let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
99 let xs = sigmoid(&xs.apply(&fc1)?.relu()?.apply(&fc2)?)?;
100
101 residual.broadcast_mul(&xs)
102 }))
103}
104
105fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
109 let (gamma, beta) = bn.weight_and_bias().unwrap();
110 let mu = bn.running_mean();
111 let sigma = (bn.running_var() + bn.eps())?.sqrt();
112 let gps = (gamma / sigma)?;
113 let bias = (beta - mu * &gps)?;
114 let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?;
115
116 Ok((weights, bias))
117}
118
119#[allow(clippy::too_many_arguments)]
124fn mobileone_block(
125 has_identity: bool,
126 k: usize,
127 dim: usize,
128 stride: usize,
129 padding: usize,
130 groups: usize,
131 kernel: usize,
132 in_channels: usize,
133 out_channels: usize,
134 vb: VarBuilder,
135) -> Result<Func<'static>> {
136 let conv2d_cfg = Conv2dConfig {
137 stride,
138 padding,
139 groups,
140 ..Default::default()
141 };
142
143 let mut w = Tensor::zeros(
144 (out_channels, in_channels / groups, kernel, kernel),
145 DType::F32,
146 vb.device(),
147 )?;
148 let mut b = Tensor::zeros(dim, DType::F32, vb.device())?;
149
150 for i in 0..k {
152 let conv_kxk_bn = batch_norm(dim, 1e-5, vb.pp(format!("conv_kxk.{i}.bn")))?;
153 let conv_kxk = conv2d_no_bias(
154 in_channels,
155 out_channels,
156 kernel,
157 conv2d_cfg,
158 vb.pp(format!("conv_kxk.{i}.conv")),
159 )?;
160 let (wk, bk) = fuse_conv_bn(conv_kxk.weight(), conv_kxk_bn)?;
161 w = (w + wk)?;
162 b = (b + bk)?;
163 }
164
165 if kernel > 1 {
166 let conv_scale_bn = batch_norm(dim, 1e-5, vb.pp("conv_scale.bn"))?;
167 let conv_scale = conv2d_no_bias(
168 in_channels,
169 out_channels,
170 1,
171 conv2d_cfg,
172 vb.pp("conv_scale.conv"),
173 )?;
174
175 let (mut ws, bs) = fuse_conv_bn(conv_scale.weight(), conv_scale_bn)?;
176 ws = ws.pad_with_zeros(D::Minus1, 1, 1)?;
178 ws = ws.pad_with_zeros(D::Minus2, 1, 1)?;
179
180 w = (w + ws)?;
181 b = (b + bs)?;
182 }
183
184 let se = squeeze_and_excitation(out_channels, out_channels / 16, vb.pp("attn"));
186
187 if has_identity {
189 let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"))?;
190
191 let mut weights: Vec<f32> = vec![0.0; w.elem_count()];
192
193 let id = in_channels / groups;
194 for i in 0..in_channels {
196 if kernel > 1 {
197 weights[i * kernel * kernel + 4] = 1.0;
198 } else {
199 weights[i * (id + 1)] = 1.0;
200 }
201 }
202
203 let weights = &Tensor::from_vec(weights, w.shape(), w.device())?;
204 let (wi, bi) = fuse_conv_bn(weights, identity_bn)?;
205
206 w = (w + wi)?;
207 b = (b + bi)?;
208 }
209
210 let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg);
211
212 Ok(Func::new(move |xs| {
213 let mut xs = xs.apply(&reparam_conv)?;
214 if let Ok(f) = &se {
215 xs = xs.apply(f)?;
216 }
217 xs = xs.relu()?;
218 Ok(xs)
219 }))
220}
221
222fn output_channels_per_stage(cfg: &Config, stage: usize) -> usize {
224 let channels = STAGES[stage].channels as f32;
225 let alpha = cfg.alphas[stage];
226
227 match stage {
228 0 => std::cmp::min(64, (channels * alpha) as usize),
229 _ => (channels * alpha) as usize,
230 }
231}
232
233fn mobileone_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
236 let nblocks = STAGES[idx].blocks;
237 let mut blocks = Vec::with_capacity(nblocks);
238
239 let mut in_channels = output_channels_per_stage(cfg, idx - 1);
240
241 for block_idx in 0..nblocks {
242 let out_channels = output_channels_per_stage(cfg, idx);
243 let (has_identity, stride) = if block_idx == 0 {
244 (false, 2)
245 } else {
246 (true, 1)
247 };
248
249 blocks.push(mobileone_block(
251 has_identity,
252 cfg.k,
253 in_channels,
254 stride,
255 1,
256 in_channels,
257 3,
258 in_channels,
259 in_channels,
260 vb.pp(block_idx * 2),
261 )?);
262
263 blocks.push(mobileone_block(
265 has_identity,
266 cfg.k,
267 out_channels,
268 1, 0, 1, 1, in_channels,
273 out_channels,
274 vb.pp(block_idx * 2 + 1),
275 )?);
276
277 in_channels = out_channels;
278 }
279
280 Ok(Func::new(move |xs| {
281 let mut xs = xs.clone();
282 for block in blocks.iter() {
283 xs = xs.apply(block)?
284 }
285 Ok(xs)
286 }))
287}
288
289fn mobileone_model(
291 config: &Config,
292 nclasses: Option<usize>,
293 vb: VarBuilder,
294) -> Result<Func<'static>> {
295 let cls = match nclasses {
296 None => None,
297 Some(nclasses) => {
298 let outputs = output_channels_per_stage(config, 4);
299 let linear = linear(outputs, nclasses, vb.pp("head.fc"))?;
300 Some(linear)
301 }
302 };
303
304 let stem_dim = output_channels_per_stage(config, 0);
305 let stem = mobileone_block(false, 1, stem_dim, 2, 1, 1, 3, 3, stem_dim, vb.pp("stem"))?;
306 let vb = vb.pp("stages");
307 let stage1 = mobileone_stage(config, 1, vb.pp(0))?;
308 let stage2 = mobileone_stage(config, 2, vb.pp(1))?;
309 let stage3 = mobileone_stage(config, 3, vb.pp(2))?;
310 let stage4 = mobileone_stage(config, 4, vb.pp(3))?;
311
312 Ok(Func::new(move |xs| {
313 let xs = xs
314 .apply(&stem)?
315 .apply(&stage1)?
316 .apply(&stage2)?
317 .apply(&stage3)?
318 .apply(&stage4)?
319 .mean(D::Minus2)?
320 .mean(D::Minus1)?;
321 match &cls {
322 None => Ok(xs),
323 Some(cls) => xs.apply(cls),
324 }
325 }))
326}
327
328pub fn mobileone(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
329 mobileone_model(cfg, Some(nclasses), vb)
330}
331
332pub fn mobileone_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
333 mobileone_model(cfg, None, vb)
334}