candle_transformers/models/convnext.rs

//! ConvNeXt implementation.
//!
//! This candle implementation builds a ConvNeXt network (v1 or v2) for inference with
//! pre-trained weights. The classification head has been trained on the ImageNet dataset
//! and outputs one raw logit per class (no softmax is applied here); apply a softmax to
//! obtain probabilities, from which e.g. the top-5 classes can be selected.
//!
//! Original code:
//! - 💻 [ConvNeXt](https://github.com/facebookresearch/ConvNeXt/)
//! - 💻 [ConvNeXt-V2](https://github.com/facebookresearch/ConvNeXt-V2/)
//! - 💻 [timm](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py)
//! - 📝 [Paper](https://arxiv.org/abs/2201.03545) A ConvNet for the 2020s
//! - 📝 [Paper](https://arxiv.org/abs/2301.00808) ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders
//!
14
15use candle::shape::ShapeWithOneHole;
16use candle::{Result, D};
17use candle_nn::{conv2d, layer_norm, linear, Conv2dConfig, Func, VarBuilder};
18
/// Architecture hyper-parameters of a ConvNeXt model.
///
/// Use one of the preset constructors (`atto` … `huge`) to select a model size.
/// In every preset each stage has twice the channels of the previous one, which the
/// downsampling layers rely on.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Config {
    /// Number of ConvNeXt blocks in each of the four stages.
    blocks: [usize; 4],
    /// Channel width of each of the four stages.
    channels: [usize; 4],
    /// When `true`, blocks use a pointwise-convolution MLP with a channels-first layer
    /// norm; otherwise they use a linear MLP applied in channels-last layout.
    use_conv_mlp: bool,
}

impl Config {
    /// `atto` preset: blocks `[2, 2, 6, 2]`, channels `[40, 80, 160, 320]`, conv MLP.
    pub fn atto() -> Self {
        Self {
            blocks: [2, 2, 6, 2],
            channels: [40, 80, 160, 320],
            use_conv_mlp: true,
        }
    }

    /// `femto` preset: blocks `[2, 2, 6, 2]`, channels `[48, 96, 192, 384]`, conv MLP.
    pub fn femto() -> Self {
        Self {
            blocks: [2, 2, 6, 2],
            channels: [48, 96, 192, 384],
            use_conv_mlp: true,
        }
    }

    /// `pico` preset: blocks `[2, 2, 6, 2]`, channels `[64, 128, 256, 512]`, conv MLP.
    pub fn pico() -> Self {
        Self {
            blocks: [2, 2, 6, 2],
            channels: [64, 128, 256, 512],
            use_conv_mlp: true,
        }
    }

    /// `nano` preset: blocks `[2, 2, 8, 2]`, channels `[80, 160, 320, 640]`, conv MLP.
    pub fn nano() -> Self {
        Self {
            blocks: [2, 2, 8, 2],
            channels: [80, 160, 320, 640],
            use_conv_mlp: true,
        }
    }

    /// `tiny` preset: blocks `[3, 3, 9, 3]`, channels `[96, 192, 384, 768]`, linear MLP.
    pub fn tiny() -> Self {
        Self {
            blocks: [3, 3, 9, 3],
            channels: [96, 192, 384, 768],
            use_conv_mlp: false,
        }
    }

    /// `small` preset: blocks `[3, 3, 27, 3]`, channels `[96, 192, 384, 768]`, linear MLP.
    pub fn small() -> Self {
        Self {
            blocks: [3, 3, 27, 3],
            channels: [96, 192, 384, 768],
            use_conv_mlp: false,
        }
    }

    /// `base` preset: blocks `[3, 3, 27, 3]`, channels `[128, 256, 512, 1024]`, linear MLP.
    pub fn base() -> Self {
        Self {
            blocks: [3, 3, 27, 3],
            channels: [128, 256, 512, 1024],
            use_conv_mlp: false,
        }
    }

    /// `large` preset: blocks `[3, 3, 27, 3]`, channels `[192, 384, 768, 1536]`, linear MLP.
    pub fn large() -> Self {
        Self {
            blocks: [3, 3, 27, 3],
            channels: [192, 384, 768, 1536],
            use_conv_mlp: false,
        }
    }

    /// `xlarge` preset: blocks `[3, 3, 27, 3]`, channels `[256, 512, 1024, 2048]`, linear MLP.
    pub fn xlarge() -> Self {
        Self {
            blocks: [3, 3, 27, 3],
            channels: [256, 512, 1024, 2048],
            use_conv_mlp: false,
        }
    }

