candle_transformers/models/
glm4.rs

1//! GLM-4 inference implementation.
2//!
3//! An open bilingual language model with 130B parameters.
4//!
//! Based on the implementation from [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B).
6
7use crate::models::with_tracing::{linear_b as linear, Linear};
8use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
9use candle_nn::VarBuilder;
10
/// Serde fallback used for `Config::rope_ratio` when the field is absent from the
/// deserialized configuration.
fn default_one() -> usize {
    const FALLBACK: usize = 1;
    FALLBACK
}
14
15#[derive(Debug, Clone, serde::Deserialize, Default)]
16pub struct Config {
17    pub num_layers: usize,
18    pub padded_vocab_size: usize,
19    pub hidden_size: usize,
20    pub ffn_hidden_size: usize,
21    pub kv_channels: usize,
22    pub num_attention_heads: usize,
23    pub seq_length: usize,
24    pub layernorm_epsilon: f64,
25    pub rmsnorm: bool,
26    pub apply_residual_connection_post_layernorm: bool,
27    pub post_layer_norm: bool,
28    pub add_bias_linear: bool,
29    pub add_qkv_bias: bool,
30    pub bias_dropout_fusion: bool,
31    pub multi_query_attention: bool,
32    pub multi_query_group_num: usize,
33    pub apply_query_key_layer_scaling: bool,
34    pub attention_softmax_in_fp32: bool,
35    pub fp32_residual_connection: bool,
36    #[serde(default = "default_one")]
37    pub rope_ratio: usize,
38}
39
40impl Config {
41    pub fn glm4() -> Self {
42        Self {
43            num_layers: 40,
44            padded_vocab_size: 151552,
45            hidden_size: 4096,
46            ffn_hidden_size: 13696,
47            kv_channels: 128,
48            num_attention_heads: 32,
49            seq_length: 8192,
50            layernorm_epsilon: 1e-5,
51            rmsnorm: true,
52            apply_residual_connection_post_layernorm: false,
53            post_layer_norm: true,
54            add_bias_linear: false,
55            add_qkv_bias: true,
56            bias_dropout_fusion: true,
57            multi_query_attention: true,
58            multi_query_group_num: 2,
59            apply_query_key_layer_scaling: true,
60            attention_softmax_in_fp32: true,
61            fp32_residual_connection: false,
62            rope_ratio: 500,
63        }
64    }
65}
66
67#[derive(Debug, Clone)]
68struct RotaryEmbedding {
69    cache: Tensor,
70}
71
72impl RotaryEmbedding {
73    fn new(cfg: &Config, dtype: DType, dev: &Device) -> Result<Self> {
74        let rotary_dim = cfg.kv_channels;
75        let n_elem = rotary_dim / 2;
76        let base = 10_000f64 * cfg.rope_ratio as f64;
77        let inv_freq: Vec<_> = (0..n_elem)
78            .step_by(2)
79            .map(|i| 1f32 / base.powf(i as f64 / n_elem as f64) as f32)
80            .collect();
81        let inv_freq_len = inv_freq.len();
82        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
83        let t = Tensor::arange(0u32, cfg.seq_length as u32, dev)?
84            .to_dtype(dtype)?
85            .reshape((cfg.seq_length, 1))?;
86        let freqs = t.matmul(&inv_freq)?;
87        let cache = Tensor::stack(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?;
88        Ok(Self { cache })
89    }
90
91    fn apply(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
92        let (seqlen, _b, np, _hn) = xs.dims4()?;
93        let cache = self.cache.narrow(0, seqlen_offset, seqlen)?;
94        let rot_dim = cache.dim(D::Minus2)? * 2;
95        let (xs, xs_pass) = (
96            xs.narrow(D::Minus1, 0, rot_dim)?,
97            xs.narrow(D::Minus1, rot_dim, rot_dim)?,
98        );
99        let xshaped = xs.reshape((seqlen, (), np, rot_dim / 2, 2))?;
100        let cache = cache.reshape((seqlen, (), 1, rot_dim / 2, 2))?;
101        let (xshaped0, xshaped1) = (
102            xshaped.i((.., .., .., .., 0))?,
103            xshaped.i((.., .., .., .., 1))?,
104        );
105        let (cache0, cache1) = (cache.i((.., .., .., .., 0))?, cache.i((.., .., .., .., 1))?);
106        let xs_out = Tensor::stack(
107            &[
108                (xshaped0.broadcast_mul(&cache0)? - xshaped1.broadcast_mul(&cache1)?)?,
109                (xshaped1.broadcast_mul(&cache0)? + xshaped0.broadcast_mul(&cache1)?)?,
110            ],
111            D::Minus1,
112        )?;
113        let xs_out = xs_out.flatten_from(3)?;
114        Tensor::cat(&[xs_out, xs_pass], D::Minus1)
115    }
116}
117
118#[derive(Debug, Clone)]
119struct CoreAttention {
120    coeff: Option<f64>,
121    norm_factor: f64,
122    dtype: DType,
123}
124
125fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32, dtype: DType) -> Result<Tensor> {
126    let shape = mask.shape();
127    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
128    let m = mask.where_cond(&on_true.to_dtype(dtype)?, on_false)?;
129    Ok(m)
130}
131
132impl CoreAttention {
133    fn new(layer_number: usize, cfg: &Config, dtype: DType) -> Result<Self> {
134        let norm_factor = (cfg.kv_channels as f64).sqrt();
135        let (norm_factor, coeff) = if cfg.apply_query_key_layer_scaling {
136            let coeff = f64::max(1.0, layer_number as f64);
137            (norm_factor * coeff, Some(coeff))
138        } else {
139            (norm_factor, None)
140        };
141        Ok(Self {
142            coeff,
143            norm_factor,
144            dtype,
145        })
146    }
147
148    fn forward(
149        &self,
150        query_layer: &Tensor,
151        key_layer: &Tensor,
152        value_layer: &Tensor,
153        attention_mask: &Option<Tensor>,
154    ) -> Result<Tensor> {
155        let output_size = (
156            query_layer.dim(1)?, // b
157            query_layer.dim(2)?, // np
158            query_layer.dim(0)?, // sq
159            key_layer.dim(0)?,   // sk
160        );
161        let query_layer =
162            query_layer.reshape((output_size.2, output_size.0 * output_size.1, ()))?;
163        let key_layer = key_layer.reshape((output_size.3, output_size.0 * output_size.1, ()))?;
164        let matmul_result = Tensor::matmul(
165            &query_layer.transpose(0, 1)?.contiguous()?,
166            &key_layer.transpose(0, 1)?.transpose(1, 2)?.contiguous()?,
167        )?;
168        let matmul_result = (matmul_result / self.norm_factor)?.reshape(output_size)?;
169        let matmul_result = match self.coeff {
170            None => matmul_result,
171            Some(coeff) => (matmul_result * coeff)?,
172        };
173        let attention_scores = match attention_mask {
174            Some(mask) => masked_fill(
175                &matmul_result,
176                &mask.broadcast_left((matmul_result.dim(0)?, matmul_result.dim(1)?))?,
177                f32::NEG_INFINITY,
178                self.dtype,
179            )?,
180            None => matmul_result,
181        };
182        let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
183
184        let output_size = (
185            value_layer.dim(1)?,
186            value_layer.dim(2)?,
187            query_layer.dim(0)?,
188            value_layer.dim(3)?,
189        );
190        let value_layer =
191            value_layer.reshape((value_layer.dim(0)?, output_size.0 * output_size.1, ()))?;
192        let attention_probs =
193            attention_probs.reshape((output_size.0 * output_size.1, output_size.2, ()))?;
194        let context_layer = Tensor::matmul(
195            &attention_probs.contiguous()?,
196            &value_layer.transpose(0, 1)?.contiguous()?,
197        )?;
198        let context_layer = context_layer.reshape(output_size)?;
199        let context_layer = context_layer.permute((2, 0, 1, 3))?.contiguous()?;
200        context_layer.flatten_from(D::Minus2)
201    }
202}
203
204#[derive(Debug, Clone)]
205struct SelfAttention {
206    query_key_value: Linear,
207    core_attention: CoreAttention,
208    dense: Linear,
209    multi_query_attention: bool,
210    num_attention_heads_per_partition: usize,
211    num_multi_query_groups_per_partition: usize,
212    hidden_size_per_attention_head: usize,
213    kv_cache: Option<(Tensor, Tensor)>,
214}
215
216impl SelfAttention {
217    fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
218        let projection_size = cfg.kv_channels * cfg.num_attention_heads;
219        let hidden_size_per_attention_head = projection_size / cfg.num_attention_heads;
220        let qkv_hidden_size = if cfg.multi_query_attention {
221            projection_size + 2 * hidden_size_per_attention_head * cfg.multi_query_group_num
222        } else {
223            3 * projection_size
224        };
225        let query_key_value = linear(
226            cfg.hidden_size,
227            qkv_hidden_size,
228            cfg.add_bias_linear || cfg.add_qkv_bias,
229            vb.pp("query_key_value"),
230        )?;
231        let core_attention = CoreAttention::new(layer_number, cfg, vb.dtype())?;
232        let dense = linear(
233            cfg.hidden_size,
234            cfg.hidden_size,
235            cfg.add_bias_linear,
236            vb.pp("dense"),
237        )?;
238        Ok(Self {
239            query_key_value,
240            core_attention,
241            dense,
242            multi_query_attention: cfg.multi_query_attention,
243            num_attention_heads_per_partition: cfg.num_attention_heads,
244            num_multi_query_groups_per_partition: cfg.multi_query_group_num,
245            hidden_size_per_attention_head: cfg.kv_channels,
246            kv_cache: None,
247        })
248    }
249
250    fn reset_kv_cache(&mut self) {
251        self.kv_cache = None
252    }
253
254    fn forward(
255        &mut self,
256        xs: &Tensor,
257        attention_mask: &Option<Tensor>,
258        rotary_emb: &RotaryEmbedding,
259    ) -> Result<Tensor> {
260        let mixed_x_layer = xs.apply(&self.query_key_value)?;
261        if !self.multi_query_attention {
262            candle::bail!("only multi_query_attention=true is supported")
263        }
264        let hpa = self.hidden_size_per_attention_head;
265        let query_layer =
266            mixed_x_layer.narrow(D::Minus1, 0, self.num_attention_heads_per_partition * hpa)?;
267        let key_layer = mixed_x_layer.narrow(
268            D::Minus1,
269            self.num_attention_heads_per_partition * hpa,
270            self.num_multi_query_groups_per_partition * hpa,
271        )?;
272        let value_layer = mixed_x_layer.narrow(
273            D::Minus1,
274            self.num_attention_heads_per_partition * hpa
275                + self.num_multi_query_groups_per_partition * hpa,
276            self.num_multi_query_groups_per_partition * hpa,
277        )?;
278        let query_layer = query_layer.reshape((
279            query_layer.dim(0)?,
280            query_layer.dim(1)?,
281            self.num_attention_heads_per_partition,
282            hpa,
283        ))?;
284        let key_layer = key_layer.reshape((
285            key_layer.dim(0)?,
286            key_layer.dim(1)?,
287            self.num_multi_query_groups_per_partition,
288            hpa,
289        ))?;
290        let value_layer = value_layer.reshape((
291            value_layer.dim(0)?,
292            value_layer.dim(1)?,
293            self.num_multi_query_groups_per_partition,
294            hpa,
295        ))?;
296
297        // Rotary embeddings.
298        let seqlen_offset = match &self.kv_cache {
299            None => 0,
300            Some((prev_k, _)) => prev_k.dim(0)?,
301        };
302        let query_layer = rotary_emb.apply(&query_layer, seqlen_offset)?;
303        let key_layer = rotary_emb.apply(&key_layer, seqlen_offset)?;
304
305        // KV cache.
306        let (key_layer, value_layer) = match &self.kv_cache {
307            None => (key_layer, value_layer),
308            Some((prev_k, prev_v)) => {
309                let k = Tensor::cat(&[prev_k, &key_layer], 0)?;
310                let v = Tensor::cat(&[prev_v, &value_layer], 0)?;
311                (k, v)
312            }
313        };
314        self.kv_cache = Some((key_layer.clone(), value_layer.clone()));
315
316        // Repeat KV.
317        let ratio =
318            self.num_attention_heads_per_partition / self.num_multi_query_groups_per_partition;
319        let key_layer = {
320            let (d0, d1, d2, d3) = key_layer.dims4()?;
321            key_layer
322                .unsqueeze(D::Minus2)?
323                .expand((d0, d1, d2, ratio, d3))?
324                .reshape((
325                    d0,
326                    d1,
327                    self.num_attention_heads_per_partition,
328                    self.hidden_size_per_attention_head,
329                ))?
330        };
331        let value_layer = {
332            let (d0, d1, d2, d3) = value_layer.dims4()?;
333            value_layer
334                .unsqueeze(D::Minus2)?
335                .expand((d0, d1, d2, ratio, d3))?
336                .reshape((
337                    d0,
338                    d1,
339                    self.num_attention_heads_per_partition,
340                    self.hidden_size_per_attention_head,
341                ))?
342        };
343
344        let context_layer =
345            self.core_attention
346                .forward(&query_layer, &key_layer, &value_layer, attention_mask)?;
347        let output = context_layer.apply(&self.dense)?;
348        Ok(output)
349    }
350}
351
352#[allow(clippy::upper_case_acronyms)]
353#[derive(Debug, Clone)]
354struct MLP {
355    dense_h_to_4h: Linear,
356    dense_4h_to_h: Linear,
357}
358
359impl MLP {
360    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
361        let dense_h_to_4h = linear(
362            cfg.hidden_size,
363            cfg.ffn_hidden_size * 2,
364            cfg.add_bias_linear,
365            vb.pp("dense_h_to_4h"),
366        )?;
367        let dense_4h_to_h = linear(
368            cfg.ffn_hidden_size,
369            cfg.hidden_size,
370            cfg.add_bias_linear,
371            vb.pp("dense_4h_to_h"),
372        )?;
373        Ok(Self {
374            dense_4h_to_h,
375            dense_h_to_4h,
376        })
377    }
378}
379
380impl Module for MLP {
381    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
382        xs.apply(&self.dense_h_to_4h)?
383            .apply(&candle_nn::Activation::Swiglu)?
384            .apply(&self.dense_4h_to_h)
385    }
386}
387
388#[derive(Debug, Clone)]
389struct Block {
390    input_layernorm: candle_nn::LayerNorm,
391    self_attention: SelfAttention,
392    post_attention_layernorm: candle_nn::LayerNorm,
393    mlp: MLP,
394    apply_residual_connection_post_layernorm: bool,
395}
396
397impl Block {
398    fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
399        let input_layernorm = if cfg.rmsnorm {
400            candle_nn::rms_norm(
401                cfg.hidden_size,
402                cfg.layernorm_epsilon,
403                vb.pp("input_layernorm"),
404            )?
405            .into_inner()
406        } else {
407            candle_nn::layer_norm(
408                cfg.hidden_size,
409                cfg.layernorm_epsilon,
410                vb.pp("input_layernorm"),
411            )?
412        };
413        let post_attention_layernorm = if cfg.rmsnorm {
414            candle_nn::rms_norm(
415                cfg.hidden_size,
416                cfg.layernorm_epsilon,
417                vb.pp("post_attention_layernorm"),
418            )?
419            .into_inner()
420        } else {
421            candle_nn::layer_norm(
422                cfg.hidden_size,
423                cfg.layernorm_epsilon,
424                vb.pp("post_attention_layernorm"),
425            )?
426        };
427        let self_attention = SelfAttention::new(layer_number, cfg, vb.pp("self_attention"))?;
428        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
429        Ok(Self {
430            input_layernorm,
431            self_attention,
432            post_attention_layernorm,
433            mlp,
434            apply_residual_connection_post_layernorm: cfg.apply_residual_connection_post_layernorm,
435        })
436    }
437
438    fn reset_kv_cache(&mut self) {
439        self.self_attention.reset_kv_cache()
440    }
441
442    fn forward(
443        &mut self,
444        xs: &Tensor,
445        attention_mask: &Option<Tensor>,
446        rotary_emb: &RotaryEmbedding,
447    ) -> Result<Tensor> {
448        let layernorm_output = xs.apply(&self.input_layernorm)?;
449        let attention_output =
450            self.self_attention
451                .forward(&layernorm_output, attention_mask, rotary_emb)?;
452        let residual = if self.apply_residual_connection_post_layernorm {
453            &layernorm_output
454        } else {
455            xs
456        };
457        let layernorm_input = (residual + attention_output)?;
458        let layernorm_output = layernorm_input.apply(&self.post_attention_layernorm)?;
459        let mlp_output = layernorm_output.apply(&self.mlp)?;
460        let residual = if self.apply_residual_connection_post_layernorm {
461            &layernorm_output
462        } else {
463            &layernorm_input
464        };
465        mlp_output + residual
466    }
467}
468
469#[derive(Debug, Clone)]
470struct Transformer {
471    layers: Vec<Block>,
472    final_layernorm: Option<candle_nn::LayerNorm>,
473    rotary_emb: RotaryEmbedding,
474}
475
476impl Transformer {
477    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
478        let vb_l = vb.pp("layers");
479        let mut layers = Vec::with_capacity(cfg.num_layers);
480        for layer_index in 0..cfg.num_layers {
481            let block = Block::new(layer_index + 1, cfg, vb_l.pp(layer_index))?;
482            layers.push(block)
483        }
484        let final_layernorm = if cfg.post_layer_norm {
485            let ln = if cfg.rmsnorm {
486                candle_nn::rms_norm(
487                    cfg.hidden_size,
488                    cfg.layernorm_epsilon,
489                    vb.pp("final_layernorm"),
490                )?
491                .into_inner()
492            } else {
493                candle_nn::layer_norm(
494                    cfg.hidden_size,
495                    cfg.layernorm_epsilon,
496                    vb.pp("final_layernorm"),
497                )?
498            };
499            Some(ln)
500        } else {
501            None
502        };
503        let rotary_emb = RotaryEmbedding::new(cfg, vb.dtype(), vb.device())?;
504        Ok(Self {
505            layers,
506            final_layernorm,
507            rotary_emb,
508        })
509    }
510
511    fn reset_kv_cache(&mut self) {
512        for block in self.layers.iter_mut() {
513            block.reset_kv_cache()
514        }
515    }
516
517    fn forward(&mut self, xs: &Tensor, attention_mask: &Option<Tensor>) -> Result<Tensor> {
518        let mut xs = xs.clone();
519        for block in self.layers.iter_mut() {
520            xs = block.forward(&xs, attention_mask, &self.rotary_emb)?
521        }
522        match self.final_layernorm.as_ref() {
523            None => Ok(xs),
524            Some(ln) => xs.apply(ln),
525        }
526    }
527}
528
529#[derive(Debug, Clone)]
530struct Embedding {
531    word_embeddings: candle_nn::Embedding,
532    fp32_residual_connection: bool,
533}
534
535impl Embedding {
536    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
537        let word_embeddings = candle_nn::embedding(
538            cfg.padded_vocab_size,
539            cfg.hidden_size,
540            vb.pp("word_embeddings"),
541        )?;
542        Ok(Self {
543            word_embeddings,
544            fp32_residual_connection: cfg.fp32_residual_connection,
545        })
546    }
547}
548
549impl Module for Embedding {
550    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
551        let xs = self.word_embeddings.forward(xs)?.transpose(0, 1)?; // b,s,h -> s,b,h
552        if self.fp32_residual_connection {
553            xs.to_dtype(candle::DType::F32)
554        } else {
555            xs.contiguous()
556        }
557    }
558}
559
560#[derive(Debug, Clone)]
561pub struct Model {
562    embedding: Embedding,
563    encoder: Transformer,
564    output_layer: Linear,
565}
566
567fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
568    let mask: Vec<_> = (0..size)
569        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
570        .collect();
571    Tensor::from_slice(&mask, (size, size), device)
572}
573
574impl Model {
575    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
576        let vb = vb.pp("transformer");
577        let embedding = Embedding::new(cfg, vb.pp("embedding"))?;
578        let encoder = Transformer::new(cfg, vb.pp("encoder"))?;
579        let output_layer = linear(
580            cfg.hidden_size,
581            cfg.padded_vocab_size,
582            false,
583            vb.pp("output_layer"),
584        )?;
585
586        Ok(Self {
587            embedding,
588            encoder,
589            output_layer,
590        })
591    }
592
593    pub fn reset_kv_cache(&mut self) {
594        self.encoder.reset_kv_cache()
595    }
596
597    pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
598        let (_b_size, seq_len) = xs.dims2()?;
599        let input_embeds = xs.apply(&self.embedding)?;
600        let attention_mask = if seq_len <= 1 {
601            None
602        } else {
603            Some(get_mask(seq_len, xs.device())?)
604        };
605        let xs = self.encoder.forward(&input_embeds, &attention_mask)?;
606        let lm_logits = xs.i(seq_len - 1)?.apply(&self.output_layer)?;
607        Ok(lm_logits)
608    }
609}