candle_transformers/models/mmdit/
embedding.rs

1use candle::{bail, DType, Module, Result, Tensor};
2use candle_nn as nn;
3
4pub struct PatchEmbedder {
5    proj: nn::Conv2d,
6}
7
8impl PatchEmbedder {
9    pub fn new(
10        patch_size: usize,
11        in_channels: usize,
12        embed_dim: usize,
13        vb: nn::VarBuilder,
14    ) -> Result<Self> {
15        let proj = nn::conv2d(
16            in_channels,
17            embed_dim,
18            patch_size,
19            nn::Conv2dConfig {
20                stride: patch_size,
21                ..Default::default()
22            },
23            vb.pp("proj"),
24        )?;
25
26        Ok(Self { proj })
27    }
28}
29
30impl Module for PatchEmbedder {
31    fn forward(&self, x: &Tensor) -> Result<Tensor> {
32        let x = self.proj.forward(x)?;
33
34        // flatten spatial dim and transpose to channels last
35        let (b, c, h, w) = x.dims4()?;
36        x.reshape((b, c, h * w))?.transpose(1, 2)
37    }
38}
39
40pub struct Unpatchifier {
41    patch_size: usize,
42    out_channels: usize,
43}
44
45impl Unpatchifier {
46    pub fn new(patch_size: usize, out_channels: usize) -> Result<Self> {
47        Ok(Self {
48            patch_size,
49            out_channels,
50        })
51    }
52
53    pub fn unpatchify(&self, x: &Tensor, h: usize, w: usize) -> Result<Tensor> {
54        let h = (h + 1) / self.patch_size;
55        let w = (w + 1) / self.patch_size;
56
57        let x = x.reshape((
58            x.dim(0)?,
59            h,
60            w,
61            self.patch_size,
62            self.patch_size,
63            self.out_channels,
64        ))?;
65        let x = x.permute((0, 5, 1, 3, 2, 4))?; // "nhwpqc->nchpwq"
66        x.reshape((
67            x.dim(0)?,
68            self.out_channels,
69            self.patch_size * h,
70            self.patch_size * w,
71        ))
72    }
73}
74
75pub struct PositionEmbedder {
76    pos_embed: Tensor,
77    patch_size: usize,
78    pos_embed_max_size: usize,
79}
80
81impl PositionEmbedder {
82    pub fn new(
83        hidden_size: usize,
84        patch_size: usize,
85        pos_embed_max_size: usize,
86        vb: nn::VarBuilder,
87    ) -> Result<Self> {
88        let pos_embed = vb.get(
89            (1, pos_embed_max_size * pos_embed_max_size, hidden_size),
90            "pos_embed",
91        )?;
92        Ok(Self {
93            pos_embed,
94            patch_size,
95            pos_embed_max_size,
96        })
97    }
98    pub fn get_cropped_pos_embed(&self, h: usize, w: usize) -> Result<Tensor> {
99        let h = (h + 1) / self.patch_size;
100        let w = (w + 1) / self.patch_size;
101
102        if h > self.pos_embed_max_size || w > self.pos_embed_max_size {
103            bail!("Input size is too large for the position embedding")
104        }
105
106        let top = (self.pos_embed_max_size - h) / 2;
107        let left = (self.pos_embed_max_size - w) / 2;
108
109        let pos_embed =
110            self.pos_embed
111                .reshape((1, self.pos_embed_max_size, self.pos_embed_max_size, ()))?;
112        let pos_embed = pos_embed.narrow(1, top, h)?.narrow(2, left, w)?;
113        pos_embed.reshape((1, h * w, ()))
114    }
115}
116
117pub struct TimestepEmbedder {
118    mlp: nn::Sequential,
119    frequency_embedding_size: usize,
120}
121
122impl TimestepEmbedder {
123    pub fn new(
124        hidden_size: usize,
125        frequency_embedding_size: usize,
126        vb: nn::VarBuilder,
127    ) -> Result<Self> {
128        let mlp = nn::seq()
129            .add(nn::linear(
130                frequency_embedding_size,
131                hidden_size,
132                vb.pp("mlp.0"),
133            )?)
134            .add(nn::Activation::Silu)
135            .add(nn::linear(hidden_size, hidden_size, vb.pp("mlp.2"))?);
136
137        Ok(Self {
138            mlp,
139            frequency_embedding_size,
140        })
141    }
142
143    fn timestep_embedding(t: &Tensor, dim: usize, max_period: f64) -> Result<Tensor> {
144        if dim % 2 != 0 {
145            bail!("Embedding dimension must be even")
146        }
147
148        if t.dtype() != DType::F32 && t.dtype() != DType::F64 {
149            bail!("Input tensor must be floating point")
150        }
151
152        let half = dim / 2;
153        let freqs = Tensor::arange(0f32, half as f32, t.device())?
154            .to_dtype(candle::DType::F32)?
155            .mul(&Tensor::full(
156                (-f64::ln(max_period) / half as f64) as f32,
157                half,
158                t.device(),
159            )?)?
160            .exp()?;
161
162        let args = t
163            .unsqueeze(1)?
164            .to_dtype(candle::DType::F32)?
165            .matmul(&freqs.unsqueeze(0)?)?;
166        let embedding = Tensor::cat(&[args.cos()?, args.sin()?], 1)?;
167        embedding.to_dtype(candle::DType::F16)
168    }
169}
170
171impl Module for TimestepEmbedder {
172    fn forward(&self, t: &Tensor) -> Result<Tensor> {
173        let t_freq = Self::timestep_embedding(t, self.frequency_embedding_size, 10000.0)?;
174        self.mlp.forward(&t_freq)
175    }
176}
177
178pub struct VectorEmbedder {
179    mlp: nn::Sequential,
180}
181
182impl VectorEmbedder {
183    pub fn new(input_dim: usize, hidden_size: usize, vb: nn::VarBuilder) -> Result<Self> {
184        let mlp = nn::seq()
185            .add(nn::linear(input_dim, hidden_size, vb.pp("mlp.0"))?)
186            .add(nn::Activation::Silu)
187            .add(nn::linear(hidden_size, hidden_size, vb.pp("mlp.2"))?);
188
189        Ok(Self { mlp })
190    }
191}
192
193impl Module for VectorEmbedder {
194    fn forward(&self, x: &Tensor) -> Result<Tensor> {
195        self.mlp.forward(x)
196    }
197}