candle_transformers/models/mmdit/
embedding.rs1use candle::{bail, DType, Module, Result, Tensor};
2use candle_nn as nn;
3
4pub struct PatchEmbedder {
5 proj: nn::Conv2d,
6}
7
8impl PatchEmbedder {
9 pub fn new(
10 patch_size: usize,
11 in_channels: usize,
12 embed_dim: usize,
13 vb: nn::VarBuilder,
14 ) -> Result<Self> {
15 let proj = nn::conv2d(
16 in_channels,
17 embed_dim,
18 patch_size,
19 nn::Conv2dConfig {
20 stride: patch_size,
21 ..Default::default()
22 },
23 vb.pp("proj"),
24 )?;
25
26 Ok(Self { proj })
27 }
28}
29
30impl Module for PatchEmbedder {
31 fn forward(&self, x: &Tensor) -> Result<Tensor> {
32 let x = self.proj.forward(x)?;
33
34 let (b, c, h, w) = x.dims4()?;
36 x.reshape((b, c, h * w))?.transpose(1, 2)
37 }
38}
39
40pub struct Unpatchifier {
41 patch_size: usize,
42 out_channels: usize,
43}
44
45impl Unpatchifier {
46 pub fn new(patch_size: usize, out_channels: usize) -> Result<Self> {
47 Ok(Self {
48 patch_size,
49 out_channels,
50 })
51 }
52
53 pub fn unpatchify(&self, x: &Tensor, h: usize, w: usize) -> Result<Tensor> {
54 let h = (h + 1) / self.patch_size;
55 let w = (w + 1) / self.patch_size;
56
57 let x = x.reshape((
58 x.dim(0)?,
59 h,
60 w,
61 self.patch_size,
62 self.patch_size,
63 self.out_channels,
64 ))?;
65 let x = x.permute((0, 5, 1, 3, 2, 4))?; x.reshape((
67 x.dim(0)?,
68 self.out_channels,
69 self.patch_size * h,
70 self.patch_size * w,
71 ))
72 }
73}
74
75pub struct PositionEmbedder {
76 pos_embed: Tensor,
77 patch_size: usize,
78 pos_embed_max_size: usize,
79}
80
81impl PositionEmbedder {
82 pub fn new(
83 hidden_size: usize,
84 patch_size: usize,
85 pos_embed_max_size: usize,
86 vb: nn::VarBuilder,
87 ) -> Result<Self> {
88 let pos_embed = vb.get(
89 (1, pos_embed_max_size * pos_embed_max_size, hidden_size),
90 "pos_embed",
91 )?;
92 Ok(Self {
93 pos_embed,
94 patch_size,
95 pos_embed_max_size,
96 })
97 }
98 pub fn get_cropped_pos_embed(&self, h: usize, w: usize) -> Result<Tensor> {
99 let h = (h + 1) / self.patch_size;
100 let w = (w + 1) / self.patch_size;
101
102 if h > self.pos_embed_max_size || w > self.pos_embed_max_size {
103 bail!("Input size is too large for the position embedding")
104 }
105
106 let top = (self.pos_embed_max_size - h) / 2;
107 let left = (self.pos_embed_max_size - w) / 2;
108
109 let pos_embed =
110 self.pos_embed
111 .reshape((1, self.pos_embed_max_size, self.pos_embed_max_size, ()))?;
112 let pos_embed = pos_embed.narrow(1, top, h)?.narrow(2, left, w)?;
113 pos_embed.reshape((1, h * w, ()))
114 }
115}
116
117pub struct TimestepEmbedder {
118 mlp: nn::Sequential,
119 frequency_embedding_size: usize,
120}
121
122impl TimestepEmbedder {
123 pub fn new(
124 hidden_size: usize,
125 frequency_embedding_size: usize,
126 vb: nn::VarBuilder,
127 ) -> Result<Self> {
128 let mlp = nn::seq()
129 .add(nn::linear(
130 frequency_embedding_size,
131 hidden_size,
132 vb.pp("mlp.0"),
133 )?)
134 .add(nn::Activation::Silu)
135 .add(nn::linear(hidden_size, hidden_size, vb.pp("mlp.2"))?);
136
137 Ok(Self {
138 mlp,
139 frequency_embedding_size,
140 })
141 }
142
143 fn timestep_embedding(t: &Tensor, dim: usize, max_period: f64) -> Result<Tensor> {
144 if dim % 2 != 0 {
145 bail!("Embedding dimension must be even")
146 }
147
148 if t.dtype() != DType::F32 && t.dtype() != DType::F64 {
149 bail!("Input tensor must be floating point")
150 }
151
152 let half = dim / 2;
153 let freqs = Tensor::arange(0f32, half as f32, t.device())?
154 .to_dtype(candle::DType::F32)?
155 .mul(&Tensor::full(
156 (-f64::ln(max_period) / half as f64) as f32,
157 half,
158 t.device(),
159 )?)?
160 .exp()?;
161
162 let args = t
163 .unsqueeze(1)?
164 .to_dtype(candle::DType::F32)?
165 .matmul(&freqs.unsqueeze(0)?)?;
166 let embedding = Tensor::cat(&[args.cos()?, args.sin()?], 1)?;
167 embedding.to_dtype(candle::DType::F16)
168 }
169}
170
171impl Module for TimestepEmbedder {
172 fn forward(&self, t: &Tensor) -> Result<Tensor> {
173 let t_freq = Self::timestep_embedding(t, self.frequency_embedding_size, 10000.0)?;
174 self.mlp.forward(&t_freq)
175 }
176}
177
178pub struct VectorEmbedder {
179 mlp: nn::Sequential,
180}
181
182impl VectorEmbedder {
183 pub fn new(input_dim: usize, hidden_size: usize, vb: nn::VarBuilder) -> Result<Self> {
184 let mlp = nn::seq()
185 .add(nn::linear(input_dim, hidden_size, vb.pp("mlp.0"))?)
186 .add(nn::Activation::Silu)
187 .add(nn::linear(hidden_size, hidden_size, vb.pp("mlp.2"))?);
188
189 Ok(Self { mlp })
190 }
191}
192
193impl Module for VectorEmbedder {
194 fn forward(&self, x: &Tensor) -> Result<Tensor> {
195 self.mlp.forward(x)
196 }
197}