candle_transformers/models/wuerstchen/
prior.rs1use super::common::{AttnBlock, ResBlock, TimestepBlock};
2use candle::{DType, Result, Tensor, D};
3use candle_nn::VarBuilder;
4
5#[derive(Debug)]
6struct Block {
7 res_block: ResBlock,
8 ts_block: TimestepBlock,
9 attn_block: AttnBlock,
10}
11
12#[derive(Debug)]
13pub struct WPrior {
14 projection: candle_nn::Conv2d,
15 cond_mapper_lin1: candle_nn::Linear,
16 cond_mapper_lin2: candle_nn::Linear,
17 blocks: Vec<Block>,
18 out_ln: super::common::WLayerNorm,
19 out_conv: candle_nn::Conv2d,
20 c_r: usize,
21}
22
23impl WPrior {
24 #[allow(clippy::too_many_arguments)]
25 pub fn new(
26 c_in: usize,
27 c: usize,
28 c_cond: usize,
29 c_r: usize,
30 depth: usize,
31 nhead: usize,
32 use_flash_attn: bool,
33 vb: VarBuilder,
34 ) -> Result<Self> {
35 let projection = candle_nn::conv2d(c_in, c, 1, Default::default(), vb.pp("projection"))?;
36 let cond_mapper_lin1 = candle_nn::linear(c_cond, c, vb.pp("cond_mapper.0"))?;
37 let cond_mapper_lin2 = candle_nn::linear(c, c, vb.pp("cond_mapper.2"))?;
38 let out_ln = super::common::WLayerNorm::new(c)?;
39 let out_conv = candle_nn::conv2d(c, c_in * 2, 1, Default::default(), vb.pp("out.1"))?;
40 let mut blocks = Vec::with_capacity(depth);
41 for index in 0..depth {
42 let res_block = ResBlock::new(c, 0, 3, vb.pp(format!("blocks.{}", 3 * index)))?;
43 let ts_block = TimestepBlock::new(c, c_r, vb.pp(format!("blocks.{}", 3 * index + 1)))?;
44 let attn_block = AttnBlock::new(
45 c,
46 c,
47 nhead,
48 true,
49 use_flash_attn,
50 vb.pp(format!("blocks.{}", 3 * index + 2)),
51 )?;
52 blocks.push(Block {
53 res_block,
54 ts_block,
55 attn_block,
56 })
57 }
58 Ok(Self {
59 projection,
60 cond_mapper_lin1,
61 cond_mapper_lin2,
62 blocks,
63 out_ln,
64 out_conv,
65 c_r,
66 })
67 }
68
69 pub fn gen_r_embedding(&self, r: &Tensor) -> Result<Tensor> {
70 const MAX_POSITIONS: usize = 10000;
71 let r = (r * MAX_POSITIONS as f64)?;
72 let half_dim = self.c_r / 2;
73 let emb = (MAX_POSITIONS as f64).ln() / (half_dim - 1) as f64;
74 let emb = (Tensor::arange(0u32, half_dim as u32, r.device())?.to_dtype(DType::F32)?
75 * -emb)?
76 .exp()?;
77 let emb = r.unsqueeze(1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
78 let emb = Tensor::cat(&[emb.sin()?, emb.cos()?], 1)?;
79 let emb = if self.c_r % 2 == 1 {
80 emb.pad_with_zeros(D::Minus1, 0, 1)?
81 } else {
82 emb
83 };
84 emb.to_dtype(r.dtype())
85 }
86
87 pub fn forward(&self, xs: &Tensor, r: &Tensor, c: &Tensor) -> Result<Tensor> {
88 let x_in = xs;
89 let mut xs = xs.apply(&self.projection)?;
90 let c_embed = c
91 .apply(&self.cond_mapper_lin1)?
92 .apply(&|xs: &_| candle_nn::ops::leaky_relu(xs, 0.2))?
93 .apply(&self.cond_mapper_lin2)?;
94 let r_embed = self.gen_r_embedding(r)?;
95 for block in self.blocks.iter() {
96 xs = block.res_block.forward(&xs, None)?;
97 xs = block.ts_block.forward(&xs, &r_embed)?;
98 xs = block.attn_block.forward(&xs, &c_embed)?;
99 }
100 let ab = xs.apply(&self.out_ln)?.apply(&self.out_conv)?.chunk(2, 1)?;
101 (x_in - &ab[0])? / ((&ab[1] - 1.)?.abs()? + 1e-5)
102 }
103}