candle_transformers/models/mmdit/
projections.rs1use candle::{Module, Result, Tensor};
2use candle_nn as nn;
3
4pub struct Qkv {
5 pub q: Tensor,
6 pub k: Tensor,
7 pub v: Tensor,
8}
9
10pub struct Mlp {
11 fc1: nn::Linear,
12 act: nn::Activation,
13 fc2: nn::Linear,
14}
15
16impl Mlp {
17 pub fn new(
18 in_features: usize,
19 hidden_features: usize,
20 vb: candle_nn::VarBuilder,
21 ) -> Result<Self> {
22 let fc1 = nn::linear(in_features, hidden_features, vb.pp("fc1"))?;
23 let act = nn::Activation::GeluPytorchTanh;
24 let fc2 = nn::linear(hidden_features, in_features, vb.pp("fc2"))?;
25
26 Ok(Self { fc1, act, fc2 })
27 }
28}
29
30impl Module for Mlp {
31 fn forward(&self, x: &Tensor) -> Result<Tensor> {
32 let x = self.fc1.forward(x)?;
33 let x = self.act.forward(&x)?;
34 self.fc2.forward(&x)
35 }
36}
37
38pub struct QkvOnlyAttnProjections {
39 qkv: nn::Linear,
40 head_dim: usize,
41}
42
43impl QkvOnlyAttnProjections {
44 pub fn new(dim: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
45 let head_dim = dim / num_heads;
46 let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?;
47 Ok(Self { qkv, head_dim })
48 }
49
50 pub fn pre_attention(&self, x: &Tensor) -> Result<Qkv> {
51 let qkv = self.qkv.forward(x)?;
52 split_qkv(&qkv, self.head_dim)
53 }
54}
55
56pub struct AttnProjections {
57 head_dim: usize,
58 qkv: nn::Linear,
59 ln_k: Option<candle_nn::RmsNorm>,
60 ln_q: Option<candle_nn::RmsNorm>,
61 proj: nn::Linear,
62}
63
64impl AttnProjections {
65 pub fn new(dim: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
66 let head_dim = dim / num_heads;
67 let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?;
68 let proj = nn::linear(dim, dim, vb.pp("proj"))?;
69 let (ln_k, ln_q) = if vb.contains_tensor("ln_k.weight") {
70 let ln_k = candle_nn::rms_norm(head_dim, 1e-6, vb.pp("ln_k"))?;
71 let ln_q = candle_nn::rms_norm(head_dim, 1e-6, vb.pp("ln_q"))?;
72 (Some(ln_k), Some(ln_q))
73 } else {
74 (None, None)
75 };
76 Ok(Self {
77 head_dim,
78 qkv,
79 proj,
80 ln_k,
81 ln_q,
82 })
83 }
84
85 pub fn pre_attention(&self, x: &Tensor) -> Result<Qkv> {
86 let qkv = self.qkv.forward(x)?;
87 let Qkv { q, k, v } = split_qkv(&qkv, self.head_dim)?;
88 let q = match self.ln_q.as_ref() {
89 None => q,
90 Some(l) => {
91 let (b, t, h) = q.dims3()?;
92 l.forward(&q.reshape((b, t, (), self.head_dim))?)?
93 .reshape((b, t, h))?
94 }
95 };
96 let k = match self.ln_k.as_ref() {
97 None => k,
98 Some(l) => {
99 let (b, t, h) = k.dims3()?;
100 l.forward(&k.reshape((b, t, (), self.head_dim))?)?
101 .reshape((b, t, h))?
102 }
103 };
104 Ok(Qkv { q, k, v })
105 }
106
107 pub fn post_attention(&self, x: &Tensor) -> Result<Tensor> {
108 self.proj.forward(x)
109 }
110}
111
112fn split_qkv(qkv: &Tensor, head_dim: usize) -> Result<Qkv> {
113 let (batch_size, seq_len, _) = qkv.dims3()?;
114 let qkv = qkv.reshape((batch_size, seq_len, 3, (), head_dim))?;
115 let q = qkv.get_on_dim(2, 0)?;
116 let q = q.reshape((batch_size, seq_len, ()))?;
117 let k = qkv.get_on_dim(2, 1)?;
118 let k = k.reshape((batch_size, seq_len, ()))?;
119 let v = qkv.get_on_dim(2, 2)?;
120 Ok(Qkv { q, k, v })
121}