candle_transformers/models/segment_anything/
mod.rs1pub use crate::models::with_tracing::Linear;
33use candle::{Result, Tensor};
34use candle_nn::{Module, VarBuilder};
35
36pub mod image_encoder;
37pub mod mask_decoder;
38pub mod prompt_encoder;
39pub mod sam;
40pub mod tiny_vit;
41pub mod transformer;
42
43pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
44 if bias {
45 crate::models::with_tracing::linear(in_dim, out_dim, vb)
46 } else {
47 crate::models::with_tracing::linear_no_bias(in_dim, out_dim, vb)
48 }
49}
50
51#[derive(Debug)]
52pub struct LayerNorm2d {
53 weight: Tensor,
54 bias: Tensor,
55 num_channels: usize,
56 eps: f64,
57}
58
59impl LayerNorm2d {
60 pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
61 let weight = vb.get(num_channels, "weight")?;
62 let bias = vb.get(num_channels, "bias")?;
63 Ok(Self {
64 weight,
65 bias,
66 num_channels,
67 eps,
68 })
69 }
70}
71
72impl Module for LayerNorm2d {
73 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
74 let u = xs.mean_keepdim(1)?;
75 let xs = xs.broadcast_sub(&u)?;
76 let s = xs.sqr()?.mean_keepdim(1)?;
77 let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?;
78 xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)?
79 .broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?)
80 }
81}
82
83#[derive(Debug)]
84pub struct MlpBlock {
85 lin1: Linear,
86 lin2: Linear,
87 activation: candle_nn::Activation,
88 span: tracing::Span,
89}
90
91impl MlpBlock {
92 pub fn new(
93 embedding_dim: usize,
94 mlp_dim: usize,
95 activation: candle_nn::Activation,
96 vb: VarBuilder,
97 ) -> Result<Self> {
98 let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?;
99 let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?;
100 let span = tracing::span!(tracing::Level::TRACE, "mlp-block");
101 Ok(Self {
102 lin1,
103 lin2,
104 activation,
105 span,
106 })
107 }
108}
109
110impl Module for MlpBlock {
111 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
112 let _enter = self.span.enter();
113 xs.apply(&self.lin1)?
114 .apply(&self.activation)?
115 .apply(&self.lin2)
116 }
117}