1use crate::models::with_tracing::{conv2d, linear, Conv2d, Linear};
18use candle::{Context, Module, ModuleT, Result, Tensor, D};
19use candle_nn::{conv2d_no_bias, layer_norm, Activation, Conv2dConfig, VarBuilder};
20use serde::Deserialize;
21use std::collections::HashMap;
22
/// SegFormer model configuration, deserialized from a JSON config file
/// (the test module at the bottom of this file shows an example).
///
/// The `Vec` fields are per-encoder-stage lists, indexed by stage
/// `0..num_encoder_blocks`.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    /// Label mapping; falls back to an empty map when the key is absent from
    /// the JSON. Presumably maps class index (as a string) to a label name —
    /// not read anywhere in this file.
    #[serde(default)]
    pub id2label: HashMap<String, String>,
    /// Number of channels of the input image (input of the first patch embedding).
    pub num_channels: usize,
    /// Number of encoder stages.
    pub num_encoder_blocks: usize,
    /// Number of transformer layers in each stage.
    pub depths: Vec<usize>,
    /// Per-stage sequence reduction ratio; a value > 1 makes attention compute
    /// keys/values from a strided-convolution downsampled input.
    pub sr_ratios: Vec<usize>,
    /// Channel width of each stage.
    pub hidden_sizes: Vec<usize>,
    /// Kernel size of each stage's patch embedding convolution (padding is half of it).
    pub patch_sizes: Vec<usize>,
    /// Stride of each stage's patch embedding convolution.
    pub strides: Vec<usize>,
    /// Number of attention heads in each stage.
    pub num_attention_heads: Vec<usize>,
    /// Expansion factor of the Mix-FFN hidden layer in each stage.
    pub mlp_ratios: Vec<usize>,
    /// Activation used inside the Mix-FFN.
    pub hidden_act: candle_nn::Activation,
    /// Epsilon passed to the layer norms.
    pub layer_norm_eps: f64,
    /// Channel width of the decode head.
    pub decoder_hidden_size: usize,
}
41
42#[derive(Debug, Clone)]
43struct SegformerOverlapPatchEmbeddings {
44 projection: Conv2d,
45 layer_norm: candle_nn::LayerNorm,
46}
47
48impl SegformerOverlapPatchEmbeddings {
49 fn new(
50 config: &Config,
51 patch_size: usize,
52 stride: usize,
53 num_channels: usize,
54 hidden_size: usize,
55 vb: VarBuilder,
56 ) -> Result<Self> {
57 let projection = conv2d(
58 num_channels,
59 hidden_size,
60 patch_size,
61 Conv2dConfig {
62 stride,
63 padding: patch_size / 2,
64 ..Default::default()
65 },
66 vb.pp("proj"),
67 )?;
68 let layer_norm =
69 candle_nn::layer_norm(hidden_size, config.layer_norm_eps, vb.pp("layer_norm"))?;
70 Ok(Self {
71 projection,
72 layer_norm,
73 })
74 }
75}
76
77impl Module for SegformerOverlapPatchEmbeddings {
78 fn forward(&self, x: &Tensor) -> Result<Tensor> {
79 let embeddings = self.projection.forward(x)?;
80 let shape = embeddings.shape();
81 let embeddings = embeddings.flatten_from(2)?.transpose(1, 2)?;
83 let embeddings = self.layer_norm.forward(&embeddings)?;
84 let embeddings = embeddings.transpose(1, 2)?.reshape(shape)?;
86 Ok(embeddings)
87 }
88}
89
90#[derive(Debug, Clone)]
91struct SegformerEfficientSelfAttention {
92 num_attention_heads: usize,
93 attention_head_size: usize,
94 query: Linear,
95 key: Linear,
96 value: Linear,
97 sr: Option<Conv2d>,
98 layer_norm: Option<layer_norm::LayerNorm>,
99}
100
101impl SegformerEfficientSelfAttention {
102 fn new(
103 config: &Config,
104 hidden_size: usize,
105 num_attention_heads: usize,
106 sequence_reduction_ratio: usize,
107 vb: VarBuilder,
108 ) -> Result<Self> {
109 if hidden_size % num_attention_heads != 0 {
110 candle::bail!(
111 "The hidden size {} is not a multiple of the number of attention heads {}",
112 hidden_size,
113 num_attention_heads
114 )
115 }
116 let attention_head_size = hidden_size / num_attention_heads;
117 let all_head_size = num_attention_heads * attention_head_size;
118 let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
119 let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
120 let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
121 let (sr, layer_norm) = if sequence_reduction_ratio > 1 {
122 (
123 Some(conv2d(
124 hidden_size,
125 hidden_size,
126 sequence_reduction_ratio,
127 Conv2dConfig {
128 stride: sequence_reduction_ratio,
129 ..Default::default()
130 },
131 vb.pp("sr"),
132 )?),
133 Some(candle_nn::layer_norm(
134 hidden_size,
135 config.layer_norm_eps,
136 vb.pp("layer_norm"),
137 )?),
138 )
139 } else {
140 (None, None)
141 };
142 Ok(Self {
143 num_attention_heads,
144 attention_head_size,
145 query,
146 key,
147 value,
148 sr,
149 layer_norm,
150 })
151 }
152
153 fn transpose_for_scores(&self, hidden_states: Tensor) -> Result<Tensor> {
154 let (batch, seq_length, _) = hidden_states.shape().dims3()?;
155 let new_shape = &[
156 batch,
157 seq_length,
158 self.num_attention_heads,
159 self.attention_head_size,
160 ];
161 let hidden_states = hidden_states.reshape(new_shape)?;
162 let hidden_states = hidden_states.permute((0, 2, 1, 3))?;
163 Ok(hidden_states)
164 }
165}
166
167impl Module for SegformerEfficientSelfAttention {
168 fn forward(&self, x: &Tensor) -> Result<Tensor> {
169 let hidden_states = x.flatten_from(2)?.permute((0, 2, 1))?;
171 let query = self
172 .transpose_for_scores(self.query.forward(&hidden_states)?)?
173 .contiguous()?;
174 let hidden_states = if let (Some(sr), Some(layer_norm)) = (&self.sr, &self.layer_norm) {
175 let hidden_states = sr.forward(x)?;
176 let hidden_states = hidden_states.flatten_from(2)?.permute((0, 2, 1))?;
178 layer_norm.forward(&hidden_states)?
179 } else {
180 hidden_states
182 };
183 let key = self
185 .transpose_for_scores(self.key.forward(&hidden_states)?)?
186 .contiguous()?;
187 let value = self
188 .transpose_for_scores(self.value.forward(&hidden_states)?)?
189 .contiguous()?;
190 let attention_scores =
191 (query.matmul(&key.t()?)? / f64::sqrt(self.attention_head_size as f64))?;
192 let attention_scores = candle_nn::ops::softmax_last_dim(&attention_scores)?;
193 let result = attention_scores.matmul(&value)?;
194 let result = result.permute((0, 2, 1, 3))?.contiguous()?;
195 result.flatten_from(D::Minus2)
196 }
197}
198
199#[derive(Debug, Clone)]
200struct SegformerSelfOutput {
201 dense: Linear,
202}
203
204impl SegformerSelfOutput {
205 fn new(hidden_size: usize, vb: VarBuilder) -> Result<Self> {
206 let dense = linear(hidden_size, hidden_size, vb.pp("dense"))?;
207 Ok(Self { dense })
208 }
209}
210
211impl Module for SegformerSelfOutput {
212 fn forward(&self, x: &Tensor) -> Result<Tensor> {
213 self.dense.forward(x)
214 }
215}
216
217#[derive(Debug, Clone)]
218struct SegformerAttention {
219 attention: SegformerEfficientSelfAttention,
220 output: SegformerSelfOutput,
221}
222
223impl SegformerAttention {
224 fn new(
225 config: &Config,
226 hidden_size: usize,
227 num_attention_heads: usize,
228 sequence_reduction_ratio: usize,
229 vb: VarBuilder,
230 ) -> Result<Self> {
231 let attention = SegformerEfficientSelfAttention::new(
232 config,
233 hidden_size,
234 num_attention_heads,
235 sequence_reduction_ratio,
236 vb.pp("self"),
237 )?;
238 let output = SegformerSelfOutput::new(hidden_size, vb.pp("output"))?;
239 Ok(Self { attention, output })
240 }
241}
242
243impl Module for SegformerAttention {
244 fn forward(&self, x: &Tensor) -> Result<Tensor> {
245 let attention_output = self.attention.forward(x)?;
246 self.output.forward(&attention_output)
247 }
248}
249
250#[derive(Debug, Clone)]
251struct SegformerDWConv {
252 dw_conv: Conv2d,
253}
254
255impl SegformerDWConv {
256 fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
257 let dw_conv = conv2d(
258 dim,
259 dim,
260 3,
261 Conv2dConfig {
262 stride: 1,
263 padding: 1,
264 groups: dim,
265 ..Default::default()
266 },
267 vb.pp("dwconv"),
268 )?;
269 Ok(Self { dw_conv })
270 }
271}
272
273impl Module for SegformerDWConv {
274 fn forward(&self, x: &Tensor) -> Result<Tensor> {
275 self.dw_conv.forward(x)
276 }
277}
278
279#[derive(Debug, Clone)]
280struct SegformerMixFFN {
281 dense1: Linear,
282 dw_conv: SegformerDWConv,
283 act: Activation,
284 dense2: Linear,
285}
286
287impl SegformerMixFFN {
288 fn new(
289 config: &Config,
290 in_features: usize,
291 hidden_features: usize,
292 out_features: usize,
293 vb: VarBuilder,
294 ) -> Result<Self> {
295 let dense1 = linear(in_features, hidden_features, vb.pp("dense1"))?;
296 let dw_conv = SegformerDWConv::new(hidden_features, vb.pp("dwconv"))?;
297 let act = config.hidden_act;
298 let dense2 = linear(hidden_features, out_features, vb.pp("dense2"))?;
299 Ok(Self {
300 dense1,
301 dw_conv,
302 act,
303 dense2,
304 })
305 }
306}
307
308impl Module for SegformerMixFFN {
309 fn forward(&self, x: &Tensor) -> Result<Tensor> {
310 let (batch, _, height, width) = x.shape().dims4()?;
311 let hidden_states = self
312 .dense1
313 .forward(&x.flatten_from(2)?.permute((0, 2, 1))?)?;
314 let channels = hidden_states.dim(2)?;
315 let hidden_states = self.dw_conv.forward(
316 &hidden_states
317 .permute((0, 2, 1))?
318 .reshape((batch, channels, height, width))?,
319 )?;
320 let hidden_states = self.act.forward(&hidden_states)?;
321 let hidden_states = self
322 .dense2
323 .forward(&hidden_states.flatten_from(2)?.permute((0, 2, 1))?)?;
324 let channels = hidden_states.dim(2)?;
325 hidden_states
326 .permute((0, 2, 1))?
327 .reshape((batch, channels, height, width))
328 }
329}
330
331#[derive(Debug, Clone)]
332struct SegformerLayer {
333 layer_norm_1: candle_nn::LayerNorm,
334 attention: SegformerAttention,
335 layer_norm_2: candle_nn::LayerNorm,
336 mlp: SegformerMixFFN,
337}
338
339impl SegformerLayer {
340 fn new(
341 config: &Config,
342 hidden_size: usize,
343 num_attention_heads: usize,
344 sequence_reduction_ratio: usize,
345 mlp_ratio: usize,
346 vb: VarBuilder,
347 ) -> Result<Self> {
348 let layer_norm_1 = layer_norm(hidden_size, config.layer_norm_eps, vb.pp("layer_norm_1"))?;
349 let attention = SegformerAttention::new(
350 config,
351 hidden_size,
352 num_attention_heads,
353 sequence_reduction_ratio,
354 vb.pp("attention"),
355 )?;
356 let layer_norm_2 = layer_norm(hidden_size, config.layer_norm_eps, vb.pp("layer_norm_2"))?;
357 let mlp = SegformerMixFFN::new(
358 config,
359 hidden_size,
360 hidden_size * mlp_ratio,
361 hidden_size,
362 vb.pp("mlp"),
363 )?;
364 Ok(Self {
365 layer_norm_1,
366 attention,
367 layer_norm_2,
368 mlp,
369 })
370 }
371}
372
373impl Module for SegformerLayer {
374 fn forward(&self, x: &Tensor) -> Result<Tensor> {
375 let shape = x.shape().dims4()?;
376 let hidden_states = x.flatten_from(2)?.permute((0, 2, 1))?;
378 let layer_norm_output = self.layer_norm_1.forward(&hidden_states)?;
379 let layer_norm_output = layer_norm_output.permute((0, 2, 1))?.reshape(shape)?;
380 let attention_output = self.attention.forward(&layer_norm_output)?;
382 let hidden_states = (attention_output + hidden_states)?;
383 let layer_norm_output = self.layer_norm_2.forward(&hidden_states)?;
384 let mlp_output = self
385 .mlp
386 .forward(&layer_norm_output.permute((0, 2, 1))?.reshape(shape)?)?;
387 hidden_states.permute((0, 2, 1))?.reshape(shape)? + mlp_output
388 }
389}
390
391#[derive(Debug, Clone)]
392struct SegformerEncoder {
393 config: Config,
395 patch_embeddings: Vec<SegformerOverlapPatchEmbeddings>,
397 blocks: Vec<Vec<SegformerLayer>>,
399 layer_norms: Vec<candle_nn::LayerNorm>,
401}
402
403impl SegformerEncoder {
404 fn new(config: Config, vb: VarBuilder) -> Result<Self> {
405 let mut patch_embeddings = Vec::with_capacity(config.num_encoder_blocks);
406 let mut blocks = Vec::with_capacity(config.num_encoder_blocks);
407 let mut layer_norms = Vec::with_capacity(config.num_encoder_blocks);
408 for i in 0..config.num_encoder_blocks {
409 let patch_size = config.patch_sizes[i];
410 let stride = config.strides[i];
411 let hidden_size = config.hidden_sizes[i];
412 let num_channels = if i == 0 {
413 config.num_channels
414 } else {
415 config.hidden_sizes[i - 1]
416 };
417 patch_embeddings.push(SegformerOverlapPatchEmbeddings::new(
418 &config,
419 patch_size,
420 stride,
421 num_channels,
422 hidden_size,
423 vb.pp(format!("patch_embeddings.{}", i)),
424 )?);
425 let mut layers = Vec::with_capacity(config.depths[i]);
426 for j in 0..config.depths[i] {
427 let sequence_reduction_ratio = config.sr_ratios[i];
428 let num_attention_heads = config.num_attention_heads[i];
429 let mlp_ratio = config.mlp_ratios[i];
430 layers.push(SegformerLayer::new(
431 &config,
432 hidden_size,
433 num_attention_heads,
434 sequence_reduction_ratio,
435 mlp_ratio,
436 vb.pp(format!("block.{}.{}", i, j)),
437 )?);
438 }
439 blocks.push(layers);
440 layer_norms.push(layer_norm(
441 hidden_size,
442 config.layer_norm_eps,
443 vb.pp(format!("layer_norm.{}", i)),
444 )?);
445 }
446 Ok(Self {
447 config,
448 patch_embeddings,
449 blocks,
450 layer_norms,
451 })
452 }
453}
454
455impl ModuleWithHiddenStates for SegformerEncoder {
456 fn forward(&self, x: &Tensor) -> Result<Vec<Tensor>> {
457 let mut all_hidden_states = Vec::with_capacity(self.config.num_encoder_blocks);
458 let mut hidden_states = x.clone();
459 for i in 0..self.config.num_encoder_blocks {
460 hidden_states = self.patch_embeddings[i].forward(&hidden_states)?;
461 for layer in &self.blocks[i] {
462 hidden_states = layer.forward(&hidden_states)?;
463 }
464 let shape = hidden_states.shape().dims4()?;
465 hidden_states =
466 self.layer_norms[i].forward(&hidden_states.flatten_from(2)?.permute((0, 2, 1))?)?;
467 hidden_states = hidden_states.permute((0, 2, 1))?.reshape(shape)?;
468 all_hidden_states.push(hidden_states.clone());
469 }
470 Ok(all_hidden_states)
471 }
472}
473
474#[derive(Debug, Clone)]
475struct SegformerModel {
476 encoder: SegformerEncoder,
477}
478
479impl SegformerModel {
480 fn new(config: &Config, vb: VarBuilder) -> Result<Self> {
481 let encoder = SegformerEncoder::new(config.clone(), vb.pp("encoder"))?;
482 Ok(Self { encoder })
483 }
484}
485
486impl ModuleWithHiddenStates for SegformerModel {
487 fn forward(&self, x: &Tensor) -> Result<Vec<Tensor>> {
488 self.encoder.forward(x)
489 }
490}
491
492#[derive(Debug, Clone)]
493struct SegformerMLP {
494 proj: Linear,
495}
496
497impl SegformerMLP {
498 fn new(config: &Config, input_dim: usize, vb: VarBuilder) -> Result<Self> {
499 let proj = linear(input_dim, config.decoder_hidden_size, vb.pp("proj"))?;
500 Ok(Self { proj })
501 }
502}
503
504impl Module for SegformerMLP {
505 fn forward(&self, x: &Tensor) -> Result<Tensor> {
506 self.proj.forward(x)
507 }
508}
509
510#[derive(Debug, Clone)]
511struct SegformerDecodeHead {
512 linear_c: Vec<SegformerMLP>,
513 linear_fuse: candle_nn::Conv2d,
514 batch_norm: candle_nn::BatchNorm,
515 classifier: candle_nn::Conv2d,
516}
517
518impl SegformerDecodeHead {
519 fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
520 let mut linear_c = Vec::with_capacity(config.num_encoder_blocks);
521 for i in 0..config.num_encoder_blocks {
522 let hidden_size = config.hidden_sizes[i];
523 linear_c.push(SegformerMLP::new(
524 config,
525 hidden_size,
526 vb.pp(format!("linear_c.{}", i)),
527 )?);
528 }
529 let linear_fuse = conv2d_no_bias(
530 config.decoder_hidden_size * config.num_encoder_blocks,
531 config.decoder_hidden_size,
532 1,
533 Conv2dConfig::default(),
534 vb.pp("linear_fuse"),
535 )?;
536 let batch_norm = candle_nn::batch_norm(
537 config.decoder_hidden_size,
538 config.layer_norm_eps,
539 vb.pp("batch_norm"),
540 )?;
541 let classifier = conv2d_no_bias(
542 config.decoder_hidden_size,
543 num_labels,
544 1,
545 Conv2dConfig::default(),
546 vb.pp("classifier"),
547 )?;
548 Ok(Self {
549 linear_c,
550 linear_fuse,
551 batch_norm,
552 classifier,
553 })
554 }
555
556 fn forward(&self, encoder_hidden_states: &[Tensor]) -> Result<Tensor> {
557 if encoder_hidden_states.len() != self.linear_c.len() {
558 candle::bail!(
559 "The number of encoder hidden states {} is not equal to the number of linear layers {}",
560 encoder_hidden_states.len(),
561 self.linear_c.len()
562 )
563 }
564 let (_, _, upsample_height, upsample_width) = encoder_hidden_states[0].shape().dims4()?;
566 let mut hidden_states = Vec::with_capacity(self.linear_c.len());
567 for (hidden_state, mlp) in encoder_hidden_states.iter().zip(&self.linear_c) {
568 let (batch, _, height, width) = hidden_state.shape().dims4()?;
569 let hidden_state = mlp.forward(&hidden_state.flatten_from(2)?.permute((0, 2, 1))?)?;
570 let hidden_state = hidden_state.permute((0, 2, 1))?.reshape((
571 batch,
572 hidden_state.dim(2)?,
573 height,
574 width,
575 ))?;
576 let hidden_state = hidden_state.upsample_nearest2d(upsample_height, upsample_width)?;
577 hidden_states.push(hidden_state);
578 }
579 hidden_states.reverse();
580 let hidden_states = Tensor::cat(&hidden_states, 1)?;
581 let hidden_states = self.linear_fuse.forward(&hidden_states)?;
582 let hidden_states = self.batch_norm.forward_t(&hidden_states, false)?;
583 let hidden_states = hidden_states.relu()?;
584 self.classifier.forward(&hidden_states)
585 }
586}
587
588trait ModuleWithHiddenStates {
589 fn forward(&self, xs: &Tensor) -> Result<Vec<Tensor>>;
590}
591
592#[derive(Debug, Clone)]
593pub struct SemanticSegmentationModel {
594 segformer: SegformerModel,
595 decode_head: SegformerDecodeHead,
596}
597
598impl SemanticSegmentationModel {
599 pub fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
600 let segformer = SegformerModel::new(config, vb.pp("segformer"))?;
601 let decode_head = SegformerDecodeHead::new(config, num_labels, vb.pp("decode_head"))?;
602 Ok(Self {
603 segformer,
604 decode_head,
605 })
606 }
607}
608
609impl Module for SemanticSegmentationModel {
610 fn forward(&self, x: &Tensor) -> Result<Tensor> {
611 let hidden_states = self.segformer.forward(x)?;
612 self.decode_head.forward(&hidden_states)
613 }
614}
615
616#[derive(Debug, Clone)]
617pub struct ImageClassificationModel {
618 segformer: SegformerModel,
619 classifier: Linear,
620}
621
622impl ImageClassificationModel {
623 pub fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
624 let segformer = SegformerModel::new(config, vb.pp("segformer"))?;
625 let classifier = linear(config.decoder_hidden_size, num_labels, vb.pp("classifier"))?;
626 Ok(Self {
627 segformer,
628 classifier,
629 })
630 }
631}
632
633impl Module for ImageClassificationModel {
634 fn forward(&self, x: &Tensor) -> Result<Tensor> {
635 let all_hidden_states = self.segformer.forward(x)?;
636 let hidden_states = all_hidden_states.last().context("no last")?;
637 let hidden_states = hidden_states.flatten_from(2)?.permute((0, 2, 1))?;
638 let mean = hidden_states.mean(1)?;
639 self.classifier.forward(&mean)
640 }
641}
642
#[cfg(test)]
mod tests {

