candle_transformers/models/
depth_anything_v2.rs

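//! Implementation of the Depth Anything V2 depth-estimation model: a DPT head on top of a DINOv2 vision transformer backbone (the DINOv2 backbone is the part that comes from FAIR).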
2//!
3//! See:
4//! - ["Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data"](https://github.com/LiheYoung/Depth-Anything)
5//!
6
7use std::sync::Arc;
8
9use candle::D::Minus1;
10use candle::{Module, Result, Tensor};
11use candle_nn::ops::Identity;
12use candle_nn::{
13    batch_norm, conv2d, conv2d_no_bias, conv_transpose2d, linear, seq, Activation, BatchNorm,
14    BatchNormConfig, Conv2d, Conv2dConfig, ConvTranspose2dConfig, Sequential, VarBuilder,
15};
16
17use crate::models::dinov2::DinoVisionTransformer;
18
/// Hyper-parameters of the Depth Anything V2 head and of the backbone layers it reads.
///
/// Use [`DepthAnythingV2Config::new`] for custom values or one of the `vit_*` presets.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DepthAnythingV2Config {
    /// Output channel count of each of the four per-stage 1x1 projections.
    out_channel_sizes: [usize; 4],
    /// Channel count of the backbone features (`embed_dim` in the Dino model).
    in_channel_size: usize,
    /// Channel count used throughout the fusion part of the head.
    num_features: usize,
    /// Whether the residual conv units own batch-norm layers.
    use_batch_norm: bool,
    /// Whether the head concatenates a class token to the patch tokens (readout projection).
    use_class_token: bool,
    /// Indices of the backbone layers whose outputs are handed to the head.
    layer_ids_vits: Vec<usize>,
    /// Side length the feature map is resized to before the final convolutions.
    input_image_size: usize,
    /// Side length of the square grid the backbone tokens are reshaped into.
    target_patch_size: usize,
}
29
30impl DepthAnythingV2Config {
31    #[allow(clippy::too_many_arguments)]
32    pub fn new(
33        out_channel_sizes: [usize; 4],
34        in_channel_size: usize,
35        num_features: usize,
36        use_batch_norm: bool,
37        use_class_token: bool,
38        layer_ids_vits: Vec<usize>,
39        input_image_size: usize,
40        target_patch_size: usize,
41    ) -> Self {
42        Self {
43            out_channel_sizes,
44            in_channel_size,
45            num_features,
46            use_batch_norm,
47            use_class_token,
48            layer_ids_vits,
49            input_image_size,
50            target_patch_size,
51        }
52    }
53
54    pub fn vit_small() -> Self {
55        Self {
56            out_channel_sizes: [48, 96, 192, 384],
57            in_channel_size: 384,
58            num_features: 64,
59            use_batch_norm: false,
60            use_class_token: false,
61            layer_ids_vits: vec![2, 5, 8, 11],
62            input_image_size: 518,
63            target_patch_size: 518 / 14,
64        }
65    }
66
67    pub fn vit_base() -> Self {
68        Self {
69            out_channel_sizes: [96, 192, 384, 768],
70            in_channel_size: 768,
71            num_features: 128,
72            use_batch_norm: false,
73            use_class_token: false,
74            layer_ids_vits: vec![2, 5, 8, 11],
75            input_image_size: 518,
76            target_patch_size: 518 / 14,
77        }
78    }
79
80    pub fn vit_large() -> Self {
81        Self {
82            out_channel_sizes: [256, 512, 1024, 1024],
83            in_channel_size: 1024,
84            num_features: 256,
85            use_batch_norm: false,
86            use_class_token: false,
87            layer_ids_vits: vec![4, 11, 17, 23],
88            input_image_size: 518,
89            target_patch_size: 518 / 14,
90        }
91    }
92
93    pub fn vit_giant() -> Self {
94        Self {
95            out_channel_sizes: [1536, 1536, 1536, 1536],
96            in_channel_size: 1536,
97            num_features: 384,
98            use_batch_norm: false,
99            use_class_token: false,
100            layer_ids_vits: vec![9, 19, 29, 39],
101            input_image_size: 518,
102            target_patch_size: 518 / 14,
103        }
104    }
105}
106
107pub struct ResidualConvUnit {
108    activation: Activation,
109    conv1: Conv2d,
110    conv2: Conv2d,
111    batch_norm1: Option<BatchNorm>,
112    batch_norm2: Option<BatchNorm>,
113}
114
115impl ResidualConvUnit {
116    pub fn new(
117        conf: &DepthAnythingV2Config,
118        activation: Activation,
119        vb: VarBuilder,
120    ) -> Result<Self> {
121        const KERNEL_SIZE: usize = 3;
122        let conv_cfg = Conv2dConfig {
123            padding: 1,
124            stride: 1,
125            dilation: 1,
126            groups: 1,
127        };
128        let conv1 = conv2d(
129            conf.num_features,
130            conf.num_features,
131            KERNEL_SIZE,
132            conv_cfg,
133            vb.pp("conv1"),
134        )?;
135        let conv2 = conv2d(
136            conf.num_features,
137            conf.num_features,
138            KERNEL_SIZE,
139            conv_cfg,
140            vb.pp("conv2"),
141        )?;
142
143        let (batch_norm1, batch_norm2) = match conf.use_batch_norm {
144            true => {
145                let batch_norm_cfg = BatchNormConfig {
146                    eps: 1e-05,
147                    remove_mean: false,
148                    affine: true,
149                    momentum: 0.1,
150                };
151                (
152                    Some(batch_norm(conf.num_features, batch_norm_cfg, vb.pp("bn1"))?),
153                    Some(batch_norm(conf.num_features, batch_norm_cfg, vb.pp("bn2"))?),
154                )
155            }
156            false => (None, None),
157        };
158
159        Ok(Self {
160            activation,
161            conv1,
162            conv2,
163            batch_norm1,
164            batch_norm2,
165        })
166    }
167}
168
169impl Module for ResidualConvUnit {
170    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
171        let out = self.activation.forward(xs)?;
172        let out = self.conv1.forward(&out)?;
173        let out = if let Some(batch_norm1) = &self.batch_norm1 {
174            batch_norm1.forward_train(&out)?
175        } else {
176            out
177        };
178
179        let out = self.activation.forward(&out)?;
180        let out = self.conv2.forward(&out)?;
181        let out = if let Some(batch_norm2) = &self.batch_norm2 {
182            batch_norm2.forward_train(&out)?
183        } else {
184            out
185        };
186
187        out + xs
188    }
189}
190
191pub struct FeatureFusionBlock {
192    res_conv_unit1: ResidualConvUnit,
193    res_conv_unit2: ResidualConvUnit,
194    output_conv: Conv2d,
195    target_patch_size: usize,
196}
197
198impl FeatureFusionBlock {
199    pub fn new(
200        conf: &DepthAnythingV2Config,
201        target_patch_size: usize,
202        activation: Activation,
203        vb: VarBuilder,
204    ) -> Result<Self> {
205        const KERNEL_SIZE: usize = 1;
206        let conv_cfg = Conv2dConfig {
207            padding: 0,
208            stride: 1,
209            dilation: 1,
210            groups: 1,
211        };
212        let output_conv = conv2d(
213            conf.num_features,
214            conf.num_features,
215            KERNEL_SIZE,
216            conv_cfg,
217            vb.pp("out_conv"),
218        )?;
219        let res_conv_unit1 = ResidualConvUnit::new(conf, activation, vb.pp("resConfUnit1"))?;
220        let res_conv_unit2 = ResidualConvUnit::new(conf, activation, vb.pp("resConfUnit2"))?;
221
222        Ok(Self {
223            res_conv_unit1,
224            res_conv_unit2,
225            output_conv,
226            target_patch_size,
227        })
228    }
229}
230
231impl Module for FeatureFusionBlock {
232    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
233        let out = self.res_conv_unit2.forward(xs)?;
234        let out = out.interpolate2d(self.target_patch_size, self.target_patch_size)?;
235
236        self.output_conv.forward(&out)
237    }
238}
239
240pub struct Scratch {
241    layer1_rn: Conv2d,
242    layer2_rn: Conv2d,
243    layer3_rn: Conv2d,
244    layer4_rn: Conv2d,
245    refine_net1: FeatureFusionBlock,
246    refine_net2: FeatureFusionBlock,
247    refine_net3: FeatureFusionBlock,
248    refine_net4: FeatureFusionBlock,
249    output_conv1: Conv2d,
250    output_conv2: Sequential,
251}
252
253impl Scratch {
254    pub fn new(conf: &DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> {
255        const KERNEL_SIZE: usize = 3;
256        let conv_cfg = Conv2dConfig {
257            padding: 1,
258            stride: 1,
259            dilation: 1,
260            groups: 1,
261        };
262
263        let layer1_rn = conv2d_no_bias(
264            conf.out_channel_sizes[0],
265            conf.num_features,
266            KERNEL_SIZE,
267            conv_cfg,
268            vb.pp("layer1_rn"),
269        )?;
270        let layer2_rn = conv2d_no_bias(
271            conf.out_channel_sizes[1],
272            conf.num_features,
273            KERNEL_SIZE,
274            conv_cfg,
275            vb.pp("layer2_rn"),
276        )?;
277        let layer3_rn = conv2d_no_bias(
278            conf.out_channel_sizes[2],
279            conf.num_features,
280            KERNEL_SIZE,
281            conv_cfg,
282            vb.pp("layer3_rn"),
283        )?;
284        let layer4_rn = conv2d_no_bias(
285            conf.out_channel_sizes[3],
286            conf.num_features,
287            KERNEL_SIZE,
288            conv_cfg,
289            vb.pp("layer4_rn"),
290        )?;
291
292        let refine_net1 = FeatureFusionBlock::new(
293            conf,
294            conf.target_patch_size * 8,
295            Activation::Relu,
296            vb.pp("refinenet1"),
297        )?;
298        let refine_net2 = FeatureFusionBlock::new(
299            conf,
300            conf.target_patch_size * 4,
301            Activation::Relu,
302            vb.pp("refinenet2"),
303        )?;
304        let refine_net3 = FeatureFusionBlock::new(
305            conf,
306            conf.target_patch_size * 2,
307            Activation::Relu,
308            vb.pp("refinenet3"),
309        )?;
310        let refine_net4 = FeatureFusionBlock::new(
311            conf,
312            conf.target_patch_size,
313            Activation::Relu,
314            vb.pp("refinenet4"),
315        )?;
316
317        let conv_cfg = Conv2dConfig {
318            padding: 1,
319            stride: 1,
320            dilation: 1,
321            groups: 1,
322        };
323        let output_conv1 = conv2d(
324            conf.num_features,
325            conf.num_features / 2,
326            KERNEL_SIZE,
327            conv_cfg,
328            vb.pp("output_conv1"),
329        )?;
330
331        let output_conv2 = seq();
332        const HEAD_FEATURES_2: usize = 32;
333        const OUT_CHANNELS_2: usize = 1;
334        const KERNEL_SIZE_2: usize = 1;
335        let output_conv2 = output_conv2.add(conv2d(
336            conf.num_features / 2,
337            HEAD_FEATURES_2,
338            KERNEL_SIZE,
339            conv_cfg,
340            vb.pp("output_conv2").pp("0"),
341        )?);
342        let output_conv2 = output_conv2
343            .add(Activation::Relu)
344            .add(conv2d(
345                HEAD_FEATURES_2,
346                OUT_CHANNELS_2,
347                KERNEL_SIZE_2,
348                conv_cfg,
349                vb.pp("output_conv2").pp("2"),
350            )?)
351            .add(Activation::Relu);
352
353        Ok(Self {
354            layer1_rn,
355            layer2_rn,
356            layer3_rn,
357            layer4_rn,
358            refine_net1,
359            refine_net2,
360            refine_net3,
361            refine_net4,
362            output_conv1,
363            output_conv2,
364        })
365    }
366}
367
/// Number of backbone feature stages processed by `DPTHead` (one per projection,
/// resize layer and readout projection).
const NUM_CHANNELS: usize = 4;
369
370pub struct DPTHead {
371    projections: Vec<Conv2d>,
372    resize_layers: Vec<Box<dyn Module>>,
373    readout_projections: Vec<Sequential>,
374    scratch: Scratch,
375    use_class_token: bool,
376    input_image_size: usize,
377    target_patch_size: usize,
378}
379
380impl DPTHead {
381    pub fn new(conf: &DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> {
382        let mut projections: Vec<Conv2d> = Vec::with_capacity(conf.out_channel_sizes.len());
383        for (conv_index, out_channel_size) in conf.out_channel_sizes.iter().enumerate() {
384            projections.push(conv2d(
385                conf.in_channel_size,
386                *out_channel_size,
387                1,
388                Default::default(),
389                vb.pp("projects").pp(conv_index.to_string()),
390            )?);
391        }
392
393        let resize_layers: Vec<Box<dyn Module>> = vec![
394            Box::new(conv_transpose2d(
395                conf.out_channel_sizes[0],
396                conf.out_channel_sizes[0],
397                4,
398                ConvTranspose2dConfig {
399                    padding: 0,
400                    stride: 4,
401                    dilation: 1,
402                    output_padding: 0,
403                },
404                vb.pp("resize_layers").pp("0"),
405            )?),
406            Box::new(conv_transpose2d(
407                conf.out_channel_sizes[1],
408                conf.out_channel_sizes[1],
409                2,
410                ConvTranspose2dConfig {
411                    padding: 0,
412                    stride: 2,
413                    dilation: 1,
414                    output_padding: 0,
415                },
416                vb.pp("resize_layers").pp("1"),
417            )?),
418            Box::new(Identity::new()),
419            Box::new(conv2d(
420                conf.out_channel_sizes[3],
421                conf.out_channel_sizes[3],
422                3,
423                Conv2dConfig {
424                    padding: 1,
425                    stride: 2,
426                    dilation: 1,
427                    groups: 1,
428                },
429                vb.pp("resize_layers").pp("3"),
430            )?),
431        ];
432
433        let readout_projections = if conf.use_class_token {
434            let rop = Vec::with_capacity(NUM_CHANNELS);
435            for rop_index in 0..NUM_CHANNELS {
436                seq()
437                    .add(linear(
438                        2 * conf.in_channel_size,
439                        conf.in_channel_size,
440                        vb.pp("readout_projects").pp(rop_index.to_string()),
441                    )?)
442                    .add(Activation::Gelu);
443            }
444            rop
445        } else {
446            vec![]
447        };
448
449        let scratch = Scratch::new(conf, vb.pp("scratch"))?;
450
451        Ok(Self {
452            projections,
453            resize_layers,
454            readout_projections,
455            scratch,
456            use_class_token: conf.use_class_token,
457            input_image_size: conf.input_image_size,
458            target_patch_size: conf.target_patch_size,
459        })
460    }
461}
462
463impl Module for DPTHead {
464    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
465        let mut out: Vec<Tensor> = Vec::with_capacity(NUM_CHANNELS);
466        for i in 0..NUM_CHANNELS {
467            let x = if self.use_class_token {
468                let x = xs.get(i)?.get(0)?;
469                let class_token = xs.get(i)?.get(1)?;
470                let readout = class_token.unsqueeze(1)?.expand(x.shape())?;
471                let to_cat = [x, readout];
472                let cat = Tensor::cat(&to_cat, Minus1)?;
473                self.readout_projections[i].forward(&cat)?
474            } else {
475                xs.get(i)?
476            };
477            let x_dims = x.dims();
478
479            let x = x.permute((0, 2, 1))?.reshape((
480                x_dims[0],
481                x_dims[x_dims.len() - 1],
482                self.target_patch_size,
483                self.target_patch_size,
484            ))?;
485            let x = self.projections[i].forward(&x)?;
486
487            let x = self.resize_layers[i].forward(&x)?;
488            out.push(x);
489        }
490
491        let layer_1_rn = self.scratch.layer1_rn.forward(&out[0])?;
492        let layer_2_rn = self.scratch.layer2_rn.forward(&out[1])?;
493        let layer_3_rn = self.scratch.layer3_rn.forward(&out[2])?;
494        let layer_4_rn = self.scratch.layer4_rn.forward(&out[3])?;
495
496        let path4 = self.scratch.refine_net4.forward(&layer_4_rn)?;
497
498        let res3_out = self
499            .scratch
500            .refine_net3
501            .res_conv_unit1
502            .forward(&layer_3_rn)?;
503        let res3_out = path4.add(&res3_out)?;
504        let path3 = self.scratch.refine_net3.forward(&res3_out)?;
505
506        let res2_out = self
507            .scratch
508            .refine_net2
509            .res_conv_unit1
510            .forward(&layer_2_rn)?;
511        let res2_out = path3.add(&res2_out)?;
512        let path2 = self.scratch.refine_net2.forward(&res2_out)?;
513
514        let res1_out = self
515            .scratch
516            .refine_net1
517            .res_conv_unit1
518            .forward(&layer_1_rn)?;
519        let res1_out = path2.add(&res1_out)?;
520        let path1 = self.scratch.refine_net1.forward(&res1_out)?;
521
522        let out = self.scratch.output_conv1.forward(&path1)?;
523
524        let out = out.interpolate2d(self.input_image_size, self.input_image_size)?;
525
526        self.scratch.output_conv2.forward(&out)
527    }
528}
529
530pub struct DepthAnythingV2 {
531    pretrained: Arc<DinoVisionTransformer>,
532    depth_head: DPTHead,
533    conf: DepthAnythingV2Config,
534}
535
536impl DepthAnythingV2 {
537    pub fn new(
538        pretrained: Arc<DinoVisionTransformer>,
539        conf: DepthAnythingV2Config,
540        vb: VarBuilder,
541    ) -> Result<Self> {
542        let depth_head = DPTHead::new(&conf, vb.pp("depth_head"))?;
543
544        Ok(Self {
545            pretrained,
546            depth_head,
547            conf,
548        })
549    }
550}
551
552impl Module for DepthAnythingV2 {
553    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
554        let features = self.pretrained.get_intermediate_layers(
555            xs,
556            &self.conf.layer_ids_vits,
557            false,
558            false,
559            true,
560        )?;
561        let depth = self.depth_head.forward(&features)?;
562
563        depth.relu()
564    }
565}