candle_transformers/models/
hiera.rs

1//! Hiera inference implementation based on timm.
2//!
3//!
4//! - 💻 [Hiera](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/hiera.py)
5//! - 📝 [Paper](https://arxiv.org/abs/2306.00989). Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles
6
7use candle::{Result, D};
8use candle_nn::{conv2d, layer_norm, linear, ops::softmax, Conv2dConfig, Func, VarBuilder};
9
10#[derive(Debug, Clone, serde::Deserialize)]
11pub struct Config {
12    channels: usize,
13    heads: usize,
14    stages: [usize; 4],
15}
16
17impl Config {
18    pub fn tiny() -> Self {
19        Self {
20            channels: 96,
21            heads: 1,
22            stages: [1, 2, 7, 2],
23        }
24    }
25    pub fn small() -> Self {
26        Self {
27            channels: 96,
28            heads: 1,
29            stages: [1, 2, 11, 2],
30        }
31    }
32    pub fn base() -> Self {
33        Self {
34            channels: 96,
35            heads: 1,
36            stages: [2, 3, 16, 3],
37        }
38    }
39    pub fn base_plus() -> Self {
40        Self {
41            channels: 112,
42            heads: 2,
43            stages: [2, 3, 16, 3],
44        }
45    }
46    pub fn large() -> Self {
47        Self {
48            channels: 144,
49            heads: 2,
50            stages: [2, 6, 36, 4],
51        }
52    }
53    pub fn huge() -> Self {
54        Self {
55            channels: 256,
56            heads: 4,
57            stages: [2, 6, 36, 4],
58        }
59    }
60}
61
/// Number of tokens the model works on: one per cell of a 56 x 56 grid.
///
/// NOTE(review): presumably this corresponds to a 224 x 224 input after the
/// stride-4 patch embedding; confirm against the expected input resolution.
const NUM_TOKENS: usize = 56 * 56;
63
64fn hiera_embeddings(channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
65    let conv_cfg = Conv2dConfig {
66        stride: 4,
67        padding: 3,
68        ..Default::default()
69    };
70    let proj = conv2d(3, channels, 7, conv_cfg, vb.pp("patch_embed.proj"))?;
71
72    let pos_embed = vb.get((1, NUM_TOKENS, channels), "pos_embed")?;
73
74    Ok(Func::new(move |xs| {
75        let xs = xs.apply(&proj)?;
76        let (b, c, _, _) = xs.dims4()?;
77        let xs = xs.reshape((b, c, ()))?.transpose(1, 2)?;
78        let xs = xs.broadcast_add(&pos_embed)?;
79        Ok(xs)
80    }))
81}
82
83fn hiera_unroll() -> Result<Func<'static>> {
84    Ok(Func::new(move |xs| {
85        let mut xs = xs.clone();
86        let (mut b, _, c) = xs.dims3()?;
87        let mut size = 56;
88
89        xs = xs.reshape((b, size, size, c))?;
90        for _ in 0..3 {
91            size /= 2;
92            let new_shape = &[b, size, 2, size, 2, c];
93            xs = xs.reshape(new_shape)?;
94            xs = xs.permute((0, 2, 4, 1, 3, 5))?;
95            xs = xs.flatten(0, 2)?;
96            b *= 4;
97        }
98        xs = xs.reshape(((), NUM_TOKENS, c))?;
99
100        Ok(xs)
101    }))
102}
103
104fn hiera_mlp(in_channels: usize, out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
105    let fc1 = linear(in_channels, out_channels, vb.pp("fc1"))?;
106    let fc2 = linear(out_channels, in_channels, vb.pp("fc2"))?;
107
108    Ok(Func::new(move |xs| {
109        let xs = xs.apply(&fc1)?.gelu()?.apply(&fc2)?;
110        Ok(xs)
111    }))
112}
113
114fn hiera_attention(
115    in_channels: usize,
116    out_channels: usize,
117    heads: usize,
118    q_stride: usize,
119    window_size: usize,
120    use_mask_attention: bool,
121    vb: VarBuilder,
122) -> Result<Func<'static>> {
123    let head_dim = out_channels / heads;
124
125    let scale = (head_dim as f64).powf(-0.5);
126
127    let proj = linear(out_channels, out_channels, vb.pp("proj"))?;
128    let qkv = linear(in_channels, out_channels * 3, vb.pp("qkv"))?;
129
130    Ok(Func::new(move |xs| {
131        let (b, n, _) = xs.dims3()?;
132
133        let num_windows = if use_mask_attention {
134            n / (q_stride * window_size)
135        } else {
136            1
137        };
138        let qkv = xs.apply(&qkv)?;
139
140        let ec = qkv.elem_count();
141        let s = ec / (b * num_windows * 3 * heads * head_dim);
142        let qkv = qkv
143            .reshape((b, s, num_windows, 3, heads, head_dim))?
144            .permute((3, 0, 4, 2, 1, 5))?;
145
146        let mut q = qkv.get(0)?;
147        let k = qkv.get(1)?;
148        let v = qkv.get(2)?;
149
150        if q_stride > 1 {
151            let ec = q.elem_count();
152            let s = ec / (b * num_windows * q_stride * heads * head_dim);
153            q = q
154                .reshape((b, heads, num_windows, q_stride, s, head_dim))?
155                .max(3)?;
156        }
157
158        let q = (q * scale)?;
159
160        // Q, K and V are 6 dimensional with the first dimension being 1.
161        // Squeeze them for the attention calculation since 6 dimensional matmuls are not supported.
162        let att = q
163            .squeeze(0)?
164            .matmul(&k.squeeze(0)?.transpose(D::Minus2, D::Minus1)?)?;
165        let att = softmax(&att, D::Minus1)?;
166        let xs = att.matmul(&v.squeeze(0)?)?.unsqueeze(0)?;
167
168        let xs = xs.transpose(1, 3)?.reshape((b, (), out_channels))?;
169        let xs = xs.apply(&proj)?;
170
171        Ok(xs)
172    }))
173}
174
175fn hiera_block(
176    heads: usize,
177    in_channels: usize,
178    out_channels: usize,
179    q_stride: usize,
180    window_size: usize,
181    use_mask_attention: bool,
182    vb: VarBuilder,
183) -> Result<Func<'static>> {
184    let norm1 = layer_norm(in_channels, 1e-6, vb.pp("norm1"))?;
185    let norm2 = layer_norm(out_channels, 1e-6, vb.pp("norm2"))?;
186    let proj = linear(in_channels, out_channels, vb.pp("proj"));
187    let stride = 4;
188    let mlp = hiera_mlp(out_channels, out_channels * 4, vb.pp("mlp"))?;
189    let attn = hiera_attention(
190        in_channels,
191        out_channels,
192        heads,
193        q_stride,
194        window_size,
195        use_mask_attention,
196        vb.pp("attn"),
197    )?;
198
199    Ok(Func::new(move |xs| {
200        let mut xs = xs.clone();
201        let xs_norm = xs.apply_t(&norm1, false)?;
202        if let Ok(p) = &proj {
203            xs = xs_norm.apply(p)?;
204            let (a, _, d) = xs.dims3()?;
205            xs = xs.reshape((a, stride, (), d))?.max(1)?;
206        }
207        let xs = (xs + &xs_norm.apply(&attn)?)?;
208
209        let xs = (&xs + &xs.apply_t(&norm2, false)?.apply(&mlp)?)?;
210
211        Ok(xs)
212    }))
213}
214
215fn hiera_blocks(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
216    let nblocks = cfg.stages.iter().sum();
217    let mut blocks = Vec::with_capacity(nblocks);
218
219    let mut out_channels = cfg.channels;
220    let mut in_channels = out_channels;
221    let mut heads = cfg.heads;
222    let mut b = 0;
223
224    let mut q_stride = 1;
225    let mut window_size = 64;
226
227    for s in 0..4 {
228        let use_mask_attention = s < 2;
229
230        for _ in 0..cfg.stages[s] {
231            blocks.push(hiera_block(
232                heads,
233                in_channels,
234                out_channels,
235                q_stride,
236                window_size,
237                use_mask_attention,
238                vb.pp(b),
239            )?);
240            b += 1;
241            in_channels = out_channels;
242            q_stride = 1;
243        }
244        q_stride = 4;
245        out_channels *= 2;
246        heads *= 2;
247        window_size /= 4;
248    }
249
250    Ok(Func::new(move |xs| {
251        let mut xs = xs.clone();
252        for block in blocks.iter() {
253            xs = xs.apply(block)?
254        }
255        Ok(xs)
256    }))
257}
258
259fn hiera_head(outputs: usize, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
260    let norm = layer_norm(outputs, 1e-6, vb.pp("norm"))?;
261    let linear = linear(outputs, nclasses, vb.pp("fc"))?;
262    Ok(Func::new(move |xs| {
263        xs.apply_t(&norm, false)?.apply(&linear)
264    }))
265}
266
267// Build a hiera model for a given configuration.
268fn hiera_model(cfg: &Config, nclasses: Option<usize>, vb: VarBuilder) -> Result<Func<'static>> {
269    let cls = match nclasses {
270        None => None,
271        Some(nclasses) => {
272            let outputs = cfg.channels * 8;
273            let head = hiera_head(outputs, nclasses, vb.pp("head"))?;
274            Some(head)
275        }
276    };
277
278    let embeddings = hiera_embeddings(cfg.channels, vb.clone())?;
279    let unroll = hiera_unroll()?;
280    let blocks = hiera_blocks(cfg, vb.pp("blocks"))?;
281
282    Ok(Func::new(move |xs| {
283        let xs = xs
284            .apply(&embeddings)?
285            .apply(&unroll)?
286            .apply(&blocks)?
287            .mean(1)?;
288        match &cls {
289            None => Ok(xs),
290            Some(cls) => xs.apply(cls),
291        }
292    }))
293}
294
295pub fn hiera(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
296    hiera_model(cfg, Some(nclasses), vb)
297}
298
299pub fn hiera_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
300    hiera_model(cfg, None, vb)
301}