candle_transformers/models/
segformer.rs

1//! Segformer model implementation for semantic segmentation and image classification.
2//!
3//! Segformer is a transformer-based model designed for vision tasks. It uses a hierarchical
4//! structure that progressively generates features at different scales.
5//!
6//! Key characteristics:
7//! - Efficient self-attention with sequence reduction
8//! - Hierarchical feature generation
9//! - Mix-FFN for local and global feature interaction
10//! - Lightweight all-MLP decode head
11//!
12//! References:
13//! - [SegFormer Paper](https://arxiv.org/abs/2105.15203)
14//! - [Model Card](https://huggingface.co/nvidia/mit-b0)
15//!
16
17use crate::models::with_tracing::{conv2d, linear, Conv2d, Linear};
18use candle::{Context, Module, ModuleT, Result, Tensor, D};
19use candle_nn::{conv2d_no_bias, layer_norm, Activation, Conv2dConfig, VarBuilder};
20use serde::Deserialize;
21use std::collections::HashMap;
22
23// https://github.com/huggingface/transformers/blob/main/src/transformers/models/segformer/configuration_segformer.py
24#[derive(Debug, Clone, PartialEq, Deserialize)]
25pub struct Config {
26    #[serde(default)]
27    pub id2label: HashMap<String, String>,
28    pub num_channels: usize,
29    pub num_encoder_blocks: usize,
30    pub depths: Vec<usize>,
31    pub sr_ratios: Vec<usize>,
32    pub hidden_sizes: Vec<usize>,
33    pub patch_sizes: Vec<usize>,
34    pub strides: Vec<usize>,
35    pub num_attention_heads: Vec<usize>,
36    pub mlp_ratios: Vec<usize>,
37    pub hidden_act: candle_nn::Activation,
38    pub layer_norm_eps: f64,
39    pub decoder_hidden_size: usize,
40}
41
42#[derive(Debug, Clone)]
43struct SegformerOverlapPatchEmbeddings {
44    projection: Conv2d,
45    layer_norm: candle_nn::LayerNorm,
46}
47
48impl SegformerOverlapPatchEmbeddings {
49    fn new(
50        config: &Config,
51        patch_size: usize,
52        stride: usize,
53        num_channels: usize,
54        hidden_size: usize,
55        vb: VarBuilder,
56    ) -> Result<Self> {
57        let projection = conv2d(
58            num_channels,
59            hidden_size,
60            patch_size,
61            Conv2dConfig {
62                stride,
63                padding: patch_size / 2,
64                ..Default::default()
65            },
66            vb.pp("proj"),
67        )?;
68        let layer_norm =
69            candle_nn::layer_norm(hidden_size, config.layer_norm_eps, vb.pp("layer_norm"))?;
70        Ok(Self {
71            projection,
72            layer_norm,
73        })
74    }
75}
76
77impl Module for SegformerOverlapPatchEmbeddings {
78    fn forward(&self, x: &Tensor) -> Result<Tensor> {
79        let embeddings = self.projection.forward(x)?;
80        let shape = embeddings.shape();
81        // [B, C, H, W] -> [B, H * W, C]
82        let embeddings = embeddings.flatten_from(2)?.transpose(1, 2)?;
83        let embeddings = self.layer_norm.forward(&embeddings)?;
84        // [B, H * W, C] -> [B, C, H, W]
85        let embeddings = embeddings.transpose(1, 2)?.reshape(shape)?;
86        Ok(embeddings)
87    }
88}
89
90#[derive(Debug, Clone)]
91struct SegformerEfficientSelfAttention {
92    num_attention_heads: usize,
93    attention_head_size: usize,
94    query: Linear,
95    key: Linear,
96    value: Linear,
97    sr: Option<Conv2d>,
98    layer_norm: Option<layer_norm::LayerNorm>,
99}
100
101impl SegformerEfficientSelfAttention {
102    fn new(
103        config: &Config,
104        hidden_size: usize,
105        num_attention_heads: usize,
106        sequence_reduction_ratio: usize,
107        vb: VarBuilder,
108    ) -> Result<Self> {
109        if hidden_size % num_attention_heads != 0 {
110            candle::bail!(
111                "The hidden size {} is not a multiple of the number of attention heads {}",
112                hidden_size,
113                num_attention_heads
114            )
115        }
116        let attention_head_size = hidden_size / num_attention_heads;
117        let all_head_size = num_attention_heads * attention_head_size;
118        let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
119        let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
120        let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
121        let (sr, layer_norm) = if sequence_reduction_ratio > 1 {
122            (
123                Some(conv2d(
124                    hidden_size,
125                    hidden_size,
126                    sequence_reduction_ratio,
127                    Conv2dConfig {
128                        stride: sequence_reduction_ratio,
129                        ..Default::default()
130                    },
131                    vb.pp("sr"),
132                )?),
133                Some(candle_nn::layer_norm(
134                    hidden_size,
135                    config.layer_norm_eps,
136                    vb.pp("layer_norm"),
137                )?),
138            )
139        } else {
140            (None, None)
141        };
142        Ok(Self {
143            num_attention_heads,
144            attention_head_size,
145            query,
146            key,
147            value,
148            sr,
149            layer_norm,
150        })
151    }
152
153    fn transpose_for_scores(&self, hidden_states: Tensor) -> Result<Tensor> {
154        let (batch, seq_length, _) = hidden_states.shape().dims3()?;
155        let new_shape = &[
156            batch,
157            seq_length,
158            self.num_attention_heads,
159            self.attention_head_size,
160        ];
161        let hidden_states = hidden_states.reshape(new_shape)?;
162        let hidden_states = hidden_states.permute((0, 2, 1, 3))?;
163        Ok(hidden_states)
164    }
165}
166
167impl Module for SegformerEfficientSelfAttention {
168    fn forward(&self, x: &Tensor) -> Result<Tensor> {
169        // [B, C, H, W] -> [B, H * W, C]
170        let hidden_states = x.flatten_from(2)?.permute((0, 2, 1))?;
171        let query = self
172            .transpose_for_scores(self.query.forward(&hidden_states)?)?
173            .contiguous()?;
174        let hidden_states = if let (Some(sr), Some(layer_norm)) = (&self.sr, &self.layer_norm) {
175            let hidden_states = sr.forward(x)?;
176            // [B, C, H, W] -> [B, H * W, C]
177            let hidden_states = hidden_states.flatten_from(2)?.permute((0, 2, 1))?;
178            layer_norm.forward(&hidden_states)?
179        } else {
180            // already [B, H * W, C]
181            hidden_states
182        };
183        // standard self-attention
184        let key = self
185            .transpose_for_scores(self.key.forward(&hidden_states)?)?
186            .contiguous()?;
187        let value = self
188            .transpose_for_scores(self.value.forward(&hidden_states)?)?
189            .contiguous()?;
190        let attention_scores =
191            (query.matmul(&key.t()?)? / f64::sqrt(self.attention_head_size as f64))?;
192        let attention_scores = candle_nn::ops::softmax_last_dim(&attention_scores)?;
193        let result = attention_scores.matmul(&value)?;
194        let result = result.permute((0, 2, 1, 3))?.contiguous()?;
195        result.flatten_from(D::Minus2)
196    }
197}
198
199#[derive(Debug, Clone)]
200struct SegformerSelfOutput {
201    dense: Linear,
202}
203
204impl SegformerSelfOutput {
205    fn new(hidden_size: usize, vb: VarBuilder) -> Result<Self> {
206        let dense = linear(hidden_size, hidden_size, vb.pp("dense"))?;
207        Ok(Self { dense })
208    }
209}
210
211impl Module for SegformerSelfOutput {
212    fn forward(&self, x: &Tensor) -> Result<Tensor> {
213        self.dense.forward(x)
214    }
215}
216
217#[derive(Debug, Clone)]
218struct SegformerAttention {
219    attention: SegformerEfficientSelfAttention,
220    output: SegformerSelfOutput,
221}
222
223impl SegformerAttention {
224    fn new(
225        config: &Config,
226        hidden_size: usize,
227        num_attention_heads: usize,
228        sequence_reduction_ratio: usize,
229        vb: VarBuilder,
230    ) -> Result<Self> {
231        let attention = SegformerEfficientSelfAttention::new(
232            config,
233            hidden_size,
234            num_attention_heads,
235            sequence_reduction_ratio,
236            vb.pp("self"),
237        )?;
238        let output = SegformerSelfOutput::new(hidden_size, vb.pp("output"))?;
239        Ok(Self { attention, output })
240    }
241}
242
243impl Module for SegformerAttention {
244    fn forward(&self, x: &Tensor) -> Result<Tensor> {
245        let attention_output = self.attention.forward(x)?;
246        self.output.forward(&attention_output)
247    }
248}
249
250#[derive(Debug, Clone)]
251struct SegformerDWConv {
252    dw_conv: Conv2d,
253}
254
255impl SegformerDWConv {
256    fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
257        let dw_conv = conv2d(
258            dim,
259            dim,
260            3,
261            Conv2dConfig {
262                stride: 1,
263                padding: 1,
264                groups: dim,
265                ..Default::default()
266            },
267            vb.pp("dwconv"),
268        )?;
269        Ok(Self { dw_conv })
270    }
271}
272
273impl Module for SegformerDWConv {
274    fn forward(&self, x: &Tensor) -> Result<Tensor> {
275        self.dw_conv.forward(x)
276    }
277}
278
279#[derive(Debug, Clone)]
280struct SegformerMixFFN {
281    dense1: Linear,
282    dw_conv: SegformerDWConv,
283    act: Activation,
284    dense2: Linear,
285}
286
287impl SegformerMixFFN {
288    fn new(
289        config: &Config,
290        in_features: usize,
291        hidden_features: usize,
292        out_features: usize,
293        vb: VarBuilder,
294    ) -> Result<Self> {
295        let dense1 = linear(in_features, hidden_features, vb.pp("dense1"))?;
296        let dw_conv = SegformerDWConv::new(hidden_features, vb.pp("dwconv"))?;
297        let act = config.hidden_act;
298        let dense2 = linear(hidden_features, out_features, vb.pp("dense2"))?;
299        Ok(Self {
300            dense1,
301            dw_conv,
302            act,
303            dense2,
304        })
305    }
306}
307
308impl Module for SegformerMixFFN {
309    fn forward(&self, x: &Tensor) -> Result<Tensor> {
310        let (batch, _, height, width) = x.shape().dims4()?;
311        let hidden_states = self
312            .dense1
313            .forward(&x.flatten_from(2)?.permute((0, 2, 1))?)?;
314        let channels = hidden_states.dim(2)?;
315        let hidden_states = self.dw_conv.forward(
316            &hidden_states
317                .permute((0, 2, 1))?
318                .reshape((batch, channels, height, width))?,
319        )?;
320        let hidden_states = self.act.forward(&hidden_states)?;
321        let hidden_states = self
322            .dense2
323            .forward(&hidden_states.flatten_from(2)?.permute((0, 2, 1))?)?;
324        let channels = hidden_states.dim(2)?;
325        hidden_states
326            .permute((0, 2, 1))?
327            .reshape((batch, channels, height, width))
328    }
329}
330
331#[derive(Debug, Clone)]
332struct SegformerLayer {
333    layer_norm_1: candle_nn::LayerNorm,
334    attention: SegformerAttention,
335    layer_norm_2: candle_nn::LayerNorm,
336    mlp: SegformerMixFFN,
337}
338
339impl SegformerLayer {
340    fn new(
341        config: &Config,
342        hidden_size: usize,
343        num_attention_heads: usize,
344        sequence_reduction_ratio: usize,
345        mlp_ratio: usize,
346        vb: VarBuilder,
347    ) -> Result<Self> {
348        let layer_norm_1 = layer_norm(hidden_size, config.layer_norm_eps, vb.pp("layer_norm_1"))?;
349        let attention = SegformerAttention::new(
350            config,
351            hidden_size,
352            num_attention_heads,
353            sequence_reduction_ratio,
354            vb.pp("attention"),
355        )?;
356        let layer_norm_2 = layer_norm(hidden_size, config.layer_norm_eps, vb.pp("layer_norm_2"))?;
357        let mlp = SegformerMixFFN::new(
358            config,
359            hidden_size,
360            hidden_size * mlp_ratio,
361            hidden_size,
362            vb.pp("mlp"),
363        )?;
364        Ok(Self {
365            layer_norm_1,
366            attention,
367            layer_norm_2,
368            mlp,
369        })
370    }
371}
372
373impl Module for SegformerLayer {
374    fn forward(&self, x: &Tensor) -> Result<Tensor> {
375        let shape = x.shape().dims4()?;
376        // [B, C, H, W] -> [B, H * W, C]
377        let hidden_states = x.flatten_from(2)?.permute((0, 2, 1))?;
378        let layer_norm_output = self.layer_norm_1.forward(&hidden_states)?;
379        let layer_norm_output = layer_norm_output.permute((0, 2, 1))?.reshape(shape)?;
380        // attention takes in [B, C, H, W] in order to properly do conv2d (and output [B, H * W, C])
381        let attention_output = self.attention.forward(&layer_norm_output)?;
382        let hidden_states = (attention_output + hidden_states)?;
383        let layer_norm_output = self.layer_norm_2.forward(&hidden_states)?;
384        let mlp_output = self
385            .mlp
386            .forward(&layer_norm_output.permute((0, 2, 1))?.reshape(shape)?)?;
387        hidden_states.permute((0, 2, 1))?.reshape(shape)? + mlp_output
388    }
389}
390
391#[derive(Debug, Clone)]
392struct SegformerEncoder {
393    /// config file
394    config: Config,
395    /// a list of embeddings
396    patch_embeddings: Vec<SegformerOverlapPatchEmbeddings>,
397    /// a list of attention blocks, each consisting of layers
398    blocks: Vec<Vec<SegformerLayer>>,
399    /// a final list of layer norms
400    layer_norms: Vec<candle_nn::LayerNorm>,
401}
402
403impl SegformerEncoder {
404    fn new(config: Config, vb: VarBuilder) -> Result<Self> {
405        let mut patch_embeddings = Vec::with_capacity(config.num_encoder_blocks);
406        let mut blocks = Vec::with_capacity(config.num_encoder_blocks);
407        let mut layer_norms = Vec::with_capacity(config.num_encoder_blocks);
408        for i in 0..config.num_encoder_blocks {
409            let patch_size = config.patch_sizes[i];
410            let stride = config.strides[i];
411            let hidden_size = config.hidden_sizes[i];
412            let num_channels = if i == 0 {
413                config.num_channels
414            } else {
415                config.hidden_sizes[i - 1]
416            };
417            patch_embeddings.push(SegformerOverlapPatchEmbeddings::new(
418                &config,
419                patch_size,
420                stride,
421                num_channels,
422                hidden_size,
423                vb.pp(format!("patch_embeddings.{}", i)),
424            )?);
425            let mut layers = Vec::with_capacity(config.depths[i]);
426            for j in 0..config.depths[i] {
427                let sequence_reduction_ratio = config.sr_ratios[i];
428                let num_attention_heads = config.num_attention_heads[i];
429                let mlp_ratio = config.mlp_ratios[i];
430                layers.push(SegformerLayer::new(
431                    &config,
432                    hidden_size,
433                    num_attention_heads,
434                    sequence_reduction_ratio,
435                    mlp_ratio,
436                    vb.pp(format!("block.{}.{}", i, j)),
437                )?);
438            }
439            blocks.push(layers);
440            layer_norms.push(layer_norm(
441                hidden_size,
442                config.layer_norm_eps,
443                vb.pp(format!("layer_norm.{}", i)),
444            )?);
445        }
446        Ok(Self {
447            config,
448            patch_embeddings,
449            blocks,
450            layer_norms,
451        })
452    }
453}
454
455impl ModuleWithHiddenStates for SegformerEncoder {
456    fn forward(&self, x: &Tensor) -> Result<Vec<Tensor>> {
457        let mut all_hidden_states = Vec::with_capacity(self.config.num_encoder_blocks);
458        let mut hidden_states = x.clone();
459        for i in 0..self.config.num_encoder_blocks {
460            hidden_states = self.patch_embeddings[i].forward(&hidden_states)?;
461            for layer in &self.blocks[i] {
462                hidden_states = layer.forward(&hidden_states)?;
463            }
464            let shape = hidden_states.shape().dims4()?;
465            hidden_states =
466                self.layer_norms[i].forward(&hidden_states.flatten_from(2)?.permute((0, 2, 1))?)?;
467            hidden_states = hidden_states.permute((0, 2, 1))?.reshape(shape)?;
468            all_hidden_states.push(hidden_states.clone());
469        }
470        Ok(all_hidden_states)
471    }
472}
473
474#[derive(Debug, Clone)]
475struct SegformerModel {
476    encoder: SegformerEncoder,
477}
478
479impl SegformerModel {
480    fn new(config: &Config, vb: VarBuilder) -> Result<Self> {
481        let encoder = SegformerEncoder::new(config.clone(), vb.pp("encoder"))?;
482        Ok(Self { encoder })
483    }
484}
485
486impl ModuleWithHiddenStates for SegformerModel {
487    fn forward(&self, x: &Tensor) -> Result<Vec<Tensor>> {
488        self.encoder.forward(x)
489    }
490}
491
492#[derive(Debug, Clone)]
493struct SegformerMLP {
494    proj: Linear,
495}
496
497impl SegformerMLP {
498    fn new(config: &Config, input_dim: usize, vb: VarBuilder) -> Result<Self> {
499        let proj = linear(input_dim, config.decoder_hidden_size, vb.pp("proj"))?;
500        Ok(Self { proj })
501    }
502}
503
504impl Module for SegformerMLP {
505    fn forward(&self, x: &Tensor) -> Result<Tensor> {
506        self.proj.forward(x)
507    }
508}
509
510#[derive(Debug, Clone)]
511struct SegformerDecodeHead {
512    linear_c: Vec<SegformerMLP>,
513    linear_fuse: candle_nn::Conv2d,
514    batch_norm: candle_nn::BatchNorm,
515    classifier: candle_nn::Conv2d,
516}
517
518impl SegformerDecodeHead {
519    fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
520        let mut linear_c = Vec::with_capacity(config.num_encoder_blocks);
521        for i in 0..config.num_encoder_blocks {
522            let hidden_size = config.hidden_sizes[i];
523            linear_c.push(SegformerMLP::new(
524                config,
525                hidden_size,
526                vb.pp(format!("linear_c.{}", i)),
527            )?);
528        }
529        let linear_fuse = conv2d_no_bias(
530            config.decoder_hidden_size * config.num_encoder_blocks,
531            config.decoder_hidden_size,
532            1,
533            Conv2dConfig::default(),
534            vb.pp("linear_fuse"),
535        )?;
536        let batch_norm = candle_nn::batch_norm(
537            config.decoder_hidden_size,
538            config.layer_norm_eps,
539            vb.pp("batch_norm"),
540        )?;
541        let classifier = conv2d_no_bias(
542            config.decoder_hidden_size,
543            num_labels,
544            1,
545            Conv2dConfig::default(),
546            vb.pp("classifier"),
547        )?;
548        Ok(Self {
549            linear_c,
550            linear_fuse,
551            batch_norm,
552            classifier,
553        })
554    }
555
556    fn forward(&self, encoder_hidden_states: &[Tensor]) -> Result<Tensor> {
557        if encoder_hidden_states.len() != self.linear_c.len() {
558            candle::bail!(
559                "The number of encoder hidden states {} is not equal to the number of linear layers {}",
560                encoder_hidden_states.len(),
561                self.linear_c.len()
562            )
563        }
564        // most fine layer
565        let (_, _, upsample_height, upsample_width) = encoder_hidden_states[0].shape().dims4()?;
566        let mut hidden_states = Vec::with_capacity(self.linear_c.len());
567        for (hidden_state, mlp) in encoder_hidden_states.iter().zip(&self.linear_c) {
568            let (batch, _, height, width) = hidden_state.shape().dims4()?;
569            let hidden_state = mlp.forward(&hidden_state.flatten_from(2)?.permute((0, 2, 1))?)?;
570            let hidden_state = hidden_state.permute((0, 2, 1))?.reshape((
571                batch,
572                hidden_state.dim(2)?,
573                height,
574                width,
575            ))?;
576            let hidden_state = hidden_state.upsample_nearest2d(upsample_height, upsample_width)?;
577            hidden_states.push(hidden_state);
578        }
579        hidden_states.reverse();
580        let hidden_states = Tensor::cat(&hidden_states, 1)?;
581        let hidden_states = self.linear_fuse.forward(&hidden_states)?;
582        let hidden_states = self.batch_norm.forward_t(&hidden_states, false)?;
583        let hidden_states = hidden_states.relu()?;
584        self.classifier.forward(&hidden_states)
585    }
586}
587
588trait ModuleWithHiddenStates {
589    fn forward(&self, xs: &Tensor) -> Result<Vec<Tensor>>;
590}
591
592#[derive(Debug, Clone)]
593pub struct SemanticSegmentationModel {
594    segformer: SegformerModel,
595    decode_head: SegformerDecodeHead,
596}
597
598impl SemanticSegmentationModel {
599    pub fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
600        let segformer = SegformerModel::new(config, vb.pp("segformer"))?;
601        let decode_head = SegformerDecodeHead::new(config, num_labels, vb.pp("decode_head"))?;
602        Ok(Self {
603            segformer,
604            decode_head,
605        })
606    }
607}
608
609impl Module for SemanticSegmentationModel {
610    fn forward(&self, x: &Tensor) -> Result<Tensor> {
611        let hidden_states = self.segformer.forward(x)?;
612        self.decode_head.forward(&hidden_states)
613    }
614}
615
616#[derive(Debug, Clone)]
617pub struct ImageClassificationModel {
618    segformer: SegformerModel,
619    classifier: Linear,
620}
621
622impl ImageClassificationModel {
623    pub fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
624        let segformer = SegformerModel::new(config, vb.pp("segformer"))?;
625        let classifier = linear(config.decoder_hidden_size, num_labels, vb.pp("classifier"))?;
626        Ok(Self {
627            segformer,
628            classifier,
629        })
630    }
631}
632
633impl Module for ImageClassificationModel {
634    fn forward(&self, x: &Tensor) -> Result<Tensor> {
635        let all_hidden_states = self.segformer.forward(x)?;
636        let hidden_states = all_hidden_states.last().context("no last")?;
637        let hidden_states = hidden_states.flatten_from(2)?.permute((0, 2, 1))?;
638        let mean = hidden_states.mean(1)?;
639        self.classifier.forward(&mean)
640    }
641}
642
#[cfg(test)]
mod tests {

