candle_transformers/models/segment_anything/
mask_decoder.rs1use candle::{IndexOp, Result, Tensor};
2use candle_nn::{Module, VarBuilder};
3
4use super::transformer::TwoWayTransformer;
5
6#[derive(Debug)]
7struct MlpMaskDecoder {
8 layers: Vec<super::Linear>,
9 sigmoid_output: bool,
10 span: tracing::Span,
11}
12
13impl MlpMaskDecoder {
14 fn new(
15 input_dim: usize,
16 hidden_dim: usize,
17 output_dim: usize,
18 num_layers: usize,
19 sigmoid_output: bool,
20 vb: VarBuilder,
21 ) -> Result<Self> {
22 let mut layers = Vec::with_capacity(num_layers);
23 let vb = vb.pp("layers");
24 for i in 0..num_layers {
25 let in_dim = if i == 0 { input_dim } else { hidden_dim };
26 let out_dim = if i + 1 == num_layers {
27 output_dim
28 } else {
29 hidden_dim
30 };
31 let layer = super::linear(vb.pp(i), in_dim, out_dim, true)?;
32 layers.push(layer)
33 }
34 let span = tracing::span!(tracing::Level::TRACE, "mlp-mask-decoder");
35 Ok(Self {
36 layers,
37 sigmoid_output,
38 span,
39 })
40 }
41}
42
43impl Module for MlpMaskDecoder {
44 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
45 let _enter = self.span.enter();
46 let mut xs = xs.clone();
47 for (i, layer) in self.layers.iter().enumerate() {
48 xs = layer.forward(&xs)?;
49 if i + 1 < self.layers.len() {
50 xs = xs.relu()?
51 }
52 }
53 if self.sigmoid_output {
54 candle_nn::ops::sigmoid(&xs)
55 } else {
56 Ok(xs)
57 }
58 }
59}
60
61#[derive(Debug)]
62pub struct MaskDecoder {
63 iou_token: candle_nn::Embedding,
64 mask_tokens: candle_nn::Embedding,
65 iou_prediction_head: MlpMaskDecoder,
66 output_upscaling_conv1: candle_nn::ConvTranspose2d,
67 output_upscaling_ln: super::LayerNorm2d,
68 output_upscaling_conv2: candle_nn::ConvTranspose2d,
69 num_mask_tokens: usize,
70 output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
71 transformer: TwoWayTransformer,
72 span: tracing::Span,
73}
74
75impl MaskDecoder {
76 pub fn new(
77 transformer_dim: usize,
78 num_multimask_outputs: usize,
79 iou_head_depth: usize,
80 iou_head_hidden_dim: usize,
81 vb: VarBuilder,
82 ) -> Result<Self> {
83 let num_mask_tokens = num_multimask_outputs + 1;
84 let iou_prediction_head = MlpMaskDecoder::new(
85 transformer_dim,
86 iou_head_hidden_dim,
87 num_mask_tokens,
88 iou_head_depth,
89 false,
90 vb.pp("iou_prediction_head"),
91 )?;
92 let iou_token = candle_nn::embedding(1, transformer_dim, vb.pp("iou_token"))?;
93 let mask_tokens =
94 candle_nn::embedding(num_mask_tokens, transformer_dim, vb.pp("mask_tokens"))?;
95 let cfg = candle_nn::ConvTranspose2dConfig {
96 stride: 2,
97 ..Default::default()
98 };
99 let output_upscaling_conv1 = candle_nn::conv_transpose2d(
100 transformer_dim,
101 transformer_dim / 4,
102 2,
103 cfg,
104 vb.pp("output_upscaling.0"),
105 )?;
106 let output_upscaling_ln =
107 super::LayerNorm2d::new(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?;
108 let output_upscaling_conv2 = candle_nn::conv_transpose2d(
109 transformer_dim / 4,
110 transformer_dim / 8,
111 2,
112 cfg,
113 vb.pp("output_upscaling.3"),
114 )?;
115 let mut output_hypernetworks_mlps = Vec::with_capacity(num_mask_tokens);
116 let vb_o = vb.pp("output_hypernetworks_mlps");
117 for i in 0..num_mask_tokens {
118 let mlp = MlpMaskDecoder::new(
119 transformer_dim,
120 transformer_dim,
121 transformer_dim / 8,
122 3,
123 false,
124 vb_o.pp(i),
125 )?;
126 output_hypernetworks_mlps.push(mlp)
127 }
128 let transformer = TwoWayTransformer::new(
129 2,
130 transformer_dim,
131 8,
132 2048,
133 vb.pp("transformer"),
134 )?;
135 let span = tracing::span!(tracing::Level::TRACE, "mask-decoder");
136 Ok(Self {
137 iou_token,
138 mask_tokens,
139 iou_prediction_head,
140 output_upscaling_conv1,
141 output_upscaling_ln,
142 output_upscaling_conv2,
143 num_mask_tokens,
144 output_hypernetworks_mlps,
145 transformer,
146 span,
147 })
148 }
149
150 pub fn forward(
151 &self,
152 image_embeddings: &Tensor,
153 image_pe: &Tensor,
154 sparse_prompt_embeddings: &Tensor,
155 dense_prompt_embeddings: &Tensor,
156 multimask_output: bool,
157 ) -> Result<(Tensor, Tensor)> {
158 let _enter = self.span.enter();
159 let (masks, iou_pred) = self.predict_masks(
160 image_embeddings,
161 image_pe,
162 sparse_prompt_embeddings,
163 dense_prompt_embeddings,
164 )?;
165 let masks = if multimask_output {
166 masks.i((.., 1..))?
167 } else {
168 masks.i((.., 0..1))?
169 };
170 let iou_pred = if multimask_output {
171 iou_pred.i((.., 1..))?
172 } else {
173 iou_pred.i((.., 0..1))?
174 };
175 Ok((masks, iou_pred))
176 }
177
178 fn predict_masks(
179 &self,
180 image_embeddings: &Tensor,
181 image_pe: &Tensor,
182 sparse_prompt_embeddings: &Tensor,
183 dense_prompt_embeddings: &Tensor,
184 ) -> Result<(Tensor, Tensor)> {
185 let output_tokens = Tensor::cat(
187 &[self.iou_token.embeddings(), self.mask_tokens.embeddings()],
188 0,
189 )?;
190 let (d1, d2) = output_tokens.dims2()?;
191 let output_tokens =
192 output_tokens
193 .unsqueeze(0)?
194 .expand((sparse_prompt_embeddings.dim(0)?, d1, d2))?;
195 let tokens = Tensor::cat(&[&output_tokens, sparse_prompt_embeddings], 1)?;
196
197 let src = repeat_interleave(image_embeddings, tokens.dim(0)?, 0)?;
199 let src = src.broadcast_add(dense_prompt_embeddings)?;
200 let pos_src = repeat_interleave(image_pe, tokens.dim(0)?, 0)?;
201 let (b, c, h, w) = src.dims4()?;
202
203 let (hs, src) = self.transformer.forward(&src, &pos_src, &tokens)?;
205 let iou_token_out = hs.i((.., 0))?;
206 let mask_tokens_out = hs.i((.., 1..1 + self.num_mask_tokens))?;
207
208 let src = src.transpose(1, 2)?.reshape((b, c, h, w))?;
210 let upscaled_embedding = self
211 .output_upscaling_conv1
212 .forward(&src)?
213 .apply(&self.output_upscaling_ln)?
214 .gelu()?
215 .apply(&self.output_upscaling_conv2)?
216 .gelu()?;
217 let mut hyper_in_list = Vec::with_capacity(self.num_mask_tokens);
218 for (i, mlp) in self.output_hypernetworks_mlps.iter().enumerate() {
219 let h = mlp.forward(&mask_tokens_out.i((.., i))?)?;
220 hyper_in_list.push(h)
221 }
222 let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?.contiguous()?;
223 let (b, c, h, w) = upscaled_embedding.dims4()?;
224 let masks = hyper_in.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?;
225 let masks = masks.reshape((b, (), h, w))?;
226
227 let iou_pred = self.iou_prediction_head.forward(&iou_token_out)?;
229 Ok((masks, iou_pred))
230 }
231}
232
233fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
235 let img = img.unsqueeze(dim + 1)?;
236 let mut dims = img.dims().to_vec();
237 dims[dim + 1] = repeats;
238 img.broadcast_as(dims)?.flatten(dim, dim + 1)
239}