// candle_transformers/models/wuerstchen/paella_vq.rs
use super::common::LayerNormNoWeights;
use candle::{DType, Module, Result, Tensor};
use candle_nn::VarBuilder;

5#[derive(Debug)]
6pub struct MixingResidualBlock {
7 norm1: LayerNormNoWeights,
8 depthwise_conv: candle_nn::Conv2d,
9 norm2: LayerNormNoWeights,
10 channelwise_lin1: candle_nn::Linear,
11 channelwise_lin2: candle_nn::Linear,
12 gammas: Vec<f32>,
13}
14
15impl MixingResidualBlock {
16 pub fn new(inp: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
17 let norm1 = LayerNormNoWeights::new(inp)?;
18 let norm2 = LayerNormNoWeights::new(inp)?;
19 let cfg = candle_nn::Conv2dConfig {
20 groups: inp,
21 ..Default::default()
22 };
23 let depthwise_conv = candle_nn::conv2d(inp, inp, 3, cfg, vb.pp("depthwise.1"))?;
24 let channelwise_lin1 = candle_nn::linear(inp, embed_dim, vb.pp("channelwise.0"))?;
25 let channelwise_lin2 = candle_nn::linear(embed_dim, inp, vb.pp("channelwise.2"))?;
26 let gammas = vb.get(6, "gammas")?.to_vec1::<f32>()?;
27 Ok(Self {
28 norm1,
29 depthwise_conv,
30 norm2,
31 channelwise_lin1,
32 channelwise_lin2,
33 gammas,
34 })
35 }
36}
37
38impl Module for MixingResidualBlock {
39 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
40 let mods = &self.gammas;
41 let x_temp = xs
42 .permute((0, 2, 3, 1))?
43 .apply(&self.norm1)?
44 .permute((0, 3, 1, 2))?
45 .affine(1. + mods[0] as f64, mods[1] as f64)?;
46 let x_temp = candle_nn::ops::replication_pad2d(&x_temp, 1)?;
47 let xs = (xs + x_temp.apply(&self.depthwise_conv)? * mods[2] as f64)?;
48 let x_temp = xs
49 .permute((0, 2, 3, 1))?
50 .apply(&self.norm2)?
51 .permute((0, 3, 1, 2))?
52 .affine(1. + mods[3] as f64, mods[4] as f64)?;
53 let x_temp = x_temp
54 .permute((0, 2, 3, 1))?
55 .contiguous()?
56 .apply(&self.channelwise_lin1)?
57 .gelu()?
58 .apply(&self.channelwise_lin2)?
59 .permute((0, 3, 1, 2))?;
60 xs + x_temp * mods[5] as f64
61 }
62}
63
64#[derive(Debug)]
65pub struct PaellaVQ {
66 in_block_conv: candle_nn::Conv2d,
67 out_block_conv: candle_nn::Conv2d,
68 down_blocks: Vec<(Option<candle_nn::Conv2d>, MixingResidualBlock)>,
69 down_blocks_conv: candle_nn::Conv2d,
70 down_blocks_bn: candle_nn::BatchNorm,
71 up_blocks_conv: candle_nn::Conv2d,
72 up_blocks: Vec<(Vec<MixingResidualBlock>, Option<candle_nn::ConvTranspose2d>)>,
73}
74
75impl PaellaVQ {
76 pub fn new(vb: VarBuilder) -> Result<Self> {
77 const IN_CHANNELS: usize = 3;
78 const OUT_CHANNELS: usize = 3;
79 const LATENT_CHANNELS: usize = 4;
80 const EMBED_DIM: usize = 384;
81 const BOTTLENECK_BLOCKS: usize = 12;
82 const C_LEVELS: [usize; 2] = [EMBED_DIM / 2, EMBED_DIM];
83
84 let in_block_conv = candle_nn::conv2d(
85 IN_CHANNELS * 4,
86 C_LEVELS[0],
87 1,
88 Default::default(),
89 vb.pp("in_block.1"),
90 )?;
91 let out_block_conv = candle_nn::conv2d(
92 C_LEVELS[0],
93 OUT_CHANNELS * 4,
94 1,
95 Default::default(),
96 vb.pp("out_block.0"),
97 )?;
98
99 let mut down_blocks = Vec::new();
100 let vb_d = vb.pp("down_blocks");
101 let mut d_idx = 0;
102 for (i, &c_level) in C_LEVELS.iter().enumerate() {
103 let conv_block = if i > 0 {
104 let cfg = candle_nn::Conv2dConfig {
105 padding: 1,
106 stride: 2,
107 ..Default::default()
108 };
109 let block = candle_nn::conv2d(C_LEVELS[i - 1], c_level, 4, cfg, vb_d.pp(d_idx))?;
110 d_idx += 1;
111 Some(block)
112 } else {
113 None
114 };
115 let res_block = MixingResidualBlock::new(c_level, c_level * 4, vb_d.pp(d_idx))?;
116 d_idx += 1;
117 down_blocks.push((conv_block, res_block))
118 }
119 let vb_d = vb_d.pp(d_idx);
120 let down_blocks_conv = candle_nn::conv2d_no_bias(
121 C_LEVELS[1],
122 LATENT_CHANNELS,
123 1,
124 Default::default(),
125 vb_d.pp(0),
126 )?;
127 let down_blocks_bn = candle_nn::batch_norm(LATENT_CHANNELS, 1e-5, vb_d.pp(1))?;
128
129 let mut up_blocks = Vec::new();
130 let vb_u = vb.pp("up_blocks");
131 let mut u_idx = 0;
132 let up_blocks_conv = candle_nn::conv2d(
133 LATENT_CHANNELS,
134 C_LEVELS[1],
135 1,
136 Default::default(),
137 vb_u.pp(u_idx).pp(0),
138 )?;
139 u_idx += 1;
140 for (i, &c_level) in C_LEVELS.iter().rev().enumerate() {
141 let mut res_blocks = Vec::new();
142 let n_bottleneck_blocks = if i == 0 { BOTTLENECK_BLOCKS } else { 1 };
143 for _j in 0..n_bottleneck_blocks {
144 let res_block = MixingResidualBlock::new(c_level, c_level * 4, vb_u.pp(u_idx))?;
145 u_idx += 1;
146 res_blocks.push(res_block)
147 }
148 let conv_block = if i < C_LEVELS.len() - 1 {
149 let cfg = candle_nn::ConvTranspose2dConfig {
150 padding: 1,
151 stride: 2,
152 ..Default::default()
153 };
154 let block = candle_nn::conv_transpose2d(
155 c_level,
156 C_LEVELS[C_LEVELS.len() - i - 2],
157 4,
158 cfg,
159 vb_u.pp(u_idx),
160 )?;
161 u_idx += 1;
162 Some(block)
163 } else {
164 None
165 };
166 up_blocks.push((res_blocks, conv_block))
167 }
168 Ok(Self {
169 in_block_conv,
170 down_blocks,
171 down_blocks_conv,
172 down_blocks_bn,
173 up_blocks,
174 up_blocks_conv,
175 out_block_conv,
176 })
177 }
178
179 pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
180 let mut xs = candle_nn::ops::pixel_unshuffle(xs, 2)?.apply(&self.in_block_conv)?;
181 for down_block in self.down_blocks.iter() {
182 if let Some(conv) = &down_block.0 {
183 xs = xs.apply(conv)?
184 }
185 xs = xs.apply(&down_block.1)?
186 }
187 xs.apply(&self.down_blocks_conv)?
188 .apply_t(&self.down_blocks_bn, false)
189 }
190
191 pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
192 let mut xs = xs.apply(&self.up_blocks_conv)?;
194 for up_block in self.up_blocks.iter() {
195 for b in up_block.0.iter() {
196 xs = xs.apply(b)?;
197 }
198 if let Some(conv) = &up_block.1 {
199 xs = xs.apply(conv)?
200 }
201 }
202 xs.apply(&self.out_block_conv)?
203 .apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, 2))
204 }
205}
206
207impl Module for PaellaVQ {
208 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
209 self.decode(&self.encode(xs)?)
210 }
211}