// candle_transformers/models/mmdit/projections.rs

1use candle::{Module, Result, Tensor};
2use candle_nn as nn;
3
4pub struct Qkv {
5    pub q: Tensor,
6    pub k: Tensor,
7    pub v: Tensor,
8}
9
10pub struct Mlp {
11    fc1: nn::Linear,
12    act: nn::Activation,
13    fc2: nn::Linear,
14}
15
16impl Mlp {
17    pub fn new(
18        in_features: usize,
19        hidden_features: usize,
20        vb: candle_nn::VarBuilder,
21    ) -> Result<Self> {
22        let fc1 = nn::linear(in_features, hidden_features, vb.pp("fc1"))?;
23        let act = nn::Activation::GeluPytorchTanh;
24        let fc2 = nn::linear(hidden_features, in_features, vb.pp("fc2"))?;
25
26        Ok(Self { fc1, act, fc2 })
27    }
28}
29
30impl Module for Mlp {
31    fn forward(&self, x: &Tensor) -> Result<Tensor> {
32        let x = self.fc1.forward(x)?;
33        let x = self.act.forward(&x)?;
34        self.fc2.forward(&x)
35    }
36}
37
38pub struct QkvOnlyAttnProjections {
39    qkv: nn::Linear,
40    head_dim: usize,
41}
42
43impl QkvOnlyAttnProjections {
44    pub fn new(dim: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
45        let head_dim = dim / num_heads;
46        let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?;
47        Ok(Self { qkv, head_dim })
48    }
49
50    pub fn pre_attention(&self, x: &Tensor) -> Result<Qkv> {
51        let qkv = self.qkv.forward(x)?;
52        split_qkv(&qkv, self.head_dim)
53    }
54}
55
56pub struct AttnProjections {
57    head_dim: usize,
58    qkv: nn::Linear,
59    ln_k: Option<candle_nn::RmsNorm>,
60    ln_q: Option<candle_nn::RmsNorm>,
61    proj: nn::Linear,
62}
63
64impl AttnProjections {
65    pub fn new(dim: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
66        let head_dim = dim / num_heads;
67        let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?;
68        let proj = nn::linear(dim, dim, vb.pp("proj"))?;
69        let (ln_k, ln_q) = if vb.contains_tensor("ln_k.weight") {
70            let ln_k = candle_nn::rms_norm(head_dim, 1e-6, vb.pp("ln_k"))?;
71            let ln_q = candle_nn::rms_norm(head_dim, 1e-6, vb.pp("ln_q"))?;
72            (Some(ln_k), Some(ln_q))
73        } else {
74            (None, None)
75        };
76        Ok(Self {
77            head_dim,
78            qkv,
79            proj,
80            ln_k,
81            ln_q,
82        })
83    }
84
85    pub fn pre_attention(&self, x: &Tensor) -> Result<Qkv> {
86        let qkv = self.qkv.forward(x)?;
87        let Qkv { q, k, v } = split_qkv(&qkv, self.head_dim)?;
88        let q = match self.ln_q.as_ref() {
89            None => q,
90            Some(l) => {
91                let (b, t, h) = q.dims3()?;
92                l.forward(&q.reshape((b, t, (), self.head_dim))?)?
93                    .reshape((b, t, h))?
94            }
95        };
96        let k = match self.ln_k.as_ref() {
97            None => k,
98            Some(l) => {
99                let (b, t, h) = k.dims3()?;
100                l.forward(&k.reshape((b, t, (), self.head_dim))?)?
101                    .reshape((b, t, h))?
102            }
103        };
104        Ok(Qkv { q, k, v })
105    }
106
107    pub fn post_attention(&self, x: &Tensor) -> Result<Tensor> {
108        self.proj.forward(x)
109    }
110}
111
112fn split_qkv(qkv: &Tensor, head_dim: usize) -> Result<Qkv> {
113    let (batch_size, seq_len, _) = qkv.dims3()?;
114    let qkv = qkv.reshape((batch_size, seq_len, 3, (), head_dim))?;
115    let q = qkv.get_on_dim(2, 0)?;
116    let q = q.reshape((batch_size, seq_len, ()))?;
117    let k = qkv.get_on_dim(2, 1)?;
118    let k = k.reshape((batch_size, seq_len, ()))?;
119    let v = qkv.get_on_dim(2, 2)?;
120    Ok(Qkv { q, k, v })
121}