1use candle::{Result, D};
14use candle_nn::{batch_norm, Conv2d, Func, VarBuilder};
15
16fn conv2d(
17 c_in: usize,
18 c_out: usize,
19 ksize: usize,
20 padding: usize,
21 stride: usize,
22 vb: VarBuilder,
23) -> Result<Conv2d> {
24 let conv2d_cfg = candle_nn::Conv2dConfig {
25 stride,
26 padding,
27 ..Default::default()
28 };
29 candle_nn::conv2d_no_bias(c_in, c_out, ksize, conv2d_cfg, vb)
30}
31
32fn downsample(c_in: usize, c_out: usize, stride: usize, vb: VarBuilder) -> Result<Func> {
33 if stride != 1 || c_in != c_out {
34 let conv = conv2d(c_in, c_out, 1, 0, stride, vb.pp(0))?;
35 let bn = batch_norm(c_out, 1e-5, vb.pp(1))?;
36 Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&bn, false)))
37 } else {
38 Ok(Func::new(|xs| Ok(xs.clone())))
39 }
40}
41
42fn basic_block(c_in: usize, c_out: usize, stride: usize, vb: VarBuilder) -> Result<Func> {
43 let conv1 = conv2d(c_in, c_out, 3, 1, stride, vb.pp("conv1"))?;
44 let bn1 = batch_norm(c_out, 1e-5, vb.pp("bn1"))?;
45 let conv2 = conv2d(c_out, c_out, 3, 1, 1, vb.pp("conv2"))?;
46 let bn2 = batch_norm(c_out, 1e-5, vb.pp("bn2"))?;
47 let downsample = downsample(c_in, c_out, stride, vb.pp("downsample"))?;
48 Ok(Func::new(move |xs| {
49 let ys = xs
50 .apply(&conv1)?
51 .apply_t(&bn1, false)?
52 .relu()?
53 .apply(&conv2)?
54 .apply_t(&bn2, false)?;
55 (xs.apply(&downsample)? + ys)?.relu()
56 }))
57}
58
59fn basic_layer(
60 c_in: usize,
61 c_out: usize,
62 stride: usize,
63 cnt: usize,
64 vb: VarBuilder,
65) -> Result<Func> {
66 let mut layers = Vec::with_capacity(cnt);
67 for index in 0..cnt {
68 let l_in = if index == 0 { c_in } else { c_out };
69 let stride = if index == 0 { stride } else { 1 };
70 layers.push(basic_block(l_in, c_out, stride, vb.pp(index))?)
71 }
72 Ok(Func::new(move |xs| {
73 let mut xs = xs.clone();
74 for layer in layers.iter() {
75 xs = xs.apply(layer)?
76 }
77 Ok(xs)
78 }))
79}
80
81fn resnet(
82 nclasses: Option<usize>,
83 c1: usize,
84 c2: usize,
85 c3: usize,
86 c4: usize,
87 vb: VarBuilder,
88) -> Result<Func> {
89 let conv1 = conv2d(3, 64, 7, 3, 2, vb.pp("conv1"))?;
90 let bn1 = batch_norm(64, 1e-5, vb.pp("bn1"))?;
91 let layer1 = basic_layer(64, 64, 1, c1, vb.pp("layer1"))?;
92 let layer2 = basic_layer(64, 128, 2, c2, vb.pp("layer2"))?;
93 let layer3 = basic_layer(128, 256, 2, c3, vb.pp("layer3"))?;
94 let layer4 = basic_layer(256, 512, 2, c4, vb.pp("layer4"))?;
95 let fc = match nclasses {
96 None => None,
97 Some(nclasses) => {
98 let linear = candle_nn::linear(512, nclasses, vb.pp("fc"))?;
99 Some(linear)
100 }
101 };
102 Ok(Func::new(move |xs| {
103 let xs = xs
104 .apply(&conv1)?
105 .apply_t(&bn1, false)?
106 .relu()?
107 .pad_with_same(D::Minus1, 1, 1)?
108 .pad_with_same(D::Minus2, 1, 1)?
109 .max_pool2d_with_stride(3, 2)?
110 .apply(&layer1)?
111 .apply(&layer2)?
112 .apply(&layer3)?
113 .apply(&layer4)?
114 .mean(D::Minus1)?
115 .mean(D::Minus1)?;
116 match &fc {
117 None => Ok(xs),
118 Some(fc) => xs.apply(fc),
119 }
120 }))
121}
122
123pub fn resnet18(num_classes: usize, vb: VarBuilder) -> Result<Func> {
125 resnet(Some(num_classes), 2, 2, 2, 2, vb)
126}
127
128pub fn resnet18_no_final_layer(vb: VarBuilder) -> Result<Func> {
129 resnet(None, 2, 2, 2, 2, vb)
130}
131
132pub fn resnet34(num_classes: usize, vb: VarBuilder) -> Result<Func> {
134 resnet(Some(num_classes), 3, 4, 6, 3, vb)
135}
136
137pub fn resnet34_no_final_layer(vb: VarBuilder) -> Result<Func> {
138 resnet(None, 3, 4, 6, 3, vb)
139}
140
141fn bottleneck_block(
143 c_in: usize,
144 c_out: usize,
145 stride: usize,
146 e: usize,
147 vb: VarBuilder,
148) -> Result<Func> {
149 let e_dim = e * c_out;
150 let conv1 = conv2d(c_in, c_out, 1, 0, 1, vb.pp("conv1"))?;
151 let bn1 = batch_norm(c_out, 1e-5, vb.pp("bn1"))?;
152 let conv2 = conv2d(c_out, c_out, 3, 1, stride, vb.pp("conv2"))?;
153 let bn2 = batch_norm(c_out, 1e-5, vb.pp("bn2"))?;
154 let conv3 = conv2d(c_out, e_dim, 1, 0, 1, vb.pp("conv3"))?;
155 let bn3 = batch_norm(e_dim, 1e-5, vb.pp("bn3"))?;
156 let downsample = downsample(c_in, e_dim, stride, vb.pp("downsample"))?;
157 Ok(Func::new(move |xs| {
158 let ys = xs
159 .apply(&conv1)?
160 .apply_t(&bn1, false)?
161 .relu()?
162 .apply(&conv2)?
163 .apply_t(&bn2, false)?
164 .relu()?
165 .apply(&conv3)?
166 .apply_t(&bn3, false)?;
167 (xs.apply(&downsample)? + ys)?.relu()
168 }))
169}
170
171fn bottleneck_layer(
172 c_in: usize,
173 c_out: usize,
174 stride: usize,
175 cnt: usize,
176 vb: VarBuilder,
177) -> Result<Func> {
178 let mut layers = Vec::with_capacity(cnt);
179 for index in 0..cnt {
180 let l_in = if index == 0 { c_in } else { 4 * c_out };
181 let stride = if index == 0 { stride } else { 1 };
182 layers.push(bottleneck_block(l_in, c_out, stride, 4, vb.pp(index))?)
183 }
184 Ok(Func::new(move |xs| {
185 let mut xs = xs.clone();
186 for layer in layers.iter() {
187 xs = xs.apply(layer)?
188 }
189 Ok(xs)
190 }))
191}
192
193fn bottleneck_resnet(
194 nclasses: Option<usize>,
195 c1: usize,
196 c2: usize,
197 c3: usize,
198 c4: usize,
199 vb: VarBuilder,
200) -> Result<Func> {
201 let conv1 = conv2d(3, 64, 7, 3, 2, vb.pp("conv1"))?;
202 let bn1 = batch_norm(64, 1e-5, vb.pp("bn1"))?;
203 let layer1 = bottleneck_layer(64, 64, 1, c1, vb.pp("layer1"))?;
204 let layer2 = bottleneck_layer(4 * 64, 128, 2, c2, vb.pp("layer2"))?;
205 let layer3 = bottleneck_layer(4 * 128, 256, 2, c3, vb.pp("layer3"))?;
206 let layer4 = bottleneck_layer(4 * 256, 512, 2, c4, vb.pp("layer4"))?;
207 let fc = match nclasses {
208 None => None,
209 Some(nclasses) => {
210 let linear = candle_nn::linear(4 * 512, nclasses, vb.pp("fc"))?;
211 Some(linear)
212 }
213 };
214 Ok(Func::new(move |xs| {
215 let xs = xs
216 .apply(&conv1)?
217 .apply_t(&bn1, false)?
218 .relu()?
219 .pad_with_same(D::Minus1, 1, 1)?
220 .pad_with_same(D::Minus2, 1, 1)?
221 .max_pool2d_with_stride(3, 2)?
222 .apply(&layer1)?
223 .apply(&layer2)?
224 .apply(&layer3)?
225 .apply(&layer4)?
226 .mean(D::Minus1)?
227 .mean(D::Minus1)?;
228 match &fc {
229 None => Ok(xs),
230 Some(fc) => xs.apply(fc),
231 }
232 }))
233}
234
235pub fn resnet50(num_classes: usize, vb: VarBuilder) -> Result<Func> {
236 bottleneck_resnet(Some(num_classes), 3, 4, 6, 3, vb)
237}
238
239pub fn resnet50_no_final_layer(vb: VarBuilder) -> Result<Func> {
240 bottleneck_resnet(None, 3, 4, 6, 3, vb)
241}
242
243pub fn resnet101(num_classes: usize, vb: VarBuilder) -> Result<Func> {
244 bottleneck_resnet(Some(num_classes), 3, 4, 23, 3, vb)
245}
246
247pub fn resnet101_no_final_layer(vb: VarBuilder) -> Result<Func> {
248 bottleneck_resnet(None, 3, 4, 23, 3, vb)
249}
250
251pub fn resnet152(num_classes: usize, vb: VarBuilder) -> Result<Func> {
252 bottleneck_resnet(Some(num_classes), 3, 8, 36, 3, vb)
253}
254
255pub fn resnet152_no_final_layer(vb: VarBuilder) -> Result<Func> {
256 bottleneck_resnet(None, 3, 8, 36, 3, vb)
257}