1use crate::models::with_tracing::{linear_b as linear, Linear};
7use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
8use candle_nn::VarBuilder;
9
/// Hyper-parameters of a ChatGLM transformer as consumed by this module.
#[derive(Debug, Clone)]
pub struct Config {
    /// Number of transformer blocks in the encoder.
    pub num_layers: usize,
    /// Vocabulary size of both the word embedding and the output projection.
    pub padded_vocab_size: usize,
    /// Width of the residual stream.
    pub hidden_size: usize,
    /// Width of one half of the gated MLP (the first projection outputs twice this).
    pub ffn_hidden_size: usize,
    /// Per-head dimension; also the rotary embedding dimension.
    pub kv_channels: usize,
    /// Number of query heads.
    pub num_attention_heads: usize,
    /// Number of positions precomputed in the rotary table.
    pub seq_length: usize,
    /// Epsilon used by every normalization layer.
    pub layernorm_epsilon: f64,
    /// Use RMSNorm instead of LayerNorm.
    pub rmsnorm: bool,
    /// Take residuals from the normalized activations instead of the block inputs.
    pub apply_residual_connection_post_layernorm: bool,
    /// Apply a final normalization after the last block.
    pub post_layer_norm: bool,
    /// Add biases to the dense/MLP projections.
    pub add_bias_linear: bool,
    /// Add a bias to the fused QKV projection.
    pub add_qkv_bias: bool,
    /// Not consulted by this module.
    pub bias_dropout_fusion: bool,
    /// Grouped (multi-query) key/value heads; the only supported mode.
    pub multi_query_attention: bool,
    /// Number of key/value groups when `multi_query_attention` is set.
    pub multi_query_group_num: usize,
    /// Scale attention logits by the (1-based) layer number.
    pub apply_query_key_layer_scaling: bool,
    /// Not consulted by this module.
    pub attention_softmax_in_fp32: bool,
    /// Cast the embedding output to f32.
    pub fp32_residual_connection: bool,
}

