1use candle::{Context, DType, Result, Tensor, D};
9use candle_nn::{
10 batch_norm, conv2d, conv2d_no_bias, linear, linear_no_bias, ops::sigmoid, ops::softmax,
11 BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
12};
13
/// Hyper-parameters describing one FastViT variant.
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
pub struct Config {
    /// Expansion ratio of the convolutional MLP in every block: its hidden width is
    /// `dim * exp_ratio`.
    pub exp_ratio: usize,
    /// Channel width of the first stage. Stage `i` uses `in_channels << i` channels and the
    /// classifier head consumes `in_channels * 16` features.
    pub in_channels: usize,
    /// Number of blocks in each of the four stages.
    pub blocks: [usize; 4],
    /// When set, the last stage (index 3) uses attention blocks instead of RepMixer blocks.
    pub attn: bool,
    /// When set, the patch-embedding layers apply a GELU after summing the large- and
    /// small-kernel convolutions.
    pub lkc_use_act: bool,
}
22
23impl Config {
24 pub fn t8() -> Self {
25 Self {
26 exp_ratio: 3,
27 in_channels: 48,
28 blocks: [2, 2, 4, 2],
29 attn: false,
30 lkc_use_act: false,
31 }
32 }
33
34 pub fn t12() -> Self {
35 Self {
36 exp_ratio: 3,
37 in_channels: 64,
38 blocks: [2, 2, 6, 2],
39 attn: false,
40 lkc_use_act: false,
41 }
42 }
43 pub fn s12() -> Self {
44 Self {
45 exp_ratio: 4,
46 in_channels: 64,
47 blocks: [2, 2, 6, 2],
48 attn: false,
49 lkc_use_act: false,
50 }
51 }
52 pub fn sa12() -> Self {
53 Self {
54 exp_ratio: 4,
55 in_channels: 64,
56 blocks: [2, 2, 6, 2],
57 attn: true,
58 lkc_use_act: false,
59 }
60 }
61 pub fn sa24() -> Self {
62 Self {
63 exp_ratio: 4,
64 in_channels: 64,
65 blocks: [4, 4, 12, 4],
66 attn: true,
67 lkc_use_act: false,
68 }
69 }
70 pub fn sa36() -> Self {
71 Self {
72 exp_ratio: 4,
73 in_channels: 64,
74 blocks: [6, 6, 18, 6],
75 attn: true,
76 lkc_use_act: false,
77 }
78 }
79 pub fn ma36() -> Self {
80 Self {
81 exp_ratio: 4,
82 in_channels: 76,
83 blocks: [6, 6, 18, 6],
84 attn: true,
85 lkc_use_act: false,
86 }
87 }
88
89 pub fn mci0() -> Self {
91 Self {
92 exp_ratio: 3,
93 in_channels: 64,
94 blocks: [2, 6, 10, 2],
95 attn: true,
96 lkc_use_act: true,
97 }
98 }
99 pub fn mci1() -> Self {
100 Self {
101 exp_ratio: 3,
102 in_channels: 64,
103 blocks: [4, 12, 20, 4],
104 attn: true,
105 lkc_use_act: true,
106 }
107 }
108 pub fn mci2() -> Self {
109 Self {
110 exp_ratio: 3,
111 in_channels: 80,
112 blocks: [4, 12, 24, 4],
113 attn: true,
114 lkc_use_act: true,
115 }
116 }
117}
118
119fn conv_norm(
120 in_channels: usize,
121 out_channels: usize,
122 kernel: usize,
123 stride: usize,
124 vb: VarBuilder,
125) -> Result<Func<'static>> {
126 let conv2d_cfg = Conv2dConfig {
127 stride,
128 padding: kernel / 2,
129 groups: in_channels,
130 ..Default::default()
131 };
132
133 let bn = batch_norm(out_channels, 1e-5, vb.pp("bn"))?;
134 let conv = conv2d_no_bias(in_channels, out_channels, kernel, conv2d_cfg, vb.pp("conv"))?;
135 let conv = conv.absorb_bn(&bn)?;
136 Ok(Func::new(move |xs| {
137 let xs = xs.apply(&conv)?;
138 Ok(xs)
139 }))
140}
141
142fn conv_mlp(dim: usize, exp_ratio: usize, vb: VarBuilder) -> Result<Func<'static>> {
143 let conv2d_cfg = Conv2dConfig {
144 ..Default::default()
145 };
146
147 let conv = conv_norm(dim, dim, 7, 1, vb.pp("conv"))?;
148 let fc1 = conv2d(dim, dim * exp_ratio, 1, conv2d_cfg, vb.pp("fc1"))?;
149 let fc2 = conv2d(dim * exp_ratio, dim, 1, conv2d_cfg, vb.pp("fc2"))?;
150
151 Ok(Func::new(move |xs| {
152 let xs = xs.apply(&conv)?.apply(&fc1)?.gelu_erf()?.apply(&fc2)?;
153 Ok(xs)
154 }))
155}
156
157fn squeeze_and_excitation(
158 in_channels: usize,
159 squeeze_channels: usize,
160 vb: VarBuilder,
161) -> Result<Func<'static>> {
162 let conv2d_cfg = Conv2dConfig {
163 ..Default::default()
164 };
165 let fc1 = conv2d(in_channels, squeeze_channels, 1, conv2d_cfg, vb.pp("fc1"))?;
166 let fc2 = conv2d(squeeze_channels, in_channels, 1, conv2d_cfg, vb.pp("fc2"))?;
167
168 Ok(Func::new(move |xs| {
169 let residual = xs;
170 let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
171 let xs = sigmoid(&xs.apply(&fc1)?.relu()?.apply(&fc2)?)?;
172
173 residual.broadcast_mul(&xs)
174 }))
175}
176
177fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
181 let (gamma, beta) = bn.weight_and_bias().context("no weight-bias")?;
182 let mu = bn.running_mean();
183 let sigma = (bn.running_var() + bn.eps())?.sqrt();
184 let gps = (gamma / sigma)?;
185 let bias = (beta - mu * &gps)?;
186 let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?;
187
188 Ok((weights, bias))
189}
190
191fn mobileone_block(
192 in_channels: usize,
193 out_channels: usize,
194 kernel: usize,
195 stride: usize,
196 group_size: usize,
197 use_act: bool,
198 vb: VarBuilder,
199) -> Result<Func<'static>> {
200 let groups = if group_size == 0 {
201 1
202 } else {
203 in_channels / group_size
204 };
205
206 let padding = kernel / 2;
207 let conv2d_cfg = Conv2dConfig {
208 stride,
209 groups,
210 padding,
211 ..Default::default()
212 };
213
214 let mut w = Tensor::zeros(
215 (out_channels, in_channels / groups, kernel, kernel),
216 DType::F32,
217 vb.device(),
218 )?;
219 let dim = out_channels;
220
221 let mut b = Tensor::zeros(dim, DType::F32, vb.device())?;
222
223 let conv_kxk_bn = batch_norm(dim, 1e-5, vb.pp("conv_kxk.0.bn"));
224 let conv_kxk = conv2d_no_bias(
225 in_channels,
226 out_channels,
227 kernel,
228 conv2d_cfg,
229 vb.pp("conv_kxk.0.conv"),
230 );
231
232 if let (Ok(conv), Ok(bn)) = (conv_kxk, conv_kxk_bn) {
233 let (wk, bk) = fuse_conv_bn(conv.weight(), bn)?;
234 w = (w + wk)?;
235 b = (b + bk)?;
236 };
237
238 let conv_scale_bn = batch_norm(dim, 1e-5, vb.pp("conv_scale.bn"));
239 let conv_scale = conv2d_no_bias(
240 in_channels,
241 out_channels,
242 1,
243 conv2d_cfg,
244 vb.pp("conv_scale.conv"),
245 );
246
247 if let (Ok(conv), Ok(bn)) = (conv_scale, conv_scale_bn) {
248 let (ws, bs) = fuse_conv_bn(conv.weight(), bn)?;
249 let ws = ws
251 .pad_with_zeros(D::Minus1, 1, 1)?
252 .pad_with_zeros(D::Minus2, 1, 1)?;
253
254 w = (w + ws)?;
255 b = (b + bs)?;
256 };
257
258 let se = squeeze_and_excitation(out_channels, out_channels / 16, vb.pp("se"));
259
260 let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"));
262
263 if let Ok(id_bn) = identity_bn {
264 let mut weights: Vec<f32> = vec![0.0; w.elem_count()];
265 let id = in_channels / groups;
266 for i in 0..in_channels {
268 if kernel > 1 {
269 weights[i * kernel * kernel + 4] = 1.0;
270 } else {
271 weights[i * (id + 1)] = 1.0;
272 }
273 }
274
275 let weights = &Tensor::from_vec(weights, w.shape(), w.device())?;
276 let (wi, bi) = fuse_conv_bn(weights, id_bn)?;
277
278 w = (w + wi)?;
279 b = (b + bi)?;
280 };
281 let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg);
282
283 Ok(Func::new(move |xs| {
284 let mut xs = xs.apply(&reparam_conv)?;
285 if let Ok(f) = &se {
286 xs = xs.apply(f)?;
287 }
288 if use_act {
289 xs = xs.gelu_erf()?;
290 };
291 Ok(xs)
292 }))
293}
294
295fn repmixer(dim: usize, kernel: usize, vb: VarBuilder) -> Result<Func<'static>> {
296 let gamma = vb.get((dim, 1, 1), "layer_scale.gamma")?;
297 let norm = mobileone_block(dim, dim, kernel, 1, 1, false, vb.pp("norm"))?;
298 let mixer = mobileone_block(dim, dim, kernel, 1, 1, false, vb.pp("mixer"))?;
299
300 Ok(Func::new(move |xs| {
301 let residual = xs.clone();
302 let xs = (xs.apply(&mixer)? - xs.apply(&norm)?)?;
303 let xs = xs.broadcast_mul(&gamma.reshape((1, (), 1, 1))?)?;
304 let xs = (xs + residual)?;
305 Ok(xs)
306 }))
307}
308
309fn repmixer_block(dim: usize, exp_ratio: usize, vb: VarBuilder) -> Result<Func<'static>> {
310 let gamma = vb.get((dim, 1, 1), "layer_scale.gamma")?;
311 let token_mixer = repmixer(dim, 3, vb.pp("token_mixer"))?;
312 let mlp = conv_mlp(dim, exp_ratio, vb.pp("mlp"))?;
313
314 Ok(Func::new(move |xs| {
315 let residual = xs.apply(&token_mixer)?;
316 let mut xs = residual.apply(&mlp)?;
317 xs = xs.broadcast_mul(&gamma.reshape((1, (), 1, 1))?)?;
318 let xs = (xs + residual)?;
319 Ok(xs)
320 }))
321}
322
323fn positional_encoding(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
324 let conv2d_cfg = Conv2dConfig {
325 stride: 1,
326 padding: 3,
327 groups: dim,
328 ..Default::default()
329 };
330
331 let conv = conv2d(dim, dim, 7, conv2d_cfg, vb.pp("pos_enc"))?;
332
333 Ok(Func::new(move |xs| {
334 let xs = (xs + xs.apply(&conv)?)?;
335 Ok(xs)
336 }))
337}
338
339fn attention(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
340 let qkv = linear_no_bias(dim, dim * 3, vb.pp("qkv"))?;
341 let proj = linear(dim, dim, vb.pp("proj"))?;
342 let head_dim = 32;
343 let num_heads = dim / head_dim;
344 let scale = (head_dim as f64).powf(-0.5);
345
346 Ok(Func::new(move |xs| {
347 let xs = xs.clone();
348 let (b, c, h, w) = xs.dims4()?;
349 let n = h * w;
350 let xs = xs.flatten_from(2)?.transpose(D::Minus1, D::Minus2)?;
351 let qkv = xs
352 .apply(&qkv)?
353 .reshape((b, n, 3, num_heads, head_dim))?
354 .permute((2, 0, 3, 1, 4))?;
355
356 let q = qkv.get(0)?;
357 let k = qkv.get(1)?;
358 let v = qkv.get(2)?;
359
360 let q = (q * scale)?;
361
362 let att = q.matmul(&k.transpose(D::Minus2, D::Minus1)?)?;
363 let att = softmax(&att, D::Minus1)?;
364 let xs = att.matmul(&v)?;
365
366 let xs = xs.transpose(1, 2)?.reshape((b, n, c))?;
367 let xs = xs.apply(&proj)?;
368 let xs = xs.transpose(D::Minus1, D::Minus2)?.reshape((b, c, h, w))?;
369
370 Ok(xs)
371 }))
372}
373
374fn attention_block(dim: usize, exp_ratio: usize, vb: VarBuilder) -> Result<Func<'static>> {
375 let gamma1 = vb.get((dim, 1, 1), "layer_scale_1.gamma")?;
376 let gamma2 = vb.get((dim, 1, 1), "layer_scale_2.gamma")?;
377 let norm = batch_norm(dim, 1e-5, vb.pp("norm"))?;
378 let token_mixer = attention(dim, vb.pp("token_mixer"))?;
379 let mlp = conv_mlp(dim, exp_ratio, vb.pp("mlp"))?;
380
381 Ok(Func::new(move |xs| {
382 let xs = xs.clone();
383 let xs = (&xs
384 + &xs
385 .apply_t(&norm, false)?
386 .apply(&token_mixer)?
387 .broadcast_mul(&gamma1.reshape((1, (), 1, 1))?)?)?;
388
389 let xs = (&xs
390 + &xs
391 .apply(&mlp)?
392 .broadcast_mul(&gamma2.reshape((1, (), 1, 1))?)?)?;
393
394 Ok(xs)
395 }))
396}
397
398fn fastvit_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
399 let nblocks = cfg.blocks[idx];
400 let mut blocks = Vec::with_capacity(nblocks);
401
402 let dim = cfg.in_channels << idx;
403 let downsample = fastvit_patch_embed(dim / 2, dim, cfg.lkc_use_act, vb.pp("downsample"));
404 for block_idx in 0..nblocks {
405 let block = if cfg.attn && idx == 3 {
406 attention_block(dim, cfg.exp_ratio, vb.pp(format!("blocks.{block_idx}")))?
407 } else {
408 repmixer_block(dim, cfg.exp_ratio, vb.pp(format!("blocks.{block_idx}")))?
409 };
410 blocks.push(block);
411 }
412 let pos_emb = positional_encoding(dim, vb.pp("pos_emb"));
413
414 Ok(Func::new(move |xs| {
415 let mut xs = xs.clone();
416 if let Ok(ds) = &downsample {
417 xs = xs.apply(ds)?;
418 }
419 if let Ok(pos) = &pos_emb {
420 xs = xs.apply(pos)?;
421 }
422 for block in blocks.iter() {
423 xs = xs.apply(block)?;
424 }
425 Ok(xs)
426 }))
427}
428
429fn fastvit_patch_embed(
430 in_channels: usize,
431 out_channels: usize,
432 use_act: bool,
433 vb: VarBuilder,
434) -> Result<Func<'static>> {
435 let lk = conv_norm(in_channels, out_channels, 7, 2, vb.pp("proj.0.large_conv"))?;
436 let sk = conv_norm(in_channels, out_channels, 3, 2, vb.pp("proj.0.small_conv"))?;
437 let se = squeeze_and_excitation(out_channels, out_channels / 4, vb.pp("proj.0.se"));
438 let mb = mobileone_block(out_channels, out_channels, 1, 1, 0, true, vb.pp("proj.1"))?;
439
440 Ok(Func::new(move |xs| {
441 let mut xs = (xs.apply(&lk)? + xs.apply(&sk)?)?;
442 if let Ok(f) = &se {
443 xs = xs.apply(f)?;
444 }
445 if use_act {
446 xs = xs.gelu_erf()?;
447 };
448 let xs = xs.apply(&mb)?;
449 Ok(xs)
450 }))
451}
452
453fn fastvit_stem(in_channels: usize, out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
454 let mb0 = mobileone_block(in_channels, out_channels, 3, 2, 0, true, vb.pp(0))?;
455 let mb1 = mobileone_block(out_channels, out_channels, 3, 2, 1, true, vb.pp(1))?;
456 let mb2 = mobileone_block(out_channels, out_channels, 1, 1, 0, true, vb.pp(2))?;
457 Ok(Func::new(move |xs| {
458 let xs = xs.apply(&mb0)?.apply(&mb1)?.apply(&mb2)?;
459 Ok(xs)
460 }))
461}
462
463fn fastvit_model(cfg: &Config, nclasses: Option<usize>, vb: VarBuilder) -> Result<Func<'static>> {
465 let cls = match nclasses {
466 None => None,
467 Some(nclasses) => {
468 let linear = linear(cfg.in_channels * 16, nclasses, vb.pp("head.fc"))?;
469 Some(linear)
470 }
471 };
472
473 let stem = fastvit_stem(3, cfg.in_channels, vb.pp("stem"))?;
474 let final_conv = mobileone_block(
475 cfg.in_channels * 8,
476 cfg.in_channels * 16,
477 3,
478 1,
479 1,
480 true,
481 vb.pp("final_conv"),
482 )?;
483
484 let vb = vb.pp("stages");
485 let stage1 = fastvit_stage(cfg, 0, vb.pp(0))?;
486 let stage2 = fastvit_stage(cfg, 1, vb.pp(1))?;
487 let stage3 = fastvit_stage(cfg, 2, vb.pp(2))?;
488 let stage4 = fastvit_stage(cfg, 3, vb.pp(3))?;
489
490 Ok(Func::new(move |xs| {
491 let xs = xs
492 .apply(&stem)?
493 .apply(&stage1)?
494 .apply(&stage2)?
495 .apply(&stage3)?
496 .apply(&stage4)?
497 .apply(&final_conv)?;
498 match &cls {
499 None => Ok(xs),
500 Some(cls) => xs.mean(D::Minus2)?.mean(D::Minus1)?.apply(cls),
501 }
502 }))
503}
504
505pub fn fastvit(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
506 fastvit_model(cfg, Some(nclasses), vb)
507}
508
509pub fn fastvit_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
510 fastvit_model(cfg, None, vb)
511}