candle_transformers/models/wuerstchen/
attention_processor.rs1use candle::{Module, Result, Tensor};
2use candle_nn::{linear, Linear, VarBuilder};
3
4#[derive(Debug)]
7pub struct Attention {
8 to_q: Linear,
9 to_k: Linear,
10 to_v: Linear,
11 to_out: Linear,
12 heads: usize,
13 scale: f64,
14 use_flash_attn: bool,
15}
16
17#[cfg(feature = "flash-attn")]
18fn flash_attn(
19 q: &Tensor,
20 k: &Tensor,
21 v: &Tensor,
22 softmax_scale: f32,
23 causal: bool,
24) -> Result<Tensor> {
25 candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
26}
27
28#[cfg(not(feature = "flash-attn"))]
29fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
30 unimplemented!("compile with '--features flash-attn'")
31}
32
33impl Attention {
34 pub fn new(
35 query_dim: usize,
36 heads: usize,
37 dim_head: usize,
38 use_flash_attn: bool,
39 vb: VarBuilder,
40 ) -> Result<Self> {
41 let inner_dim = dim_head * heads;
42 let scale = 1.0 / f64::sqrt(dim_head as f64);
43 let to_q = linear(query_dim, inner_dim, vb.pp("to_q"))?;
44 let to_k = linear(query_dim, inner_dim, vb.pp("to_k"))?;
45 let to_v = linear(query_dim, inner_dim, vb.pp("to_v"))?;
46 let to_out = linear(inner_dim, query_dim, vb.pp("to_out.0"))?;
47 Ok(Self {
48 to_q,
49 to_k,
50 to_v,
51 to_out,
52 scale,
53 heads,
54 use_flash_attn,
55 })
56 }
57
58 fn batch_to_head_dim(&self, xs: &Tensor) -> Result<Tensor> {
59 let (b_size, seq_len, dim) = xs.dims3()?;
60 xs.reshape((b_size / self.heads, self.heads, seq_len, dim))?
61 .permute((0, 2, 1, 3))?
62 .reshape((b_size / self.heads, seq_len, dim * self.heads))
63 }
64
65 fn head_to_batch_dim(&self, xs: &Tensor) -> Result<Tensor> {
66 let (b_size, seq_len, dim) = xs.dims3()?;
67 xs.reshape((b_size, seq_len, self.heads, dim / self.heads))?
68 .permute((0, 2, 1, 3))?
69 .reshape((b_size * self.heads, seq_len, dim / self.heads))
70 }
71
72 fn get_attention_scores(&self, query: &Tensor, key: &Tensor) -> Result<Tensor> {
73 let attn_probs = (query.matmul(&key.t()?)? * self.scale)?;
74 candle_nn::ops::softmax_last_dim(&attn_probs)
75 }
76
77 pub fn forward(&self, xs: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
78 let (b_size, channel, h, w) = xs.dims4()?;
79 let xs = xs.reshape((b_size, channel, h * w))?.t()?;
80
81 let query = self.to_q.forward(&xs)?;
82 let key = self.to_k.forward(encoder_hidden_states)?;
83 let value = self.to_v.forward(encoder_hidden_states)?;
84
85 let query = self.head_to_batch_dim(&query)?;
86 let key = self.head_to_batch_dim(&key)?;
87 let value = self.head_to_batch_dim(&value)?;
88
89 let xs = if self.use_flash_attn {
90 let init_dtype = query.dtype();
91 let q = query
92 .to_dtype(candle::DType::F16)?
93 .unsqueeze(0)?
94 .transpose(1, 2)?;
95 let k = key
96 .to_dtype(candle::DType::F16)?
97 .unsqueeze(0)?
98 .transpose(1, 2)?;
99 let v = value
100 .to_dtype(candle::DType::F16)?
101 .unsqueeze(0)?
102 .transpose(1, 2)?;
103 flash_attn(&q, &k, &v, self.scale as f32, false)?
104 .transpose(1, 2)?
105 .squeeze(0)?
106 .to_dtype(init_dtype)?
107 } else {
108 let attn_prs = self.get_attention_scores(&query, &key)?;
109 attn_prs.matmul(&value)?
110 };
111 let xs = self.batch_to_head_dim(&xs)?;
112
113 self.to_out
114 .forward(&xs)?
115 .t()?
116 .reshape((b_size, channel, h, w))
117 }
118}