candle_transformers/models/
resnet.rs

//! # ResNet Implementation
//!
//! Implementation of the ResNet architectures described in the paper referenced below.
//!
//! ## Reference
//!
//! [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385)
//! He et al. (2015)
//!
//! This paper introduced ResNet, a deep neural network architecture that utilizes
//! skip connections ("residual connections") to enable training of very deep networks.
12
13use candle::{Result, D};
14use candle_nn::{batch_norm, Conv2d, Func, VarBuilder};
15
16fn conv2d(
17    c_in: usize,
18    c_out: usize,
19    ksize: usize,
20    padding: usize,
21    stride: usize,
22    vb: VarBuilder,
23) -> Result<Conv2d> {
24    let conv2d_cfg = candle_nn::Conv2dConfig {
25        stride,
26        padding,
27        ..Default::default()
28    };
29    candle_nn::conv2d_no_bias(c_in, c_out, ksize, conv2d_cfg, vb)
30}
31
32fn downsample(c_in: usize, c_out: usize, stride: usize, vb: VarBuilder) -> Result<Func> {
33    if stride != 1 || c_in != c_out {
34        let conv = conv2d(c_in, c_out, 1, 0, stride, vb.pp(0))?;
35        let bn = batch_norm(c_out, 1e-5, vb.pp(1))?;
36        Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&bn, false)))
37    } else {
38        Ok(Func::new(|xs| Ok(xs.clone())))
39    }
40}
41
42fn basic_block(c_in: usize, c_out: usize, stride: usize, vb: VarBuilder) -> Result<Func> {
43    let conv1 = conv2d(c_in, c_out, 3, 1, stride, vb.pp("conv1"))?;
44    let bn1 = batch_norm(c_out, 1e-5, vb.pp("bn1"))?;
45    let conv2 = conv2d(c_out, c_out, 3, 1, 1, vb.pp("conv2"))?;
46    let bn2 = batch_norm(c_out, 1e-5, vb.pp("bn2"))?;
47    let downsample = downsample(c_in, c_out, stride, vb.pp("downsample"))?;
48    Ok(Func::new(move |xs| {
49        let ys = xs
50            .apply(&conv1)?
51            .apply_t(&bn1, false)?
52            .relu()?
53            .apply(&conv2)?
54            .apply_t(&bn2, false)?;
55        (xs.apply(&downsample)? + ys)?.relu()
56    }))
57}
58
59fn basic_layer(
60    c_in: usize,
61    c_out: usize,
62    stride: usize,
63    cnt: usize,
64    vb: VarBuilder,
65) -> Result<Func> {
66    let mut layers = Vec::with_capacity(cnt);
67    for index in 0..cnt {
68        let l_in = if index == 0 { c_in } else { c_out };
69        let stride = if index == 0 { stride } else { 1 };
70        layers.push(basic_block(l_in, c_out, stride, vb.pp(index))?)
71    }
72    Ok(Func::new(move |xs| {
73        let mut xs = xs.clone();
74        for layer in layers.iter() {
75            xs = xs.apply(layer)?
76        }
77        Ok(xs)
78    }))
79}
80
81fn resnet(
82    nclasses: Option<usize>,
83    c1: usize,
84    c2: usize,
85    c3: usize,
86    c4: usize,
87    vb: VarBuilder,
88) -> Result<Func> {
89    let conv1 = conv2d(3, 64, 7, 3, 2, vb.pp("conv1"))?;
90    let bn1 = batch_norm(64, 1e-5, vb.pp("bn1"))?;
91    let layer1 = basic_layer(64, 64, 1, c1, vb.pp("layer1"))?;
92    let layer2 = basic_layer(64, 128, 2, c2, vb.pp("layer2"))?;
93    let layer3 = basic_layer(128, 256, 2, c3, vb.pp("layer3"))?;
94    let layer4 = basic_layer(256, 512, 2, c4, vb.pp("layer4"))?;
95    let fc = match nclasses {
96        None => None,
97        Some(nclasses) => {
98            let linear = candle_nn::linear(512, nclasses, vb.pp("fc"))?;
99            Some(linear)
100        }
101    };
102    Ok(Func::new(move |xs| {
103        let xs = xs
104            .apply(&conv1)?
105            .apply_t(&bn1, false)?
106            .relu()?
107            .pad_with_same(D::Minus1, 1, 1)?
108            .pad_with_same(D::Minus2, 1, 1)?
109            .max_pool2d_with_stride(3, 2)?
110            .apply(&layer1)?
111            .apply(&layer2)?
112            .apply(&layer3)?
113            .apply(&layer4)?
114            .mean(D::Minus1)?
115            .mean(D::Minus1)?;
116        match &fc {
117            None => Ok(xs),
118            Some(fc) => xs.apply(fc),
119        }
120    }))
121}
122
123/// Creates a ResNet-18 model.
124pub fn resnet18(num_classes: usize, vb: VarBuilder) -> Result<Func> {
125    resnet(Some(num_classes), 2, 2, 2, 2, vb)
126}
127
128pub fn resnet18_no_final_layer(vb: VarBuilder) -> Result<Func> {
129    resnet(None, 2, 2, 2, 2, vb)
130}
131
132/// Creates a ResNet-34 model.
133pub fn resnet34(num_classes: usize, vb: VarBuilder) -> Result<Func> {
134    resnet(Some(num_classes), 3, 4, 6, 3, vb)
135}
136
137pub fn resnet34_no_final_layer(vb: VarBuilder) -> Result<Func> {
138    resnet(None, 3, 4, 6, 3, vb)
139}
140
141// Bottleneck versions for ResNet 50, 101, and 152.
142fn bottleneck_block(
143    c_in: usize,
144    c_out: usize,
145    stride: usize,
146    e: usize,
147    vb: VarBuilder,
148) -> Result<Func> {
149    let e_dim = e * c_out;
150    let conv1 = conv2d(c_in, c_out, 1, 0, 1, vb.pp("conv1"))?;
151    let bn1 = batch_norm(c_out, 1e-5, vb.pp("bn1"))?;
152    let conv2 = conv2d(c_out, c_out, 3, 1, stride, vb.pp("conv2"))?;
153    let bn2 = batch_norm(c_out, 1e-5, vb.pp("bn2"))?;
154    let conv3 = conv2d(c_out, e_dim, 1, 0, 1, vb.pp("conv3"))?;
155    let bn3 = batch_norm(e_dim, 1e-5, vb.pp("bn3"))?;
156    let downsample = downsample(c_in, e_dim, stride, vb.pp("downsample"))?;
157    Ok(Func::new(move |xs| {
158        let ys = xs
159            .apply(&conv1)?
160            .apply_t(&bn1, false)?
161            .relu()?
162            .apply(&conv2)?
163            .apply_t(&bn2, false)?
164            .relu()?
165            .apply(&conv3)?
166            .apply_t(&bn3, false)?;
167        (xs.apply(&downsample)? + ys)?.relu()
168    }))
169}
170
171fn bottleneck_layer(
172    c_in: usize,
173    c_out: usize,
174    stride: usize,
175    cnt: usize,
176    vb: VarBuilder,
177) -> Result<Func> {
178    let mut layers = Vec::with_capacity(cnt);
179    for index in 0..cnt {
180        let l_in = if index == 0 { c_in } else { 4 * c_out };
181        let stride = if index == 0 { stride } else { 1 };
182        layers.push(bottleneck_block(l_in, c_out, stride, 4, vb.pp(index))?)
183    }
184    Ok(Func::new(move |xs| {
185        let mut xs = xs.clone();
186        for layer in layers.iter() {
187            xs = xs.apply(layer)?
188        }
189        Ok(xs)
190    }))
191}
192
193fn bottleneck_resnet(
194    nclasses: Option<usize>,
195    c1: usize,
196    c2: usize,
197    c3: usize,
198    c4: usize,
199    vb: VarBuilder,
200) -> Result<Func> {
201    let conv1 = conv2d(3, 64, 7, 3, 2, vb.pp("conv1"))?;
202    let bn1 = batch_norm(64, 1e-5, vb.pp("bn1"))?;
203    let layer1 = bottleneck_layer(64, 64, 1, c1, vb.pp("layer1"))?;
204    let layer2 = bottleneck_layer(4 * 64, 128, 2, c2, vb.pp("layer2"))?;
205    let layer3 = bottleneck_layer(4 * 128, 256, 2, c3, vb.pp("layer3"))?;
206    let layer4 = bottleneck_layer(4 * 256, 512, 2, c4, vb.pp("layer4"))?;
207    let fc = match nclasses {
208        None => None,
209        Some(nclasses) => {
210            let linear = candle_nn::linear(4 * 512, nclasses, vb.pp("fc"))?;
211            Some(linear)
212        }
213    };
214    Ok(Func::new(move |xs| {
215        let xs = xs
216            .apply(&conv1)?
217            .apply_t(&bn1, false)?
218            .relu()?
219            .pad_with_same(D::Minus1, 1, 1)?
220            .pad_with_same(D::Minus2, 1, 1)?
221            .max_pool2d_with_stride(3, 2)?
222            .apply(&layer1)?
223            .apply(&layer2)?
224            .apply(&layer3)?
225            .apply(&layer4)?
226            .mean(D::Minus1)?
227            .mean(D::Minus1)?;
228        match &fc {
229            None => Ok(xs),
230            Some(fc) => xs.apply(fc),
231        }
232    }))
233}
234
235pub fn resnet50(num_classes: usize, vb: VarBuilder) -> Result<Func> {
236    bottleneck_resnet(Some(num_classes), 3, 4, 6, 3, vb)
237}
238
239pub fn resnet50_no_final_layer(vb: VarBuilder) -> Result<Func> {
240    bottleneck_resnet(None, 3, 4, 6, 3, vb)
241}
242
243pub fn resnet101(num_classes: usize, vb: VarBuilder) -> Result<Func> {
244    bottleneck_resnet(Some(num_classes), 3, 4, 23, 3, vb)
245}
246
247pub fn resnet101_no_final_layer(vb: VarBuilder) -> Result<Func> {
248    bottleneck_resnet(None, 3, 4, 23, 3, vb)
249}
250
251pub fn resnet152(num_classes: usize, vb: VarBuilder) -> Result<Func> {
252    bottleneck_resnet(Some(num_classes), 3, 8, 36, 3, vb)
253}
254
255pub fn resnet152_no_final_layer(vb: VarBuilder) -> Result<Func> {
256    bottleneck_resnet(None, 3, 8, 36, 3, vb)
257}