candle_transformers/models/wuerstchen/
common.rs1use candle::{DType, Module, Result, Tensor, D};
2use candle_nn::VarBuilder;
3
4#[derive(Debug)]
6pub struct WLayerNorm {
7 eps: f64,
8}
9
10impl WLayerNorm {
11 pub fn new(_size: usize) -> Result<Self> {
12 Ok(Self { eps: 1e-6 })
13 }
14}
15
16impl Module for WLayerNorm {
17 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
18 let xs = xs.permute((0, 2, 3, 1))?;
19
20 let x_dtype = xs.dtype();
21 let internal_dtype = match x_dtype {
22 DType::F16 | DType::BF16 => DType::F32,
23 d => d,
24 };
25
26 let hidden_size = xs.dim(D::Minus1)?;
27 let xs = xs.to_dtype(internal_dtype)?;
28 let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
29 let xs = xs.broadcast_sub(&mean_x)?;
30 let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
31 xs.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
32 .to_dtype(x_dtype)?
33 .permute((0, 3, 1, 2))
34 }
35}
36
37#[derive(Debug)]
38pub struct LayerNormNoWeights {
39 eps: f64,
40}
41
42impl LayerNormNoWeights {
43 pub fn new(_size: usize) -> Result<Self> {
44 Ok(Self { eps: 1e-6 })
45 }
46}
47
48impl Module for LayerNormNoWeights {
49 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
50 let x_dtype = xs.dtype();
51 let internal_dtype = match x_dtype {
52 DType::F16 | DType::BF16 => DType::F32,
53 d => d,
54 };
55 let hidden_size = xs.dim(D::Minus1)?;
56 let xs = xs.to_dtype(internal_dtype)?;
57 let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
58 let xs = xs.broadcast_sub(&mean_x)?;
59 let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
60 xs.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
61 .to_dtype(x_dtype)
62 }
63}
64
65#[derive(Debug)]
66pub struct TimestepBlock {
67 mapper: candle_nn::Linear,
68}
69
70impl TimestepBlock {
71 pub fn new(c: usize, c_timestep: usize, vb: VarBuilder) -> Result<Self> {
72 let mapper = candle_nn::linear(c_timestep, c * 2, vb.pp("mapper"))?;
73 Ok(Self { mapper })
74 }
75
76 pub fn forward(&self, xs: &Tensor, t: &Tensor) -> Result<Tensor> {
77 let ab = self
78 .mapper
79 .forward(t)?
80 .unsqueeze(2)?
81 .unsqueeze(3)?
82 .chunk(2, 1)?;
83 xs.broadcast_mul(&(&ab[0] + 1.)?)?.broadcast_add(&ab[1])
84 }
85}
86
87#[derive(Debug)]
88pub struct GlobalResponseNorm {
89 gamma: Tensor,
90 beta: Tensor,
91}
92
93impl GlobalResponseNorm {
94 pub fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
95 let gamma = vb.get((1, 1, 1, dim), "gamma")?;
96 let beta = vb.get((1, 1, 1, dim), "beta")?;
97 Ok(Self { gamma, beta })
98 }
99}
100
101impl Module for GlobalResponseNorm {
102 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
103 let agg_norm = xs.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
104 let stand_div_norm =
105 agg_norm.broadcast_div(&(agg_norm.mean_keepdim(D::Minus1)? + 1e-6)?)?;
106 xs.broadcast_mul(&stand_div_norm)?
107 .broadcast_mul(&self.gamma)?
108 .broadcast_add(&self.beta)?
109 + xs
110 }
111}
112
113#[derive(Debug)]
114pub struct ResBlock {
115 depthwise: candle_nn::Conv2d,
116 norm: WLayerNorm,
117 channelwise_lin1: candle_nn::Linear,
118 channelwise_grn: GlobalResponseNorm,
119 channelwise_lin2: candle_nn::Linear,
120}
121
122impl ResBlock {
123 pub fn new(c: usize, c_skip: usize, ksize: usize, vb: VarBuilder) -> Result<Self> {
124 let cfg = candle_nn::Conv2dConfig {
125 padding: ksize / 2,
126 groups: c,
127 ..Default::default()
128 };
129 let depthwise = candle_nn::conv2d(c + c_skip, c, ksize, cfg, vb.pp("depthwise"))?;
130 let norm = WLayerNorm::new(c)?;
131 let channelwise_lin1 = candle_nn::linear(c, c * 4, vb.pp("channelwise.0"))?;
132 let channelwise_grn = GlobalResponseNorm::new(c * 4, vb.pp("channelwise.2"))?;
133 let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;
134 Ok(Self {
135 depthwise,
136 norm,
137 channelwise_lin1,
138 channelwise_grn,
139 channelwise_lin2,
140 })
141 }
142
143 pub fn forward(&self, xs: &Tensor, x_skip: Option<&Tensor>) -> Result<Tensor> {
144 let x_res = xs;
145 let xs = match x_skip {
146 None => xs.clone(),
147 Some(x_skip) => Tensor::cat(&[xs, x_skip], 1)?,
148 };
149 let xs = xs
150 .apply(&self.depthwise)?
151 .apply(&self.norm)?
152 .permute((0, 2, 3, 1))?;
153 let xs = xs
154 .apply(&self.channelwise_lin1)?
155 .gelu_erf()?
156 .apply(&self.channelwise_grn)?
157 .apply(&self.channelwise_lin2)?
158 .permute((0, 3, 1, 2))?;
159 xs + x_res
160 }
161}
162use super::attention_processor::Attention;
163#[derive(Debug)]
164pub struct AttnBlock {
165 self_attn: bool,
166 norm: WLayerNorm,
167 attention: Attention,
168 kv_mapper_lin: candle_nn::Linear,
169}
170
171impl AttnBlock {
172 pub fn new(
173 c: usize,
174 c_cond: usize,
175 nhead: usize,
176 self_attn: bool,
177 use_flash_attn: bool,
178 vb: VarBuilder,
179 ) -> Result<Self> {
180 let norm = WLayerNorm::new(c)?;
181 let attention = Attention::new(c, nhead, c / nhead, use_flash_attn, vb.pp("attention"))?;
182 let kv_mapper_lin = candle_nn::linear(c_cond, c, vb.pp("kv_mapper.1"))?;
183 Ok(Self {
184 self_attn,
185 norm,
186 attention,
187 kv_mapper_lin,
188 })
189 }
190
191 pub fn forward(&self, xs: &Tensor, kv: &Tensor) -> Result<Tensor> {
192 let kv = candle_nn::ops::silu(kv)?.apply(&self.kv_mapper_lin)?;
193 let norm_xs = self.norm.forward(xs)?;
194 let kv = if self.self_attn {
195 let (b_size, channel, _, _) = xs.dims4()?;
196 let norm_xs = norm_xs.reshape((b_size, channel, ()))?.transpose(1, 2)?;
197 Tensor::cat(&[&norm_xs, &kv], 1)?.contiguous()?
198 } else {
199 kv
200 };
201 xs + self.attention.forward(&norm_xs, &kv)
202 }
203}