1use super::common::{AttnBlock, GlobalResponseNorm, LayerNormNoWeights, TimestepBlock, WLayerNorm};
2use candle::{DType, Module, Result, Tensor, D};
3use candle_nn::VarBuilder;
4
/// Residual block made of a grouped convolution, a normalization and a
/// per-position ("channelwise") two-layer MLP.
///
/// An optional skip tensor can be concatenated on dimension 1 before the MLP
/// (see `ResBlockStageB::forward`).
#[derive(Debug)]
pub struct ResBlockStageB {
    /// Convolution built with `groups` equal to its channel count.
    depthwise: candle_nn::Conv2d,
    /// Normalization applied right after `depthwise`.
    norm: WLayerNorm,
    /// First linear layer of the channelwise MLP; its input width includes the skip channels.
    channelwise_lin1: candle_nn::Linear,
    /// Global response normalization applied after the GELU activation.
    channelwise_grn: GlobalResponseNorm,
    /// Second linear layer, projecting back to the block's channel count.
    channelwise_lin2: candle_nn::Linear,
}
13
14impl ResBlockStageB {
15 pub fn new(c: usize, c_skip: usize, ksize: usize, vb: VarBuilder) -> Result<Self> {
16 let cfg = candle_nn::Conv2dConfig {
17 groups: c,
18 padding: ksize / 2,
19 ..Default::default()
20 };
21 let depthwise = candle_nn::conv2d(c, c, ksize, cfg, vb.pp("depthwise"))?;
22 let norm = WLayerNorm::new(c)?;
23 let channelwise_lin1 = candle_nn::linear(c + c_skip, c * 4, vb.pp("channelwise.0"))?;
24 let channelwise_grn = GlobalResponseNorm::new(4 * c, vb.pp("channelwise.2"))?;
25 let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;
26 Ok(Self {
27 depthwise,
28 norm,
29 channelwise_lin1,
30 channelwise_grn,
31 channelwise_lin2,
32 })
33 }
34
35 pub fn forward(&self, xs: &Tensor, x_skip: Option<&Tensor>) -> Result<Tensor> {
36 let x_res = xs;
37 let xs = xs.apply(&self.depthwise)?.apply(&self.norm)?;
38 let xs = match x_skip {
39 None => xs.clone(),
40 Some(x_skip) => Tensor::cat(&[&xs, x_skip], 1)?,
41 };
42 let xs = xs
43 .permute((0, 2, 3, 1))?
44 .contiguous()?
45 .apply(&self.channelwise_lin1)?
46 .gelu()?
47 .apply(&self.channelwise_grn)?
48 .apply(&self.channelwise_lin2)?
49 .permute((0, 3, 1, 2))?;
50 xs + x_res
51 }
52}
53
/// One repeated unit of a down or up block: a residual block, a timestep
/// block and an optional attention block, applied in that order by
/// `WDiffNeXt::forward`.
#[derive(Debug)]
struct SubBlock {
    /// Residual block; receives the optional skip/effnet conditioning tensor.
    res_block: ResBlockStageB,
    /// Block conditioned on the timestep embedding.
    ts_block: TimestepBlock,
    /// Attention over the clip conditioning; `None` for the first resolution level.
    attn_block: Option<AttnBlock>,
}
60
/// One level of the downsampling path.
///
/// In `WDiffNeXt::forward` the optional `layer_norm` and `conv` are applied
/// first, then every entry of `sub_blocks` in order.
#[derive(Debug)]
struct DownBlock {
    /// Normalization applied before `conv`; `None` for the first level.
    layer_norm: Option<WLayerNorm>,
    /// Stride-2 convolution that changes the channel count; `None` for the first level.
    conv: Option<candle_nn::Conv2d>,
    /// Residual/timestep/attention units of this level.
    sub_blocks: Vec<SubBlock>,
}
67
/// One level of the upsampling path.
///
/// In `WDiffNeXt::forward` every entry of `sub_blocks` is applied first, then
/// the optional `layer_norm` and `conv`.
#[derive(Debug)]
struct UpBlock {
    /// Residual/timestep/attention units of this level.
    sub_blocks: Vec<SubBlock>,
    /// Normalization applied before `conv`; `None` for the last up block built.
    layer_norm: Option<WLayerNorm>,
    /// Stride-2 transposed convolution; `None` for the last up block built.
    conv: Option<candle_nn::ConvTranspose2d>,
}
74
/// U-shaped network built from `DownBlock`s and `UpBlock`s, conditioned on a
/// timestep-like value `r`, an `effnet` tensor and an optional `clip` tensor.
#[derive(Debug)]
pub struct WDiffNeXt {
    /// Linear projection applied to the clip input before `seq_norm`.
    clip_mapper: candle_nn::Linear,
    /// 1x1 convolutions for the effnet conditioning: one slot per down block
    /// followed by one slot per up block; `None` where nothing is injected.
    effnet_mappers: Vec<Option<candle_nn::Conv2d>>,
    /// Normalization applied to the mapped clip input.
    seq_norm: LayerNormNoWeights,
    /// 1x1 convolution applied to the pixel-unshuffled input.
    embedding_conv: candle_nn::Conv2d,
    /// Normalization applied after `embedding_conv`.
    embedding_ln: WLayerNorm,
    /// Downsampling path, in application order.
    down_blocks: Vec<DownBlock>,
    /// Upsampling path, in application order.
    up_blocks: Vec<UpBlock>,
    /// Normalization of the output head.
    clf_ln: WLayerNorm,
    /// 1x1 convolution of the output head, applied before the pixel shuffle.
    clf_conv: candle_nn::Conv2d,
    /// Width of the embedding produced by `gen_r_embedding`.
    c_r: usize,
    /// Factor used for the input pixel-unshuffle and the output pixel-shuffle.
    patch_size: usize,
}
89
90impl WDiffNeXt {
91 #[allow(clippy::too_many_arguments)]
92 pub fn new(
93 c_in: usize,
94 c_out: usize,
95 c_r: usize,
96 c_cond: usize,
97 clip_embd: usize,
98 patch_size: usize,
99 use_flash_attn: bool,
100 vb: VarBuilder,
101 ) -> Result<Self> {
102 const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280];
103 const BLOCKS: [usize; 4] = [4, 4, 14, 4];
104 const NHEAD: [usize; 4] = [1, 10, 20, 20];
105 const INJECT_EFFNET: [bool; 4] = [false, true, true, true];
106 const EFFNET_EMBD: usize = 16;
107
108 let clip_mapper = candle_nn::linear(clip_embd, c_cond, vb.pp("clip_mapper"))?;
109 let mut effnet_mappers = Vec::with_capacity(2 * INJECT_EFFNET.len());
110 let vb_e = vb.pp("effnet_mappers");
111 for (i, &inject) in INJECT_EFFNET.iter().enumerate() {
112 let c = if inject {
113 Some(candle_nn::conv2d(
114 EFFNET_EMBD,
115 c_cond,
116 1,
117 Default::default(),
118 vb_e.pp(i),
119 )?)
120 } else {
121 None
122 };
123 effnet_mappers.push(c)
124 }
125 for (i, &inject) in INJECT_EFFNET.iter().rev().enumerate() {
126 let c = if inject {
127 Some(candle_nn::conv2d(
128 EFFNET_EMBD,
129 c_cond,
130 1,
131 Default::default(),
132 vb_e.pp(i + INJECT_EFFNET.len()),
133 )?)
134 } else {
135 None
136 };
137 effnet_mappers.push(c)
138 }
139 let seq_norm = LayerNormNoWeights::new(c_cond)?;
140 let embedding_ln = WLayerNorm::new(C_HIDDEN[0])?;
141 let embedding_conv = candle_nn::conv2d(
142 c_in * patch_size * patch_size,
143 C_HIDDEN[0],
144 1,
145 Default::default(),
146 vb.pp("embedding.1"),
147 )?;
148
149 let mut down_blocks = Vec::with_capacity(C_HIDDEN.len());
150 for (i, &c_hidden) in C_HIDDEN.iter().enumerate() {
151 let vb = vb.pp("down_blocks").pp(i);
152 let (layer_norm, conv, start_layer_i) = if i > 0 {
153 let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
154 let cfg = candle_nn::Conv2dConfig {
155 stride: 2,
156 ..Default::default()
157 };
158 let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp("0.1"))?;
159 (Some(layer_norm), Some(conv), 1)
160 } else {
161 (None, None, 0)
162 };
163 let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
164 let mut layer_i = start_layer_i;
165 for _j in 0..BLOCKS[i] {
166 let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
167 let res_block = ResBlockStageB::new(c_hidden, c_skip, 3, vb.pp(layer_i))?;
168 layer_i += 1;
169 let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
170 layer_i += 1;
171 let attn_block = if i == 0 {
172 None
173 } else {
174 let attn_block = AttnBlock::new(
175 c_hidden,
176 c_cond,
177 NHEAD[i],
178 true,
179 use_flash_attn,
180 vb.pp(layer_i),
181 )?;
182 layer_i += 1;
183 Some(attn_block)
184 };
185 let sub_block = SubBlock {
186 res_block,
187 ts_block,
188 attn_block,
189 };
190 sub_blocks.push(sub_block)
191 }
192 let down_block = DownBlock {
193 layer_norm,
194 conv,
195 sub_blocks,
196 };
197 down_blocks.push(down_block)
198 }
199
200 let mut up_blocks = Vec::with_capacity(C_HIDDEN.len());
201 for (i, &c_hidden) in C_HIDDEN.iter().enumerate().rev() {
202 let vb = vb.pp("up_blocks").pp(C_HIDDEN.len() - 1 - i);
203 let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
204 let mut layer_i = 0;
205 for j in 0..BLOCKS[i] {
206 let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
207 let c_skip_res = if i < BLOCKS.len() - 1 && j == 0 {
208 c_hidden + c_skip
209 } else {
210 c_skip
211 };
212 let res_block = ResBlockStageB::new(c_hidden, c_skip_res, 3, vb.pp(layer_i))?;
213 layer_i += 1;
214 let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
215 layer_i += 1;
216 let attn_block = if i == 0 {
217 None
218 } else {
219 let attn_block = AttnBlock::new(
220 c_hidden,
221 c_cond,
222 NHEAD[i],
223 true,
224 use_flash_attn,
225 vb.pp(layer_i),
226 )?;
227 layer_i += 1;
228 Some(attn_block)
229 };
230 let sub_block = SubBlock {
231 res_block,
232 ts_block,
233 attn_block,
234 };
235 sub_blocks.push(sub_block)
236 }
237 let (layer_norm, conv) = if i > 0 {
238 let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
239 let cfg = candle_nn::ConvTranspose2dConfig {
240 stride: 2,
241 ..Default::default()
242 };
243 let conv = candle_nn::conv_transpose2d(
244 c_hidden,
245 C_HIDDEN[i - 1],
246 2,
247 cfg,
248 vb.pp(layer_i).pp(1),
249 )?;
250 (Some(layer_norm), Some(conv))
251 } else {
252 (None, None)
253 };
254 let up_block = UpBlock {
255 layer_norm,
256 conv,
257 sub_blocks,
258 };
259 up_blocks.push(up_block)
260 }
261
262 let clf_ln = WLayerNorm::new(C_HIDDEN[0])?;
263 let clf_conv = candle_nn::conv2d(
264 C_HIDDEN[0],
265 2 * c_out * patch_size * patch_size,
266 1,
267 Default::default(),
268 vb.pp("clf.1"),
269 )?;
270 Ok(Self {
271 clip_mapper,
272 effnet_mappers,
273 seq_norm,
274 embedding_conv,
275 embedding_ln,
276 down_blocks,
277 up_blocks,
278 clf_ln,
279 clf_conv,
280 c_r,
281 patch_size,
282 })
283 }
284
285 fn gen_r_embedding(&self, r: &Tensor) -> Result<Tensor> {
286 const MAX_POSITIONS: usize = 10000;
287 let r = (r * MAX_POSITIONS as f64)?;
288 let half_dim = self.c_r / 2;
289 let emb = (MAX_POSITIONS as f64).ln() / (half_dim - 1) as f64;
290 let emb = (Tensor::arange(0u32, half_dim as u32, r.device())?.to_dtype(DType::F32)?
291 * -emb)?
292 .exp()?;
293 let emb = r.unsqueeze(1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
294 let emb = Tensor::cat(&[emb.sin()?, emb.cos()?], 1)?;
295 let emb = if self.c_r % 2 == 1 {
296 emb.pad_with_zeros(D::Minus1, 0, 1)?
297 } else {
298 emb
299 };
300 emb.to_dtype(r.dtype())
301 }
302
303 fn gen_c_embeddings(&self, clip: &Tensor) -> Result<Tensor> {
304 clip.apply(&self.clip_mapper)?.apply(&self.seq_norm)
305 }
306
307 pub fn forward(
308 &self,
309 xs: &Tensor,
310 r: &Tensor,
311 effnet: &Tensor,
312 clip: Option<&Tensor>,
313 ) -> Result<Tensor> {
314 const EPS: f64 = 1e-3;
315
316 let r_embed = self.gen_r_embedding(r)?;
317 let clip = match clip {
318 None => None,
319 Some(clip) => Some(self.gen_c_embeddings(clip)?),
320 };
321 let x_in = xs;
322
323 let mut xs = xs
324 .apply(&|xs: &_| candle_nn::ops::pixel_unshuffle(xs, self.patch_size))?
325 .apply(&self.embedding_conv)?
326 .apply(&self.embedding_ln)?;
327
328 let mut level_outputs = Vec::new();
329 for (i, down_block) in self.down_blocks.iter().enumerate() {
330 if let Some(ln) = &down_block.layer_norm {
331 xs = xs.apply(ln)?
332 }
333 if let Some(conv) = &down_block.conv {
334 xs = xs.apply(conv)?
335 }
336 let skip = match &self.effnet_mappers[i] {
337 None => None,
338 Some(m) => {
339 let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
340 Some(m.forward(&effnet)?)
341 }
342 };
343 for block in down_block.sub_blocks.iter() {
344 xs = block.res_block.forward(&xs, skip.as_ref())?;
345 xs = block.ts_block.forward(&xs, &r_embed)?;
346 if let Some(attn_block) = &block.attn_block {
347 xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;
348 }
349 }
350 level_outputs.push(xs.clone())
351 }
352 level_outputs.reverse();
353 let mut xs = level_outputs[0].clone();
354
355 for (i, up_block) in self.up_blocks.iter().enumerate() {
356 let effnet_c = match &self.effnet_mappers[self.down_blocks.len() + i] {
357 None => None,
358 Some(m) => {
359 let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
360 Some(m.forward(&effnet)?)
361 }
362 };
363 for (j, block) in up_block.sub_blocks.iter().enumerate() {
364 let skip = if j == 0 && i > 0 {
365 Some(&level_outputs[i])
366 } else {
367 None
368 };
369 let skip = match (skip, effnet_c.as_ref()) {
370 (Some(skip), Some(effnet_c)) => Some(Tensor::cat(&[skip, effnet_c], 1)?),
371 (None, Some(skip)) | (Some(skip), None) => Some(skip.clone()),
372 (None, None) => None,
373 };
374 xs = block.res_block.forward(&xs, skip.as_ref())?;
375 xs = block.ts_block.forward(&xs, &r_embed)?;
376 if let Some(attn_block) = &block.attn_block {
377 xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;
378 }
379 }
380 if let Some(ln) = &up_block.layer_norm {
381 xs = xs.apply(ln)?
382 }
383 if let Some(conv) = &up_block.conv {
384 xs = xs.apply(conv)?
385 }
386 }
387
388 let ab = xs
389 .apply(&self.clf_ln)?
390 .apply(&self.clf_conv)?
391 .apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, self.patch_size))?
392 .chunk(2, 1)?;
393 let b = ((candle_nn::ops::sigmoid(&ab[1])? * (1. - EPS * 2.))? + EPS)?;
394 (x_in - &ab[0])? / b
395 }
396}