    /// `huge` preset: blocks `[3, 3, 27, 3]`, channels `[352, 704, 1408, 2816]`, linear MLP.
    pub fn huge() -> Self {
        Self {
            blocks: [3, 3, 27, 3],
            channels: [352, 704, 1408, 2816],
            use_conv_mlp: false,
        }
    }
}
107
108// Layer norm for data in channels-last format.
109fn layer_norm_cl(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
110    let norm = layer_norm(dim, 1e-6, vb)?;
111
112    Ok(Func::new(move |xs| xs.apply(&norm)))
113}
114
115// Layer norm for data in channels-first format.
116fn layer_norm_cf(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
117    let norm = layer_norm(dim, 1e-6, vb)?;
118
119    Ok(Func::new(move |xs| {
120        let xs = xs
121            .permute((0, 2, 3, 1))?
122            .apply(&norm)?
123            .permute((0, 3, 1, 2))?;
124        Ok(xs)
125    }))
126}
127
128// Global response normalization layer
129// Based on https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/grn.py
130fn convnext2_grn(dim: usize, channels_last: bool, vb: VarBuilder) -> Result<Func<'static>> {
131    let (shape, spatial_dim, channel_dim) = if channels_last {
132        ((1, 1, 1, ()).into_shape(dim)?, [1, 2], 3)
133    } else {
134        ((1, (), 1, 1).into_shape(dim)?, [2, 3], 1)
135    };
136
137    let gamma = vb.get(dim, "weight")?.reshape(&shape)?;
138    let beta = vb.get(dim, "bias")?.reshape(&shape)?;
139
140    Ok(Func::new(move |xs| {
141        let residual = xs;
142        let gx = xs
143            .sqr()?
144            .sum_keepdim(spatial_dim)?
145            .mean_keepdim(spatial_dim)?
146            .sqrt()?;
147
148        let gxmean = gx.mean_keepdim(channel_dim)?;
149        let nx = gx.broadcast_div(&(gxmean + 1e-6)?)?;
150        let xs = xs
151            .broadcast_mul(&nx)?
152            .broadcast_mul(&gamma)?
153            .broadcast_add(&beta)?;
154
155        xs + residual
156    }))
157}
158
159// Initial downsampling via a patchify layer.
160fn convnext_stem(out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
161    let conv2d_cfg = Conv2dConfig {
162        stride: 4,
163        ..Default::default()
164    };
165    let patchify = conv2d(3, out_channels, 4, conv2d_cfg, vb.pp(0))?;
166    let norm = layer_norm_cf(out_channels, vb.pp(1))?;
167
168    Ok(Func::new(move |xs| xs.apply(&patchify)?.apply(&norm)))
169}
170
171// Downsampling applied after the stages.
172fn convnext_downsample(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
173    let conv2d_cfg = Conv2dConfig {
174        stride: 2,
175        ..Default::default()
176    };
177    let norm = layer_norm_cf(dim / 2, vb.pp(0))?;
178    let conv = conv2d(dim / 2, dim, 2, conv2d_cfg, vb.pp(1))?;
179
180    Ok(Func::new(move |xs| xs.apply(&norm)?.apply(&conv)))
181}
182
183// MLP block from the original paper with optional GRN layer (v2 models).
184fn convnext_mlp(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
185    let fc1 = linear(dim, 4 * dim, vb.pp("fc1"))?;
186    let fc2 = linear(4 * dim, dim, vb.pp("fc2"))?;
187    let grn = convnext2_grn(4 * dim, true, vb.pp("grn"));
188
189    Ok(Func::new(move |xs| {
190        let mut xs = xs.apply(&fc1)?.gelu_erf()?;
191        if let Ok(g) = &grn {
192            xs = xs.apply(g)?;
193        }
194        xs = xs.apply(&fc2)?;
195        Ok(xs)
196    }))
197}
198
199// MLP block using pointwise convolutions, with optional GRN layer (v2 models).
200fn convnext_conv_mlp(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
201    let conv2d_cfg = Conv2dConfig {
202        ..Default::default()
203    };
204    let fc1 = conv2d(dim, 4 * dim, 1, conv2d_cfg, vb.pp("fc1"))?;
205    let fc2 = conv2d(4 * dim, dim, 1, conv2d_cfg, vb.pp("fc2"))?;
206
207    let grn = convnext2_grn(4 * dim, false, vb.pp("grn"));
208    Ok(Func::new(move |xs| {
209        let mut xs = xs.apply(&fc1)?.gelu_erf()?;
210        if let Ok(g) = &grn {
211            xs = xs.apply(g)?;
212        }
213        xs = xs.apply(&fc2)?;
214        Ok(xs)
215    }))
216}
217
218// A block consisting of a depthwise convolution, a MLP and layer scaling (v1 models only).
219fn convnext_block(dim: usize, use_conv_mlp: bool, vb: VarBuilder) -> Result<Func<'static>> {
220    let conv2d_cfg = Conv2dConfig {
221        groups: dim,
222        padding: 3,
223        ..Default::default()
224    };
225
226    let conv_dw = conv2d(dim, dim, 7, conv2d_cfg, vb.pp("conv_dw"))?;
227    let gamma = vb.get(dim, "gamma");
228
229    let (mlp, norm) = if use_conv_mlp {
230        (
231            convnext_conv_mlp(dim, vb.pp("mlp"))?,
232            layer_norm_cf(dim, vb.pp("norm"))?,
233        )
234    } else {
235        (
236            convnext_mlp(dim, vb.pp("mlp"))?,
237            layer_norm_cl(dim, vb.pp("norm"))?,
238        )
239    };
240
241    Ok(Func::new(move |xs| {
242        let residual = xs;
243        let mut xs = xs.apply(&conv_dw)?;
244
245        xs = if use_conv_mlp {
246            xs.apply(&norm)?.apply(&mlp)?
247        } else {
248            xs.permute((0, 2, 3, 1))?
249                .apply(&norm)?
250                .apply(&mlp)?
251                .permute((0, 3, 1, 2))?
252        };
253
254        if let Ok(g) = &gamma {
255            xs = xs.broadcast_mul(&g.reshape((1, (), 1, 1))?)?;
256        };
257
258        xs + residual
259    }))
260}
261
262// Each stage contains blocks and a downsampling layer for the previous stage.
263fn convnext_stage(cfg: &Config, stage_idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
264    let nblocks = cfg.blocks[stage_idx];
265    let mut blocks = Vec::with_capacity(nblocks);
266
267    let dim = cfg.channels[stage_idx];
268
269    if stage_idx > 0 {
270        blocks.push(convnext_downsample(dim, vb.pp("downsample"))?);
271    }
272
273    for block_idx in 0..nblocks {
274        blocks.push(convnext_block(
275            dim,
276            cfg.use_conv_mlp,
277            vb.pp(format!("blocks.{block_idx}")),
278        )?);
279    }
280
281    Ok(Func::new(move |xs| {
282        let mut xs = xs.clone();
283        for block in blocks.iter() {
284            xs = xs.apply(block)?
285        }
286        Ok(xs)
287    }))
288}
289
290// Classification head.
291fn convnext_head(outputs: usize, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
292    let norm = layer_norm_cl(outputs, vb.pp("norm"))?;
293    let linear = linear(outputs, nclasses, vb.pp("fc"))?;
294    Ok(Func::new(move |xs| xs.apply(&norm)?.apply(&linear)))
295}
296
297// Build a convnext model for a given configuration.
298fn convnext_model(
299    config: &Config,
300    nclasses: Option<usize>,
301    vb: VarBuilder,
302) -> Result<Func<'static>> {
303    let head = match nclasses {
304        None => None,
305        Some(nclasses) => {
306            let head = convnext_head(config.channels[3], nclasses, vb.pp("head"))?;
307            Some(head)
308        }
309    };
310
311    let stem = convnext_stem(config.channels[0], vb.pp("stem"))?;
312    let vb = vb.pp("stages");
313    let stage1 = convnext_stage(config, 0, vb.pp(0))?;
314    let stage2 = convnext_stage(config, 1, vb.pp(1))?;
315    let stage3 = convnext_stage(config, 2, vb.pp(2))?;
316    let stage4 = convnext_stage(config, 3, vb.pp(3))?;
317
318    Ok(Func::new(move |xs| {
319        let xs = xs
320            .apply(&stem)?
321            .apply(&stage1)?
322            .apply(&stage2)?
323            .apply(&stage3)?
324            .apply(&stage4)?
325            .mean(D::Minus2)?
326            .mean(D::Minus1)?;
327        match &head {
328            None => Ok(xs),
329            Some(head) => xs.apply(head),
330        }
331    }))
332}
333
334pub fn convnext(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
335    convnext_model(cfg, Some(nclasses), vb)
336}
337
338pub fn convnext_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
339    convnext_model(cfg, None, vb)
340}