candle_transformers/models/wuerstchen/
attention_processor.rs

1use candle::{Module, Result, Tensor};
2use candle_nn::{linear, Linear, VarBuilder};
3
4// A simplified version of:
5// https://github.com/huggingface/diffusers/blob/119ad2c3dc8a8fb8446a83f4bf6f20929487b47f/src/diffusers/models/attention_processor.py#L38
/// Multi-head attention layer whose queries come from a 4D input and whose
/// keys/values come from separate encoder hidden states (see `forward`).
#[derive(Debug)]
pub struct Attention {
    // Projection applied to the flattened input to produce the queries.
    to_q: Linear,
    // Projection applied to the encoder hidden states to produce the keys.
    to_k: Linear,
    // Projection applied to the encoder hidden states to produce the values.
    to_v: Linear,
    // Output projection mapping `dim_head * heads` features back to `query_dim`
    // (loaded from the `to_out.0` variable prefix in `new`).
    to_out: Linear,
    // Number of attention heads the feature dimension is split into.
    heads: usize,
    // Factor applied to the query/key products; set to `1 / sqrt(dim_head)` in `new`.
    scale: f64,
    // When true, `forward` goes through `flash_attn` instead of the
    // matmul + softmax path.
    use_flash_attn: bool,
}
16
/// Forwards to `candle_flash_attn::flash_attn`.
///
/// Only compiled when the `flash-attn` feature is enabled; all arguments are
/// passed through unchanged and any error from the kernel is propagated.
#[cfg(feature = "flash-attn")]
fn flash_attn(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    softmax_scale: f32,
    causal: bool,
) -> Result<Tensor> {
    let attended = candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)?;
    Ok(attended)
}
27
28#[cfg(not(feature = "flash-attn"))]
29fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
30    unimplemented!("compile with '--features flash-attn'")
31}
32
33impl Attention {
34    pub fn new(
35        query_dim: usize,
36        heads: usize,
37        dim_head: usize,
38        use_flash_attn: bool,
39        vb: VarBuilder,
40    ) -> Result<Self> {
41        let inner_dim = dim_head * heads;
42        let scale = 1.0 / f64::sqrt(dim_head as f64);
43        let to_q = linear(query_dim, inner_dim, vb.pp("to_q"))?;
44        let to_k = linear(query_dim, inner_dim, vb.pp("to_k"))?;
45        let to_v = linear(query_dim, inner_dim, vb.pp("to_v"))?;
46        let to_out = linear(inner_dim, query_dim, vb.pp("to_out.0"))?;
47        Ok(Self {
48            to_q,
49            to_k,
50            to_v,
51            to_out,
52            scale,
53            heads,
54            use_flash_attn,
55        })
56    }
57
58    fn batch_to_head_dim(&self, xs: &Tensor) -> Result<Tensor> {
59        let (b_size, seq_len, dim) = xs.dims3()?;
60        xs.reshape((b_size / self.heads, self.heads, seq_len, dim))?
61            .permute((0, 2, 1, 3))?
62            .reshape((b_size / self.heads, seq_len, dim * self.heads))
63    }
64
65    fn head_to_batch_dim(&self, xs: &Tensor) -> Result<Tensor> {
66        let (b_size, seq_len, dim) = xs.dims3()?;
67        xs.reshape((b_size, seq_len, self.heads, dim / self.heads))?
68            .permute((0, 2, 1, 3))?
69            .reshape((b_size * self.heads, seq_len, dim / self.heads))
70    }
71
72    fn get_attention_scores(&self, query: &Tensor, key: &Tensor) -> Result<Tensor> {
73        let attn_probs = (query.matmul(&key.t()?)? * self.scale)?;
74        candle_nn::ops::softmax_last_dim(&attn_probs)
75    }
76
77    pub fn forward(&self, xs: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
78        let (b_size, channel, h, w) = xs.dims4()?;
79        let xs = xs.reshape((b_size, channel, h * w))?.t()?;
80
81        let query = self.to_q.forward(&xs)?;
82        let key = self.to_k.forward(encoder_hidden_states)?;
83        let value = self.to_v.forward(encoder_hidden_states)?;
84
85        let query = self.head_to_batch_dim(&query)?;
86        let key = self.head_to_batch_dim(&key)?;
87        let value = self.head_to_batch_dim(&value)?;
88
89        let xs = if self.use_flash_attn {
90            let init_dtype = query.dtype();
91            let q = query
92                .to_dtype(candle::DType::F16)?
93                .unsqueeze(0)?
94                .transpose(1, 2)?;
95            let k = key
96                .to_dtype(candle::DType::F16)?
97                .unsqueeze(0)?
98                .transpose(1, 2)?;
99            let v = value
100                .to_dtype(candle::DType::F16)?
101                .unsqueeze(0)?
102                .transpose(1, 2)?;
103            flash_attn(&q, &k, &v, self.scale as f32, false)?
104                .transpose(1, 2)?
105                .squeeze(0)?
106                .to_dtype(init_dtype)?
107        } else {
108            let attn_prs = self.get_attention_scores(&query, &key)?;
109            attn_prs.matmul(&value)?
110        };
111        let xs = self.batch_to_head_dim(&xs)?;
112
113        self.to_out
114            .forward(&xs)?
115            .t()?
116            .reshape((b_size, channel, h, w))
117    }
118}