candle_transformers/models/wuerstchen/paella_vq.rs

1use super::common::LayerNormNoWeights;
2use candle::{Module, Result, Tensor};
3use candle_nn::VarBuilder;
4
5#[derive(Debug)]
6pub struct MixingResidualBlock {
7    norm1: LayerNormNoWeights,
8    depthwise_conv: candle_nn::Conv2d,
9    norm2: LayerNormNoWeights,
10    channelwise_lin1: candle_nn::Linear,
11    channelwise_lin2: candle_nn::Linear,
12    gammas: Vec<f32>,
13}
14
15impl MixingResidualBlock {
16    pub fn new(inp: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
17        let norm1 = LayerNormNoWeights::new(inp)?;
18        let norm2 = LayerNormNoWeights::new(inp)?;
19        let cfg = candle_nn::Conv2dConfig {
20            groups: inp,
21            ..Default::default()
22        };
23        let depthwise_conv = candle_nn::conv2d(inp, inp, 3, cfg, vb.pp("depthwise.1"))?;
24        let channelwise_lin1 = candle_nn::linear(inp, embed_dim, vb.pp("channelwise.0"))?;
25        let channelwise_lin2 = candle_nn::linear(embed_dim, inp, vb.pp("channelwise.2"))?;
26        let gammas = vb.get(6, "gammas")?.to_vec1::<f32>()?;
27        Ok(Self {
28            norm1,
29            depthwise_conv,
30            norm2,
31            channelwise_lin1,
32            channelwise_lin2,
33            gammas,
34        })
35    }
36}
37
38impl Module for MixingResidualBlock {
39    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
40        let mods = &self.gammas;
41        let x_temp = xs
42            .permute((0, 2, 3, 1))?
43            .apply(&self.norm1)?
44            .permute((0, 3, 1, 2))?
45            .affine(1. + mods[0] as f64, mods[1] as f64)?;
46        let x_temp = candle_nn::ops::replication_pad2d(&x_temp, 1)?;
47        let xs = (xs + x_temp.apply(&self.depthwise_conv)? * mods[2] as f64)?;
48        let x_temp = xs
49            .permute((0, 2, 3, 1))?
50            .apply(&self.norm2)?
51            .permute((0, 3, 1, 2))?
52            .affine(1. + mods[3] as f64, mods[4] as f64)?;
53        let x_temp = x_temp
54            .permute((0, 2, 3, 1))?
55            .contiguous()?
56            .apply(&self.channelwise_lin1)?
57            .gelu()?
58            .apply(&self.channelwise_lin2)?
59            .permute((0, 3, 1, 2))?;
60        xs + x_temp * mods[5] as f64
61    }
62}
63
64#[derive(Debug)]
65pub struct PaellaVQ {
66    in_block_conv: candle_nn::Conv2d,
67    out_block_conv: candle_nn::Conv2d,
68    down_blocks: Vec<(Option<candle_nn::Conv2d>, MixingResidualBlock)>,
69    down_blocks_conv: candle_nn::Conv2d,
70    down_blocks_bn: candle_nn::BatchNorm,
71    up_blocks_conv: candle_nn::Conv2d,
72    up_blocks: Vec<(Vec<MixingResidualBlock>, Option<candle_nn::ConvTranspose2d>)>,
73}
74
75impl PaellaVQ {
76    pub fn new(vb: VarBuilder) -> Result<Self> {
77        const IN_CHANNELS: usize = 3;
78        const OUT_CHANNELS: usize = 3;
79        const LATENT_CHANNELS: usize = 4;
80        const EMBED_DIM: usize = 384;
81        const BOTTLENECK_BLOCKS: usize = 12;
82        const C_LEVELS: [usize; 2] = [EMBED_DIM / 2, EMBED_DIM];
83
84        let in_block_conv = candle_nn::conv2d(
85            IN_CHANNELS * 4,
86            C_LEVELS[0],
87            1,
88            Default::default(),
89            vb.pp("in_block.1"),
90        )?;
91        let out_block_conv = candle_nn::conv2d(
92            C_LEVELS[0],
93            OUT_CHANNELS * 4,
94            1,
95            Default::default(),
96            vb.pp("out_block.0"),
97        )?;
98
99        let mut down_blocks = Vec::new();
100        let vb_d = vb.pp("down_blocks");
101        let mut d_idx = 0;
102        for (i, &c_level) in C_LEVELS.iter().enumerate() {
103            let conv_block = if i > 0 {
104                let cfg = candle_nn::Conv2dConfig {
105                    padding: 1,
106                    stride: 2,
107                    ..Default::default()
108                };
109                let block = candle_nn::conv2d(C_LEVELS[i - 1], c_level, 4, cfg, vb_d.pp(d_idx))?;
110                d_idx += 1;
111                Some(block)
112            } else {
113                None
114            };
115            let res_block = MixingResidualBlock::new(c_level, c_level * 4, vb_d.pp(d_idx))?;
116            d_idx += 1;
117            down_blocks.push((conv_block, res_block))
118        }
119        let vb_d = vb_d.pp(d_idx);
120        let down_blocks_conv = candle_nn::conv2d_no_bias(
121            C_LEVELS[1],
122            LATENT_CHANNELS,
123            1,
124            Default::default(),
125            vb_d.pp(0),
126        )?;
127        let down_blocks_bn = candle_nn::batch_norm(LATENT_CHANNELS, 1e-5, vb_d.pp(1))?;
128
129        let mut up_blocks = Vec::new();
130        let vb_u = vb.pp("up_blocks");
131        let mut u_idx = 0;
132        let up_blocks_conv = candle_nn::conv2d(
133            LATENT_CHANNELS,
134            C_LEVELS[1],
135            1,
136            Default::default(),
137            vb_u.pp(u_idx).pp(0),
138        )?;
139        u_idx += 1;
140        for (i, &c_level) in C_LEVELS.iter().rev().enumerate() {
141            let mut res_blocks = Vec::new();
142            let n_bottleneck_blocks = if i == 0 { BOTTLENECK_BLOCKS } else { 1 };
143            for _j in 0..n_bottleneck_blocks {
144                let res_block = MixingResidualBlock::new(c_level, c_level * 4, vb_u.pp(u_idx))?;
145                u_idx += 1;
146                res_blocks.push(res_block)
147            }
148            let conv_block = if i < C_LEVELS.len() - 1 {
149                let cfg = candle_nn::ConvTranspose2dConfig {
150                    padding: 1,
151                    stride: 2,
152                    ..Default::default()
153                };
154                let block = candle_nn::conv_transpose2d(
155                    c_level,
156                    C_LEVELS[C_LEVELS.len() - i - 2],
157                    4,
158                    cfg,
159                    vb_u.pp(u_idx),
160                )?;
161                u_idx += 1;
162                Some(block)
163            } else {
164                None
165            };
166            up_blocks.push((res_blocks, conv_block))
167        }
168        Ok(Self {
169            in_block_conv,
170            down_blocks,
171            down_blocks_conv,
172            down_blocks_bn,
173            up_blocks,
174            up_blocks_conv,
175            out_block_conv,
176        })
177    }
178
179    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
180        let mut xs = candle_nn::ops::pixel_unshuffle(xs, 2)?.apply(&self.in_block_conv)?;
181        for down_block in self.down_blocks.iter() {
182            if let Some(conv) = &down_block.0 {
183                xs = xs.apply(conv)?
184            }
185            xs = xs.apply(&down_block.1)?
186        }
187        xs.apply(&self.down_blocks_conv)?
188            .apply_t(&self.down_blocks_bn, false)
189    }
190
191    pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
192        // TODO: quantizer if we want to support `force_not_quantize=False`.
193        let mut xs = xs.apply(&self.up_blocks_conv)?;
194        for up_block in self.up_blocks.iter() {
195            for b in up_block.0.iter() {
196                xs = xs.apply(b)?;
197            }
198            if let Some(conv) = &up_block.1 {
199                xs = xs.apply(conv)?
200            }
201        }
202        xs.apply(&self.out_block_conv)?
203            .apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, 2))
204    }
205}
206
207impl Module for PaellaVQ {
208    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
209        self.decode(&self.encode(xs)?)
210    }
211}