1use candle::{DType, IndexOp, Result, Tensor, D};
3use candle_nn as nn;
4use candle_nn::Module;
5
6#[derive(Debug)]
7struct GeGlu {
8 proj: nn::Linear,
9 span: tracing::Span,
10}
11
12impl GeGlu {
13 fn new(vs: nn::VarBuilder, dim_in: usize, dim_out: usize) -> Result<Self> {
14 let proj = nn::linear(dim_in, dim_out * 2, vs.pp("proj"))?;
15 let span = tracing::span!(tracing::Level::TRACE, "geglu");
16 Ok(Self { proj, span })
17 }
18}
19
20impl Module for GeGlu {
21 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
22 let _enter = self.span.enter();
23 let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?;
24 &hidden_states_and_gate[0] * hidden_states_and_gate[1].gelu()?
25 }
26}
27
28#[derive(Debug)]
30struct FeedForward {
31 project_in: GeGlu,
32 linear: nn::Linear,
33 span: tracing::Span,
34}
35
36impl FeedForward {
37 fn new(vs: nn::VarBuilder, dim: usize, dim_out: Option<usize>, mult: usize) -> Result<Self> {
42 let inner_dim = dim * mult;
43 let dim_out = dim_out.unwrap_or(dim);
44 let vs = vs.pp("net");
45 let project_in = GeGlu::new(vs.pp("0"), dim, inner_dim)?;
46 let linear = nn::linear(inner_dim, dim_out, vs.pp("2"))?;
47 let span = tracing::span!(tracing::Level::TRACE, "ff");
48 Ok(Self {
49 project_in,
50 linear,
51 span,
52 })
53 }
54}
55
56impl Module for FeedForward {
57 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
58 let _enter = self.span.enter();
59 let xs = self.project_in.forward(xs)?;
60 self.linear.forward(&xs)
61 }
62}
63
/// Thin forwarding wrapper around `candle_flash_attn::flash_attn`, compiled
/// only when the `flash-attn` feature is enabled.
///
/// All arguments are passed through unchanged.
#[cfg(feature = "flash-attn")]
fn flash_attn(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    softmax_scale: f32,
    causal: bool,
) -> Result<Tensor> {
    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
}
74
75#[cfg(not(feature = "flash-attn"))]
76fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
77 unimplemented!("compile with '--features flash-attn'")
78}
79
80#[derive(Debug)]
81pub struct CrossAttention {
82 to_q: nn::Linear,
83 to_k: nn::Linear,
84 to_v: nn::Linear,
85 to_out: nn::Linear,
86 heads: usize,
87 scale: f64,
88 slice_size: Option<usize>,
89 span: tracing::Span,
90 span_attn: tracing::Span,
91 span_softmax: tracing::Span,
92 use_flash_attn: bool,
93}
94
95impl CrossAttention {
96 pub fn new(
98 vs: nn::VarBuilder,
99 query_dim: usize,
100 context_dim: Option<usize>,
101 heads: usize,
102 dim_head: usize,
103 slice_size: Option<usize>,
104 use_flash_attn: bool,
105 ) -> Result<Self> {
106 let inner_dim = dim_head * heads;
107 let context_dim = context_dim.unwrap_or(query_dim);
108 let scale = 1.0 / f64::sqrt(dim_head as f64);
109 let to_q = nn::linear_no_bias(query_dim, inner_dim, vs.pp("to_q"))?;
110 let to_k = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_k"))?;
111 let to_v = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_v"))?;
112 let to_out = nn::linear(inner_dim, query_dim, vs.pp("to_out.0"))?;
113 let span = tracing::span!(tracing::Level::TRACE, "xa");
114 let span_attn = tracing::span!(tracing::Level::TRACE, "xa-attn");
115 let span_softmax = tracing::span!(tracing::Level::TRACE, "xa-softmax");
116 Ok(Self {
117 to_q,
118 to_k,
119 to_v,
120 to_out,
121 heads,
122 scale,
123 slice_size,
124 span,
125 span_attn,
126 span_softmax,
127 use_flash_attn,
128 })
129 }
130
131 fn reshape_heads_to_batch_dim(&self, xs: &Tensor) -> Result<Tensor> {
132 let (batch_size, seq_len, dim) = xs.dims3()?;
133 xs.reshape((batch_size, seq_len, self.heads, dim / self.heads))?
134 .transpose(1, 2)?
135 .reshape((batch_size * self.heads, seq_len, dim / self.heads))
136 }
137
138 fn reshape_batch_dim_to_heads(&self, xs: &Tensor) -> Result<Tensor> {
139 let (batch_size, seq_len, dim) = xs.dims3()?;
140 xs.reshape((batch_size / self.heads, self.heads, seq_len, dim))?
141 .transpose(1, 2)?
142 .reshape((batch_size / self.heads, seq_len, dim * self.heads))
143 }
144
145 fn sliced_attention(
146 &self,
147 query: &Tensor,
148 key: &Tensor,
149 value: &Tensor,
150 slice_size: usize,
151 ) -> Result<Tensor> {
152 let batch_size_attention = query.dim(0)?;
153 let mut hidden_states = Vec::with_capacity(batch_size_attention / slice_size);
154 let in_dtype = query.dtype();
155 let query = query.to_dtype(DType::F32)?;
156 let key = key.to_dtype(DType::F32)?;
157 let value = value.to_dtype(DType::F32)?;
158
159 for i in 0..batch_size_attention / slice_size {
160 let start_idx = i * slice_size;
161 let end_idx = (i + 1) * slice_size;
162
163 let xs = query
164 .i(start_idx..end_idx)?
165 .matmul(&(key.i(start_idx..end_idx)?.t()? * self.scale)?)?;
166 let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(&value.i(start_idx..end_idx)?)?;
167 hidden_states.push(xs)
168 }
169 let hidden_states = Tensor::stack(&hidden_states, 0)?.to_dtype(in_dtype)?;
170 self.reshape_batch_dim_to_heads(&hidden_states)
171 }
172
173 fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> {
174 let _enter = self.span_attn.enter();
175 let xs = if self.use_flash_attn {
176 let init_dtype = query.dtype();
177 let q = query
178 .to_dtype(candle::DType::F16)?
179 .unsqueeze(0)?
180 .transpose(1, 2)?;
181 let k = key
182 .to_dtype(candle::DType::F16)?
183 .unsqueeze(0)?
184 .transpose(1, 2)?;
185 let v = value
186 .to_dtype(candle::DType::F16)?
187 .unsqueeze(0)?
188 .transpose(1, 2)?;
189 flash_attn(&q, &k, &v, self.scale as f32, false)?
190 .transpose(1, 2)?
191 .squeeze(0)?
192 .to_dtype(init_dtype)?
193 } else {
194 let in_dtype = query.dtype();
195 let query = query.to_dtype(DType::F32)?;
196 let key = key.to_dtype(DType::F32)?;
197 let value = value.to_dtype(DType::F32)?;
198 let xs = query.matmul(&(key.t()? * self.scale)?)?;
199 let xs = {
200 let _enter = self.span_softmax.enter();
201 nn::ops::softmax_last_dim(&xs)?
202 };
203 xs.matmul(&value)?.to_dtype(in_dtype)?
204 };
205 self.reshape_batch_dim_to_heads(&xs)
206 }
207
208 pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
209 let _enter = self.span.enter();
210 let query = self.to_q.forward(xs)?;
211 let context = context.unwrap_or(xs).contiguous()?;
212 let key = self.to_k.forward(&context)?;
213 let value = self.to_v.forward(&context)?;
214 let query = self.reshape_heads_to_batch_dim(&query)?;
215 let key = self.reshape_heads_to_batch_dim(&key)?;
216 let value = self.reshape_heads_to_batch_dim(&value)?;
217 let dim0 = query.dim(0)?;
218 let slice_size = self.slice_size.and_then(|slice_size| {
219 if dim0 < slice_size {
220 None
221 } else {
222 Some(slice_size)
223 }
224 });
225 let xs = match slice_size {
226 None => self.attention(&query, &key, &value)?,
227 Some(slice_size) => self.sliced_attention(&query, &key, &value, slice_size)?,
228 };
229 self.to_out.forward(&xs)
230 }
231}
232
233#[derive(Debug)]
235struct BasicTransformerBlock {
236 attn1: CrossAttention,
237 ff: FeedForward,
238 attn2: CrossAttention,
239 norm1: nn::LayerNorm,
240 norm2: nn::LayerNorm,
241 norm3: nn::LayerNorm,
242 span: tracing::Span,
243}
244
245impl BasicTransformerBlock {
246 fn new(
247 vs: nn::VarBuilder,
248 dim: usize,
249 n_heads: usize,
250 d_head: usize,
251 context_dim: Option<usize>,
252 sliced_attention_size: Option<usize>,
253 use_flash_attn: bool,
254 ) -> Result<Self> {
255 let attn1 = CrossAttention::new(
256 vs.pp("attn1"),
257 dim,
258 None,
259 n_heads,
260 d_head,
261 sliced_attention_size,
262 use_flash_attn,
263 )?;
264 let ff = FeedForward::new(vs.pp("ff"), dim, None, 4)?;
265 let attn2 = CrossAttention::new(
266 vs.pp("attn2"),
267 dim,
268 context_dim,
269 n_heads,
270 d_head,
271 sliced_attention_size,
272 use_flash_attn,
273 )?;
274 let norm1 = nn::layer_norm(dim, 1e-5, vs.pp("norm1"))?;
275 let norm2 = nn::layer_norm(dim, 1e-5, vs.pp("norm2"))?;
276 let norm3 = nn::layer_norm(dim, 1e-5, vs.pp("norm3"))?;
277 let span = tracing::span!(tracing::Level::TRACE, "basic-transformer");
278 Ok(Self {
279 attn1,
280 ff,
281 attn2,
282 norm1,
283 norm2,
284 norm3,
285 span,
286 })
287 }
288
289 fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
290 let _enter = self.span.enter();
291 let xs = (self.attn1.forward(&self.norm1.forward(xs)?, None)? + xs)?;
292 let xs = (self.attn2.forward(&self.norm2.forward(&xs)?, context)? + xs)?;
293 self.ff.forward(&self.norm3.forward(&xs)?)? + xs
294 }
295}
296
/// Configuration for [`SpatialTransformer`].
#[derive(Debug, Clone, Copy)]
pub struct SpatialTransformerConfig {
    /// Number of transformer blocks to stack.
    pub depth: usize,
    /// Number of groups for the input group norm.
    pub num_groups: usize,
    /// Feature width of the cross-attention context, if any.
    pub context_dim: Option<usize>,
    /// Slice size forwarded to the attention layers, if any.
    pub sliced_attention_size: Option<usize>,
    /// Use linear layers instead of 1x1 convolutions for the input and output
    /// projections.
    pub use_linear_projection: bool,
}

