1use candle::{Result, Tensor, D};
15use candle_nn::{
16 batch_norm, conv2d_no_bias, linear, BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
17};
18
/// Base channel widths of the stem (index 0) and the four body stages (indices 1..=4),
/// before they are scaled by the width multipliers stored in [`Config`].
const CHANNELS_PER_STAGE: [usize; 5] = [64, 64, 128, 256, 512];

/// Hyper-parameters of a RepVGG variant.
///
/// Instances are obtained through the preset constructors (`a0`, `b1`, `b2g4`, ...).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Config {
    /// Width multiplier applied to the stem and stages 1..=3 (the stem is capped at 64 channels).
    a: f32,
    /// Width multiplier applied to the last stage.
    b: f32,
    /// Group count used by the layers that are allowed to be grouped (every other layer);
    /// `1` means no grouped convolutions.
    groups: usize,
    /// Number of RepVGG blocks in each of the four body stages.
    stages: [usize; 4],
}
28
29impl Config {
30 pub fn a0() -> Self {
31 Self {
32 a: 0.75,
33 b: 2.5,
34 groups: 1,
35 stages: [2, 4, 14, 1],
36 }
37 }
38
39 pub fn a1() -> Self {
40 Self {
41 a: 1.0,
42 b: 2.5,
43 groups: 1,
44 stages: [2, 4, 14, 1],
45 }
46 }
47
48 pub fn a2() -> Self {
49 Self {
50 a: 1.5,
51 b: 2.75,
52 groups: 1,
53 stages: [2, 4, 14, 1],
54 }
55 }
56
57 pub fn b0() -> Self {
58 Self {
59 a: 1.0,
60 b: 2.5,
61 groups: 1,
62 stages: [4, 6, 16, 1],
63 }
64 }
65
66 pub fn b1() -> Self {
67 Self {
68 a: 2.0,
69 b: 4.0,
70 groups: 1,
71 stages: [4, 6, 16, 1],
72 }
73 }
74
75 pub fn b2() -> Self {
76 Self {
77 a: 2.5,
78 b: 5.0,
79 groups: 1,
80 stages: [4, 6, 16, 1],
81 }
82 }
83
84 pub fn b3() -> Self {
85 Self {
86 a: 3.0,
87 b: 5.0,
88 groups: 1,
89 stages: [4, 6, 16, 1],
90 }
91 }
92
93 pub fn b1g4() -> Self {
94 Self {
95 a: 2.0,
96 b: 4.0,
97 groups: 4,
98 stages: [4, 6, 16, 1],
99 }
100 }
101
102 pub fn b2g4() -> Self {
103 Self {
104 a: 2.5,
105 b: 5.0,
106 groups: 4,
107 stages: [4, 6, 16, 1],
108 }
109 }
110
111 pub fn b3g4() -> Self {
112 Self {
113 a: 3.0,
114 b: 5.0,
115 groups: 4,
116 stages: [4, 6, 16, 1],
117 }
118 }
119}
120
121fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
125 let (gamma, beta) = bn.weight_and_bias().unwrap();
126 let mu = bn.running_mean();
127 let sigma = (bn.running_var() + bn.eps())?.sqrt();
128 let gps = (gamma / sigma)?;
129 let bias = (beta - mu * &gps)?;
130 let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?;
131
132 Ok((weights, bias))
133}
134
135fn repvgg_layer(
140 has_identity: bool,
141 dim: usize,
142 stride: usize,
143 in_channels: usize,
144 out_channels: usize,
145 groups: usize,
146 vb: VarBuilder,
147) -> Result<Func<'static>> {
148 let conv2d_cfg = Conv2dConfig {
149 stride,
150 groups,
151 padding: 1,
152 ..Default::default()
153 };
154
155 let conv1x1_bn = batch_norm(dim, 1e-5, vb.pp("conv_1x1.bn"))?;
159 let conv1x1 = conv2d_no_bias(
160 in_channels,
161 out_channels,
162 1,
163 conv2d_cfg,
164 vb.pp("conv_1x1.conv"),
165 )?;
166
167 let (mut w1, b1) = fuse_conv_bn(conv1x1.weight(), conv1x1_bn)?;
168
169 w1 = w1.pad_with_zeros(D::Minus1, 1, 1)?;
171 w1 = w1.pad_with_zeros(D::Minus2, 1, 1)?;
172
173 let convkxk_bn = batch_norm(dim, 1e-5, vb.pp("conv_kxk.bn"))?;
175 let conv3x3 = conv2d_no_bias(
176 in_channels,
177 out_channels,
178 3,
179 conv2d_cfg,
180 vb.pp("conv_kxk.conv"),
181 )?;
182
183 let (w3, b3) = fuse_conv_bn(conv3x3.weight(), convkxk_bn)?;
184
185 let mut w = (w1 + w3)?;
186 let mut b = (b1 + b3)?;
187
188 if has_identity {
190 let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"))?;
191
192 let mut weights: Vec<f32> = vec![0.0; conv3x3.weight().elem_count()];
194
195 let in_dim = in_channels / groups;
197 for i in 0..in_channels {
198 weights[i * in_dim * 3 * 3 + (i % in_dim) * 3 * 3 + 4] = 1.0;
199 }
200
201 let weights = &Tensor::from_vec(weights, w.shape(), w.device())?;
202 let (wi, bi) = fuse_conv_bn(weights, identity_bn)?;
203
204 w = (w + wi)?;
205 b = (b + bi)?;
206 }
207
208 let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg);
210
211 Ok(Func::new(move |xs| {
212 let xs = xs.apply(&reparam_conv)?.relu()?;
213 Ok(xs)
214 }))
215}
216
217fn output_channels_per_stage(a: f32, b: f32, stage: usize) -> usize {
219 let channels = CHANNELS_PER_STAGE[stage] as f32;
220
221 match stage {
222 0 => std::cmp::min(64, (channels * a) as usize),
223 4 => (channels * b) as usize,
224 _ => (channels * a) as usize,
225 }
226}
227
228fn repvgg_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
234 let nlayers = cfg.stages[idx - 1];
235 let mut layers = Vec::with_capacity(nlayers);
236 let prev_layers: usize = cfg.stages[..idx - 1].iter().sum();
237 let out_channels_prev = output_channels_per_stage(cfg.a, cfg.b, idx - 1);
238 let out_channels = output_channels_per_stage(cfg.a, cfg.b, idx);
239
240 for layer_idx in 0..nlayers {
241 let (has_identity, stride, in_channels) = if layer_idx == 0 {
242 (false, 2, out_channels_prev)
243 } else {
244 (true, 1, out_channels)
245 };
246
247 let groups = if (prev_layers + layer_idx) % 2 == 1 {
248 cfg.groups
249 } else {
250 1
251 };
252
253 layers.push(repvgg_layer(
254 has_identity,
255 out_channels,
256 stride,
257 in_channels,
258 out_channels,
259 groups,
260 vb.pp(layer_idx),
261 )?)
262 }
263
264 Ok(Func::new(move |xs| {
265 let mut xs = xs.clone();
266 for layer in layers.iter() {
267 xs = xs.apply(layer)?
268 }
269 Ok(xs)
270 }))
271}
272
273fn repvgg_model(config: &Config, nclasses: Option<usize>, vb: VarBuilder) -> Result<Func<'static>> {
275 let cls = match nclasses {
276 None => None,
277 Some(nclasses) => {
278 let outputs = output_channels_per_stage(config.a, config.b, 4);
279 let linear = linear(outputs, nclasses, vb.pp("head.fc"))?;
280 Some(linear)
281 }
282 };
283
284 let stem_dim = output_channels_per_stage(config.a, config.b, 0);
285 let stem = repvgg_layer(false, stem_dim, 2, 3, stem_dim, 1, vb.pp("stem"))?;
286 let vb = vb.pp("stages");
287 let stage1 = repvgg_stage(config, 1, vb.pp(0))?;
288 let stage2 = repvgg_stage(config, 2, vb.pp(1))?;
289 let stage3 = repvgg_stage(config, 3, vb.pp(2))?;
290 let stage4 = repvgg_stage(config, 4, vb.pp(3))?;
291
292 Ok(Func::new(move |xs| {
293 let xs = xs
294 .apply(&stem)?
295 .apply(&stage1)?
296 .apply(&stage2)?
297 .apply(&stage3)?
298 .apply(&stage4)?
299 .mean(D::Minus1)?
300 .mean(D::Minus1)?;
301 match &cls {
302 None => Ok(xs),
303 Some(cls) => xs.apply(cls),
304 }
305 }))
306}
307
308pub fn repvgg(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
309 repvgg_model(cfg, Some(nclasses), vb)
310}
311
312pub fn repvgg_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
313 repvgg_model(cfg, None, vb)
314}