candle_transformers/models/segment_anything/
mask_decoder.rs

1use candle::{IndexOp, Result, Tensor};
2use candle_nn::{Module, VarBuilder};
3
4use super::transformer::TwoWayTransformer;
5
6#[derive(Debug)]
7struct MlpMaskDecoder {
8    layers: Vec<super::Linear>,
9    sigmoid_output: bool,
10    span: tracing::Span,
11}
12
13impl MlpMaskDecoder {
14    fn new(
15        input_dim: usize,
16        hidden_dim: usize,
17        output_dim: usize,
18        num_layers: usize,
19        sigmoid_output: bool,
20        vb: VarBuilder,
21    ) -> Result<Self> {
22        let mut layers = Vec::with_capacity(num_layers);
23        let vb = vb.pp("layers");
24        for i in 0..num_layers {
25            let in_dim = if i == 0 { input_dim } else { hidden_dim };
26            let out_dim = if i + 1 == num_layers {
27                output_dim
28            } else {
29                hidden_dim
30            };
31            let layer = super::linear(vb.pp(i), in_dim, out_dim, true)?;
32            layers.push(layer)
33        }
34        let span = tracing::span!(tracing::Level::TRACE, "mlp-mask-decoder");
35        Ok(Self {
36            layers,
37            sigmoid_output,
38            span,
39        })
40    }
41}
42
43impl Module for MlpMaskDecoder {
44    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
45        let _enter = self.span.enter();
46        let mut xs = xs.clone();
47        for (i, layer) in self.layers.iter().enumerate() {
48            xs = layer.forward(&xs)?;
49            if i + 1 < self.layers.len() {
50                xs = xs.relu()?
51            }
52        }
53        if self.sigmoid_output {
54            candle_nn::ops::sigmoid(&xs)
55        } else {
56            Ok(xs)
57        }
58    }
59}
60
61#[derive(Debug)]
62pub struct MaskDecoder {
63    iou_token: candle_nn::Embedding,
64    mask_tokens: candle_nn::Embedding,
65    iou_prediction_head: MlpMaskDecoder,
66    output_upscaling_conv1: candle_nn::ConvTranspose2d,
67    output_upscaling_ln: super::LayerNorm2d,
68    output_upscaling_conv2: candle_nn::ConvTranspose2d,
69    num_mask_tokens: usize,
70    output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
71    transformer: TwoWayTransformer,
72    span: tracing::Span,
73}
74
75impl MaskDecoder {
76    pub fn new(
77        transformer_dim: usize,
78        num_multimask_outputs: usize,
79        iou_head_depth: usize,
80        iou_head_hidden_dim: usize,
81        vb: VarBuilder,
82    ) -> Result<Self> {
83        let num_mask_tokens = num_multimask_outputs + 1;
84        let iou_prediction_head = MlpMaskDecoder::new(
85            transformer_dim,
86            iou_head_hidden_dim,
87            num_mask_tokens,
88            iou_head_depth,
89            false,
90            vb.pp("iou_prediction_head"),
91        )?;
92        let iou_token = candle_nn::embedding(1, transformer_dim, vb.pp("iou_token"))?;
93        let mask_tokens =
94            candle_nn::embedding(num_mask_tokens, transformer_dim, vb.pp("mask_tokens"))?;
95        let cfg = candle_nn::ConvTranspose2dConfig {
96            stride: 2,
97            ..Default::default()
98        };
99        let output_upscaling_conv1 = candle_nn::conv_transpose2d(
100            transformer_dim,
101            transformer_dim / 4,
102            2,
103            cfg,
104            vb.pp("output_upscaling.0"),
105        )?;
106        let output_upscaling_ln =
107            super::LayerNorm2d::new(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?;
108        let output_upscaling_conv2 = candle_nn::conv_transpose2d(
109            transformer_dim / 4,
110            transformer_dim / 8,
111            2,
112            cfg,
113            vb.pp("output_upscaling.3"),
114        )?;
115        let mut output_hypernetworks_mlps = Vec::with_capacity(num_mask_tokens);
116        let vb_o = vb.pp("output_hypernetworks_mlps");
117        for i in 0..num_mask_tokens {
118            let mlp = MlpMaskDecoder::new(
119                transformer_dim,
120                transformer_dim,
121                transformer_dim / 8,
122                3,
123                false,
124                vb_o.pp(i),
125            )?;
126            output_hypernetworks_mlps.push(mlp)
127        }
128        let transformer = TwoWayTransformer::new(
129            /* depth */ 2,
130            /* embedding_dim */ transformer_dim,
131            /* num_heads */ 8,
132            /* mlp_dim */ 2048,
133            vb.pp("transformer"),
134        )?;
135        let span = tracing::span!(tracing::Level::TRACE, "mask-decoder");
136        Ok(Self {
137            iou_token,
138            mask_tokens,
139            iou_prediction_head,
140            output_upscaling_conv1,
141            output_upscaling_ln,
142            output_upscaling_conv2,
143            num_mask_tokens,
144            output_hypernetworks_mlps,
145            transformer,
146            span,
147        })
148    }
149
150    pub fn forward(
151        &self,
152        image_embeddings: &Tensor,
153        image_pe: &Tensor,
154        sparse_prompt_embeddings: &Tensor,
155        dense_prompt_embeddings: &Tensor,
156        multimask_output: bool,
157    ) -> Result<(Tensor, Tensor)> {
158        let _enter = self.span.enter();
159        let (masks, iou_pred) = self.predict_masks(
160            image_embeddings,
161            image_pe,
162            sparse_prompt_embeddings,
163            dense_prompt_embeddings,
164        )?;
165        let masks = if multimask_output {
166            masks.i((.., 1..))?
167        } else {
168            masks.i((.., 0..1))?
169        };
170        let iou_pred = if multimask_output {
171            iou_pred.i((.., 1..))?
172        } else {
173            iou_pred.i((.., 0..1))?
174        };
175        Ok((masks, iou_pred))
176    }
177
178    fn predict_masks(
179        &self,
180        image_embeddings: &Tensor,
181        image_pe: &Tensor,
182        sparse_prompt_embeddings: &Tensor,
183        dense_prompt_embeddings: &Tensor,
184    ) -> Result<(Tensor, Tensor)> {
185        // Concatenate output tokens.
186        let output_tokens = Tensor::cat(
187            &[self.iou_token.embeddings(), self.mask_tokens.embeddings()],
188            0,
189        )?;
190        let (d1, d2) = output_tokens.dims2()?;
191        let output_tokens =
192            output_tokens
193                .unsqueeze(0)?
194                .expand((sparse_prompt_embeddings.dim(0)?, d1, d2))?;
195        let tokens = Tensor::cat(&[&output_tokens, sparse_prompt_embeddings], 1)?;
196
197        // Expand per-image data in batch direction to be per mask
198        let src = repeat_interleave(image_embeddings, tokens.dim(0)?, 0)?;
199        let src = src.broadcast_add(dense_prompt_embeddings)?;
200        let pos_src = repeat_interleave(image_pe, tokens.dim(0)?, 0)?;
201        let (b, c, h, w) = src.dims4()?;
202
203        // Run the transformer
204        let (hs, src) = self.transformer.forward(&src, &pos_src, &tokens)?;
205        let iou_token_out = hs.i((.., 0))?;
206        let mask_tokens_out = hs.i((.., 1..1 + self.num_mask_tokens))?;
207
208        // Upscale mask embeddings and predict masks using the masks tokens.
209        let src = src.transpose(1, 2)?.reshape((b, c, h, w))?;
210        let upscaled_embedding = self
211            .output_upscaling_conv1
212            .forward(&src)?
213            .apply(&self.output_upscaling_ln)?
214            .gelu()?
215            .apply(&self.output_upscaling_conv2)?
216            .gelu()?;
217        let mut hyper_in_list = Vec::with_capacity(self.num_mask_tokens);
218        for (i, mlp) in self.output_hypernetworks_mlps.iter().enumerate() {
219            let h = mlp.forward(&mask_tokens_out.i((.., i))?)?;
220            hyper_in_list.push(h)
221        }
222        let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?.contiguous()?;
223        let (b, c, h, w) = upscaled_embedding.dims4()?;
224        let masks = hyper_in.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?;
225        let masks = masks.reshape((b, (), h, w))?;
226
227        // Generate mask quality predictions.
228        let iou_pred = self.iou_prediction_head.forward(&iou_token_out)?;
229        Ok((masks, iou_pred))
230    }
231}
232
233// Equivalent to torch.repeat_interleave
234fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
235    let img = img.unsqueeze(dim + 1)?;
236    let mut dims = img.dims().to_vec();
237    dims[dim + 1] = repeats;
238    img.broadcast_as(dims)?.flatten(dim, dim + 1)
239}