candle_transformers/models/stable_diffusion/attention.rs

1//! Attention Based Building Blocks
2use candle::{DType, IndexOp, Result, Tensor, D};
3use candle_nn as nn;
4use candle_nn::Module;
5
6#[derive(Debug)]
7struct GeGlu {
8    proj: nn::Linear,
9    span: tracing::Span,
10}
11
12impl GeGlu {
13    fn new(vs: nn::VarBuilder, dim_in: usize, dim_out: usize) -> Result<Self> {
14        let proj = nn::linear(dim_in, dim_out * 2, vs.pp("proj"))?;
15        let span = tracing::span!(tracing::Level::TRACE, "geglu");
16        Ok(Self { proj, span })
17    }
18}
19
20impl Module for GeGlu {
21    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
22        let _enter = self.span.enter();
23        let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?;
24        &hidden_states_and_gate[0] * hidden_states_and_gate[1].gelu()?
25    }
26}
27
28/// A feed-forward layer.
29#[derive(Debug)]
30struct FeedForward {
31    project_in: GeGlu,
32    linear: nn::Linear,
33    span: tracing::Span,
34}
35
36impl FeedForward {
37    // The glu parameter in the python code is unused?
38    // https://github.com/huggingface/diffusers/blob/d3d22ce5a894becb951eec03e663951b28d45135/src/diffusers/models/attention.py#L347
39    /// Creates a new feed-forward layer based on some given input dimension, some
40    /// output dimension, and a multiplier to be used for the intermediary layer.
41    fn new(vs: nn::VarBuilder, dim: usize, dim_out: Option<usize>, mult: usize) -> Result<Self> {
42        let inner_dim = dim * mult;
43        let dim_out = dim_out.unwrap_or(dim);
44        let vs = vs.pp("net");
45        let project_in = GeGlu::new(vs.pp("0"), dim, inner_dim)?;
46        let linear = nn::linear(inner_dim, dim_out, vs.pp("2"))?;
47        let span = tracing::span!(tracing::Level::TRACE, "ff");
48        Ok(Self {
49            project_in,
50            linear,
51            span,
52        })
53    }
54}
55
56impl Module for FeedForward {
57    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
58        let _enter = self.span.enter();
59        let xs = self.project_in.forward(xs)?;
60        self.linear.forward(&xs)
61    }
62}
63
64#[cfg(feature = "flash-attn")]
65fn flash_attn(
66    q: &Tensor,
67    k: &Tensor,
68    v: &Tensor,
69    softmax_scale: f32,
70    causal: bool,
71) -> Result<Tensor> {
72    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
73}
74
75#[cfg(not(feature = "flash-attn"))]
76fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
77    unimplemented!("compile with '--features flash-attn'")
78}
79
80#[derive(Debug)]
81pub struct CrossAttention {
82    to_q: nn::Linear,
83    to_k: nn::Linear,
84    to_v: nn::Linear,
85    to_out: nn::Linear,
86    heads: usize,
87    scale: f64,
88    slice_size: Option<usize>,
89    span: tracing::Span,
90    span_attn: tracing::Span,
91    span_softmax: tracing::Span,
92    use_flash_attn: bool,
93}
94
95impl CrossAttention {
96    // Defaults should be heads = 8, dim_head = 64, context_dim = None
97    pub fn new(
98        vs: nn::VarBuilder,
99        query_dim: usize,
100        context_dim: Option<usize>,
101        heads: usize,
102        dim_head: usize,
103        slice_size: Option<usize>,
104        use_flash_attn: bool,
105    ) -> Result<Self> {
106        let inner_dim = dim_head * heads;
107        let context_dim = context_dim.unwrap_or(query_dim);
108        let scale = 1.0 / f64::sqrt(dim_head as f64);
109        let to_q = nn::linear_no_bias(query_dim, inner_dim, vs.pp("to_q"))?;
110        let to_k = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_k"))?;
111        let to_v = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_v"))?;
112        let to_out = nn::linear(inner_dim, query_dim, vs.pp("to_out.0"))?;
113        let span = tracing::span!(tracing::Level::TRACE, "xa");
114        let span_attn = tracing::span!(tracing::Level::TRACE, "xa-attn");
115        let span_softmax = tracing::span!(tracing::Level::TRACE, "xa-softmax");
116        Ok(Self {
117            to_q,
118            to_k,
119            to_v,
120            to_out,
121            heads,
122            scale,
123            slice_size,
124            span,
125            span_attn,
126            span_softmax,
127            use_flash_attn,
128        })
129    }
130
131    fn reshape_heads_to_batch_dim(&self, xs: &Tensor) -> Result<Tensor> {
132        let (batch_size, seq_len, dim) = xs.dims3()?;
133        xs.reshape((batch_size, seq_len, self.heads, dim / self.heads))?
134            .transpose(1, 2)?
135            .reshape((batch_size * self.heads, seq_len, dim / self.heads))
136    }
137
138    fn reshape_batch_dim_to_heads(&self, xs: &Tensor) -> Result<Tensor> {
139        let (batch_size, seq_len, dim) = xs.dims3()?;
140        xs.reshape((batch_size / self.heads, self.heads, seq_len, dim))?
141            .transpose(1, 2)?
142            .reshape((batch_size / self.heads, seq_len, dim * self.heads))
143    }
144
145    fn sliced_attention(
146        &self,
147        query: &Tensor,
148        key: &Tensor,
149        value: &Tensor,
150        slice_size: usize,
151    ) -> Result<Tensor> {
152        let batch_size_attention = query.dim(0)?;
153        let mut hidden_states = Vec::with_capacity(batch_size_attention / slice_size);
154        let in_dtype = query.dtype();
155        let query = query.to_dtype(DType::F32)?;
156        let key = key.to_dtype(DType::F32)?;
157        let value = value.to_dtype(DType::F32)?;
158
159        for i in 0..batch_size_attention / slice_size {
160            let start_idx = i * slice_size;
161            let end_idx = (i + 1) * slice_size;
162
163            let xs = query
164                .i(start_idx..end_idx)?
165                .matmul(&(key.i(start_idx..end_idx)?.t()? * self.scale)?)?;
166            let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(&value.i(start_idx..end_idx)?)?;
167            hidden_states.push(xs)
168        }
169        let hidden_states = Tensor::stack(&hidden_states, 0)?.to_dtype(in_dtype)?;
170        self.reshape_batch_dim_to_heads(&hidden_states)
171    }
172
173    fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> {
174        let _enter = self.span_attn.enter();
175        let xs = if self.use_flash_attn {
176            let init_dtype = query.dtype();
177            let q = query
178                .to_dtype(candle::DType::F16)?
179                .unsqueeze(0)?
180                .transpose(1, 2)?;
181            let k = key
182                .to_dtype(candle::DType::F16)?
183                .unsqueeze(0)?
184                .transpose(1, 2)?;
185            let v = value
186                .to_dtype(candle::DType::F16)?
187                .unsqueeze(0)?
188                .transpose(1, 2)?;
189            flash_attn(&q, &k, &v, self.scale as f32, false)?
190                .transpose(1, 2)?
191                .squeeze(0)?
192                .to_dtype(init_dtype)?
193        } else {
194            let in_dtype = query.dtype();
195            let query = query.to_dtype(DType::F32)?;
196            let key = key.to_dtype(DType::F32)?;
197            let value = value.to_dtype(DType::F32)?;
198            let xs = query.matmul(&(key.t()? * self.scale)?)?;
199            let xs = {
200                let _enter = self.span_softmax.enter();
201                nn::ops::softmax_last_dim(&xs)?
202            };
203            xs.matmul(&value)?.to_dtype(in_dtype)?
204        };
205        self.reshape_batch_dim_to_heads(&xs)
206    }
207
208    pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
209        let _enter = self.span.enter();
210        let query = self.to_q.forward(xs)?;
211        let context = context.unwrap_or(xs).contiguous()?;
212        let key = self.to_k.forward(&context)?;
213        let value = self.to_v.forward(&context)?;
214        let query = self.reshape_heads_to_batch_dim(&query)?;
215        let key = self.reshape_heads_to_batch_dim(&key)?;
216        let value = self.reshape_heads_to_batch_dim(&value)?;
217        let dim0 = query.dim(0)?;
218        let slice_size = self.slice_size.and_then(|slice_size| {
219            if dim0 < slice_size {
220                None
221            } else {
222                Some(slice_size)
223            }
224        });
225        let xs = match slice_size {
226            None => self.attention(&query, &key, &value)?,
227            Some(slice_size) => self.sliced_attention(&query, &key, &value, slice_size)?,
228        };
229        self.to_out.forward(&xs)
230    }
231}
232
233/// A basic Transformer block.
234#[derive(Debug)]
235struct BasicTransformerBlock {
236    attn1: CrossAttention,
237    ff: FeedForward,
238    attn2: CrossAttention,
239    norm1: nn::LayerNorm,
240    norm2: nn::LayerNorm,
241    norm3: nn::LayerNorm,
242    span: tracing::Span,
243}
244
245impl BasicTransformerBlock {
246    fn new(
247        vs: nn::VarBuilder,
248        dim: usize,
249        n_heads: usize,
250        d_head: usize,
251        context_dim: Option<usize>,
252        sliced_attention_size: Option<usize>,
253        use_flash_attn: bool,
254    ) -> Result<Self> {
255        let attn1 = CrossAttention::new(
256            vs.pp("attn1"),
257            dim,
258            None,
259            n_heads,
260            d_head,
261            sliced_attention_size,
262            use_flash_attn,
263        )?;
264        let ff = FeedForward::new(vs.pp("ff"), dim, None, 4)?;
265        let attn2 = CrossAttention::new(
266            vs.pp("attn2"),
267            dim,
268            context_dim,
269            n_heads,
270            d_head,
271            sliced_attention_size,
272            use_flash_attn,
273        )?;
274        let norm1 = nn::layer_norm(dim, 1e-5, vs.pp("norm1"))?;
275        let norm2 = nn::layer_norm(dim, 1e-5, vs.pp("norm2"))?;
276        let norm3 = nn::layer_norm(dim, 1e-5, vs.pp("norm3"))?;
277        let span = tracing::span!(tracing::Level::TRACE, "basic-transformer");
278        Ok(Self {
279            attn1,
280            ff,
281            attn2,
282            norm1,
283            norm2,
284            norm3,
285            span,
286        })
287    }
288
289    fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
290        let _enter = self.span.enter();
291        let xs = (self.attn1.forward(&self.norm1.forward(xs)?, None)? + xs)?;
292        let xs = (self.attn2.forward(&self.norm2.forward(&xs)?, context)? + xs)?;
293        self.ff.forward(&self.norm3.forward(&xs)?)? + xs
294    }
295}
296
/// Configuration for a `SpatialTransformer`.
#[derive(Debug, Clone, Copy)]
pub struct SpatialTransformerConfig {
    /// Number of stacked transformer blocks.
    pub depth: usize,
    /// Number of groups for the input group normalization.
    pub num_groups: usize,
    /// Feature size of the cross-attention context; `None` means the same as the query width.
    pub context_dim: Option<usize>,
    /// Optional slice size forwarded to every attention layer.
    pub sliced_attention_size: Option<usize>,
    /// `true` selects linear input/output projections; `false` selects 1x1 convolutions.
    pub use_linear_projection: bool,
}

