1use candle::{Result, D};
8use candle_nn::{conv2d, layer_norm, linear, ops::softmax, Conv2dConfig, Func, VarBuilder};
9
/// Hyper-parameters describing one Hiera model variant.
///
/// The fields are private; instances come from the preset constructors on
/// this type or from deserialization.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
    /// Channel width of the first stage (also the patch-embedding width);
    /// `hiera_blocks` doubles it after every stage.
    channels: usize,
    /// Attention head count of the first stage; `hiera_blocks` doubles it
    /// after every stage.
    heads: usize,
    /// Number of transformer blocks in each of the four stages.
    stages: [usize; 4],
}
16
17impl Config {
18 pub fn tiny() -> Self {
19 Self {
20 channels: 96,
21 heads: 1,
22 stages: [1, 2, 7, 2],
23 }
24 }
25 pub fn small() -> Self {
26 Self {
27 channels: 96,
28 heads: 1,
29 stages: [1, 2, 11, 2],
30 }
31 }
32 pub fn base() -> Self {
33 Self {
34 channels: 96,
35 heads: 1,
36 stages: [2, 3, 16, 3],
37 }
38 }
39 pub fn base_plus() -> Self {
40 Self {
41 channels: 112,
42 heads: 2,
43 stages: [2, 3, 16, 3],
44 }
45 }
46 pub fn large() -> Self {
47 Self {
48 channels: 144,
49 heads: 2,
50 stages: [2, 6, 36, 4],
51 }
52 }
53 pub fn huge() -> Self {
54 Self {
55 channels: 256,
56 heads: 4,
57 stages: [2, 6, 36, 4],
58 }
59 }
60}
61
// Number of tokens the model is wired for: a 56 x 56 grid. It fixes the
// second axis of the `pos_embed` tensor and the output shape of
// `hiera_unroll`.
// NOTE(review): presumably matches 224x224 inputs passed through the
// stride-4 patch embedding — confirm against the callers.
const NUM_TOKENS: usize = 56 * 56;
63
64fn hiera_embeddings(channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
65 let conv_cfg = Conv2dConfig {
66 stride: 4,
67 padding: 3,
68 ..Default::default()
69 };
70 let proj = conv2d(3, channels, 7, conv_cfg, vb.pp("patch_embed.proj"))?;
71
72 let pos_embed = vb.get((1, NUM_TOKENS, channels), "pos_embed")?;
73
74 Ok(Func::new(move |xs| {
75 let xs = xs.apply(&proj)?;
76 let (b, c, _, _) = xs.dims4()?;
77 let xs = xs.reshape((b, c, ()))?.transpose(1, 2)?;
78 let xs = xs.broadcast_add(&pos_embed)?;
79 Ok(xs)
80 }))
81}
82
83fn hiera_unroll() -> Result<Func<'static>> {
84 Ok(Func::new(move |xs| {
85 let mut xs = xs.clone();
86 let (mut b, _, c) = xs.dims3()?;
87 let mut size = 56;
88
89 xs = xs.reshape((b, size, size, c))?;
90 for _ in 0..3 {
91 size /= 2;
92 let new_shape = &[b, size, 2, size, 2, c];
93 xs = xs.reshape(new_shape)?;
94 xs = xs.permute((0, 2, 4, 1, 3, 5))?;
95 xs = xs.flatten(0, 2)?;
96 b *= 4;
97 }
98 xs = xs.reshape(((), NUM_TOKENS, c))?;
99
100 Ok(xs)
101 }))
102}
103
104fn hiera_mlp(in_channels: usize, out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
105 let fc1 = linear(in_channels, out_channels, vb.pp("fc1"))?;
106 let fc2 = linear(out_channels, in_channels, vb.pp("fc2"))?;
107
108 Ok(Func::new(move |xs| {
109 let xs = xs.apply(&fc1)?.gelu()?.apply(&fc2)?;
110 Ok(xs)
111 }))
112}
113
114fn hiera_attention(
115 in_channels: usize,
116 out_channels: usize,
117 heads: usize,
118 q_stride: usize,
119 window_size: usize,
120 use_mask_attention: bool,
121 vb: VarBuilder,
122) -> Result<Func<'static>> {
123 let head_dim = out_channels / heads;
124
125 let scale = (head_dim as f64).powf(-0.5);
126
127 let proj = linear(out_channels, out_channels, vb.pp("proj"))?;
128 let qkv = linear(in_channels, out_channels * 3, vb.pp("qkv"))?;
129
130 Ok(Func::new(move |xs| {
131 let (b, n, _) = xs.dims3()?;
132
133 let num_windows = if use_mask_attention {
134 n / (q_stride * window_size)
135 } else {
136 1
137 };
138 let qkv = xs.apply(&qkv)?;
139
140 let ec = qkv.elem_count();
141 let s = ec / (b * num_windows * 3 * heads * head_dim);
142 let qkv = qkv
143 .reshape((b, s, num_windows, 3, heads, head_dim))?
144 .permute((3, 0, 4, 2, 1, 5))?;
145
146 let mut q = qkv.get(0)?;
147 let k = qkv.get(1)?;
148 let v = qkv.get(2)?;
149
150 if q_stride > 1 {
151 let ec = q.elem_count();
152 let s = ec / (b * num_windows * q_stride * heads * head_dim);
153 q = q
154 .reshape((b, heads, num_windows, q_stride, s, head_dim))?
155 .max(3)?;
156 }
157
158 let q = (q * scale)?;
159
160 let att = q
163 .squeeze(0)?
164 .matmul(&k.squeeze(0)?.transpose(D::Minus2, D::Minus1)?)?;
165 let att = softmax(&att, D::Minus1)?;
166 let xs = att.matmul(&v.squeeze(0)?)?.unsqueeze(0)?;
167
168 let xs = xs.transpose(1, 3)?.reshape((b, (), out_channels))?;
169 let xs = xs.apply(&proj)?;
170
171 Ok(xs)
172 }))
173}
174
175fn hiera_block(
176 heads: usize,
177 in_channels: usize,
178 out_channels: usize,
179 q_stride: usize,
180 window_size: usize,
181 use_mask_attention: bool,
182 vb: VarBuilder,
183) -> Result<Func<'static>> {
184 let norm1 = layer_norm(in_channels, 1e-6, vb.pp("norm1"))?;
185 let norm2 = layer_norm(out_channels, 1e-6, vb.pp("norm2"))?;
186 let proj = linear(in_channels, out_channels, vb.pp("proj"));
187 let stride = 4;
188 let mlp = hiera_mlp(out_channels, out_channels * 4, vb.pp("mlp"))?;
189 let attn = hiera_attention(
190 in_channels,
191 out_channels,
192 heads,
193 q_stride,
194 window_size,
195 use_mask_attention,
196 vb.pp("attn"),
197 )?;
198
199 Ok(Func::new(move |xs| {
200 let mut xs = xs.clone();
201 let xs_norm = xs.apply_t(&norm1, false)?;
202 if let Ok(p) = &proj {
203 xs = xs_norm.apply(p)?;
204 let (a, _, d) = xs.dims3()?;
205 xs = xs.reshape((a, stride, (), d))?.max(1)?;
206 }
207 let xs = (xs + &xs_norm.apply(&attn)?)?;
208
209 let xs = (&xs + &xs.apply_t(&norm2, false)?.apply(&mlp)?)?;
210
211 Ok(xs)
212 }))
213}
214
215fn hiera_blocks(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
216 let nblocks = cfg.stages.iter().sum();
217 let mut blocks = Vec::with_capacity(nblocks);
218
219 let mut out_channels = cfg.channels;
220 let mut in_channels = out_channels;
221 let mut heads = cfg.heads;
222 let mut b = 0;
223
224 let mut q_stride = 1;
225 let mut window_size = 64;
226
227 for s in 0..4 {
228 let use_mask_attention = s < 2;
229
230 for _ in 0..cfg.stages[s] {
231 blocks.push(hiera_block(
232 heads,
233 in_channels,
234 out_channels,
235 q_stride,
236 window_size,
237 use_mask_attention,
238 vb.pp(b),
239 )?);
240 b += 1;
241 in_channels = out_channels;
242 q_stride = 1;
243 }
244 q_stride = 4;
245 out_channels *= 2;
246 heads *= 2;
247 window_size /= 4;
248 }
249
250 Ok(Func::new(move |xs| {
251 let mut xs = xs.clone();
252 for block in blocks.iter() {
253 xs = xs.apply(block)?
254 }
255 Ok(xs)
256 }))
257}
258
259fn hiera_head(outputs: usize, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
260 let norm = layer_norm(outputs, 1e-6, vb.pp("norm"))?;
261 let linear = linear(outputs, nclasses, vb.pp("fc"))?;
262 Ok(Func::new(move |xs| {
263 xs.apply_t(&norm, false)?.apply(&linear)
264 }))
265}
266
267fn hiera_model(cfg: &Config, nclasses: Option<usize>, vb: VarBuilder) -> Result<Func<'static>> {
269 let cls = match nclasses {
270 None => None,
271 Some(nclasses) => {
272 let outputs = cfg.channels * 8;
273 let head = hiera_head(outputs, nclasses, vb.pp("head"))?;
274 Some(head)
275 }
276 };
277
278 let embeddings = hiera_embeddings(cfg.channels, vb.clone())?;
279 let unroll = hiera_unroll()?;
280 let blocks = hiera_blocks(cfg, vb.pp("blocks"))?;
281
282 Ok(Func::new(move |xs| {
283 let xs = xs
284 .apply(&embeddings)?
285 .apply(&unroll)?
286 .apply(&blocks)?
287 .mean(1)?;
288 match &cls {
289 None => Ok(xs),
290 Some(cls) => xs.apply(cls),
291 }
292 }))
293}
294
295pub fn hiera(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
296 hiera_model(cfg, Some(nclasses), vb)
297}
298
299pub fn hiera_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
300 hiera_model(cfg, None, vb)
301}