candle_transformers/models/wuerstchen/
common.rs

1use candle::{DType, Module, Result, Tensor, D};
2use candle_nn::VarBuilder;
3
4// https://github.com/huggingface/diffusers/blob/19edca82f1ff194c07317369a92b470dbae97f34/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py#L22
5#[derive(Debug)]
6pub struct WLayerNorm {
7    eps: f64,
8}
9
10impl WLayerNorm {
11    pub fn new(_size: usize) -> Result<Self> {
12        Ok(Self { eps: 1e-6 })
13    }
14}
15
16impl Module for WLayerNorm {
17    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
18        let xs = xs.permute((0, 2, 3, 1))?;
19
20        let x_dtype = xs.dtype();
21        let internal_dtype = match x_dtype {
22            DType::F16 | DType::BF16 => DType::F32,
23            d => d,
24        };
25
26        let hidden_size = xs.dim(D::Minus1)?;
27        let xs = xs.to_dtype(internal_dtype)?;
28        let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
29        let xs = xs.broadcast_sub(&mean_x)?;
30        let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
31        xs.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
32            .to_dtype(x_dtype)?
33            .permute((0, 3, 1, 2))
34    }
35}
36
37#[derive(Debug)]
38pub struct LayerNormNoWeights {
39    eps: f64,
40}
41
42impl LayerNormNoWeights {
43    pub fn new(_size: usize) -> Result<Self> {
44        Ok(Self { eps: 1e-6 })
45    }
46}
47
48impl Module for LayerNormNoWeights {
49    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
50        let x_dtype = xs.dtype();
51        let internal_dtype = match x_dtype {
52            DType::F16 | DType::BF16 => DType::F32,
53            d => d,
54        };
55        let hidden_size = xs.dim(D::Minus1)?;
56        let xs = xs.to_dtype(internal_dtype)?;
57        let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
58        let xs = xs.broadcast_sub(&mean_x)?;
59        let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
60        xs.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
61            .to_dtype(x_dtype)
62    }
63}
64
65#[derive(Debug)]
66pub struct TimestepBlock {
67    mapper: candle_nn::Linear,
68}
69
70impl TimestepBlock {
71    pub fn new(c: usize, c_timestep: usize, vb: VarBuilder) -> Result<Self> {
72        let mapper = candle_nn::linear(c_timestep, c * 2, vb.pp("mapper"))?;
73        Ok(Self { mapper })
74    }
75
76    pub fn forward(&self, xs: &Tensor, t: &Tensor) -> Result<Tensor> {
77        let ab = self
78            .mapper
79            .forward(t)?
80            .unsqueeze(2)?
81            .unsqueeze(3)?
82            .chunk(2, 1)?;
83        xs.broadcast_mul(&(&ab[0] + 1.)?)?.broadcast_add(&ab[1])
84    }
85}
86
87#[derive(Debug)]
88pub struct GlobalResponseNorm {
89    gamma: Tensor,
90    beta: Tensor,
91}
92
93impl GlobalResponseNorm {
94    pub fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
95        let gamma = vb.get((1, 1, 1, dim), "gamma")?;
96        let beta = vb.get((1, 1, 1, dim), "beta")?;
97        Ok(Self { gamma, beta })
98    }
99}
100
101impl Module for GlobalResponseNorm {
102    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
103        let agg_norm = xs.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
104        let stand_div_norm =
105            agg_norm.broadcast_div(&(agg_norm.mean_keepdim(D::Minus1)? + 1e-6)?)?;
106        xs.broadcast_mul(&stand_div_norm)?
107            .broadcast_mul(&self.gamma)?
108            .broadcast_add(&self.beta)?
109            + xs
110    }
111}
112
113#[derive(Debug)]
114pub struct ResBlock {
115    depthwise: candle_nn::Conv2d,
116    norm: WLayerNorm,
117    channelwise_lin1: candle_nn::Linear,
118    channelwise_grn: GlobalResponseNorm,
119    channelwise_lin2: candle_nn::Linear,
120}
121
122impl ResBlock {
123    pub fn new(c: usize, c_skip: usize, ksize: usize, vb: VarBuilder) -> Result<Self> {
124        let cfg = candle_nn::Conv2dConfig {
125            padding: ksize / 2,
126            groups: c,
127            ..Default::default()
128        };
129        let depthwise = candle_nn::conv2d(c + c_skip, c, ksize, cfg, vb.pp("depthwise"))?;
130        let norm = WLayerNorm::new(c)?;
131        let channelwise_lin1 = candle_nn::linear(c, c * 4, vb.pp("channelwise.0"))?;
132        let channelwise_grn = GlobalResponseNorm::new(c * 4, vb.pp("channelwise.2"))?;
133        let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;
134        Ok(Self {
135            depthwise,
136            norm,
137            channelwise_lin1,
138            channelwise_grn,
139            channelwise_lin2,
140        })
141    }
142
143    pub fn forward(&self, xs: &Tensor, x_skip: Option<&Tensor>) -> Result<Tensor> {
144        let x_res = xs;
145        let xs = match x_skip {
146            None => xs.clone(),
147            Some(x_skip) => Tensor::cat(&[xs, x_skip], 1)?,
148        };
149        let xs = xs
150            .apply(&self.depthwise)?
151            .apply(&self.norm)?
152            .permute((0, 2, 3, 1))?;
153        let xs = xs
154            .apply(&self.channelwise_lin1)?
155            .gelu_erf()?
156            .apply(&self.channelwise_grn)?
157            .apply(&self.channelwise_lin2)?
158            .permute((0, 3, 1, 2))?;
159        xs + x_res
160    }
161}
162use super::attention_processor::Attention;
163#[derive(Debug)]
164pub struct AttnBlock {
165    self_attn: bool,
166    norm: WLayerNorm,
167    attention: Attention,
168    kv_mapper_lin: candle_nn::Linear,
169}
170
171impl AttnBlock {
172    pub fn new(
173        c: usize,
174        c_cond: usize,
175        nhead: usize,
176        self_attn: bool,
177        use_flash_attn: bool,
178        vb: VarBuilder,
179    ) -> Result<Self> {
180        let norm = WLayerNorm::new(c)?;
181        let attention = Attention::new(c, nhead, c / nhead, use_flash_attn, vb.pp("attention"))?;
182        let kv_mapper_lin = candle_nn::linear(c_cond, c, vb.pp("kv_mapper.1"))?;
183        Ok(Self {
184            self_attn,
185            norm,
186            attention,
187            kv_mapper_lin,
188        })
189    }
190
191    pub fn forward(&self, xs: &Tensor, kv: &Tensor) -> Result<Tensor> {
192        let kv = candle_nn::ops::silu(kv)?.apply(&self.kv_mapper_lin)?;
193        let norm_xs = self.norm.forward(xs)?;
194        let kv = if self.self_attn {
195            let (b_size, channel, _, _) = xs.dims4()?;
196            let norm_xs = norm_xs.reshape((b_size, channel, ()))?.transpose(1, 2)?;
197            Tensor::cat(&[&norm_xs, &kv], 1)?.contiguous()?
198        } else {
199            kv
200        };
201        xs + self.attention.forward(&norm_xs, &kv)
202    }
203}