1use std::sync::Arc;
8
9use candle::D::Minus1;
10use candle::{Module, Result, Tensor};
11use candle_nn::ops::Identity;
12use candle_nn::{
13 batch_norm, conv2d, conv2d_no_bias, conv_transpose2d, linear, seq, Activation, BatchNorm,
14 BatchNormConfig, Conv2d, Conv2dConfig, ConvTranspose2dConfig, Sequential, VarBuilder,
15};
16
17use crate::models::dinov2::DinoVisionTransformer;
18
/// Hyper-parameters for one Depth Anything V2 variant.
///
/// The fields are private; build a value with [`DepthAnythingV2Config::new`] or one of the
/// `vit_*` presets.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DepthAnythingV2Config {
    /// Output channels of the four 1x1 projection convolutions built in `DPTHead::new`; the
    /// same values are the input channels of the `layerN_rn` convolutions in `Scratch::new`.
    out_channel_sizes: [usize; 4],
    /// Input channels of the projection convolutions (and of the readout linear layers).
    in_channel_size: usize,
    /// Channel width shared by the residual units, fusion blocks and `output_conv1` input.
    num_features: usize,
    /// When `true`, every `ResidualConvUnit` loads two batch-norm layers (`bn1`, `bn2`).
    use_batch_norm: bool,
    /// When `true`, the head loads readout projections and concatenates a class token with the
    /// patch tokens before projecting them.
    use_class_token: bool,
    /// Layer indices handed to the backbone's `get_intermediate_layers`.
    layer_ids_vits: Vec<usize>,
    /// Side length the head output is interpolated to before `output_conv2`.
    input_image_size: usize,
    /// Side length of the square grid the backbone tokens are reshaped into.
    target_patch_size: usize,
}
29
30impl DepthAnythingV2Config {
31 #[allow(clippy::too_many_arguments)]
32 pub fn new(
33 out_channel_sizes: [usize; 4],
34 in_channel_size: usize,
35 num_features: usize,
36 use_batch_norm: bool,
37 use_class_token: bool,
38 layer_ids_vits: Vec<usize>,
39 input_image_size: usize,
40 target_patch_size: usize,
41 ) -> Self {
42 Self {
43 out_channel_sizes,
44 in_channel_size,
45 num_features,
46 use_batch_norm,
47 use_class_token,
48 layer_ids_vits,
49 input_image_size,
50 target_patch_size,
51 }
52 }
53
54 pub fn vit_small() -> Self {
55 Self {
56 out_channel_sizes: [48, 96, 192, 384],
57 in_channel_size: 384,
58 num_features: 64,
59 use_batch_norm: false,
60 use_class_token: false,
61 layer_ids_vits: vec![2, 5, 8, 11],
62 input_image_size: 518,
63 target_patch_size: 518 / 14,
64 }
65 }
66
67 pub fn vit_base() -> Self {
68 Self {
69 out_channel_sizes: [96, 192, 384, 768],
70 in_channel_size: 768,
71 num_features: 128,
72 use_batch_norm: false,
73 use_class_token: false,
74 layer_ids_vits: vec![2, 5, 8, 11],
75 input_image_size: 518,
76 target_patch_size: 518 / 14,
77 }
78 }
79
80 pub fn vit_large() -> Self {
81 Self {
82 out_channel_sizes: [256, 512, 1024, 1024],
83 in_channel_size: 1024,
84 num_features: 256,
85 use_batch_norm: false,
86 use_class_token: false,
87 layer_ids_vits: vec![4, 11, 17, 23],
88 input_image_size: 518,
89 target_patch_size: 518 / 14,
90 }
91 }
92
93 pub fn vit_giant() -> Self {
94 Self {
95 out_channel_sizes: [1536, 1536, 1536, 1536],
96 in_channel_size: 1536,
97 num_features: 384,
98 use_batch_norm: false,
99 use_class_token: false,
100 layer_ids_vits: vec![9, 19, 29, 39],
101 input_image_size: 518,
102 target_patch_size: 518 / 14,
103 }
104 }
105}
106
/// Residual unit made of two activation + 3x3 convolution stages (each optionally followed by
/// batch norm) whose result is added back to the unit's input.
pub struct ResidualConvUnit {
    /// Activation applied before each of the two convolutions.
    activation: Activation,
    /// First convolution, loaded from the `conv1` prefix.
    conv1: Conv2d,
    /// Second convolution, loaded from the `conv2` prefix.
    conv2: Conv2d,
    /// Batch norm applied after `conv1`; `None` unless the config enables batch norm.
    batch_norm1: Option<BatchNorm>,
    /// Batch norm applied after `conv2`; `None` unless the config enables batch norm.
    batch_norm2: Option<BatchNorm>,
}
114
115impl ResidualConvUnit {
116 pub fn new(
117 conf: &DepthAnythingV2Config,
118 activation: Activation,
119 vb: VarBuilder,
120 ) -> Result<Self> {
121 const KERNEL_SIZE: usize = 3;
122 let conv_cfg = Conv2dConfig {
123 padding: 1,
124 stride: 1,
125 dilation: 1,
126 groups: 1,
127 };
128 let conv1 = conv2d(
129 conf.num_features,
130 conf.num_features,
131 KERNEL_SIZE,
132 conv_cfg,
133 vb.pp("conv1"),
134 )?;
135 let conv2 = conv2d(
136 conf.num_features,
137 conf.num_features,
138 KERNEL_SIZE,
139 conv_cfg,
140 vb.pp("conv2"),
141 )?;
142
143 let (batch_norm1, batch_norm2) = match conf.use_batch_norm {
144 true => {
145 let batch_norm_cfg = BatchNormConfig {
146 eps: 1e-05,
147 remove_mean: false,
148 affine: true,
149 momentum: 0.1,
150 };
151 (
152 Some(batch_norm(conf.num_features, batch_norm_cfg, vb.pp("bn1"))?),
153 Some(batch_norm(conf.num_features, batch_norm_cfg, vb.pp("bn2"))?),
154 )
155 }
156 false => (None, None),
157 };
158
159 Ok(Self {
160 activation,
161 conv1,
162 conv2,
163 batch_norm1,
164 batch_norm2,
165 })
166 }
167}
168
169impl Module for ResidualConvUnit {
170 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
171 let out = self.activation.forward(xs)?;
172 let out = self.conv1.forward(&out)?;
173 let out = if let Some(batch_norm1) = &self.batch_norm1 {
174 batch_norm1.forward_train(&out)?
175 } else {
176 out
177 };
178
179 let out = self.activation.forward(&out)?;
180 let out = self.conv2.forward(&out)?;
181 let out = if let Some(batch_norm2) = &self.batch_norm2 {
182 batch_norm2.forward_train(&out)?
183 } else {
184 out
185 };
186
187 out + xs
188 }
189}
190
/// One refinement stage of the DPT head: residual unit, resize, then a 1x1 output convolution.
pub struct FeatureFusionBlock {
    /// Residual unit loaded from `resConfUnit1`. This block's own `forward` does not call it;
    /// `DPTHead::forward` applies it to the skip feature map before adding the previous path.
    res_conv_unit1: ResidualConvUnit,
    /// Residual unit loaded from `resConfUnit2`, applied first in `forward`.
    res_conv_unit2: ResidualConvUnit,
    /// Convolution loaded from `out_conv`, applied after the resize.
    output_conv: Conv2d,
    /// Side length of the square size passed to `interpolate2d` in `forward`.
    target_patch_size: usize,
}
197
198impl FeatureFusionBlock {
199 pub fn new(
200 conf: &DepthAnythingV2Config,
201 target_patch_size: usize,
202 activation: Activation,
203 vb: VarBuilder,
204 ) -> Result<Self> {
205 const KERNEL_SIZE: usize = 1;
206 let conv_cfg = Conv2dConfig {
207 padding: 0,
208 stride: 1,
209 dilation: 1,
210 groups: 1,
211 };
212 let output_conv = conv2d(
213 conf.num_features,
214 conf.num_features,
215 KERNEL_SIZE,
216 conv_cfg,
217 vb.pp("out_conv"),
218 )?;
219 let res_conv_unit1 = ResidualConvUnit::new(conf, activation, vb.pp("resConfUnit1"))?;
220 let res_conv_unit2 = ResidualConvUnit::new(conf, activation, vb.pp("resConfUnit2"))?;
221
222 Ok(Self {
223 res_conv_unit1,
224 res_conv_unit2,
225 output_conv,
226 target_patch_size,
227 })
228 }
229}
230
231impl Module for FeatureFusionBlock {
232 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
233 let out = self.res_conv_unit2.forward(xs)?;
234 let out = out.interpolate2d(self.target_patch_size, self.target_patch_size)?;
235
236 self.output_conv.forward(&out)
237 }
238}
239
/// Layers of the DPT head that live under the `scratch` weight prefix.
pub struct Scratch {
    /// Bias-free 3x3 convolution mapping `out_channel_sizes[0]` channels to `num_features`.
    layer1_rn: Conv2d,
    /// Bias-free 3x3 convolution mapping `out_channel_sizes[1]` channels to `num_features`.
    layer2_rn: Conv2d,
    /// Bias-free 3x3 convolution mapping `out_channel_sizes[2]` channels to `num_features`.
    layer3_rn: Conv2d,
    /// Bias-free 3x3 convolution mapping `out_channel_sizes[3]` channels to `num_features`.
    layer4_rn: Conv2d,
    /// Fusion block resizing to `target_patch_size * 8`.
    refine_net1: FeatureFusionBlock,
    /// Fusion block resizing to `target_patch_size * 4`.
    refine_net2: FeatureFusionBlock,
    /// Fusion block resizing to `target_patch_size * 2`.
    refine_net3: FeatureFusionBlock,
    /// Fusion block resizing to `target_patch_size`.
    refine_net4: FeatureFusionBlock,
    /// 3x3 convolution halving the channel count (`num_features -> num_features / 2`).
    output_conv1: Conv2d,
    /// Final stack: convolution, ReLU, convolution to a single channel, ReLU.
    output_conv2: Sequential,
}
252
253impl Scratch {
254 pub fn new(conf: &DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> {
255 const KERNEL_SIZE: usize = 3;
256 let conv_cfg = Conv2dConfig {
257 padding: 1,
258 stride: 1,
259 dilation: 1,
260 groups: 1,
261 };
262
263 let layer1_rn = conv2d_no_bias(
264 conf.out_channel_sizes[0],
265 conf.num_features,
266 KERNEL_SIZE,
267 conv_cfg,
268 vb.pp("layer1_rn"),
269 )?;
270 let layer2_rn = conv2d_no_bias(
271 conf.out_channel_sizes[1],
272 conf.num_features,
273 KERNEL_SIZE,
274 conv_cfg,
275 vb.pp("layer2_rn"),
276 )?;
277 let layer3_rn = conv2d_no_bias(
278 conf.out_channel_sizes[2],
279 conf.num_features,
280 KERNEL_SIZE,
281 conv_cfg,
282 vb.pp("layer3_rn"),
283 )?;
284 let layer4_rn = conv2d_no_bias(
285 conf.out_channel_sizes[3],
286 conf.num_features,
287 KERNEL_SIZE,
288 conv_cfg,
289 vb.pp("layer4_rn"),
290 )?;
291
292 let refine_net1 = FeatureFusionBlock::new(
293 conf,
294 conf.target_patch_size * 8,
295 Activation::Relu,
296 vb.pp("refinenet1"),
297 )?;
298 let refine_net2 = FeatureFusionBlock::new(
299 conf,
300 conf.target_patch_size * 4,
301 Activation::Relu,
302 vb.pp("refinenet2"),
303 )?;
304 let refine_net3 = FeatureFusionBlock::new(
305 conf,
306 conf.target_patch_size * 2,
307 Activation::Relu,
308 vb.pp("refinenet3"),
309 )?;
310 let refine_net4 = FeatureFusionBlock::new(
311 conf,
312 conf.target_patch_size,
313 Activation::Relu,
314 vb.pp("refinenet4"),
315 )?;
316
317 let conv_cfg = Conv2dConfig {
318 padding: 1,
319 stride: 1,
320 dilation: 1,
321 groups: 1,
322 };
323 let output_conv1 = conv2d(
324 conf.num_features,
325 conf.num_features / 2,
326 KERNEL_SIZE,
327 conv_cfg,
328 vb.pp("output_conv1"),
329 )?;
330
331 let output_conv2 = seq();
332 const HEAD_FEATURES_2: usize = 32;
333 const OUT_CHANNELS_2: usize = 1;
334 const KERNEL_SIZE_2: usize = 1;
335 let output_conv2 = output_conv2.add(conv2d(
336 conf.num_features / 2,
337 HEAD_FEATURES_2,
338 KERNEL_SIZE,
339 conv_cfg,
340 vb.pp("output_conv2").pp("0"),
341 )?);
342 let output_conv2 = output_conv2
343 .add(Activation::Relu)
344 .add(conv2d(
345 HEAD_FEATURES_2,
346 OUT_CHANNELS_2,
347 KERNEL_SIZE_2,
348 conv_cfg,
349 vb.pp("output_conv2").pp("2"),
350 )?)
351 .add(Activation::Relu);
352
353 Ok(Self {
354 layer1_rn,
355 layer2_rn,
356 layer3_rn,
357 layer4_rn,
358 refine_net1,
359 refine_net2,
360 refine_net3,
361 refine_net4,
362 output_conv1,
363 output_conv2,
364 })
365 }
366}
367
/// Number of feature levels handled by the DPT head: the readout-projection count in
/// `DPTHead::new` and the loop bound in `DPTHead::forward`.
const NUM_CHANNELS: usize = 4;
369
/// Decoder head that turns backbone features into a single-channel map.
pub struct DPTHead {
    /// One 1x1 convolution per feature level, loaded from `projects.{i}`.
    projections: Vec<Conv2d>,
    /// One resize module per feature level: two transposed convolutions, an identity and a
    /// stride-2 convolution, in that order.
    resize_layers: Vec<Box<dyn Module>>,
    /// Per-level readout projections, indexed in `forward` when `use_class_token` is set.
    readout_projections: Vec<Sequential>,
    /// Fusion and output layers loaded from the `scratch` prefix.
    scratch: Scratch,
    /// Copied from the config; selects the class-token branch in `forward`.
    use_class_token: bool,
    /// Copied from the config; side length of the final resize in `forward`.
    input_image_size: usize,
    /// Copied from the config; side length of the grid tokens are reshaped into.
    target_patch_size: usize,
}
379
380impl DPTHead {
381 pub fn new(conf: &DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> {
382 let mut projections: Vec<Conv2d> = Vec::with_capacity(conf.out_channel_sizes.len());
383 for (conv_index, out_channel_size) in conf.out_channel_sizes.iter().enumerate() {
384 projections.push(conv2d(
385 conf.in_channel_size,
386 *out_channel_size,
387 1,
388 Default::default(),
389 vb.pp("projects").pp(conv_index.to_string()),
390 )?);
391 }
392
393 let resize_layers: Vec<Box<dyn Module>> = vec![
394 Box::new(conv_transpose2d(
395 conf.out_channel_sizes[0],
396 conf.out_channel_sizes[0],
397 4,
398 ConvTranspose2dConfig {
399 padding: 0,
400 stride: 4,
401 dilation: 1,
402 output_padding: 0,
403 },
404 vb.pp("resize_layers").pp("0"),
405 )?),
406 Box::new(conv_transpose2d(
407 conf.out_channel_sizes[1],
408 conf.out_channel_sizes[1],
409 2,
410 ConvTranspose2dConfig {
411 padding: 0,
412 stride: 2,
413 dilation: 1,
414 output_padding: 0,
415 },
416 vb.pp("resize_layers").pp("1"),
417 )?),
418 Box::new(Identity::new()),
419 Box::new(conv2d(
420 conf.out_channel_sizes[3],
421 conf.out_channel_sizes[3],
422 3,
423 Conv2dConfig {
424 padding: 1,
425 stride: 2,
426 dilation: 1,
427 groups: 1,
428 },
429 vb.pp("resize_layers").pp("3"),
430 )?),
431 ];
432
433 let readout_projections = if conf.use_class_token {
434 let rop = Vec::with_capacity(NUM_CHANNELS);
435 for rop_index in 0..NUM_CHANNELS {
436 seq()
437 .add(linear(
438 2 * conf.in_channel_size,
439 conf.in_channel_size,
440 vb.pp("readout_projects").pp(rop_index.to_string()),
441 )?)
442 .add(Activation::Gelu);
443 }
444 rop
445 } else {
446 vec![]
447 };
448
449 let scratch = Scratch::new(conf, vb.pp("scratch"))?;
450
451 Ok(Self {
452 projections,
453 resize_layers,
454 readout_projections,
455 scratch,
456 use_class_token: conf.use_class_token,
457 input_image_size: conf.input_image_size,
458 target_patch_size: conf.target_patch_size,
459 })
460 }
461}
462
463impl Module for DPTHead {
464 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
465 let mut out: Vec<Tensor> = Vec::with_capacity(NUM_CHANNELS);
466 for i in 0..NUM_CHANNELS {
467 let x = if self.use_class_token {
468 let x = xs.get(i)?.get(0)?;
469 let class_token = xs.get(i)?.get(1)?;
470 let readout = class_token.unsqueeze(1)?.expand(x.shape())?;
471 let to_cat = [x, readout];
472 let cat = Tensor::cat(&to_cat, Minus1)?;
473 self.readout_projections[i].forward(&cat)?
474 } else {
475 xs.get(i)?
476 };
477 let x_dims = x.dims();
478
479 let x = x.permute((0, 2, 1))?.reshape((
480 x_dims[0],
481 x_dims[x_dims.len() - 1],
482 self.target_patch_size,
483 self.target_patch_size,
484 ))?;
485 let x = self.projections[i].forward(&x)?;
486
487 let x = self.resize_layers[i].forward(&x)?;
488 out.push(x);
489 }
490
491 let layer_1_rn = self.scratch.layer1_rn.forward(&out[0])?;
492 let layer_2_rn = self.scratch.layer2_rn.forward(&out[1])?;
493 let layer_3_rn = self.scratch.layer3_rn.forward(&out[2])?;
494 let layer_4_rn = self.scratch.layer4_rn.forward(&out[3])?;
495
496 let path4 = self.scratch.refine_net4.forward(&layer_4_rn)?;
497
498 let res3_out = self
499 .scratch
500 .refine_net3
501 .res_conv_unit1
502 .forward(&layer_3_rn)?;
503 let res3_out = path4.add(&res3_out)?;
504 let path3 = self.scratch.refine_net3.forward(&res3_out)?;
505
506 let res2_out = self
507 .scratch
508 .refine_net2
509 .res_conv_unit1
510 .forward(&layer_2_rn)?;
511 let res2_out = path3.add(&res2_out)?;
512 let path2 = self.scratch.refine_net2.forward(&res2_out)?;
513
514 let res1_out = self
515 .scratch
516 .refine_net1
517 .res_conv_unit1
518 .forward(&layer_1_rn)?;
519 let res1_out = path2.add(&res1_out)?;
520 let path1 = self.scratch.refine_net1.forward(&res1_out)?;
521
522 let out = self.scratch.output_conv1.forward(&path1)?;
523
524 let out = out.interpolate2d(self.input_image_size, self.input_image_size)?;
525
526 self.scratch.output_conv2.forward(&out)
527 }
528}
529
/// Full model: a shared DINOv2 vision transformer backbone plus a DPT decoder head.
pub struct DepthAnythingV2 {
    /// Backbone queried for intermediate layer outputs; shared through an `Arc`.
    pretrained: Arc<DinoVisionTransformer>,
    /// Decoder head loaded from the `depth_head` prefix.
    depth_head: DPTHead,
    /// Configuration kept so `forward` can read `layer_ids_vits`.
    conf: DepthAnythingV2Config,
}
535
536impl DepthAnythingV2 {
537 pub fn new(
538 pretrained: Arc<DinoVisionTransformer>,
539 conf: DepthAnythingV2Config,
540 vb: VarBuilder,
541 ) -> Result<Self> {
542 let depth_head = DPTHead::new(&conf, vb.pp("depth_head"))?;
543
544 Ok(Self {
545 pretrained,
546 depth_head,
547 conf,
548 })
549 }
550}
551
552impl Module for DepthAnythingV2 {
553 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
554 let features = self.pretrained.get_intermediate_layers(
555 xs,
556 &self.conf.layer_ids_vits,
557 false,
558 false,
559 true,
560 )?;
561 let depth = self.depth_head.forward(&features)?;
562
563 depth.relu()
564 }
565}