candle_transformers/models/segment_anything/
mod.rs

//! Segment Anything Model (SAM)
//!
//! SAM is an image segmentation architecture capable of segmenting any object
//! in an image based on prompts such as points or boxes. This model provides a
//! robust and fast image segmentation pipeline that can be steered via prompting:
//! requesting some points to be in the target mask, requesting some points to be
//! part of the background (so _not_ in the target mask), or specifying a bounding
//! box.
//!
//! - ⚡ [Interactive Wasm Example](https://huggingface.co/spaces/radames/candle-segment-anything-wasm)
//! - 💻 [Original implementation on GitHub](https://github.com/facebookresearch/segment-anything)
//! - 📝 [Paper](https://arxiv.org/abs/2304.02643)
//! - 💡 The default backbone can be replaced by the smaller and faster TinyViT model based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM).
//!
//! ## Example
//!
//! ```bash
//! cargo run --example segment-anything --release -- \
//!     --image candle-examples/examples/yolo-v8/assets/bike.jpg \
//!     --use-tiny --point 0.6,0.6 --point 0.6,0.55
//! ```
//!
//! <div align="center" style="display: flex; justify-content: center; gap: 10px;">
//!   <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.jpg" alt="Original image" width="30%">
//!   <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/single_pt_prompt.jpg" alt="Mask from a single point prompt" width="30%">
//!   <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/two_pt_prompt.jpg" alt="Mask from two point prompts" width="30%">
//! </div>
//!
//! > Left to right: the original image; the mask prompted with `--point 0.6,0.55`;
//! > the mask prompted with `--point 0.6,0.6 --point 0.6,0.55`.
//!
32pub use crate::models::with_tracing::Linear;
33use candle::{Result, Tensor};
34use candle_nn::{Module, VarBuilder};
35
36pub mod image_encoder;
37pub mod mask_decoder;
38pub mod prompt_encoder;
39pub mod sam;
40pub mod tiny_vit;
41pub mod transformer;
42
43pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
44    if bias {
45        crate::models::with_tracing::linear(in_dim, out_dim, vb)
46    } else {
47        crate::models::with_tracing::linear_no_bias(in_dim, out_dim, vb)
48    }
49}
50
51#[derive(Debug)]
52pub struct LayerNorm2d {
53    weight: Tensor,
54    bias: Tensor,
55    num_channels: usize,
56    eps: f64,
57}
58
59impl LayerNorm2d {
60    pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
61        let weight = vb.get(num_channels, "weight")?;
62        let bias = vb.get(num_channels, "bias")?;
63        Ok(Self {
64            weight,
65            bias,
66            num_channels,
67            eps,
68        })
69    }
70}
71
72impl Module for LayerNorm2d {
73    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
74        let u = xs.mean_keepdim(1)?;
75        let xs = xs.broadcast_sub(&u)?;
76        let s = xs.sqr()?.mean_keepdim(1)?;
77        let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?;
78        xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)?
79            .broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?)
80    }
81}
82
83#[derive(Debug)]
84pub struct MlpBlock {
85    lin1: Linear,
86    lin2: Linear,
87    activation: candle_nn::Activation,
88    span: tracing::Span,
89}
90
91impl MlpBlock {
92    pub fn new(
93        embedding_dim: usize,
94        mlp_dim: usize,
95        activation: candle_nn::Activation,
96        vb: VarBuilder,
97    ) -> Result<Self> {
98        let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?;
99        let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?;
100        let span = tracing::span!(tracing::Level::TRACE, "mlp-block");
101        Ok(Self {
102            lin1,
103            lin2,
104            activation,
105            span,
106        })
107    }
108}
109
110impl Module for MlpBlock {
111    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
112        let _enter = self.span.enter();
113        xs.apply(&self.lin1)?
114            .apply(&self.activation)?
115            .apply(&self.lin2)
116    }
117}