impl Default for SpatialTransformerConfig {
    /// One block, 32 norm groups, no context, no slicing, convolutional
    /// projections.
    fn default() -> Self {
        Self {
            depth: 1,
            num_groups: 32,
            context_dim: None,
            sliced_attention_size: None,
            use_linear_projection: false,
        }
    }
}
317
318#[derive(Debug)]
319enum Proj {
320 Conv2d(nn::Conv2d),
321 Linear(nn::Linear),
322}
323
324#[derive(Debug)]
326pub struct SpatialTransformer {
327 norm: nn::GroupNorm,
328 proj_in: Proj,
329 transformer_blocks: Vec<BasicTransformerBlock>,
330 proj_out: Proj,
331 span: tracing::Span,
332 pub config: SpatialTransformerConfig,
333}
334
335impl SpatialTransformer {
336 pub fn new(
337 vs: nn::VarBuilder,
338 in_channels: usize,
339 n_heads: usize,
340 d_head: usize,
341 use_flash_attn: bool,
342 config: SpatialTransformerConfig,
343 ) -> Result<Self> {
344 let inner_dim = n_heads * d_head;
345 let norm = nn::group_norm(config.num_groups, in_channels, 1e-6, vs.pp("norm"))?;
346 let proj_in = if config.use_linear_projection {
347 Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_in"))?)
348 } else {
349 Proj::Conv2d(nn::conv2d(
350 in_channels,
351 inner_dim,
352 1,
353 Default::default(),
354 vs.pp("proj_in"),
355 )?)
356 };
357 let mut transformer_blocks = vec![];
358 let vs_tb = vs.pp("transformer_blocks");
359 for index in 0..config.depth {
360 let tb = BasicTransformerBlock::new(
361 vs_tb.pp(index.to_string()),
362 inner_dim,
363 n_heads,
364 d_head,
365 config.context_dim,
366 config.sliced_attention_size,
367 use_flash_attn,
368 )?;
369 transformer_blocks.push(tb)
370 }
371 let proj_out = if config.use_linear_projection {
372 Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_out"))?)
373 } else {
374 Proj::Conv2d(nn::conv2d(
375 inner_dim,
376 in_channels,
377 1,
378 Default::default(),
379 vs.pp("proj_out"),
380 )?)
381 };
382 let span = tracing::span!(tracing::Level::TRACE, "spatial-transformer");
383 Ok(Self {
384 norm,
385 proj_in,
386 transformer_blocks,
387 proj_out,
388 span,
389 config,
390 })
391 }
392
393 pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
394 let _enter = self.span.enter();
395 let (batch, _channel, height, weight) = xs.dims4()?;
396 let residual = xs;
397 let xs = self.norm.forward(xs)?;
398 let (inner_dim, xs) = match &self.proj_in {
399 Proj::Conv2d(p) => {
400 let xs = p.forward(&xs)?;
401 let inner_dim = xs.dim(1)?;
402 let xs = xs
403 .transpose(1, 2)?
404 .t()?
405 .reshape((batch, height * weight, inner_dim))?;
406 (inner_dim, xs)
407 }
408 Proj::Linear(p) => {
409 let inner_dim = xs.dim(1)?;
410 let xs = xs
411 .transpose(1, 2)?
412 .t()?
413 .reshape((batch, height * weight, inner_dim))?;
414 (inner_dim, p.forward(&xs)?)
415 }
416 };
417 let mut xs = xs;
418 for block in self.transformer_blocks.iter() {
419 xs = block.forward(&xs, context)?
420 }
421 let xs = match &self.proj_out {
422 Proj::Conv2d(p) => p.forward(
423 &xs.reshape((batch, height, weight, inner_dim))?
424 .t()?
425 .transpose(1, 2)?,
426 )?,
427 Proj::Linear(p) => p
428 .forward(&xs)?
429 .reshape((batch, height, weight, inner_dim))?
430 .t()?
431 .transpose(1, 2)?,
432 };
433 xs + residual
434 }
435}
436
/// Configuration for [`AttentionBlock`].
#[derive(Debug, Clone, Copy)]
pub struct AttentionBlockConfig {
    /// Channels per attention head; when `None` a single head spans all
    /// channels.
    pub num_head_channels: Option<usize>,
    /// Number of groups for the group norm.
    pub num_groups: usize,
    /// Divisor applied to the block output after the residual addition.
    pub rescale_output_factor: f64,
    /// Epsilon used by the group norm.
    pub eps: f64,
}