impl Config {
    /// Preset matching the ChatGLM3-6B checkpoint.
    pub fn glm3_6b() -> Self {
        Self {
            // Model geometry.
            num_layers: 28,
            hidden_size: 4096,
            ffn_hidden_size: 13696,
            padded_vocab_size: 65024,
            seq_length: 8192,
            // Attention layout.
            num_attention_heads: 32,
            kv_channels: 128,
            multi_query_attention: true,
            multi_query_group_num: 2,
            apply_query_key_layer_scaling: true,
            attention_softmax_in_fp32: true,
            // Normalization and residual wiring.
            rmsnorm: true,
            layernorm_epsilon: 1e-5,
            post_layer_norm: true,
            apply_residual_connection_post_layernorm: false,
            fp32_residual_connection: false,
            // Biases.
            add_bias_linear: false,
            add_qkv_bias: true,
            bias_dropout_fusion: true,
        }
    }
}
58
59#[derive(Debug, Clone)]
60struct RotaryEmbedding {
61 cache: Tensor,
62}
63
64impl RotaryEmbedding {
65 fn new(cfg: &Config, dtype: DType, dev: &Device) -> Result<Self> {
66 let rotary_dim = cfg.kv_channels;
67 let n_elem = rotary_dim / 2;
68 let inv_freq: Vec<_> = (0..n_elem)
69 .step_by(2)
70 .map(|i| 1f32 / 10_000f64.powf(i as f64 / n_elem as f64) as f32)
71 .collect();
72 let inv_freq_len = inv_freq.len();
73 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
74 let t = Tensor::arange(0u32, cfg.seq_length as u32, dev)?
75 .to_dtype(dtype)?
76 .reshape((cfg.seq_length, 1))?;
77 let freqs = t.matmul(&inv_freq)?;
78 let cache = Tensor::stack(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?;
79 Ok(Self { cache })
80 }
81
82 fn apply(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
83 let (seqlen, _b, np, _hn) = xs.dims4()?;
84 let cache = self.cache.narrow(0, seqlen_offset, seqlen)?;
85 let rot_dim = cache.dim(D::Minus2)? * 2;
86 let (xs, xs_pass) = (
87 xs.narrow(D::Minus1, 0, rot_dim)?,
88 xs.narrow(D::Minus1, rot_dim, rot_dim)?,
89 );
90 let xshaped = xs.reshape((seqlen, (), np, rot_dim / 2, 2))?;
91 let cache = cache.reshape((seqlen, (), 1, rot_dim / 2, 2))?;
92 let (xshaped0, xshaped1) = (
93 xshaped.i((.., .., .., .., 0))?,
94 xshaped.i((.., .., .., .., 1))?,
95 );
96 let (cache0, cache1) = (cache.i((.., .., .., .., 0))?, cache.i((.., .., .., .., 1))?);
97 let xs_out = Tensor::stack(
98 &[
99 (xshaped0.broadcast_mul(&cache0)? - xshaped1.broadcast_mul(&cache1)?)?,
100 (xshaped1.broadcast_mul(&cache0)? + xshaped0.broadcast_mul(&cache1)?)?,
101 ],
102 D::Minus1,
103 )?;
104 let xs_out = xs_out.flatten_from(3)?;
105 Tensor::cat(&[xs_out, xs_pass], D::Minus1)
106 }
107}
108
109#[derive(Debug, Clone)]
110struct CoreAttention {
111 coeff: Option<f64>,
112 norm_factor: f64,
113}
114
115fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
116 let shape = mask.shape();
117 let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
118 let m = mask.where_cond(&on_true, on_false)?;
119 Ok(m)
120}
121
122impl CoreAttention {
123 fn new(layer_number: usize, cfg: &Config) -> Result<Self> {
124 let norm_factor = (cfg.kv_channels as f64).sqrt();
125 let (norm_factor, coeff) = if cfg.apply_query_key_layer_scaling {
126 let coeff = f64::max(1.0, layer_number as f64);
127 (norm_factor * coeff, Some(coeff))
128 } else {
129 (norm_factor, None)
130 };
131 Ok(Self { coeff, norm_factor })
132 }
133
134 fn forward(
135 &self,
136 query_layer: &Tensor,
137 key_layer: &Tensor,
138 value_layer: &Tensor,
139 attention_mask: &Option<Tensor>,
140 ) -> Result<Tensor> {
141 let output_size = (
142 query_layer.dim(1)?, query_layer.dim(2)?, query_layer.dim(0)?, key_layer.dim(0)?, );
147 let query_layer =
148 query_layer.reshape((output_size.2, output_size.0 * output_size.1, ()))?;
149 let key_layer = key_layer.reshape((output_size.3, output_size.0 * output_size.1, ()))?;
150 let matmul_result = Tensor::matmul(
151 &query_layer.transpose(0, 1)?,
152 &key_layer.transpose(0, 1)?.transpose(1, 2)?,
153 )?;
154 let matmul_result = (matmul_result / self.norm_factor)?.reshape(output_size)?;
155 let matmul_result = match self.coeff {
156 None => matmul_result,
157 Some(coeff) => (matmul_result * coeff)?,
158 };
159 let attention_scores = match attention_mask {
160 Some(mask) => masked_fill(
161 &matmul_result,
162 &mask.broadcast_left((matmul_result.dim(0)?, matmul_result.dim(1)?))?,
163 f32::NEG_INFINITY,
164 )?,
165 None => matmul_result,
166 };
167 let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
168
169 let output_size = (
170 value_layer.dim(1)?,
171 value_layer.dim(2)?,
172 query_layer.dim(0)?,
173 value_layer.dim(3)?,
174 );
175 let value_layer =
176 value_layer.reshape((value_layer.dim(0)?, output_size.0 * output_size.1, ()))?;
177 let attention_probs =
178 attention_probs.reshape((output_size.0 * output_size.1, output_size.2, ()))?;
179 let context_layer = Tensor::matmul(&attention_probs, &value_layer.transpose(0, 1)?)?;
180 let context_layer = context_layer.reshape(output_size)?;
181 let context_layer = context_layer.permute((2, 0, 1, 3))?.contiguous()?;
182 context_layer.flatten_from(D::Minus2)
183 }
184}
185
186#[derive(Debug, Clone)]
187struct SelfAttention {
188 query_key_value: Linear,
189 core_attention: CoreAttention,
190 dense: Linear,
191 multi_query_attention: bool,
192 num_attention_heads_per_partition: usize,
193 num_multi_query_groups_per_partition: usize,
194 hidden_size_per_attention_head: usize,
195 kv_cache: Option<(Tensor, Tensor)>,
196}
197
198impl SelfAttention {
199 fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
200 let projection_size = cfg.kv_channels * cfg.num_attention_heads;
201 let hidden_size_per_attention_head = projection_size / cfg.num_attention_heads;
202 let qkv_hidden_size = if cfg.multi_query_attention {
203 projection_size + 2 * hidden_size_per_attention_head * cfg.multi_query_group_num
204 } else {
205 3 * projection_size
206 };
207 let query_key_value = linear(
208 cfg.hidden_size,
209 qkv_hidden_size,
210 cfg.add_bias_linear || cfg.add_qkv_bias,
211 vb.pp("query_key_value"),
212 )?;
213 let core_attention = CoreAttention::new(layer_number, cfg)?;
214 let dense = linear(
215 cfg.hidden_size,
216 cfg.hidden_size,
217 cfg.add_bias_linear,
218 vb.pp("dense"),
219 )?;
220 Ok(Self {
221 query_key_value,
222 core_attention,
223 dense,
224 multi_query_attention: cfg.multi_query_attention,
225 num_attention_heads_per_partition: cfg.num_attention_heads,
226 num_multi_query_groups_per_partition: cfg.multi_query_group_num,
227 hidden_size_per_attention_head: cfg.kv_channels,
228 kv_cache: None,
229 })
230 }
231
232 fn reset_kv_cache(&mut self) {
233 self.kv_cache = None
234 }
235
236 fn forward(
237 &mut self,
238 xs: &Tensor,
239 attention_mask: &Option<Tensor>,
240 rotary_emb: &RotaryEmbedding,
241 ) -> Result<Tensor> {
242 let mixed_x_layer = xs.apply(&self.query_key_value)?;
243 if !self.multi_query_attention {
244 candle::bail!("only multi_query_attention=true is supported")
245 }
246 let hpa = self.hidden_size_per_attention_head;
247 let query_layer =
248 mixed_x_layer.narrow(D::Minus1, 0, self.num_attention_heads_per_partition * hpa)?;
249 let key_layer = mixed_x_layer.narrow(
250 D::Minus1,
251 self.num_attention_heads_per_partition * hpa,
252 self.num_multi_query_groups_per_partition * hpa,
253 )?;
254 let value_layer = mixed_x_layer.narrow(
255 D::Minus1,
256 self.num_attention_heads_per_partition * hpa
257 + self.num_multi_query_groups_per_partition * hpa,
258 self.num_multi_query_groups_per_partition * hpa,
259 )?;
260 let query_layer = query_layer.reshape((
261 query_layer.dim(0)?,
262 query_layer.dim(1)?,
263 self.num_attention_heads_per_partition,
264 hpa,
265 ))?;
266 let key_layer = key_layer.reshape((
267 key_layer.dim(0)?,
268 key_layer.dim(1)?,
269 self.num_multi_query_groups_per_partition,
270 hpa,
271 ))?;
272 let value_layer = value_layer.reshape((
273 value_layer.dim(0)?,
274 value_layer.dim(1)?,
275 self.num_multi_query_groups_per_partition,
276 hpa,
277 ))?;
278
279 let seqlen_offset = match &self.kv_cache {
281 None => 0,
282 Some((prev_k, _)) => prev_k.dim(0)?,
283 };
284 let query_layer = rotary_emb.apply(&query_layer, seqlen_offset)?;
285 let key_layer = rotary_emb.apply(&key_layer, seqlen_offset)?;
286
287 let (key_layer, value_layer) = match &self.kv_cache {
289 None => (key_layer, value_layer),
290 Some((prev_k, prev_v)) => {
291 let k = Tensor::cat(&[prev_k, &key_layer], 0)?;
292 let v = Tensor::cat(&[prev_v, &value_layer], 0)?;
293 (k, v)
294 }
295 };
296 self.kv_cache = Some((key_layer.clone(), value_layer.clone()));
297
298 let ratio =
300 self.num_attention_heads_per_partition / self.num_multi_query_groups_per_partition;
301 let key_layer = {
302 let (d0, d1, d2, d3) = key_layer.dims4()?;
303 key_layer
304 .unsqueeze(D::Minus2)?
305 .expand((d0, d1, d2, ratio, d3))?
306 .reshape((
307 d0,
308 d1,
309 self.num_attention_heads_per_partition,
310 self.hidden_size_per_attention_head,
311 ))?
312 };
313 let value_layer = {
314 let (d0, d1, d2, d3) = value_layer.dims4()?;
315 value_layer
316 .unsqueeze(D::Minus2)?
317 .expand((d0, d1, d2, ratio, d3))?
318 .reshape((
319 d0,
320 d1,
321 self.num_attention_heads_per_partition,
322 self.hidden_size_per_attention_head,
323 ))?
324 };
325
326 let context_layer =
327 self.core_attention
328 .forward(&query_layer, &key_layer, &value_layer, attention_mask)?;
329 let output = context_layer.apply(&self.dense)?;
330 Ok(output)
331 }
332}
333
334#[allow(clippy::upper_case_acronyms)]
335#[derive(Debug, Clone)]
336struct MLP {
337 dense_h_to_4h: Linear,
338 dense_4h_to_h: Linear,
339}
340
341impl MLP {
342 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
343 let dense_h_to_4h = linear(
344 cfg.hidden_size,
345 cfg.ffn_hidden_size * 2,
346 cfg.add_bias_linear,
347 vb.pp("dense_h_to_4h"),
348 )?;
349 let dense_4h_to_h = linear(
350 cfg.ffn_hidden_size,
351 cfg.hidden_size,
352 cfg.add_bias_linear,
353 vb.pp("dense_4h_to_h"),
354 )?;
355 Ok(Self {
356 dense_4h_to_h,
357 dense_h_to_4h,
358 })
359 }
360}
361
362impl Module for MLP {
363 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
364 xs.apply(&self.dense_h_to_4h)?
365 .apply(&candle_nn::Activation::Swiglu)?
366 .apply(&self.dense_4h_to_h)
367 }
368}
369
370#[derive(Debug, Clone)]
371struct Block {
372 input_layernorm: candle_nn::LayerNorm,
373 self_attention: SelfAttention,
374 post_attention_layernorm: candle_nn::LayerNorm,
375 mlp: MLP,
376 apply_residual_connection_post_layernorm: bool,
377}
378
379impl Block {
380 fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
381 let input_layernorm = if cfg.rmsnorm {
382 candle_nn::rms_norm(
383 cfg.hidden_size,
384 cfg.layernorm_epsilon,
385 vb.pp("input_layernorm"),
386 )?
387 .into_inner()
388 } else {
389 candle_nn::layer_norm(
390 cfg.hidden_size,
391 cfg.layernorm_epsilon,
392 vb.pp("input_layernorm"),
393 )?
394 };
395 let post_attention_layernorm = if cfg.rmsnorm {
396 candle_nn::rms_norm(
397 cfg.hidden_size,
398 cfg.layernorm_epsilon,
399 vb.pp("post_attention_layernorm"),
400 )?
401 .into_inner()
402 } else {
403 candle_nn::layer_norm(
404 cfg.hidden_size,
405 cfg.layernorm_epsilon,
406 vb.pp("post_attention_layernorm"),
407 )?
408 };
409 let self_attention = SelfAttention::new(layer_number, cfg, vb.pp("self_attention"))?;
410 let mlp = MLP::new(cfg, vb.pp("mlp"))?;
411 Ok(Self {
412 input_layernorm,
413 self_attention,
414 post_attention_layernorm,
415 mlp,
416 apply_residual_connection_post_layernorm: cfg.apply_residual_connection_post_layernorm,
417 })
418 }
419
420 fn reset_kv_cache(&mut self) {
421 self.self_attention.reset_kv_cache()
422 }
423
424 fn forward(
425 &mut self,
426 xs: &Tensor,
427 attention_mask: &Option<Tensor>,
428 rotary_emb: &RotaryEmbedding,
429 ) -> Result<Tensor> {
430 let layernorm_output = xs.apply(&self.input_layernorm)?;
431 let attention_output =
432 self.self_attention
433 .forward(&layernorm_output, attention_mask, rotary_emb)?;
434 let residual = if self.apply_residual_connection_post_layernorm {
435 &layernorm_output
436 } else {
437 xs
438 };
439 let layernorm_input = (residual + attention_output)?;
440 let layernorm_output = layernorm_input.apply(&self.post_attention_layernorm)?;
441 let mlp_output = layernorm_output.apply(&self.mlp)?;
442 let residual = if self.apply_residual_connection_post_layernorm {
443 &layernorm_output
444 } else {
445 &layernorm_input
446 };
447 mlp_output + residual
448 }
449}
450
451#[derive(Debug, Clone)]
452struct Transformer {
453 layers: Vec<Block>,
454 final_layernorm: Option<candle_nn::LayerNorm>,
455 rotary_emb: RotaryEmbedding,
456}
457
458impl Transformer {
459 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
460 let vb_l = vb.pp("layers");
461 let mut layers = Vec::with_capacity(cfg.num_layers);
462 for layer_index in 0..cfg.num_layers {
463 let block = Block::new(layer_index + 1, cfg, vb_l.pp(layer_index))?;
464 layers.push(block)
465 }
466 let final_layernorm = if cfg.post_layer_norm {
467 let ln = if cfg.rmsnorm {
468 candle_nn::rms_norm(
469 cfg.hidden_size,
470 cfg.layernorm_epsilon,
471 vb.pp("final_layernorm"),
472 )?
473 .into_inner()
474 } else {
475 candle_nn::layer_norm(
476 cfg.hidden_size,
477 cfg.layernorm_epsilon,
478 vb.pp("final_layernorm"),
479 )?
480 };
481 Some(ln)
482 } else {
483 None
484 };
485 let rotary_emb = RotaryEmbedding::new(cfg, vb.dtype(), vb.device())?;
486 Ok(Self {
487 layers,
488 final_layernorm,
489 rotary_emb,
490 })
491 }
492
493 fn reset_kv_cache(&mut self) {
494 for block in self.layers.iter_mut() {
495 block.reset_kv_cache()
496 }
497 }
498
499 fn forward(&mut self, xs: &Tensor, attention_mask: &Option<Tensor>) -> Result<Tensor> {
500 let mut xs = xs.clone();
501 for block in self.layers.iter_mut() {
502 xs = block.forward(&xs, attention_mask, &self.rotary_emb)?
503 }
504 match self.final_layernorm.as_ref() {
505 None => Ok(xs),
506 Some(ln) => xs.apply(ln),
507 }
508 }
509}
510
511#[derive(Debug, Clone)]
512struct Embedding {
513 word_embeddings: candle_nn::Embedding,
514 fp32_residual_connection: bool,
515}
516
517impl Embedding {
518 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
519 let word_embeddings = candle_nn::embedding(
520 cfg.padded_vocab_size,
521 cfg.hidden_size,
522 vb.pp("word_embeddings"),
523 )?;
524 Ok(Self {
525 word_embeddings,
526 fp32_residual_connection: cfg.fp32_residual_connection,
527 })
528 }
529}
530
531impl Module for Embedding {
532 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
533 let xs = self.word_embeddings.forward(xs)?.transpose(0, 1)?; if self.fp32_residual_connection {
535 xs.to_dtype(candle::DType::F32)
536 } else {
537 xs.contiguous()
538 }
539 }
540}
541
542#[derive(Debug, Clone)]
543pub struct Model {
544 embedding: Embedding,
545 encoder: Transformer,
546 output_layer: Linear,
547}
548
549fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
550 let mask: Vec<_> = (0..size)
551 .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
552 .collect();
553 Tensor::from_slice(&mask, (size, size), device)
554}
555
556impl Model {
557 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
558 let vb = vb.pp("transformer");
559 let embedding = Embedding::new(cfg, vb.pp("embedding"))?;
560 let encoder = Transformer::new(cfg, vb.pp("encoder"))?;
561 let output_layer = linear(
562 cfg.hidden_size,
563 cfg.padded_vocab_size,
564 false,
565 vb.pp("output_layer"),
566 )?;
567 Ok(Self {
568 embedding,
569 encoder,
570 output_layer,
571 })
572 }
573
574 pub fn reset_kv_cache(&mut self) {
575 self.encoder.reset_kv_cache()
576 }
577
578 pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
579 let (_b_size, seq_len) = xs.dims2()?;
580 let input_embeds = xs.apply(&self.embedding)?;
581 let attention_mask = if seq_len <= 1 {
582 None
583 } else {
584 Some(get_mask(seq_len, xs.device())?)
585 };
586 let xs = self.encoder.forward(&input_embeds, &attention_mask)?;
587 let lm_logits = xs.i(seq_len - 1)?.apply(&self.output_layer)?;
588 Ok(lm_logits)
589 }
590}