impl Default for SpatialTransformerConfig {
    /// One block, 32 norm groups, no context, no slicing, convolutional projections.
    fn default() -> Self {
        Self {
            use_linear_projection: false,
            sliced_attention_size: None,
            context_dim: None,
            num_groups: 32,
            depth: 1,
        }
    }
}
317
318#[derive(Debug)]
319enum Proj {
320    Conv2d(nn::Conv2d),
321    Linear(nn::Linear),
322}
323
324// Aka Transformer2DModel
325#[derive(Debug)]
326pub struct SpatialTransformer {
327    norm: nn::GroupNorm,
328    proj_in: Proj,
329    transformer_blocks: Vec<BasicTransformerBlock>,
330    proj_out: Proj,
331    span: tracing::Span,
332    pub config: SpatialTransformerConfig,
333}
334
335impl SpatialTransformer {
336    pub fn new(
337        vs: nn::VarBuilder,
338        in_channels: usize,
339        n_heads: usize,
340        d_head: usize,
341        use_flash_attn: bool,
342        config: SpatialTransformerConfig,
343    ) -> Result<Self> {
344        let inner_dim = n_heads * d_head;
345        let norm = nn::group_norm(config.num_groups, in_channels, 1e-6, vs.pp("norm"))?;
346        let proj_in = if config.use_linear_projection {
347            Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_in"))?)
348        } else {
349            Proj::Conv2d(nn::conv2d(
350                in_channels,
351                inner_dim,
352                1,
353                Default::default(),
354                vs.pp("proj_in"),
355            )?)
356        };
357        let mut transformer_blocks = vec![];
358        let vs_tb = vs.pp("transformer_blocks");
359        for index in 0..config.depth {
360            let tb = BasicTransformerBlock::new(
361                vs_tb.pp(index.to_string()),
362                inner_dim,
363                n_heads,
364                d_head,
365                config.context_dim,
366                config.sliced_attention_size,
367                use_flash_attn,
368            )?;
369            transformer_blocks.push(tb)
370        }
371        let proj_out = if config.use_linear_projection {
372            Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_out"))?)
373        } else {
374            Proj::Conv2d(nn::conv2d(
375                inner_dim,
376                in_channels,
377                1,
378                Default::default(),
379                vs.pp("proj_out"),
380            )?)
381        };
382        let span = tracing::span!(tracing::Level::TRACE, "spatial-transformer");
383        Ok(Self {
384            norm,
385            proj_in,
386            transformer_blocks,
387            proj_out,
388            span,
389            config,
390        })
391    }
392
393    pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
394        let _enter = self.span.enter();
395        let (batch, _channel, height, weight) = xs.dims4()?;
396        let residual = xs;
397        let xs = self.norm.forward(xs)?;
398        let (inner_dim, xs) = match &self.proj_in {
399            Proj::Conv2d(p) => {
400                let xs = p.forward(&xs)?;
401                let inner_dim = xs.dim(1)?;
402                let xs = xs
403                    .transpose(1, 2)?
404                    .t()?
405                    .reshape((batch, height * weight, inner_dim))?;
406                (inner_dim, xs)
407            }
408            Proj::Linear(p) => {
409                let inner_dim = xs.dim(1)?;
410                let xs = xs
411                    .transpose(1, 2)?
412                    .t()?
413                    .reshape((batch, height * weight, inner_dim))?;
414                (inner_dim, p.forward(&xs)?)
415            }
416        };
417        let mut xs = xs;
418        for block in self.transformer_blocks.iter() {
419            xs = block.forward(&xs, context)?
420        }
421        let xs = match &self.proj_out {
422            Proj::Conv2d(p) => p.forward(
423                &xs.reshape((batch, height, weight, inner_dim))?
424                    .t()?
425                    .transpose(1, 2)?,
426            )?,
427            Proj::Linear(p) => p
428                .forward(&xs)?
429                .reshape((batch, height, weight, inner_dim))?
430                .t()?
431                .transpose(1, 2)?,
432        };
433        xs + residual
434    }
435}
436
/// Configuration for an attention block.
#[derive(Debug, Clone, Copy)]
pub struct AttentionBlockConfig {
    /// Channels per attention head; `None` means a single head spanning all channels.
    pub num_head_channels: Option<usize>,
    /// Number of groups for the input group normalization.
    pub num_groups: usize,
    /// Divisor applied to `output + residual` at the end of the block.
    pub rescale_output_factor: f64,
    /// Epsilon for the input group normalization.
    pub eps: f64,
}