    use super::*;

    /// A Hugging Face `config.json` must deserialize even though it carries keys that
    /// `Config` does not declare and omits `id2label`.
    #[test]
    fn test_config_json_load() {
        let raw_json = r#"{
            "architectures": [
              "SegformerForImageClassification"
            ],
            "attention_probs_dropout_prob": 0.0,
            "classifier_dropout_prob": 0.1,
            "decoder_hidden_size": 256,
            "depths": [
              2,
              2,
              2,
              2
            ],
            "downsampling_rates": [
              1,
              4,
              8,
              16
            ],
            "drop_path_rate": 0.1,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.0,
            "hidden_sizes": [
              32,
              64,
              160,
              256
            ],
            "image_size": 224,
            "initializer_range": 0.02,
            "layer_norm_eps": 1e-06,
            "mlp_ratios": [
              4,
              4,
              4,
              4
            ],
            "model_type": "segformer",
            "num_attention_heads": [
              1,
              2,
              5,
              8
            ],
            "num_channels": 3,
            "num_encoder_blocks": 4,
            "patch_sizes": [
              7,
              3,
              3,
              3
            ],
            "sr_ratios": [
              8,
              4,
              2,
              1
            ],
            "strides": [
              4,
              2,
              2,
              2
            ],
            "torch_dtype": "float32",
            "transformers_version": "4.12.0.dev0"
          }"#;
        let config: Config = serde_json::from_str(raw_json).unwrap();
        assert_eq!(vec![4, 2, 2, 2], config.strides);
        assert_eq!(1e-6, config.layer_norm_eps);
        // `id2label` is absent from the JSON, so the serde default must kick in.
        assert!(config.id2label.is_empty());
        // Scalars and the remaining per-stage vectors are read as written.
        assert_eq!(3, config.num_channels);
        assert_eq!(4, config.num_encoder_blocks);
        assert_eq!(256, config.decoder_hidden_size);
        assert_eq!(vec![2, 2, 2, 2], config.depths);
        assert_eq!(vec![32, 64, 160, 256], config.hidden_sizes);
        assert_eq!(vec![7, 3, 3, 3], config.patch_sizes);
        assert_eq!(vec![8, 4, 2, 1], config.sr_ratios);
        assert_eq!(vec![1, 2, 5, 8], config.num_attention_heads);
        assert_eq!(vec![4, 4, 4, 4], config.mlp_ratios);
    }
}