candle_transformers/models/
fastvit.rs

1//! # FastViT inference implementation based on timm
2//!
3//! ## Description
4//! See ["FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization"](https://arxiv.org/pdf/2303.14189)
5//!
6//! Implementation based on [timm model](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/fastvit.py)
7
8use candle::{Context, DType, Result, Tensor, D};
9use candle_nn::{
10    batch_norm, conv2d, conv2d_no_bias, linear, linear_no_bias, ops::sigmoid, ops::softmax,
11    BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
12};
13
14#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
15pub struct Config {
16    pub exp_ratio: usize,
17    pub in_channels: usize,
18    pub blocks: [usize; 4],
19    pub attn: bool,
20    pub lkc_use_act: bool,
21}
22
23impl Config {
24    pub fn t8() -> Self {
25        Self {
26            exp_ratio: 3,
27            in_channels: 48,
28            blocks: [2, 2, 4, 2],
29            attn: false,
30            lkc_use_act: false,
31        }
32    }
33
34    pub fn t12() -> Self {
35        Self {
36            exp_ratio: 3,
37            in_channels: 64,
38            blocks: [2, 2, 6, 2],
39            attn: false,
40            lkc_use_act: false,
41        }
42    }
43    pub fn s12() -> Self {
44        Self {
45            exp_ratio: 4,
46            in_channels: 64,
47            blocks: [2, 2, 6, 2],
48            attn: false,
49            lkc_use_act: false,
50        }
51    }
52    pub fn sa12() -> Self {
53        Self {
54            exp_ratio: 4,
55            in_channels: 64,
56            blocks: [2, 2, 6, 2],
57            attn: true,
58            lkc_use_act: false,
59        }
60    }
61    pub fn sa24() -> Self {
62        Self {
63            exp_ratio: 4,
64            in_channels: 64,
65            blocks: [4, 4, 12, 4],
66            attn: true,
67            lkc_use_act: false,
68        }
69    }
70    pub fn sa36() -> Self {
71        Self {
72            exp_ratio: 4,
73            in_channels: 64,
74            blocks: [6, 6, 18, 6],
75            attn: true,
76            lkc_use_act: false,
77        }
78    }
79    pub fn ma36() -> Self {
80        Self {
81            exp_ratio: 4,
82            in_channels: 76,
83            blocks: [6, 6, 18, 6],
84            attn: true,
85            lkc_use_act: false,
86        }
87    }
88
89    // configs used by MobileCLIP's image encoder
90    pub fn mci0() -> Self {
91        Self {
92            exp_ratio: 3,
93            in_channels: 64,
94            blocks: [2, 6, 10, 2],
95            attn: true,
96            lkc_use_act: true,
97        }
98    }
99    pub fn mci1() -> Self {
100        Self {
101            exp_ratio: 3,
102            in_channels: 64,
103            blocks: [4, 12, 20, 4],
104            attn: true,
105            lkc_use_act: true,
106        }
107    }
108    pub fn mci2() -> Self {
109        Self {
110            exp_ratio: 3,
111            in_channels: 80,
112            blocks: [4, 12, 24, 4],
113            attn: true,
114            lkc_use_act: true,
115        }
116    }
117}
118
119fn conv_norm(
120    in_channels: usize,
121    out_channels: usize,
122    kernel: usize,
123    stride: usize,
124    vb: VarBuilder,
125) -> Result<Func<'static>> {
126    let conv2d_cfg = Conv2dConfig {
127        stride,
128        padding: kernel / 2,
129        groups: in_channels,
130        ..Default::default()
131    };
132
133    let bn = batch_norm(out_channels, 1e-5, vb.pp("bn"))?;
134    let conv = conv2d_no_bias(in_channels, out_channels, kernel, conv2d_cfg, vb.pp("conv"))?;
135    let conv = conv.absorb_bn(&bn)?;
136    Ok(Func::new(move |xs| {
137        let xs = xs.apply(&conv)?;
138        Ok(xs)
139    }))
140}
141
142fn conv_mlp(dim: usize, exp_ratio: usize, vb: VarBuilder) -> Result<Func<'static>> {
143    let conv2d_cfg = Conv2dConfig {
144        ..Default::default()
145    };
146
147    let conv = conv_norm(dim, dim, 7, 1, vb.pp("conv"))?;
148    let fc1 = conv2d(dim, dim * exp_ratio, 1, conv2d_cfg, vb.pp("fc1"))?;
149    let fc2 = conv2d(dim * exp_ratio, dim, 1, conv2d_cfg, vb.pp("fc2"))?;
150
151    Ok(Func::new(move |xs| {
152        let xs = xs.apply(&conv)?.apply(&fc1)?.gelu_erf()?.apply(&fc2)?;
153        Ok(xs)
154    }))
155}
156
157fn squeeze_and_excitation(
158    in_channels: usize,
159    squeeze_channels: usize,
160    vb: VarBuilder,
161) -> Result<Func<'static>> {
162    let conv2d_cfg = Conv2dConfig {
163        ..Default::default()
164    };
165    let fc1 = conv2d(in_channels, squeeze_channels, 1, conv2d_cfg, vb.pp("fc1"))?;
166    let fc2 = conv2d(squeeze_channels, in_channels, 1, conv2d_cfg, vb.pp("fc2"))?;
167
168    Ok(Func::new(move |xs| {
169        let residual = xs;
170        let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
171        let xs = sigmoid(&xs.apply(&fc1)?.relu()?.apply(&fc2)?)?;
172
173        residual.broadcast_mul(&xs)
174    }))
175}
176
177// fuses a convolutional kernel and a batchnorm layer into a convolutional layer
178// based on the _fuse_bn_tensor method in timm
179// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602
180fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
181    let (gamma, beta) = bn.weight_and_bias().context("no weight-bias")?;
182    let mu = bn.running_mean();
183    let sigma = (bn.running_var() + bn.eps())?.sqrt();
184    let gps = (gamma / sigma)?;
185    let bias = (beta - mu * &gps)?;
186    let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?;
187
188    Ok((weights, bias))
189}
190
191fn mobileone_block(
192    in_channels: usize,
193    out_channels: usize,
194    kernel: usize,
195    stride: usize,
196    group_size: usize,
197    use_act: bool,
198    vb: VarBuilder,
199) -> Result<Func<'static>> {
200    let groups = if group_size == 0 {
201        1
202    } else {
203        in_channels / group_size
204    };
205
206    let padding = kernel / 2;
207    let conv2d_cfg = Conv2dConfig {
208        stride,
209        groups,
210        padding,
211        ..Default::default()
212    };
213
214    let mut w = Tensor::zeros(
215        (out_channels, in_channels / groups, kernel, kernel),
216        DType::F32,
217        vb.device(),
218    )?;
219    let dim = out_channels;
220
221    let mut b = Tensor::zeros(dim, DType::F32, vb.device())?;
222
223    let conv_kxk_bn = batch_norm(dim, 1e-5, vb.pp("conv_kxk.0.bn"));
224    let conv_kxk = conv2d_no_bias(
225        in_channels,
226        out_channels,
227        kernel,
228        conv2d_cfg,
229        vb.pp("conv_kxk.0.conv"),
230    );
231
232    if let (Ok(conv), Ok(bn)) = (conv_kxk, conv_kxk_bn) {
233        let (wk, bk) = fuse_conv_bn(conv.weight(), bn)?;
234        w = (w + wk)?;
235        b = (b + bk)?;
236    };
237
238    let conv_scale_bn = batch_norm(dim, 1e-5, vb.pp("conv_scale.bn"));
239    let conv_scale = conv2d_no_bias(
240        in_channels,
241        out_channels,
242        1,
243        conv2d_cfg,
244        vb.pp("conv_scale.conv"),
245    );
246
247    if let (Ok(conv), Ok(bn)) = (conv_scale, conv_scale_bn) {
248        let (ws, bs) = fuse_conv_bn(conv.weight(), bn)?;
249        // pad to 3x3
250        let ws = ws
251            .pad_with_zeros(D::Minus1, 1, 1)?
252            .pad_with_zeros(D::Minus2, 1, 1)?;
253
254        w = (w + ws)?;
255        b = (b + bs)?;
256    };
257
258    let se = squeeze_and_excitation(out_channels, out_channels / 16, vb.pp("se"));
259
260    // read and reparameterize the identity bn into wi and bi
261    let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"));
262
263    if let Ok(id_bn) = identity_bn {
264        let mut weights: Vec<f32> = vec![0.0; w.elem_count()];
265        let id = in_channels / groups;
266        // See https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L809
267        for i in 0..in_channels {
268            if kernel > 1 {
269                weights[i * kernel * kernel + 4] = 1.0;
270            } else {
271                weights[i * (id + 1)] = 1.0;
272            }
273        }
274
275        let weights = &Tensor::from_vec(weights, w.shape(), w.device())?;
276        let (wi, bi) = fuse_conv_bn(weights, id_bn)?;
277
278        w = (w + wi)?;
279        b = (b + bi)?;
280    };
281    let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg);
282
283    Ok(Func::new(move |xs| {
284        let mut xs = xs.apply(&reparam_conv)?;
285        if let Ok(f) = &se {
286            xs = xs.apply(f)?;
287        }
288        if use_act {
289            xs = xs.gelu_erf()?;
290        };
291        Ok(xs)
292    }))
293}
294
295fn repmixer(dim: usize, kernel: usize, vb: VarBuilder) -> Result<Func<'static>> {
296    let gamma = vb.get((dim, 1, 1), "layer_scale.gamma")?;
297    let norm = mobileone_block(dim, dim, kernel, 1, 1, false, vb.pp("norm"))?;
298    let mixer = mobileone_block(dim, dim, kernel, 1, 1, false, vb.pp("mixer"))?;
299
300    Ok(Func::new(move |xs| {
301        let residual = xs.clone();
302        let xs = (xs.apply(&mixer)? - xs.apply(&norm)?)?;
303        let xs = xs.broadcast_mul(&gamma.reshape((1, (), 1, 1))?)?;
304        let xs = (xs + residual)?;
305        Ok(xs)
306    }))
307}
308
309fn repmixer_block(dim: usize, exp_ratio: usize, vb: VarBuilder) -> Result<Func<'static>> {
310    let gamma = vb.get((dim, 1, 1), "layer_scale.gamma")?;
311    let token_mixer = repmixer(dim, 3, vb.pp("token_mixer"))?;
312    let mlp = conv_mlp(dim, exp_ratio, vb.pp("mlp"))?;
313
314    Ok(Func::new(move |xs| {
315        let residual = xs.apply(&token_mixer)?;
316        let mut xs = residual.apply(&mlp)?;
317        xs = xs.broadcast_mul(&gamma.reshape((1, (), 1, 1))?)?;
318        let xs = (xs + residual)?;
319        Ok(xs)
320    }))
321}
322
323fn positional_encoding(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
324    let conv2d_cfg = Conv2dConfig {
325        stride: 1,
326        padding: 3,
327        groups: dim,
328        ..Default::default()
329    };
330
331    let conv = conv2d(dim, dim, 7, conv2d_cfg, vb.pp("pos_enc"))?;
332
333    Ok(Func::new(move |xs| {
334        let xs = (xs + xs.apply(&conv)?)?;
335        Ok(xs)
336    }))
337}
338
339fn attention(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
340    let qkv = linear_no_bias(dim, dim * 3, vb.pp("qkv"))?;
341    let proj = linear(dim, dim, vb.pp("proj"))?;
342    let head_dim = 32;
343    let num_heads = dim / head_dim;
344    let scale = (head_dim as f64).powf(-0.5);
345
346    Ok(Func::new(move |xs| {
347        let xs = xs.clone();
348        let (b, c, h, w) = xs.dims4()?;
349        let n = h * w;
350        let xs = xs.flatten_from(2)?.transpose(D::Minus1, D::Minus2)?;
351        let qkv = xs
352            .apply(&qkv)?
353            .reshape((b, n, 3, num_heads, head_dim))?
354            .permute((2, 0, 3, 1, 4))?;
355
356        let q = qkv.get(0)?;
357        let k = qkv.get(1)?;
358        let v = qkv.get(2)?;
359
360        let q = (q * scale)?;
361
362        let att = q.matmul(&k.transpose(D::Minus2, D::Minus1)?)?;
363        let att = softmax(&att, D::Minus1)?;
364        let xs = att.matmul(&v)?;
365
366        let xs = xs.transpose(1, 2)?.reshape((b, n, c))?;
367        let xs = xs.apply(&proj)?;
368        let xs = xs.transpose(D::Minus1, D::Minus2)?.reshape((b, c, h, w))?;
369
370        Ok(xs)
371    }))
372}
373
374fn attention_block(dim: usize, exp_ratio: usize, vb: VarBuilder) -> Result<Func<'static>> {
375    let gamma1 = vb.get((dim, 1, 1), "layer_scale_1.gamma")?;
376    let gamma2 = vb.get((dim, 1, 1), "layer_scale_2.gamma")?;
377    let norm = batch_norm(dim, 1e-5, vb.pp("norm"))?;
378    let token_mixer = attention(dim, vb.pp("token_mixer"))?;
379    let mlp = conv_mlp(dim, exp_ratio, vb.pp("mlp"))?;
380
381    Ok(Func::new(move |xs| {
382        let xs = xs.clone();
383        let xs = (&xs
384            + &xs
385                .apply_t(&norm, false)?
386                .apply(&token_mixer)?
387                .broadcast_mul(&gamma1.reshape((1, (), 1, 1))?)?)?;
388
389        let xs = (&xs
390            + &xs
391                .apply(&mlp)?
392                .broadcast_mul(&gamma2.reshape((1, (), 1, 1))?)?)?;
393
394        Ok(xs)
395    }))
396}
397
398fn fastvit_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
399    let nblocks = cfg.blocks[idx];
400    let mut blocks = Vec::with_capacity(nblocks);
401
402    let dim = cfg.in_channels << idx;
403    let downsample = fastvit_patch_embed(dim / 2, dim, cfg.lkc_use_act, vb.pp("downsample"));
404    for block_idx in 0..nblocks {
405        let block = if cfg.attn && idx == 3 {
406            attention_block(dim, cfg.exp_ratio, vb.pp(format!("blocks.{block_idx}")))?
407        } else {
408            repmixer_block(dim, cfg.exp_ratio, vb.pp(format!("blocks.{block_idx}")))?
409        };
410        blocks.push(block);
411    }
412    let pos_emb = positional_encoding(dim, vb.pp("pos_emb"));
413
414    Ok(Func::new(move |xs| {
415        let mut xs = xs.clone();
416        if let Ok(ds) = &downsample {
417            xs = xs.apply(ds)?;
418        }
419        if let Ok(pos) = &pos_emb {
420            xs = xs.apply(pos)?;
421        }
422        for block in blocks.iter() {
423            xs = xs.apply(block)?;
424        }
425        Ok(xs)
426    }))
427}
428
429fn fastvit_patch_embed(
430    in_channels: usize,
431    out_channels: usize,
432    use_act: bool,
433    vb: VarBuilder,
434) -> Result<Func<'static>> {
435    let lk = conv_norm(in_channels, out_channels, 7, 2, vb.pp("proj.0.large_conv"))?;
436    let sk = conv_norm(in_channels, out_channels, 3, 2, vb.pp("proj.0.small_conv"))?;
437    let se = squeeze_and_excitation(out_channels, out_channels / 4, vb.pp("proj.0.se"));
438    let mb = mobileone_block(out_channels, out_channels, 1, 1, 0, true, vb.pp("proj.1"))?;
439
440    Ok(Func::new(move |xs| {
441        let mut xs = (xs.apply(&lk)? + xs.apply(&sk)?)?;
442        if let Ok(f) = &se {
443            xs = xs.apply(f)?;
444        }
445        if use_act {
446            xs = xs.gelu_erf()?;
447        };
448        let xs = xs.apply(&mb)?;
449        Ok(xs)
450    }))
451}
452
453fn fastvit_stem(in_channels: usize, out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
454    let mb0 = mobileone_block(in_channels, out_channels, 3, 2, 0, true, vb.pp(0))?;
455    let mb1 = mobileone_block(out_channels, out_channels, 3, 2, 1, true, vb.pp(1))?;
456    let mb2 = mobileone_block(out_channels, out_channels, 1, 1, 0, true, vb.pp(2))?;
457    Ok(Func::new(move |xs| {
458        let xs = xs.apply(&mb0)?.apply(&mb1)?.apply(&mb2)?;
459        Ok(xs)
460    }))
461}
462
463// Build a fastvit model for a given configuration.
464fn fastvit_model(cfg: &Config, nclasses: Option<usize>, vb: VarBuilder) -> Result<Func<'static>> {
465    let cls = match nclasses {
466        None => None,
467        Some(nclasses) => {
468            let linear = linear(cfg.in_channels * 16, nclasses, vb.pp("head.fc"))?;
469            Some(linear)
470        }
471    };
472
473    let stem = fastvit_stem(3, cfg.in_channels, vb.pp("stem"))?;
474    let final_conv = mobileone_block(
475        cfg.in_channels * 8,
476        cfg.in_channels * 16,
477        3,
478        1,
479        1,
480        true,
481        vb.pp("final_conv"),
482    )?;
483
484    let vb = vb.pp("stages");
485    let stage1 = fastvit_stage(cfg, 0, vb.pp(0))?;
486    let stage2 = fastvit_stage(cfg, 1, vb.pp(1))?;
487    let stage3 = fastvit_stage(cfg, 2, vb.pp(2))?;
488    let stage4 = fastvit_stage(cfg, 3, vb.pp(3))?;
489
490    Ok(Func::new(move |xs| {
491        let xs = xs
492            .apply(&stem)?
493            .apply(&stage1)?
494            .apply(&stage2)?
495            .apply(&stage3)?
496            .apply(&stage4)?
497            .apply(&final_conv)?;
498        match &cls {
499            None => Ok(xs),
500            Some(cls) => xs.mean(D::Minus2)?.mean(D::Minus1)?.apply(cls),
501        }
502    }))
503}
504
505pub fn fastvit(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
506    fastvit_model(cfg, Some(nclasses), vb)
507}
508
509pub fn fastvit_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
510    fastvit_model(cfg, None, vb)
511}