    use super::*;

    /// Parses a real SegFormer (mit-b0 style) config. The JSON includes keys
    /// that `Config` does not declare, and omits `id2label`.
    #[test]
    fn test_config_json_load() {
        let raw_json = r#"{
            "architectures": [
                "SegformerForImageClassification"
            ],
            "attention_probs_dropout_prob": 0.0,
            "classifier_dropout_prob": 0.1,
            "decoder_hidden_size": 256,
            "depths": [
                2,
                2,
                2,
                2
            ],
            "downsampling_rates": [
                1,
                4,
                8,
                16
            ],
            "drop_path_rate": 0.1,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.0,
            "hidden_sizes": [
                32,
                64,
                160,
                256
            ],
            "image_size": 224,
            "initializer_range": 0.02,
            "layer_norm_eps": 1e-06,
            "mlp_ratios": [
                4,
                4,
                4,
                4
            ],
            "model_type": "segformer",
            "num_attention_heads": [
                1,
                2,
                5,
                8
            ],
            "num_channels": 3,
            "num_encoder_blocks": 4,
            "patch_sizes": [
                7,
                3,
                3,
                3
            ],
            "sr_ratios": [
                8,
                4,
                2,
                1
            ],
            "strides": [
                4,
                2,
                2,
                2
            ],
            "torch_dtype": "float32",
            "transformers_version": "4.12.0.dev0"
        }"#;
        let config: Config = serde_json::from_str(raw_json).unwrap();
        assert_eq!(vec![4, 2, 2, 2], config.strides);
        assert_eq!(1e-6, config.layer_norm_eps);
        // `id2label` is absent from the JSON, so `#[serde(default)]` must kick in.
        assert!(config.id2label.is_empty());
        assert_eq!(3, config.num_channels);
        assert_eq!(4, config.num_encoder_blocks);
        assert_eq!(vec![2, 2, 2, 2], config.depths);
        assert_eq!(vec![8, 4, 2, 1], config.sr_ratios);
        assert_eq!(vec![32, 64, 160, 256], config.hidden_sizes);
        assert_eq!(vec![7, 3, 3, 3], config.patch_sizes);
        assert_eq!(vec![1, 2, 5, 8], config.num_attention_heads);
        assert_eq!(vec![4, 4, 4, 4], config.mlp_ratios);
        assert_eq!(256, config.decoder_hidden_size);
    }
}