1use candle::shape::ShapeWithOneHole;
16use candle::{Result, D};
17use candle_nn::{conv2d, layer_norm, linear, Conv2dConfig, Func, VarBuilder};
18
/// Hyper-parameters describing one ConvNeXt variant.
///
/// Both arrays hold one entry per stage; the model is always built from four stages.
#[derive(Clone, Debug)]
pub struct Config {
    /// Number of blocks stacked in each stage.
    blocks: [usize; 4],
    /// Channel width used by each stage.
    channels: [usize; 4],
    /// `true` selects the pointwise-convolution MLP (`convnext_conv_mlp`),
    /// `false` selects the linear MLP (`convnext_mlp`).
    use_conv_mlp: bool,
}

impl Config {
    /// Private constructor shared by all presets so that each preset is a single line.
    const fn new(blocks: [usize; 4], channels: [usize; 4], use_conv_mlp: bool) -> Self {
        Self {
            blocks,
            channels,
            use_conv_mlp,
        }
    }

    /// "Atto" preset: blocks `[2, 2, 6, 2]`, channels `[40, 80, 160, 320]`, conv MLP.
    pub fn atto() -> Self {
        Self::new([2, 2, 6, 2], [40, 80, 160, 320], true)
    }

    /// "Femto" preset: blocks `[2, 2, 6, 2]`, channels `[48, 96, 192, 384]`, conv MLP.
    pub fn femto() -> Self {
        Self::new([2, 2, 6, 2], [48, 96, 192, 384], true)
    }

    /// "Pico" preset: blocks `[2, 2, 6, 2]`, channels `[64, 128, 256, 512]`, conv MLP.
    pub fn pico() -> Self {
        Self::new([2, 2, 6, 2], [64, 128, 256, 512], true)
    }

    /// "Nano" preset: blocks `[2, 2, 8, 2]`, channels `[80, 160, 320, 640]`, conv MLP.
    pub fn nano() -> Self {
        Self::new([2, 2, 8, 2], [80, 160, 320, 640], true)
    }

    /// "Tiny" preset: blocks `[3, 3, 9, 3]`, channels `[96, 192, 384, 768]`, linear MLP.
    pub fn tiny() -> Self {
        Self::new([3, 3, 9, 3], [96, 192, 384, 768], false)
    }

    /// "Small" preset: blocks `[3, 3, 27, 3]`, channels `[96, 192, 384, 768]`, linear MLP.
    pub fn small() -> Self {
        Self::new([3, 3, 27, 3], [96, 192, 384, 768], false)
    }

    /// "Base" preset: blocks `[3, 3, 27, 3]`, channels `[128, 256, 512, 1024]`, linear MLP.
    pub fn base() -> Self {
        Self::new([3, 3, 27, 3], [128, 256, 512, 1024], false)
    }

    /// "Large" preset: blocks `[3, 3, 27, 3]`, channels `[192, 384, 768, 1536]`, linear MLP.
    pub fn large() -> Self {
        Self::new([3, 3, 27, 3], [192, 384, 768, 1536], false)
    }

    /// "XLarge" preset: blocks `[3, 3, 27, 3]`, channels `[256, 512, 1024, 2048]`, linear MLP.
    pub fn xlarge() -> Self {
        Self::new([3, 3, 27, 3], [256, 512, 1024, 2048], false)
    }

