candle_transformers/models/wuerstchen/
prior.rs

1use super::common::{AttnBlock, ResBlock, TimestepBlock};
2use candle::{DType, Result, Tensor, D};
3use candle_nn::VarBuilder;
4
5#[derive(Debug)]
6struct Block {
7    res_block: ResBlock,
8    ts_block: TimestepBlock,
9    attn_block: AttnBlock,
10}
11
12#[derive(Debug)]
13pub struct WPrior {
14    projection: candle_nn::Conv2d,
15    cond_mapper_lin1: candle_nn::Linear,
16    cond_mapper_lin2: candle_nn::Linear,
17    blocks: Vec<Block>,
18    out_ln: super::common::WLayerNorm,
19    out_conv: candle_nn::Conv2d,
20    c_r: usize,
21}
22
23impl WPrior {
24    #[allow(clippy::too_many_arguments)]
25    pub fn new(
26        c_in: usize,
27        c: usize,
28        c_cond: usize,
29        c_r: usize,
30        depth: usize,
31        nhead: usize,
32        use_flash_attn: bool,
33        vb: VarBuilder,
34    ) -> Result<Self> {
35        let projection = candle_nn::conv2d(c_in, c, 1, Default::default(), vb.pp("projection"))?;
36        let cond_mapper_lin1 = candle_nn::linear(c_cond, c, vb.pp("cond_mapper.0"))?;
37        let cond_mapper_lin2 = candle_nn::linear(c, c, vb.pp("cond_mapper.2"))?;
38        let out_ln = super::common::WLayerNorm::new(c)?;
39        let out_conv = candle_nn::conv2d(c, c_in * 2, 1, Default::default(), vb.pp("out.1"))?;
40        let mut blocks = Vec::with_capacity(depth);
41        for index in 0..depth {
42            let res_block = ResBlock::new(c, 0, 3, vb.pp(format!("blocks.{}", 3 * index)))?;
43            let ts_block = TimestepBlock::new(c, c_r, vb.pp(format!("blocks.{}", 3 * index + 1)))?;
44            let attn_block = AttnBlock::new(
45                c,
46                c,
47                nhead,
48                true,
49                use_flash_attn,
50                vb.pp(format!("blocks.{}", 3 * index + 2)),
51            )?;
52            blocks.push(Block {
53                res_block,
54                ts_block,
55                attn_block,
56            })
57        }
58        Ok(Self {
59            projection,
60            cond_mapper_lin1,
61            cond_mapper_lin2,
62            blocks,
63            out_ln,
64            out_conv,
65            c_r,
66        })
67    }
68
69    pub fn gen_r_embedding(&self, r: &Tensor) -> Result<Tensor> {
70        const MAX_POSITIONS: usize = 10000;
71        let r = (r * MAX_POSITIONS as f64)?;
72        let half_dim = self.c_r / 2;
73        let emb = (MAX_POSITIONS as f64).ln() / (half_dim - 1) as f64;
74        let emb = (Tensor::arange(0u32, half_dim as u32, r.device())?.to_dtype(DType::F32)?
75            * -emb)?
76            .exp()?;
77        let emb = r.unsqueeze(1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
78        let emb = Tensor::cat(&[emb.sin()?, emb.cos()?], 1)?;
79        let emb = if self.c_r % 2 == 1 {
80            emb.pad_with_zeros(D::Minus1, 0, 1)?
81        } else {
82            emb
83        };
84        emb.to_dtype(r.dtype())
85    }
86
87    pub fn forward(&self, xs: &Tensor, r: &Tensor, c: &Tensor) -> Result<Tensor> {
88        let x_in = xs;
89        let mut xs = xs.apply(&self.projection)?;
90        let c_embed = c
91            .apply(&self.cond_mapper_lin1)?
92            .apply(&|xs: &_| candle_nn::ops::leaky_relu(xs, 0.2))?
93            .apply(&self.cond_mapper_lin2)?;
94        let r_embed = self.gen_r_embedding(r)?;
95        for block in self.blocks.iter() {
96            xs = block.res_block.forward(&xs, None)?;
97            xs = block.ts_block.forward(&xs, &r_embed)?;
98            xs = block.attn_block.forward(&xs, &c_embed)?;
99        }
100        let ab = xs.apply(&self.out_ln)?.apply(&self.out_conv)?.chunk(2, 1)?;
101        (x_in - &ab[0])? / ((&ab[1] - 1.)?.abs()? + 1e-5)
102    }
103}