// candle_transformers/models/wuerstchen/diffnext.rs

1use super::common::{AttnBlock, GlobalResponseNorm, LayerNormNoWeights, TimestepBlock, WLayerNorm};
2use candle::{DType, Module, Result, Tensor, D};
3use candle_nn::VarBuilder;
4
5#[derive(Debug)]
6pub struct ResBlockStageB {
7    depthwise: candle_nn::Conv2d,
8    norm: WLayerNorm,
9    channelwise_lin1: candle_nn::Linear,
10    channelwise_grn: GlobalResponseNorm,
11    channelwise_lin2: candle_nn::Linear,
12}
13
14impl ResBlockStageB {
15    pub fn new(c: usize, c_skip: usize, ksize: usize, vb: VarBuilder) -> Result<Self> {
16        let cfg = candle_nn::Conv2dConfig {
17            groups: c,
18            padding: ksize / 2,
19            ..Default::default()
20        };
21        let depthwise = candle_nn::conv2d(c, c, ksize, cfg, vb.pp("depthwise"))?;
22        let norm = WLayerNorm::new(c)?;
23        let channelwise_lin1 = candle_nn::linear(c + c_skip, c * 4, vb.pp("channelwise.0"))?;
24        let channelwise_grn = GlobalResponseNorm::new(4 * c, vb.pp("channelwise.2"))?;
25        let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;
26        Ok(Self {
27            depthwise,
28            norm,
29            channelwise_lin1,
30            channelwise_grn,
31            channelwise_lin2,
32        })
33    }
34
35    pub fn forward(&self, xs: &Tensor, x_skip: Option<&Tensor>) -> Result<Tensor> {
36        let x_res = xs;
37        let xs = xs.apply(&self.depthwise)?.apply(&self.norm)?;
38        let xs = match x_skip {
39            None => xs.clone(),
40            Some(x_skip) => Tensor::cat(&[&xs, x_skip], 1)?,
41        };
42        let xs = xs
43            .permute((0, 2, 3, 1))?
44            .contiguous()?
45            .apply(&self.channelwise_lin1)?
46            .gelu()?
47            .apply(&self.channelwise_grn)?
48            .apply(&self.channelwise_lin2)?
49            .permute((0, 3, 1, 2))?;
50        xs + x_res
51    }
52}
53
54#[derive(Debug)]
55struct SubBlock {
56    res_block: ResBlockStageB,
57    ts_block: TimestepBlock,
58    attn_block: Option<AttnBlock>,
59}
60
61#[derive(Debug)]
62struct DownBlock {
63    layer_norm: Option<WLayerNorm>,
64    conv: Option<candle_nn::Conv2d>,
65    sub_blocks: Vec<SubBlock>,
66}
67
68#[derive(Debug)]
69struct UpBlock {
70    sub_blocks: Vec<SubBlock>,
71    layer_norm: Option<WLayerNorm>,
72    conv: Option<candle_nn::ConvTranspose2d>,
73}
74
75#[derive(Debug)]
76pub struct WDiffNeXt {
77    clip_mapper: candle_nn::Linear,
78    effnet_mappers: Vec<Option<candle_nn::Conv2d>>,
79    seq_norm: LayerNormNoWeights,
80    embedding_conv: candle_nn::Conv2d,
81    embedding_ln: WLayerNorm,
82    down_blocks: Vec<DownBlock>,
83    up_blocks: Vec<UpBlock>,
84    clf_ln: WLayerNorm,
85    clf_conv: candle_nn::Conv2d,
86    c_r: usize,
87    patch_size: usize,
88}
89
90impl WDiffNeXt {
91    #[allow(clippy::too_many_arguments)]
92    pub fn new(
93        c_in: usize,
94        c_out: usize,
95        c_r: usize,
96        c_cond: usize,
97        clip_embd: usize,
98        patch_size: usize,
99        use_flash_attn: bool,
100        vb: VarBuilder,
101    ) -> Result<Self> {
102        const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280];
103        const BLOCKS: [usize; 4] = [4, 4, 14, 4];
104        const NHEAD: [usize; 4] = [1, 10, 20, 20];
105        const INJECT_EFFNET: [bool; 4] = [false, true, true, true];
106        const EFFNET_EMBD: usize = 16;
107
108        let clip_mapper = candle_nn::linear(clip_embd, c_cond, vb.pp("clip_mapper"))?;
109        let mut effnet_mappers = Vec::with_capacity(2 * INJECT_EFFNET.len());
110        let vb_e = vb.pp("effnet_mappers");
111        for (i, &inject) in INJECT_EFFNET.iter().enumerate() {
112            let c = if inject {
113                Some(candle_nn::conv2d(
114                    EFFNET_EMBD,
115                    c_cond,
116                    1,
117                    Default::default(),
118                    vb_e.pp(i),
119                )?)
120            } else {
121                None
122            };
123            effnet_mappers.push(c)
124        }
125        for (i, &inject) in INJECT_EFFNET.iter().rev().enumerate() {
126            let c = if inject {
127                Some(candle_nn::conv2d(
128                    EFFNET_EMBD,
129                    c_cond,
130                    1,
131                    Default::default(),
132                    vb_e.pp(i + INJECT_EFFNET.len()),
133                )?)
134            } else {
135                None
136            };
137            effnet_mappers.push(c)
138        }
139        let seq_norm = LayerNormNoWeights::new(c_cond)?;
140        let embedding_ln = WLayerNorm::new(C_HIDDEN[0])?;
141        let embedding_conv = candle_nn::conv2d(
142            c_in * patch_size * patch_size,
143            C_HIDDEN[0],
144            1,
145            Default::default(),
146            vb.pp("embedding.1"),
147        )?;
148
149        let mut down_blocks = Vec::with_capacity(C_HIDDEN.len());
150        for (i, &c_hidden) in C_HIDDEN.iter().enumerate() {
151            let vb = vb.pp("down_blocks").pp(i);
152            let (layer_norm, conv, start_layer_i) = if i > 0 {
153                let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
154                let cfg = candle_nn::Conv2dConfig {
155                    stride: 2,
156                    ..Default::default()
157                };
158                let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp("0.1"))?;
159                (Some(layer_norm), Some(conv), 1)
160            } else {
161                (None, None, 0)
162            };
163            let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
164            let mut layer_i = start_layer_i;
165            for _j in 0..BLOCKS[i] {
166                let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
167                let res_block = ResBlockStageB::new(c_hidden, c_skip, 3, vb.pp(layer_i))?;
168                layer_i += 1;
169                let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
170                layer_i += 1;
171                let attn_block = if i == 0 {
172                    None
173                } else {
174                    let attn_block = AttnBlock::new(
175                        c_hidden,
176                        c_cond,
177                        NHEAD[i],
178                        true,
179                        use_flash_attn,
180                        vb.pp(layer_i),
181                    )?;
182                    layer_i += 1;
183                    Some(attn_block)
184                };
185                let sub_block = SubBlock {
186                    res_block,
187                    ts_block,
188                    attn_block,
189                };
190                sub_blocks.push(sub_block)
191            }
192            let down_block = DownBlock {
193                layer_norm,
194                conv,
195                sub_blocks,
196            };
197            down_blocks.push(down_block)
198        }
199
200        let mut up_blocks = Vec::with_capacity(C_HIDDEN.len());
201        for (i, &c_hidden) in C_HIDDEN.iter().enumerate().rev() {
202            let vb = vb.pp("up_blocks").pp(C_HIDDEN.len() - 1 - i);
203            let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
204            let mut layer_i = 0;
205            for j in 0..BLOCKS[i] {
206                let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
207                let c_skip_res = if i < BLOCKS.len() - 1 && j == 0 {
208                    c_hidden + c_skip
209                } else {
210                    c_skip
211                };
212                let res_block = ResBlockStageB::new(c_hidden, c_skip_res, 3, vb.pp(layer_i))?;
213                layer_i += 1;
214                let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
215                layer_i += 1;
216                let attn_block = if i == 0 {
217                    None
218                } else {
219                    let attn_block = AttnBlock::new(
220                        c_hidden,
221                        c_cond,
222                        NHEAD[i],
223                        true,
224                        use_flash_attn,
225                        vb.pp(layer_i),
226                    )?;
227                    layer_i += 1;
228                    Some(attn_block)
229                };
230                let sub_block = SubBlock {
231                    res_block,
232                    ts_block,
233                    attn_block,
234                };
235                sub_blocks.push(sub_block)
236            }
237            let (layer_norm, conv) = if i > 0 {
238                let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
239                let cfg = candle_nn::ConvTranspose2dConfig {
240                    stride: 2,
241                    ..Default::default()
242                };
243                let conv = candle_nn::conv_transpose2d(
244                    c_hidden,
245                    C_HIDDEN[i - 1],
246                    2,
247                    cfg,
248                    vb.pp(layer_i).pp(1),
249                )?;
250                (Some(layer_norm), Some(conv))
251            } else {
252                (None, None)
253            };
254            let up_block = UpBlock {
255                layer_norm,
256                conv,
257                sub_blocks,
258            };
259            up_blocks.push(up_block)
260        }
261
262        let clf_ln = WLayerNorm::new(C_HIDDEN[0])?;
263        let clf_conv = candle_nn::conv2d(
264            C_HIDDEN[0],
265            2 * c_out * patch_size * patch_size,
266            1,
267            Default::default(),
268            vb.pp("clf.1"),
269        )?;
270        Ok(Self {
271            clip_mapper,
272            effnet_mappers,
273            seq_norm,
274            embedding_conv,
275            embedding_ln,
276            down_blocks,
277            up_blocks,
278            clf_ln,
279            clf_conv,
280            c_r,
281            patch_size,
282        })
283    }
284
285    fn gen_r_embedding(&self, r: &Tensor) -> Result<Tensor> {
286        const MAX_POSITIONS: usize = 10000;
287        let r = (r * MAX_POSITIONS as f64)?;
288        let half_dim = self.c_r / 2;
289        let emb = (MAX_POSITIONS as f64).ln() / (half_dim - 1) as f64;
290        let emb = (Tensor::arange(0u32, half_dim as u32, r.device())?.to_dtype(DType::F32)?
291            * -emb)?
292            .exp()?;
293        let emb = r.unsqueeze(1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
294        let emb = Tensor::cat(&[emb.sin()?, emb.cos()?], 1)?;
295        let emb = if self.c_r % 2 == 1 {
296            emb.pad_with_zeros(D::Minus1, 0, 1)?
297        } else {
298            emb
299        };
300        emb.to_dtype(r.dtype())
301    }
302
303    fn gen_c_embeddings(&self, clip: &Tensor) -> Result<Tensor> {
304        clip.apply(&self.clip_mapper)?.apply(&self.seq_norm)
305    }
306
307    pub fn forward(
308        &self,
309        xs: &Tensor,
310        r: &Tensor,
311        effnet: &Tensor,
312        clip: Option<&Tensor>,
313    ) -> Result<Tensor> {
314        const EPS: f64 = 1e-3;
315
316        let r_embed = self.gen_r_embedding(r)?;
317        let clip = match clip {
318            None => None,
319            Some(clip) => Some(self.gen_c_embeddings(clip)?),
320        };
321        let x_in = xs;
322
323        let mut xs = xs
324            .apply(&|xs: &_| candle_nn::ops::pixel_unshuffle(xs, self.patch_size))?
325            .apply(&self.embedding_conv)?
326            .apply(&self.embedding_ln)?;
327
328        let mut level_outputs = Vec::new();
329        for (i, down_block) in self.down_blocks.iter().enumerate() {
330            if let Some(ln) = &down_block.layer_norm {
331                xs = xs.apply(ln)?
332            }
333            if let Some(conv) = &down_block.conv {
334                xs = xs.apply(conv)?
335            }
336            let skip = match &self.effnet_mappers[i] {
337                None => None,
338                Some(m) => {
339                    let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
340                    Some(m.forward(&effnet)?)
341                }
342            };
343            for block in down_block.sub_blocks.iter() {
344                xs = block.res_block.forward(&xs, skip.as_ref())?;
345                xs = block.ts_block.forward(&xs, &r_embed)?;
346                if let Some(attn_block) = &block.attn_block {
347                    xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;
348                }
349            }
350            level_outputs.push(xs.clone())
351        }
352        level_outputs.reverse();
353        let mut xs = level_outputs[0].clone();
354
355        for (i, up_block) in self.up_blocks.iter().enumerate() {
356            let effnet_c = match &self.effnet_mappers[self.down_blocks.len() + i] {
357                None => None,
358                Some(m) => {
359                    let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
360                    Some(m.forward(&effnet)?)
361                }
362            };
363            for (j, block) in up_block.sub_blocks.iter().enumerate() {
364                let skip = if j == 0 && i > 0 {
365                    Some(&level_outputs[i])
366                } else {
367                    None
368                };
369                let skip = match (skip, effnet_c.as_ref()) {
370                    (Some(skip), Some(effnet_c)) => Some(Tensor::cat(&[skip, effnet_c], 1)?),
371                    (None, Some(skip)) | (Some(skip), None) => Some(skip.clone()),
372                    (None, None) => None,
373                };
374                xs = block.res_block.forward(&xs, skip.as_ref())?;
375                xs = block.ts_block.forward(&xs, &r_embed)?;
376                if let Some(attn_block) = &block.attn_block {
377                    xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;
378                }
379            }
380            if let Some(ln) = &up_block.layer_norm {
381                xs = xs.apply(ln)?
382            }
383            if let Some(conv) = &up_block.conv {
384                xs = xs.apply(conv)?
385            }
386        }
387
388        let ab = xs
389            .apply(&self.clf_ln)?
390            .apply(&self.clf_conv)?
391            .apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, self.patch_size))?
392            .chunk(2, 1)?;
393        let b = ((candle_nn::ops::sigmoid(&ab[1])? * (1. - EPS * 2.))? + EPS)?;
394        (x_in - &ab[0])? / b
395    }
396}