impl Default for AttentionBlockConfig {
    /// Single head, 32 norm groups, no output rescaling, epsilon of `1e-5`.
    fn default() -> Self {
        Self {
            num_head_channels: None,
            num_groups: 32,
            rescale_output_factor: 1.,
            eps: 1e-5,
        }
    }
}
456
457#[derive(Debug)]
458pub struct AttentionBlock {
459 group_norm: nn::GroupNorm,
460 query: nn::Linear,
461 key: nn::Linear,
462 value: nn::Linear,
463 proj_attn: nn::Linear,
464 channels: usize,
465 num_heads: usize,
466 span: tracing::Span,
467 config: AttentionBlockConfig,
468}
469
470fn get_qkv_linear(channels: usize, vs: nn::VarBuilder) -> Result<nn::Linear> {
476 match vs.get((channels, channels), "weight") {
477 Ok(_) => nn::linear(channels, channels, vs),
478 Err(_) => {
479 let weight = vs
480 .get((channels, channels, 1, 1), "weight")?
481 .reshape((channels, channels))?;
482 let bias = vs.get((channels,), "bias")?;
483 Ok(nn::Linear::new(weight, Some(bias)))
484 }
485 }
486}
487
488impl AttentionBlock {
489 pub fn new(vs: nn::VarBuilder, channels: usize, config: AttentionBlockConfig) -> Result<Self> {
490 let num_head_channels = config.num_head_channels.unwrap_or(channels);
491 let num_heads = channels / num_head_channels;
492 let group_norm =
493 nn::group_norm(config.num_groups, channels, config.eps, vs.pp("group_norm"))?;
494 let (q_path, k_path, v_path, out_path) = if vs.contains_tensor("to_q.weight") {
495 ("to_q", "to_k", "to_v", "to_out.0")
496 } else {
497 ("query", "key", "value", "proj_attn")
498 };
499 let query = get_qkv_linear(channels, vs.pp(q_path))?;
500 let key = get_qkv_linear(channels, vs.pp(k_path))?;
501 let value = get_qkv_linear(channels, vs.pp(v_path))?;
502 let proj_attn = get_qkv_linear(channels, vs.pp(out_path))?;
503 let span = tracing::span!(tracing::Level::TRACE, "attn-block");
504 Ok(Self {
505 group_norm,
506 query,
507 key,
508 value,
509 proj_attn,
510 channels,
511 num_heads,
512 span,
513 config,
514 })
515 }
516
517 fn transpose_for_scores(&self, xs: Tensor) -> Result<Tensor> {
518 let (batch, t, h_times_d) = xs.dims3()?;
519 xs.reshape((batch, t, self.num_heads, h_times_d / self.num_heads))?
520 .transpose(1, 2)
521 }
522}
523
524impl Module for AttentionBlock {
525 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
526 let _enter = self.span.enter();
527 let in_dtype = xs.dtype();
528 let residual = xs;
529 let (batch, channel, height, width) = xs.dims4()?;
530 let xs = self
531 .group_norm
532 .forward(xs)?
533 .reshape((batch, channel, height * width))?
534 .transpose(1, 2)?;
535
536 let query_proj = self.query.forward(&xs)?;
537 let key_proj = self.key.forward(&xs)?;
538 let value_proj = self.value.forward(&xs)?;
539
540 let query_states = self
541 .transpose_for_scores(query_proj)?
542 .to_dtype(DType::F32)?;
543 let key_states = self.transpose_for_scores(key_proj)?.to_dtype(DType::F32)?;
544 let value_states = self
545 .transpose_for_scores(value_proj)?
546 .to_dtype(DType::F32)?;
547
548 let scale = f64::powf(self.channels as f64 / self.num_heads as f64, -0.25);
551 let attention_scores = (query_states * scale)?.matmul(&(key_states.t()? * scale)?)?;
552 let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?;
553
554 let xs = attention_probs.matmul(&value_states)?;
557 let xs = xs.to_dtype(in_dtype)?;
558 let xs = xs.transpose(1, 2)?.contiguous()?;
559 let xs = xs.flatten_from(D::Minus2)?;
560 let xs = self
561 .proj_attn
562 .forward(&xs)?
563 .t()?
564 .reshape((batch, channel, height, width))?;
565 (xs + residual)? / self.config.rescale_output_factor
566 }
567}