candle_transformers/models/
repvgg.rs

1//! RepVGG inference implementation
2//!
3//! Key characteristics:
4//! - Efficient inference architecture through structural reparameterization
5//! - Single 3x3 conv layer after fusing 3x3 branch, 1x1 branch and identity branch
6//! - Different configurations including a0-a2, b0-b3 and variants with group convolutions
7//! - High accuracy with VGG-like plain architecture and training
8//!
9//! References:
10//! - [RepVGG Paper](https://arxiv.org/abs/2101.03697). RepVGG: Making VGG-style ConvNets Great Again
11//! - [Official Implementation](https://github.com/DingXiaoH/RepVGG)
12//!
13
14use candle::{Result, Tensor, D};
15use candle_nn::{
16    batch_norm, conv2d_no_bias, linear, BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
17};
18
// Base channel widths before the `a` / `b` width multipliers are applied:
// index 0 is the stem, indices 1 to 4 are the four stages.
const CHANNELS_PER_STAGE: [usize; 5] = [64, 64, 128, 256, 512];
20
/// Hyper-parameters selecting a RepVGG variant.
///
/// Instances are obtained through the named constructors such as [`Config::a0`]
/// or [`Config::b1g4`].
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
    /// Width multiplier applied to the stem and to the first three stages.
    a: f32,
    /// Width multiplier applied to the last stage.
    b: f32,
    /// Number of groups used by the grouped convolutions (1 keeps every layer dense).
    groups: usize,
    /// Number of layers in each of the four stages.
    stages: [usize; 4],
}
28
29impl Config {
30    pub fn a0() -> Self {
31        Self {
32            a: 0.75,
33            b: 2.5,
34            groups: 1,
35            stages: [2, 4, 14, 1],
36        }
37    }
38
39    pub fn a1() -> Self {
40        Self {
41            a: 1.0,
42            b: 2.5,
43            groups: 1,
44            stages: [2, 4, 14, 1],
45        }
46    }
47
48    pub fn a2() -> Self {
49        Self {
50            a: 1.5,
51            b: 2.75,
52            groups: 1,
53            stages: [2, 4, 14, 1],
54        }
55    }
56
57    pub fn b0() -> Self {
58        Self {
59            a: 1.0,
60            b: 2.5,
61            groups: 1,
62            stages: [4, 6, 16, 1],
63        }
64    }
65
66    pub fn b1() -> Self {
67        Self {
68            a: 2.0,
69            b: 4.0,
70            groups: 1,
71            stages: [4, 6, 16, 1],
72        }
73    }
74
75    pub fn b2() -> Self {
76        Self {
77            a: 2.5,
78            b: 5.0,
79            groups: 1,
80            stages: [4, 6, 16, 1],
81        }
82    }
83
84    pub fn b3() -> Self {
85        Self {
86            a: 3.0,
87            b: 5.0,
88            groups: 1,
89            stages: [4, 6, 16, 1],
90        }
91    }
92
93    pub fn b1g4() -> Self {
94        Self {
95            a: 2.0,
96            b: 4.0,
97            groups: 4,
98            stages: [4, 6, 16, 1],
99        }
100    }
101
102    pub fn b2g4() -> Self {
103        Self {
104            a: 2.5,
105            b: 5.0,
106            groups: 4,
107            stages: [4, 6, 16, 1],
108        }
109    }
110
111    pub fn b3g4() -> Self {
112        Self {
113            a: 3.0,
114            b: 5.0,
115            groups: 4,
116            stages: [4, 6, 16, 1],
117        }
118    }
119}
120
121// fuses a convolutional kernel and a batchnorm layer into a convolutional layer
122// based on the _fuse_bn_tensor method in timm
123// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602
124fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
125    let (gamma, beta) = bn.weight_and_bias().unwrap();
126    let mu = bn.running_mean();
127    let sigma = (bn.running_var() + bn.eps())?.sqrt();
128    let gps = (gamma / sigma)?;
129    let bias = (beta - mu * &gps)?;
130    let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?;
131
132    Ok((weights, bias))
133}
134
135// A RepVGG layer has a different training time and inference time architecture.
136// The latter is a simple and efficient equivalent transformation of the former
137// realized by a structural reparameterization technique, where 3x3 and 1x1 convolutions
138// along with identity branches and batchnorm layers are fused into a single 3x3 convolution.
139fn repvgg_layer(
140    has_identity: bool,
141    dim: usize,
142    stride: usize,
143    in_channels: usize,
144    out_channels: usize,
145    groups: usize,
146    vb: VarBuilder,
147) -> Result<Func<'static>> {
148    let conv2d_cfg = Conv2dConfig {
149        stride,
150        groups,
151        padding: 1,
152        ..Default::default()
153    };
154
155    // read and reparameterize the 1x1 conv and bn into w1 and b1
156    // based on https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L543
157
158    let conv1x1_bn = batch_norm(dim, 1e-5, vb.pp("conv_1x1.bn"))?;
159    let conv1x1 = conv2d_no_bias(
160        in_channels,
161        out_channels,
162        1,
163        conv2d_cfg,
164        vb.pp("conv_1x1.conv"),
165    )?;
166
167    let (mut w1, b1) = fuse_conv_bn(conv1x1.weight(), conv1x1_bn)?;
168
169    // resize to 3x3
170    w1 = w1.pad_with_zeros(D::Minus1, 1, 1)?;
171    w1 = w1.pad_with_zeros(D::Minus2, 1, 1)?;
172
173    // read and reparameterize the 3x3 conv and bn into w3 and b3
174    let convkxk_bn = batch_norm(dim, 1e-5, vb.pp("conv_kxk.bn"))?;
175    let conv3x3 = conv2d_no_bias(
176        in_channels,
177        out_channels,
178        3,
179        conv2d_cfg,
180        vb.pp("conv_kxk.conv"),
181    )?;
182
183    let (w3, b3) = fuse_conv_bn(conv3x3.weight(), convkxk_bn)?;
184
185    let mut w = (w1 + w3)?;
186    let mut b = (b1 + b3)?;
187
188    // read and reparameterize the identity bn into wi and bi
189    if has_identity {
190        let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"))?;
191
192        // create a 3x3 convolution equivalent to the identity branch
193        let mut weights: Vec<f32> = vec![0.0; conv3x3.weight().elem_count()];
194
195        // https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L620
196        let in_dim = in_channels / groups;
197        for i in 0..in_channels {
198            weights[i * in_dim * 3 * 3 + (i % in_dim) * 3 * 3 + 4] = 1.0;
199        }
200
201        let weights = &Tensor::from_vec(weights, w.shape(), w.device())?;
202        let (wi, bi) = fuse_conv_bn(weights, identity_bn)?;
203
204        w = (w + wi)?;
205        b = (b + bi)?;
206    }
207
208    // create the 3x3 conv equivalent to the sum of 3x3, 1x1 and identity branches
209    let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg);
210
211    Ok(Func::new(move |xs| {
212        let xs = xs.apply(&reparam_conv)?.relu()?;
213        Ok(xs)
214    }))
215}
216
217// Get the number of output channels per stage taking into account the multipliers
218fn output_channels_per_stage(a: f32, b: f32, stage: usize) -> usize {
219    let channels = CHANNELS_PER_STAGE[stage] as f32;
220
221    match stage {
222        0 => std::cmp::min(64, (channels * a) as usize),
223        4 => (channels * b) as usize,
224        _ => (channels * a) as usize,
225    }
226}
227
228// Each stage is made of layers. The first layer always downsamples with stride 2.
229// All but the first layer have a residual connection.
230// The G4 variants have a groupwise convolution instead of a dense one on odd layers
231// counted across stage boundaries, so we keep track of which layer we are in the
232// full model.
233fn repvgg_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
234    let nlayers = cfg.stages[idx - 1];
235    let mut layers = Vec::with_capacity(nlayers);
236    let prev_layers: usize = cfg.stages[..idx - 1].iter().sum();
237    let out_channels_prev = output_channels_per_stage(cfg.a, cfg.b, idx - 1);
238    let out_channels = output_channels_per_stage(cfg.a, cfg.b, idx);
239
240    for layer_idx in 0..nlayers {
241        let (has_identity, stride, in_channels) = if layer_idx == 0 {
242            (false, 2, out_channels_prev)
243        } else {
244            (true, 1, out_channels)
245        };
246
247        let groups = if (prev_layers + layer_idx) % 2 == 1 {
248            cfg.groups
249        } else {
250            1
251        };
252
253        layers.push(repvgg_layer(
254            has_identity,
255            out_channels,
256            stride,
257            in_channels,
258            out_channels,
259            groups,
260            vb.pp(layer_idx),
261        )?)
262    }
263
264    Ok(Func::new(move |xs| {
265        let mut xs = xs.clone();
266        for layer in layers.iter() {
267            xs = xs.apply(layer)?
268        }
269        Ok(xs)
270    }))
271}
272
273// Build a RepVGG model for a given configuration.
274fn repvgg_model(config: &Config, nclasses: Option<usize>, vb: VarBuilder) -> Result<Func<'static>> {
275    let cls = match nclasses {
276        None => None,
277        Some(nclasses) => {
278            let outputs = output_channels_per_stage(config.a, config.b, 4);
279            let linear = linear(outputs, nclasses, vb.pp("head.fc"))?;
280            Some(linear)
281        }
282    };
283
284    let stem_dim = output_channels_per_stage(config.a, config.b, 0);
285    let stem = repvgg_layer(false, stem_dim, 2, 3, stem_dim, 1, vb.pp("stem"))?;
286    let vb = vb.pp("stages");
287    let stage1 = repvgg_stage(config, 1, vb.pp(0))?;
288    let stage2 = repvgg_stage(config, 2, vb.pp(1))?;
289    let stage3 = repvgg_stage(config, 3, vb.pp(2))?;
290    let stage4 = repvgg_stage(config, 4, vb.pp(3))?;
291
292    Ok(Func::new(move |xs| {
293        let xs = xs
294            .apply(&stem)?
295            .apply(&stage1)?
296            .apply(&stage2)?
297            .apply(&stage3)?
298            .apply(&stage4)?
299            .mean(D::Minus1)?
300            .mean(D::Minus1)?;
301        match &cls {
302            None => Ok(xs),
303            Some(cls) => xs.apply(cls),
304        }
305    }))
306}
307
/// Builds a RepVGG model ending with a linear classification layer.
///
/// The weights are read from `vb`; the final layer is loaded from `head.fc` and produces
/// `nclasses` outputs. The returned function applies the stem, the four stages, an average
/// over the last two dimensions and then the classification layer.
///
/// The stem convolution expects 3 input channels (presumably an NCHW image batch;
/// confirm against the caller).
pub fn repvgg(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
    repvgg_model(cfg, Some(nclasses), vb)
}
311
/// Builds a RepVGG model without the final classification layer.
///
/// The returned function applies the stem and the four stages, then averages over the
/// last two dimensions; no `head.fc` weights are read from `vb`.
pub fn repvgg_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
    repvgg_model(cfg, None, vb)
}