    /// "Huge" preset: blocks `[3, 3, 27, 3]`, channels `[352, 704, 1408, 2816]`, linear MLP.
    pub fn huge() -> Self {
        Self::new([3, 3, 27, 3], [352, 704, 1408, 2816], false)
    }
}
107
108fn layer_norm_cl(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
110 let norm = layer_norm(dim, 1e-6, vb)?;
111
112 Ok(Func::new(move |xs| xs.apply(&norm)))
113}
114
115fn layer_norm_cf(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
117 let norm = layer_norm(dim, 1e-6, vb)?;
118
119 Ok(Func::new(move |xs| {
120 let xs = xs
121 .permute((0, 2, 3, 1))?
122 .apply(&norm)?
123 .permute((0, 3, 1, 2))?;
124 Ok(xs)
125 }))
126}
127
128fn convnext2_grn(dim: usize, channels_last: bool, vb: VarBuilder) -> Result<Func<'static>> {
131 let (shape, spatial_dim, channel_dim) = if channels_last {
132 ((1, 1, 1, ()).into_shape(dim)?, [1, 2], 3)
133 } else {
134 ((1, (), 1, 1).into_shape(dim)?, [2, 3], 1)
135 };
136
137 let gamma = vb.get(dim, "weight")?.reshape(&shape)?;
138 let beta = vb.get(dim, "bias")?.reshape(&shape)?;
139
140 Ok(Func::new(move |xs| {
141 let residual = xs;
142 let gx = xs
143 .sqr()?
144 .sum_keepdim(spatial_dim)?
145 .mean_keepdim(spatial_dim)?
146 .sqrt()?;
147
148 let gxmean = gx.mean_keepdim(channel_dim)?;
149 let nx = gx.broadcast_div(&(gxmean + 1e-6)?)?;
150 let xs = xs
151 .broadcast_mul(&nx)?
152 .broadcast_mul(&gamma)?
153 .broadcast_add(&beta)?;
154
155 xs + residual
156 }))
157}
158
159fn convnext_stem(out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
161 let conv2d_cfg = Conv2dConfig {
162 stride: 4,
163 ..Default::default()
164 };
165 let patchify = conv2d(3, out_channels, 4, conv2d_cfg, vb.pp(0))?;
166 let norm = layer_norm_cf(out_channels, vb.pp(1))?;
167
168 Ok(Func::new(move |xs| xs.apply(&patchify)?.apply(&norm)))
169}
170
171fn convnext_downsample(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
173 let conv2d_cfg = Conv2dConfig {
174 stride: 2,
175 ..Default::default()
176 };
177 let norm = layer_norm_cf(dim / 2, vb.pp(0))?;
178 let conv = conv2d(dim / 2, dim, 2, conv2d_cfg, vb.pp(1))?;
179
180 Ok(Func::new(move |xs| xs.apply(&norm)?.apply(&conv)))
181}
182
183fn convnext_mlp(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
185 let fc1 = linear(dim, 4 * dim, vb.pp("fc1"))?;
186 let fc2 = linear(4 * dim, dim, vb.pp("fc2"))?;
187 let grn = convnext2_grn(4 * dim, true, vb.pp("grn"));
188
189 Ok(Func::new(move |xs| {
190 let mut xs = xs.apply(&fc1)?.gelu_erf()?;
191 if let Ok(g) = &grn {
192 xs = xs.apply(g)?;
193 }
194 xs = xs.apply(&fc2)?;
195 Ok(xs)
196 }))
197}
198
199fn convnext_conv_mlp(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
201 let conv2d_cfg = Conv2dConfig {
202 ..Default::default()
203 };
204 let fc1 = conv2d(dim, 4 * dim, 1, conv2d_cfg, vb.pp("fc1"))?;
205 let fc2 = conv2d(4 * dim, dim, 1, conv2d_cfg, vb.pp("fc2"))?;
206
207 let grn = convnext2_grn(4 * dim, false, vb.pp("grn"));
208 Ok(Func::new(move |xs| {
209 let mut xs = xs.apply(&fc1)?.gelu_erf()?;
210 if let Ok(g) = &grn {
211 xs = xs.apply(g)?;
212 }
213 xs = xs.apply(&fc2)?;
214 Ok(xs)
215 }))
216}
217
218fn convnext_block(dim: usize, use_conv_mlp: bool, vb: VarBuilder) -> Result<Func<'static>> {
220 let conv2d_cfg = Conv2dConfig {
221 groups: dim,
222 padding: 3,
223 ..Default::default()
224 };
225
226 let conv_dw = conv2d(dim, dim, 7, conv2d_cfg, vb.pp("conv_dw"))?;
227 let gamma = vb.get(dim, "gamma");
228
229 let (mlp, norm) = if use_conv_mlp {
230 (
231 convnext_conv_mlp(dim, vb.pp("mlp"))?,
232 layer_norm_cf(dim, vb.pp("norm"))?,
233 )
234 } else {
235 (
236 convnext_mlp(dim, vb.pp("mlp"))?,
237 layer_norm_cl(dim, vb.pp("norm"))?,
238 )
239 };
240
241 Ok(Func::new(move |xs| {
242 let residual = xs;
243 let mut xs = xs.apply(&conv_dw)?;
244
245 xs = if use_conv_mlp {
246 xs.apply(&norm)?.apply(&mlp)?
247 } else {
248 xs.permute((0, 2, 3, 1))?
249 .apply(&norm)?
250 .apply(&mlp)?
251 .permute((0, 3, 1, 2))?
252 };
253
254 if let Ok(g) = &gamma {
255 xs = xs.broadcast_mul(&g.reshape((1, (), 1, 1))?)?;
256 };
257
258 xs + residual
259 }))
260}
261
262fn convnext_stage(cfg: &Config, stage_idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
264 let nblocks = cfg.blocks[stage_idx];
265 let mut blocks = Vec::with_capacity(nblocks);
266
267 let dim = cfg.channels[stage_idx];
268
269 if stage_idx > 0 {
270 blocks.push(convnext_downsample(dim, vb.pp("downsample"))?);
271 }
272
273 for block_idx in 0..nblocks {
274 blocks.push(convnext_block(
275 dim,
276 cfg.use_conv_mlp,
277 vb.pp(format!("blocks.{block_idx}")),
278 )?);
279 }
280
281 Ok(Func::new(move |xs| {
282 let mut xs = xs.clone();
283 for block in blocks.iter() {
284 xs = xs.apply(block)?
285 }
286 Ok(xs)
287 }))
288}
289
290fn convnext_head(outputs: usize, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
292 let norm = layer_norm_cl(outputs, vb.pp("norm"))?;
293 let linear = linear(outputs, nclasses, vb.pp("fc"))?;
294 Ok(Func::new(move |xs| xs.apply(&norm)?.apply(&linear)))
295}
296
297fn convnext_model(
299 config: &Config,
300 nclasses: Option<usize>,
301 vb: VarBuilder,
302) -> Result<Func<'static>> {
303 let head = match nclasses {
304 None => None,
305 Some(nclasses) => {
306 let head = convnext_head(config.channels[3], nclasses, vb.pp("head"))?;
307 Some(head)
308 }
309 };
310
311 let stem = convnext_stem(config.channels[0], vb.pp("stem"))?;
312 let vb = vb.pp("stages");
313 let stage1 = convnext_stage(config, 0, vb.pp(0))?;
314 let stage2 = convnext_stage(config, 1, vb.pp(1))?;
315 let stage3 = convnext_stage(config, 2, vb.pp(2))?;
316 let stage4 = convnext_stage(config, 3, vb.pp(3))?;
317
318 Ok(Func::new(move |xs| {
319 let xs = xs
320 .apply(&stem)?
321 .apply(&stage1)?
322 .apply(&stage2)?
323 .apply(&stage3)?
324 .apply(&stage4)?
325 .mean(D::Minus2)?
326 .mean(D::Minus1)?;
327 match &head {
328 None => Ok(xs),
329 Some(head) => xs.apply(head),
330 }
331 }))
332}
333
334pub fn convnext(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
335 convnext_model(cfg, Some(nclasses), vb)
336}
337
338pub fn convnext_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
339 convnext_model(cfg, None, vb)
340}