1use crate::models::with_tracing::{linear_b as linear, Linear};
10use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
11use candle_nn::VarBuilder;
12
/// Serde default provider for [`Config::rope_ratio`]: yields `1` when the
/// field is missing from the deserialized configuration.
fn default_one() -> usize {
    1usize
}
16
17#[derive(Debug, Clone, serde::Deserialize, Default)]
18pub struct Config {
19 pub num_layers: usize,
20 pub padded_vocab_size: usize,
21 pub hidden_size: usize,
22 pub ffn_hidden_size: usize,
23 pub kv_channels: usize,
24 pub num_attention_heads: usize,
25 pub seq_length: usize,
26 pub layernorm_epsilon: f64,
27 pub rmsnorm: bool,
28 pub apply_residual_connection_post_layernorm: bool,
29 pub post_layer_norm: bool,
30 pub add_bias_linear: bool,
31 pub add_qkv_bias: bool,
32 pub bias_dropout_fusion: bool,
33 pub multi_query_attention: bool,
34 pub multi_query_group_num: usize,
35 pub apply_query_key_layer_scaling: bool,
36 pub attention_softmax_in_fp32: bool,
37 pub fp32_residual_connection: bool,
38 #[serde(default = "default_one")]
39 pub rope_ratio: usize,
40}
41
42impl Config {
43 pub fn codegeex4() -> Self {
44 Self {
45 num_layers: 40,
46 padded_vocab_size: 151552,
47 hidden_size: 4096,
48 ffn_hidden_size: 13696,
49 kv_channels: 128,
50 num_attention_heads: 32,
51 seq_length: 131072,
52 layernorm_epsilon: 1e-5,
53 rmsnorm: true,
54 apply_residual_connection_post_layernorm: false,
55 post_layer_norm: true,
56 add_bias_linear: false,
57 add_qkv_bias: true,
58 bias_dropout_fusion: true,
59 multi_query_attention: true,
60 multi_query_group_num: 2,
61 apply_query_key_layer_scaling: true,
62 attention_softmax_in_fp32: true,
63 fp32_residual_connection: false,
64 rope_ratio: 500,
65 }
66 }
67}
68
69#[derive(Debug, Clone)]
70struct RotaryEmbedding {
71 cache: Tensor,
72}
73
74impl RotaryEmbedding {
75 fn new(cfg: &Config, dtype: DType, dev: &Device) -> Result<Self> {
76 let rotary_dim = cfg.kv_channels;
77 let n_elem = rotary_dim / 2;
78 let base = 10_000f64 * cfg.rope_ratio as f64;
79 let inv_freq: Vec<_> = (0..n_elem)
80 .step_by(2)
81 .map(|i| 1f32 / base.powf(i as f64 / n_elem as f64) as f32)
82 .collect();
83 let inv_freq_len = inv_freq.len();
84 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
85 let t = Tensor::arange(0u32, cfg.seq_length as u32, dev)?
86 .to_dtype(dtype)
87 .expect("unalbe to dytpe in Rotray Embedding new")
88 .reshape((cfg.seq_length, 1))?;
89 let freqs = t.matmul(&inv_freq)?;
90 let cache = Tensor::stack(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?;
91 Ok(Self { cache })
92 }
93
94 fn apply(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
95 let (seqlen, _b, np, _hn) = xs.dims4()?;
96 let cache = self.cache.narrow(0, seqlen_offset, seqlen)?;
97 let rot_dim = cache.dim(D::Minus2)? * 2;
98 let (xs, xs_pass) = (
99 xs.narrow(D::Minus1, 0, rot_dim)?,
100 xs.narrow(D::Minus1, rot_dim, rot_dim)?,
101 );
102 let xshaped = xs.reshape((seqlen, (), np, rot_dim / 2, 2))?;
103 let cache = cache.reshape((seqlen, (), 1, rot_dim / 2, 2))?;
104 let (xshaped0, xshaped1) = (
105 xshaped.i((.., .., .., .., 0))?,
106 xshaped.i((.., .., .., .., 1))?,
107 );
108 let (cache0, cache1) = (cache.i((.., .., .., .., 0))?, cache.i((.., .., .., .., 1))?);
109 let xs_out = Tensor::stack(
110 &[
111 (xshaped0.broadcast_mul(&cache0)? - xshaped1.broadcast_mul(&cache1)?)?,
112 (xshaped1.broadcast_mul(&cache0)? + xshaped0.broadcast_mul(&cache1)?)?,
113 ],
114 D::Minus1,
115 )?;
116 let xs_out = xs_out.flatten_from(3)?;
117 Tensor::cat(&[xs_out, xs_pass], D::Minus1)
118 }
119}
120
121#[derive(Debug, Clone)]
122struct CoreAttention {
123 coeff: Option<f64>,
124 norm_factor: f64,
125 dtype: DType,
126}
127
128fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32, dtype: DType) -> Result<Tensor> {
129 let shape = mask.shape();
130 let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
131 let m = mask.where_cond(&on_true.to_dtype(dtype)?, on_false)?;
132 Ok(m)
133}
134
135impl CoreAttention {
136 fn new(layer_number: usize, cfg: &Config, dtype: DType) -> Result<Self> {
137 let norm_factor = (cfg.kv_channels as f64).sqrt();
138 let (norm_factor, coeff) = if cfg.apply_query_key_layer_scaling {
139 let coeff = f64::max(1.0, layer_number as f64);
140 (norm_factor * coeff, Some(coeff))
141 } else {
142 (norm_factor, None)
143 };
144 Ok(Self {
145 coeff,
146 norm_factor,
147 dtype,
148 })
149 }
150
151 fn forward(
152 &self,
153 query_layer: &Tensor,
154 key_layer: &Tensor,
155 value_layer: &Tensor,
156 attention_mask: &Option<Tensor>,
157 ) -> Result<Tensor> {
158 let output_size = (
159 query_layer.dim(1)?, query_layer.dim(2)?, query_layer.dim(0)?, key_layer.dim(0)?, );
164 let query_layer =
165 query_layer.reshape((output_size.2, output_size.0 * output_size.1, ()))?;
166 let key_layer = key_layer.reshape((output_size.3, output_size.0 * output_size.1, ()))?;
167 let matmul_result = Tensor::matmul(
168 &query_layer.transpose(0, 1)?.contiguous()?,
169 &key_layer.transpose(0, 1)?.transpose(1, 2)?.contiguous()?,
170 )?;
171 let matmul_result = (matmul_result / self.norm_factor)?.reshape(output_size)?;
172 let matmul_result = match self.coeff {
173 None => matmul_result,
174 Some(coeff) => (matmul_result * coeff)?,
175 };
176 let attention_scores = match attention_mask {
177 Some(mask) => masked_fill(
178 &matmul_result,
179 &mask.broadcast_left((matmul_result.dim(0)?, matmul_result.dim(1)?))?,
180 f32::NEG_INFINITY,
181 self.dtype,
182 )?,
183 None => matmul_result,
184 };
185 let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
186
187 let output_size = (
188 value_layer.dim(1)?,
189 value_layer.dim(2)?,
190 query_layer.dim(0)?,
191 value_layer.dim(3)?,
192 );
193 let value_layer =
194 value_layer.reshape((value_layer.dim(0)?, output_size.0 * output_size.1, ()))?;
195 let attention_probs =
196 attention_probs.reshape((output_size.0 * output_size.1, output_size.2, ()))?;
197 let context_layer = Tensor::matmul(
198 &attention_probs.contiguous()?,
199 &value_layer.transpose(0, 1)?.contiguous()?,
200 )?;
201 let context_layer = context_layer.reshape(output_size)?;
202 let context_layer = context_layer.permute((2, 0, 1, 3))?.contiguous()?;
203 context_layer.flatten_from(D::Minus2)
204 }
205}
206
207#[derive(Debug, Clone)]
208struct SelfAttention {
209 query_key_value: Linear,
210 core_attention: CoreAttention,
211 dense: Linear,
212 multi_query_attention: bool,
213 num_attention_heads_per_partition: usize,
214 num_multi_query_groups_per_partition: usize,
215 hidden_size_per_attention_head: usize,
216 kv_cache: Option<(Tensor, Tensor)>,
217}
218
219impl SelfAttention {
220 fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
221 let projection_size = cfg.kv_channels * cfg.num_attention_heads;
222 let hidden_size_per_attention_head = projection_size / cfg.num_attention_heads;
223 let qkv_hidden_size = if cfg.multi_query_attention {
224 projection_size + 2 * hidden_size_per_attention_head * cfg.multi_query_group_num
225 } else {
226 3 * projection_size
227 };
228 let query_key_value = linear(
229 cfg.hidden_size,
230 qkv_hidden_size,
231 cfg.add_bias_linear || cfg.add_qkv_bias,
232 vb.pp("query_key_value"),
233 )?;
234 let core_attention = CoreAttention::new(layer_number, cfg, vb.dtype())?;
235 let dense = linear(
236 cfg.hidden_size,
237 cfg.hidden_size,
238 cfg.add_bias_linear,
239 vb.pp("dense"),
240 )?;
241 Ok(Self {
242 query_key_value,
243 core_attention,
244 dense,
245 multi_query_attention: cfg.multi_query_attention,
246 num_attention_heads_per_partition: cfg.num_attention_heads,
247 num_multi_query_groups_per_partition: cfg.multi_query_group_num,
248 hidden_size_per_attention_head: cfg.kv_channels,
249 kv_cache: None,
250 })
251 }
252
253 fn reset_kv_cache(&mut self) {
254 self.kv_cache = None
255 }
256
257 fn forward(
258 &mut self,
259 xs: &Tensor,
260 attention_mask: &Option<Tensor>,
261 rotary_emb: &RotaryEmbedding,
262 ) -> Result<Tensor> {
263 let mixed_x_layer = xs.apply(&self.query_key_value)?;
264 if !self.multi_query_attention {
265 candle::bail!("only multi_query_attention=true is supported")
266 }
267 let hpa = self.hidden_size_per_attention_head;
268 let query_layer =
269 mixed_x_layer.narrow(D::Minus1, 0, self.num_attention_heads_per_partition * hpa)?;
270 let key_layer = mixed_x_layer.narrow(
271 D::Minus1,
272 self.num_attention_heads_per_partition * hpa,
273 self.num_multi_query_groups_per_partition * hpa,
274 )?;
275 let value_layer = mixed_x_layer.narrow(
276 D::Minus1,
277 self.num_attention_heads_per_partition * hpa
278 + self.num_multi_query_groups_per_partition * hpa,
279 self.num_multi_query_groups_per_partition * hpa,
280 )?;
281 let query_layer = query_layer.reshape((
282 query_layer.dim(0)?,
283 query_layer.dim(1)?,
284 self.num_attention_heads_per_partition,
285 hpa,
286 ))?;
287 let key_layer = key_layer.reshape((
288 key_layer.dim(0)?,
289 key_layer.dim(1)?,
290 self.num_multi_query_groups_per_partition,
291 hpa,
292 ))?;
293 let value_layer = value_layer.reshape((
294 value_layer.dim(0)?,
295 value_layer.dim(1)?,
296 self.num_multi_query_groups_per_partition,
297 hpa,
298 ))?;
299
300 let seqlen_offset = match &self.kv_cache {
302 None => 0,
303 Some((prev_k, _)) => prev_k.dim(0)?,
304 };
305 let query_layer = rotary_emb.apply(&query_layer, seqlen_offset)?;
306 let key_layer = rotary_emb.apply(&key_layer, seqlen_offset)?;
307
308 let (key_layer, value_layer) = match &self.kv_cache {
310 None => (key_layer, value_layer),
311 Some((prev_k, prev_v)) => {
312 let k = Tensor::cat(&[prev_k, &key_layer], 0)?;
313 let v = Tensor::cat(&[prev_v, &value_layer], 0)?;
314 (k, v)
315 }
316 };
317 self.kv_cache = Some((key_layer.clone(), value_layer.clone()));
318
319 let ratio =
321 self.num_attention_heads_per_partition / self.num_multi_query_groups_per_partition;
322 let key_layer = {
323 let (d0, d1, d2, d3) = key_layer.dims4()?;
324 key_layer
325 .unsqueeze(D::Minus2)?
326 .expand((d0, d1, d2, ratio, d3))?
327 .reshape((
328 d0,
329 d1,
330 self.num_attention_heads_per_partition,
331 self.hidden_size_per_attention_head,
332 ))?
333 };
334 let value_layer = {
335 let (d0, d1, d2, d3) = value_layer.dims4()?;
336 value_layer
337 .unsqueeze(D::Minus2)?
338 .expand((d0, d1, d2, ratio, d3))?
339 .reshape((
340 d0,
341 d1,
342 self.num_attention_heads_per_partition,
343 self.hidden_size_per_attention_head,
344 ))?
345 };
346
347 let context_layer =
348 self.core_attention
349 .forward(&query_layer, &key_layer, &value_layer, attention_mask)?;
350 let output = context_layer.apply(&self.dense)?;
351 Ok(output)
352 }
353}
354
355#[allow(clippy::upper_case_acronyms)]
356#[derive(Debug, Clone)]
357struct MLP {
358 dense_h_to_4h: Linear,
359 dense_4h_to_h: Linear,
360}
361
362impl MLP {
363 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
364 let dense_h_to_4h = linear(
365 cfg.hidden_size,
366 cfg.ffn_hidden_size * 2,
367 cfg.add_bias_linear,
368 vb.pp("dense_h_to_4h"),
369 )?;
370 let dense_4h_to_h = linear(
371 cfg.ffn_hidden_size,
372 cfg.hidden_size,
373 cfg.add_bias_linear,
374 vb.pp("dense_4h_to_h"),
375 )?;
376 Ok(Self {
377 dense_4h_to_h,
378 dense_h_to_4h,
379 })
380 }
381}
382
383impl Module for MLP {
384 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
385 xs.apply(&self.dense_h_to_4h)?
386 .apply(&candle_nn::Activation::Swiglu)?
387 .apply(&self.dense_4h_to_h)
388 }
389}
390
391#[derive(Debug, Clone)]
392struct Block {
393 input_layernorm: candle_nn::LayerNorm,
394 self_attention: SelfAttention,
395 post_attention_layernorm: candle_nn::LayerNorm,
396 mlp: MLP,
397 apply_residual_connection_post_layernorm: bool,
398}
399
400impl Block {
401 fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
402 let input_layernorm = if cfg.rmsnorm {
403 candle_nn::rms_norm(
404 cfg.hidden_size,
405 cfg.layernorm_epsilon,
406 vb.pp("input_layernorm"),
407 )?
408 .into_inner()
409 } else {
410 candle_nn::layer_norm(
411 cfg.hidden_size,
412 cfg.layernorm_epsilon,
413 vb.pp("input_layernorm"),
414 )?
415 };
416 let post_attention_layernorm = if cfg.rmsnorm {
417 candle_nn::rms_norm(
418 cfg.hidden_size,
419 cfg.layernorm_epsilon,
420 vb.pp("post_attention_layernorm"),
421 )?
422 .into_inner()
423 } else {
424 candle_nn::layer_norm(
425 cfg.hidden_size,
426 cfg.layernorm_epsilon,
427 vb.pp("post_attention_layernorm"),
428 )?
429 };
430 let self_attention = SelfAttention::new(layer_number, cfg, vb.pp("self_attention"))?;
431 let mlp = MLP::new(cfg, vb.pp("mlp"))?;
432 Ok(Self {
433 input_layernorm,
434 self_attention,
435 post_attention_layernorm,
436 mlp,
437 apply_residual_connection_post_layernorm: cfg.apply_residual_connection_post_layernorm,
438 })
439 }
440
441 fn reset_kv_cache(&mut self) {
442 self.self_attention.reset_kv_cache()
443 }
444
445 fn forward(
446 &mut self,
447 xs: &Tensor,
448 attention_mask: &Option<Tensor>,
449 rotary_emb: &RotaryEmbedding,
450 ) -> Result<Tensor> {
451 let layernorm_output = xs.apply(&self.input_layernorm)?;
452 let attention_output =
453 self.self_attention
454 .forward(&layernorm_output, attention_mask, rotary_emb)?;
455 let residual = if self.apply_residual_connection_post_layernorm {
456 &layernorm_output
457 } else {
458 xs
459 };
460 let layernorm_input = (residual + attention_output)?;
461 let layernorm_output = layernorm_input.apply(&self.post_attention_layernorm)?;
462 let mlp_output = layernorm_output.apply(&self.mlp)?;
463 let residual = if self.apply_residual_connection_post_layernorm {
464 &layernorm_output
465 } else {
466 &layernorm_input
467 };
468 mlp_output + residual
469 }
470}
471
472#[derive(Debug, Clone)]
473struct Transformer {
474 layers: Vec<Block>,
475 final_layernorm: Option<candle_nn::LayerNorm>,
476 rotary_emb: RotaryEmbedding,
477}
478
479impl Transformer {
480 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
481 let vb_l = vb.pp("layers");
482 let mut layers = Vec::with_capacity(cfg.num_layers);
483 for layer_index in 0..cfg.num_layers {
484 let block = Block::new(layer_index + 1, cfg, vb_l.pp(layer_index))?;
485 layers.push(block)
486 }
487 let final_layernorm = if cfg.post_layer_norm {
488 let ln = if cfg.rmsnorm {
489 candle_nn::rms_norm(
490 cfg.hidden_size,
491 cfg.layernorm_epsilon,
492 vb.pp("final_layernorm"),
493 )?
494 .into_inner()
495 } else {
496 candle_nn::layer_norm(
497 cfg.hidden_size,
498 cfg.layernorm_epsilon,
499 vb.pp("final_layernorm"),
500 )?
501 };
502 Some(ln)
503 } else {
504 None
505 };
506 let rotary_emb = RotaryEmbedding::new(cfg, vb.dtype(), vb.device())?;
507 Ok(Self {
508 layers,
509 final_layernorm,
510 rotary_emb,
511 })
512 }
513
514 fn reset_kv_cache(&mut self) {
515 for block in self.layers.iter_mut() {
516 block.reset_kv_cache()
517 }
518 }
519
520 fn forward(&mut self, xs: &Tensor, attention_mask: &Option<Tensor>) -> Result<Tensor> {
521 let mut xs = xs.clone();
522 for block in self.layers.iter_mut() {
523 xs = block.forward(&xs, attention_mask, &self.rotary_emb)?
524 }
525 match self.final_layernorm.as_ref() {
526 None => Ok(xs),
527 Some(ln) => xs.apply(ln),
528 }
529 }
530}
531
532#[derive(Debug, Clone)]
533struct Embedding {
534 word_embeddings: candle_nn::Embedding,
535 fp32_residual_connection: bool,
536}
537
538impl Embedding {
539 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
540 let word_embeddings = candle_nn::embedding(
541 cfg.padded_vocab_size,
542 cfg.hidden_size,
543 vb.pp("word_embeddings"),
544 )?;
545 Ok(Self {
546 word_embeddings,
547 fp32_residual_connection: cfg.fp32_residual_connection,
548 })
549 }
550}
551
552impl Module for Embedding {
553 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
554 let xs = self.word_embeddings.forward(xs)?.transpose(0, 1)?; if self.fp32_residual_connection {
556 xs.to_dtype(candle::DType::F32)
557 } else {
558 xs.contiguous()
559 }
560 }
561}
562
563#[derive(Debug, Clone)]
564pub struct Model {
565 embedding: Embedding,
566 encoder: Transformer,
567 output_layer: Linear,
568}
569
570fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
571 let mask: Vec<_> = (0..size)
572 .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
573 .collect();
574 Tensor::from_slice(&mask, (size, size), device)
575}
576
577impl Model {
578 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
579 let vb = vb.pp("transformer");
580 let embedding = Embedding::new(cfg, vb.pp("embedding"))?;
581 let encoder = Transformer::new(cfg, vb.pp("encoder"))?;
582 let output_layer = linear(
583 cfg.hidden_size,
584 cfg.padded_vocab_size,
585 false,
586 vb.pp("output_layer"),
587 )?;
588
589 Ok(Self {
590 embedding,
591 encoder,
592 output_layer,
593 })
594 }
595
596 pub fn reset_kv_cache(&mut self) {
597 self.encoder.reset_kv_cache()
598 }
599
600 pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
601 let (_b_size, seq_len) = xs.dims2()?;
602 let input_embeds = xs.apply(&self.embedding)?;
603 let attention_mask = if seq_len <= 1 {
604 None
605 } else {
606 Some(get_mask(seq_len, xs.device())?)
607 };
608 let xs = self.encoder.forward(&input_embeds, &attention_mask)?;
609 let lm_logits = xs.i(seq_len - 1)?.apply(&self.output_layer)?;
610 Ok(lm_logits)
611 }
612}