impl Default for AttentionBlockConfig {
    /// Single head, 32 norm groups, no output rescaling, `eps = 1e-5`.
    fn default() -> Self {
        Self {
            eps: 1e-5,
            rescale_output_factor: 1.0,
            num_groups: 32,
            num_head_channels: None,
        }
    }
}
456
457#[derive(Debug)]
458pub struct AttentionBlock {
459    group_norm: nn::GroupNorm,
460    query: nn::Linear,
461    key: nn::Linear,
462    value: nn::Linear,
463    proj_attn: nn::Linear,
464    channels: usize,
465    num_heads: usize,
466    span: tracing::Span,
467    config: AttentionBlockConfig,
468}
469
470// In the .safetensor weights of official Stable Diffusion 3 Medium Huggingface repo
471// https://huggingface.co/stabilityai/stable-diffusion-3-medium
472// Linear layer may use a different dimension for the weight in the linear, which is
473// incompatible with the current implementation of the nn::linear constructor.
474// This is a workaround to handle the different dimensions.
475fn get_qkv_linear(channels: usize, vs: nn::VarBuilder) -> Result<nn::Linear> {
476    match vs.get((channels, channels), "weight") {
477        Ok(_) => nn::linear(channels, channels, vs),
478        Err(_) => {
479            let weight = vs
480                .get((channels, channels, 1, 1), "weight")?
481                .reshape((channels, channels))?;
482            let bias = vs.get((channels,), "bias")?;
483            Ok(nn::Linear::new(weight, Some(bias)))
484        }
485    }
486}
487
488impl AttentionBlock {
489    pub fn new(vs: nn::VarBuilder, channels: usize, config: AttentionBlockConfig) -> Result<Self> {
490        let num_head_channels = config.num_head_channels.unwrap_or(channels);
491        let num_heads = channels / num_head_channels;
492        let group_norm =
493            nn::group_norm(config.num_groups, channels, config.eps, vs.pp("group_norm"))?;
494        let (q_path, k_path, v_path, out_path) = if vs.contains_tensor("to_q.weight") {
495            ("to_q", "to_k", "to_v", "to_out.0")
496        } else {
497            ("query", "key", "value", "proj_attn")
498        };
499        let query = get_qkv_linear(channels, vs.pp(q_path))?;
500        let key = get_qkv_linear(channels, vs.pp(k_path))?;
501        let value = get_qkv_linear(channels, vs.pp(v_path))?;
502        let proj_attn = get_qkv_linear(channels, vs.pp(out_path))?;
503        let span = tracing::span!(tracing::Level::TRACE, "attn-block");
504        Ok(Self {
505            group_norm,
506            query,
507            key,
508            value,
509            proj_attn,
510            channels,
511            num_heads,
512            span,
513            config,
514        })
515    }
516
517    fn transpose_for_scores(&self, xs: Tensor) -> Result<Tensor> {
518        let (batch, t, h_times_d) = xs.dims3()?;
519        xs.reshape((batch, t, self.num_heads, h_times_d / self.num_heads))?
520            .transpose(1, 2)
521    }
522}
523
524impl Module for AttentionBlock {
525    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
526        let _enter = self.span.enter();
527        let in_dtype = xs.dtype();
528        let residual = xs;
529        let (batch, channel, height, width) = xs.dims4()?;
530        let xs = self
531            .group_norm
532            .forward(xs)?
533            .reshape((batch, channel, height * width))?
534            .transpose(1, 2)?;
535
536        let query_proj = self.query.forward(&xs)?;
537        let key_proj = self.key.forward(&xs)?;
538        let value_proj = self.value.forward(&xs)?;
539
540        let query_states = self
541            .transpose_for_scores(query_proj)?
542            .to_dtype(DType::F32)?;
543        let key_states = self.transpose_for_scores(key_proj)?.to_dtype(DType::F32)?;
544        let value_states = self
545            .transpose_for_scores(value_proj)?
546            .to_dtype(DType::F32)?;
547
548        // scale is applied twice, hence the -0.25 here rather than -0.5.
549        // https://github.com/huggingface/diffusers/blob/d3d22ce5a894becb951eec03e663951b28d45135/src/diffusers/models/attention.py#L87
550        let scale = f64::powf(self.channels as f64 / self.num_heads as f64, -0.25);
551        let attention_scores = (query_states * scale)?.matmul(&(key_states.t()? * scale)?)?;
552        let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?;
553
554        // TODO: revert the call to force_contiguous once the three matmul kernels have been
555        // adapted to handle layout with some dims set to 1.
556        let xs = attention_probs.matmul(&value_states)?;
557        let xs = xs.to_dtype(in_dtype)?;
558        let xs = xs.transpose(1, 2)?.contiguous()?;
559        let xs = xs.flatten_from(D::Minus2)?;
560        let xs = self
561            .proj_attn
562            .forward(&xs)?
563            .t()?
564            .reshape((batch, channel, height, width))?;
565        (xs + residual)? / self.config.rescale_output_factor
566    }
567}