candle_transformers/models/
chatglm.rs

//! Implementation of the ChatGLM2/3 models from THUDM.
//!
//! - 💻 [GitHub](https://github.com/THUDM/ChatGLM3) ChatGLM3: Advancing Multilingual Conversational Language Models with High-Quality Data
//! - 💻 [GitHub](https://github.com/THUDM/ChatGLM2-6B) ChatGLM2-6B.
//!
6use crate::models::with_tracing::{linear_b as linear, Linear};
7use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
8use candle_nn::VarBuilder;
9
/// Hyper-parameters of a ChatGLM2/3 model.
#[derive(Debug, Clone)]
pub struct Config {
    /// Number of transformer blocks in the encoder.
    pub num_layers: usize,
    /// Rows of the word-embedding table and width of the output (logits) projection.
    pub padded_vocab_size: usize,
    /// Model width: embedding size and the size every norm operates on.
    pub hidden_size: usize,
    /// Width of one SwiGLU half; the up-projection produces twice this many features.
    pub ffn_hidden_size: usize,
    /// Per-head dimension; also the base dimension of the rotary embedding.
    pub kv_channels: usize,
    /// Number of query heads.
    pub num_attention_heads: usize,
    /// Number of positions precomputed in the rotary cache (maximum total sequence length).
    pub seq_length: usize,
    /// Epsilon used by every LayerNorm / RMSNorm.
    pub layernorm_epsilon: f64,
    /// Use RMSNorm instead of LayerNorm.
    pub rmsnorm: bool,
    /// Take residuals from the normalized activations instead of the block inputs.
    pub apply_residual_connection_post_layernorm: bool,
    /// Apply a final norm after the last block.
    pub post_layer_norm: bool,
    /// Add biases to the attention output and MLP projections.
    pub add_bias_linear: bool,
    /// Add a bias to the fused query/key/value projection.
    pub add_qkv_bias: bool,
    /// Not consulted by this module.
    pub bias_dropout_fusion: bool,
    /// Grouped/multi-query attention; only `true` is supported by this implementation.
    pub multi_query_attention: bool,
    /// Number of key/value head groups shared among the query heads.
    pub multi_query_group_num: usize,
    /// Scale attention scores per layer by the (1-based) layer number.
    pub apply_query_key_layer_scaling: bool,
    /// Not consulted by this module.
    pub attention_softmax_in_fp32: bool,
    /// Cast the embedding output to `f32`.
    pub fp32_residual_connection: bool,
}

impl Config {
    /// Configuration of the ChatGLM3-6B checkpoint.
    pub fn glm3_6b() -> Self {
        Self {
            // Sizes.
            num_layers: 28,
            padded_vocab_size: 65024,
            hidden_size: 4096,
            ffn_hidden_size: 13696,
            kv_channels: 128,
            num_attention_heads: 32,
            seq_length: 8192,
            // Normalization and residual wiring.
            layernorm_epsilon: 1e-5,
            rmsnorm: true,
            apply_residual_connection_post_layernorm: false,
            post_layer_norm: true,
            // Linear-layer biases.
            add_bias_linear: false,
            add_qkv_bias: true,
            bias_dropout_fusion: true,
            // Attention layout and numerics.
            multi_query_attention: true,
            multi_query_group_num: 2,
            apply_query_key_layer_scaling: true,
            attention_softmax_in_fp32: true,
            fp32_residual_connection: false,
        }
    }
}
58
59#[derive(Debug, Clone)]
60struct RotaryEmbedding {
61    cache: Tensor,
62}
63
64impl RotaryEmbedding {
65    fn new(cfg: &Config, dtype: DType, dev: &Device) -> Result<Self> {
66        let rotary_dim = cfg.kv_channels;
67        let n_elem = rotary_dim / 2;
68        let inv_freq: Vec<_> = (0..n_elem)
69            .step_by(2)
70            .map(|i| 1f32 / 10_000f64.powf(i as f64 / n_elem as f64) as f32)
71            .collect();
72        let inv_freq_len = inv_freq.len();
73        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
74        let t = Tensor::arange(0u32, cfg.seq_length as u32, dev)?
75            .to_dtype(dtype)?
76            .reshape((cfg.seq_length, 1))?;
77        let freqs = t.matmul(&inv_freq)?;
78        let cache = Tensor::stack(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?;
79        Ok(Self { cache })
80    }
81
82    fn apply(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
83        let (seqlen, _b, np, _hn) = xs.dims4()?;
84        let cache = self.cache.narrow(0, seqlen_offset, seqlen)?;
85        let rot_dim = cache.dim(D::Minus2)? * 2;
86        let (xs, xs_pass) = (
87            xs.narrow(D::Minus1, 0, rot_dim)?,
88            xs.narrow(D::Minus1, rot_dim, rot_dim)?,
89        );
90        let xshaped = xs.reshape((seqlen, (), np, rot_dim / 2, 2))?;
91        let cache = cache.reshape((seqlen, (), 1, rot_dim / 2, 2))?;
92        let (xshaped0, xshaped1) = (
93            xshaped.i((.., .., .., .., 0))?,
94            xshaped.i((.., .., .., .., 1))?,
95        );
96        let (cache0, cache1) = (cache.i((.., .., .., .., 0))?, cache.i((.., .., .., .., 1))?);
97        let xs_out = Tensor::stack(
98            &[
99                (xshaped0.broadcast_mul(&cache0)? - xshaped1.broadcast_mul(&cache1)?)?,
100                (xshaped1.broadcast_mul(&cache0)? + xshaped0.broadcast_mul(&cache1)?)?,
101            ],
102            D::Minus1,
103        )?;
104        let xs_out = xs_out.flatten_from(3)?;
105        Tensor::cat(&[xs_out, xs_pass], D::Minus1)
106    }
107}
108
109#[derive(Debug, Clone)]
110struct CoreAttention {
111    coeff: Option<f64>,
112    norm_factor: f64,
113}
114
115fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
116    let shape = mask.shape();
117    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
118    let m = mask.where_cond(&on_true, on_false)?;
119    Ok(m)
120}
121
122impl CoreAttention {
123    fn new(layer_number: usize, cfg: &Config) -> Result<Self> {
124        let norm_factor = (cfg.kv_channels as f64).sqrt();
125        let (norm_factor, coeff) = if cfg.apply_query_key_layer_scaling {
126            let coeff = f64::max(1.0, layer_number as f64);
127            (norm_factor * coeff, Some(coeff))
128        } else {
129            (norm_factor, None)
130        };
131        Ok(Self { coeff, norm_factor })
132    }
133
134    fn forward(
135        &self,
136        query_layer: &Tensor,
137        key_layer: &Tensor,
138        value_layer: &Tensor,
139        attention_mask: &Option<Tensor>,
140    ) -> Result<Tensor> {
141        let output_size = (
142            query_layer.dim(1)?, // b
143            query_layer.dim(2)?, // np
144            query_layer.dim(0)?, // sq
145            key_layer.dim(0)?,   // sk
146        );
147        let query_layer =
148            query_layer.reshape((output_size.2, output_size.0 * output_size.1, ()))?;
149        let key_layer = key_layer.reshape((output_size.3, output_size.0 * output_size.1, ()))?;
150        let matmul_result = Tensor::matmul(
151            &query_layer.transpose(0, 1)?,
152            &key_layer.transpose(0, 1)?.transpose(1, 2)?,
153        )?;
154        let matmul_result = (matmul_result / self.norm_factor)?.reshape(output_size)?;
155        let matmul_result = match self.coeff {
156            None => matmul_result,
157            Some(coeff) => (matmul_result * coeff)?,
158        };
159        let attention_scores = match attention_mask {
160            Some(mask) => masked_fill(
161                &matmul_result,
162                &mask.broadcast_left((matmul_result.dim(0)?, matmul_result.dim(1)?))?,
163                f32::NEG_INFINITY,
164            )?,
165            None => matmul_result,
166        };
167        let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
168
169        let output_size = (
170            value_layer.dim(1)?,
171            value_layer.dim(2)?,
172            query_layer.dim(0)?,
173            value_layer.dim(3)?,
174        );
175        let value_layer =
176            value_layer.reshape((value_layer.dim(0)?, output_size.0 * output_size.1, ()))?;
177        let attention_probs =
178            attention_probs.reshape((output_size.0 * output_size.1, output_size.2, ()))?;
179        let context_layer = Tensor::matmul(&attention_probs, &value_layer.transpose(0, 1)?)?;
180        let context_layer = context_layer.reshape(output_size)?;
181        let context_layer = context_layer.permute((2, 0, 1, 3))?.contiguous()?;
182        context_layer.flatten_from(D::Minus2)
183    }
184}
185
186#[derive(Debug, Clone)]
187struct SelfAttention {
188    query_key_value: Linear,
189    core_attention: CoreAttention,
190    dense: Linear,
191    multi_query_attention: bool,
192    num_attention_heads_per_partition: usize,
193    num_multi_query_groups_per_partition: usize,
194    hidden_size_per_attention_head: usize,
195    kv_cache: Option<(Tensor, Tensor)>,
196}
197
198impl SelfAttention {
199    fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
200        let projection_size = cfg.kv_channels * cfg.num_attention_heads;
201        let hidden_size_per_attention_head = projection_size / cfg.num_attention_heads;
202        let qkv_hidden_size = if cfg.multi_query_attention {
203            projection_size + 2 * hidden_size_per_attention_head * cfg.multi_query_group_num
204        } else {
205            3 * projection_size
206        };
207        let query_key_value = linear(
208            cfg.hidden_size,
209            qkv_hidden_size,
210            cfg.add_bias_linear || cfg.add_qkv_bias,
211            vb.pp("query_key_value"),
212        )?;
213        let core_attention = CoreAttention::new(layer_number, cfg)?;
214        let dense = linear(
215            cfg.hidden_size,
216            cfg.hidden_size,
217            cfg.add_bias_linear,
218            vb.pp("dense"),
219        )?;
220        Ok(Self {
221            query_key_value,
222            core_attention,
223            dense,
224            multi_query_attention: cfg.multi_query_attention,
225            num_attention_heads_per_partition: cfg.num_attention_heads,
226            num_multi_query_groups_per_partition: cfg.multi_query_group_num,
227            hidden_size_per_attention_head: cfg.kv_channels,
228            kv_cache: None,
229        })
230    }
231
232    fn reset_kv_cache(&mut self) {
233        self.kv_cache = None
234    }
235
236    fn forward(
237        &mut self,
238        xs: &Tensor,
239        attention_mask: &Option<Tensor>,
240        rotary_emb: &RotaryEmbedding,
241    ) -> Result<Tensor> {
242        let mixed_x_layer = xs.apply(&self.query_key_value)?;
243        if !self.multi_query_attention {
244            candle::bail!("only multi_query_attention=true is supported")
245        }
246        let hpa = self.hidden_size_per_attention_head;
247        let query_layer =
248            mixed_x_layer.narrow(D::Minus1, 0, self.num_attention_heads_per_partition * hpa)?;
249        let key_layer = mixed_x_layer.narrow(
250            D::Minus1,
251            self.num_attention_heads_per_partition * hpa,
252            self.num_multi_query_groups_per_partition * hpa,
253        )?;
254        let value_layer = mixed_x_layer.narrow(
255            D::Minus1,
256            self.num_attention_heads_per_partition * hpa
257                + self.num_multi_query_groups_per_partition * hpa,
258            self.num_multi_query_groups_per_partition * hpa,
259        )?;
260        let query_layer = query_layer.reshape((
261            query_layer.dim(0)?,
262            query_layer.dim(1)?,
263            self.num_attention_heads_per_partition,
264            hpa,
265        ))?;
266        let key_layer = key_layer.reshape((
267            key_layer.dim(0)?,
268            key_layer.dim(1)?,
269            self.num_multi_query_groups_per_partition,
270            hpa,
271        ))?;
272        let value_layer = value_layer.reshape((
273            value_layer.dim(0)?,
274            value_layer.dim(1)?,
275            self.num_multi_query_groups_per_partition,
276            hpa,
277        ))?;
278
279        // Rotary embeddings.
280        let seqlen_offset = match &self.kv_cache {
281            None => 0,
282            Some((prev_k, _)) => prev_k.dim(0)?,
283        };
284        let query_layer = rotary_emb.apply(&query_layer, seqlen_offset)?;
285        let key_layer = rotary_emb.apply(&key_layer, seqlen_offset)?;
286
287        // KV cache.
288        let (key_layer, value_layer) = match &self.kv_cache {
289            None => (key_layer, value_layer),
290            Some((prev_k, prev_v)) => {
291                let k = Tensor::cat(&[prev_k, &key_layer], 0)?;
292                let v = Tensor::cat(&[prev_v, &value_layer], 0)?;
293                (k, v)
294            }
295        };
296        self.kv_cache = Some((key_layer.clone(), value_layer.clone()));
297
298        // Repeat KV.
299        let ratio =
300            self.num_attention_heads_per_partition / self.num_multi_query_groups_per_partition;
301        let key_layer = {
302            let (d0, d1, d2, d3) = key_layer.dims4()?;
303            key_layer
304                .unsqueeze(D::Minus2)?
305                .expand((d0, d1, d2, ratio, d3))?
306                .reshape((
307                    d0,
308                    d1,
309                    self.num_attention_heads_per_partition,
310                    self.hidden_size_per_attention_head,
311                ))?
312        };
313        let value_layer = {
314            let (d0, d1, d2, d3) = value_layer.dims4()?;
315            value_layer
316                .unsqueeze(D::Minus2)?
317                .expand((d0, d1, d2, ratio, d3))?
318                .reshape((
319                    d0,
320                    d1,
321                    self.num_attention_heads_per_partition,
322                    self.hidden_size_per_attention_head,
323                ))?
324        };
325
326        let context_layer =
327            self.core_attention
328                .forward(&query_layer, &key_layer, &value_layer, attention_mask)?;
329        let output = context_layer.apply(&self.dense)?;
330        Ok(output)
331    }
332}
333
334#[allow(clippy::upper_case_acronyms)]
335#[derive(Debug, Clone)]
336struct MLP {
337    dense_h_to_4h: Linear,
338    dense_4h_to_h: Linear,
339}
340
341impl MLP {
342    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
343        let dense_h_to_4h = linear(
344            cfg.hidden_size,
345            cfg.ffn_hidden_size * 2,
346            cfg.add_bias_linear,
347            vb.pp("dense_h_to_4h"),
348        )?;
349        let dense_4h_to_h = linear(
350            cfg.ffn_hidden_size,
351            cfg.hidden_size,
352            cfg.add_bias_linear,
353            vb.pp("dense_4h_to_h"),
354        )?;
355        Ok(Self {
356            dense_4h_to_h,
357            dense_h_to_4h,
358        })
359    }
360}
361
362impl Module for MLP {
363    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
364        xs.apply(&self.dense_h_to_4h)?
365            .apply(&candle_nn::Activation::Swiglu)?
366            .apply(&self.dense_4h_to_h)
367    }
368}
369
370#[derive(Debug, Clone)]
371struct Block {
372    input_layernorm: candle_nn::LayerNorm,
373    self_attention: SelfAttention,
374    post_attention_layernorm: candle_nn::LayerNorm,
375    mlp: MLP,
376    apply_residual_connection_post_layernorm: bool,
377}
378
379impl Block {
380    fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
381        let input_layernorm = if cfg.rmsnorm {
382            candle_nn::rms_norm(
383                cfg.hidden_size,
384                cfg.layernorm_epsilon,
385                vb.pp("input_layernorm"),
386            )?
387            .into_inner()
388        } else {
389            candle_nn::layer_norm(
390                cfg.hidden_size,
391                cfg.layernorm_epsilon,
392                vb.pp("input_layernorm"),
393            )?
394        };
395        let post_attention_layernorm = if cfg.rmsnorm {
396            candle_nn::rms_norm(
397                cfg.hidden_size,
398                cfg.layernorm_epsilon,
399                vb.pp("post_attention_layernorm"),
400            )?
401            .into_inner()
402        } else {
403            candle_nn::layer_norm(
404                cfg.hidden_size,
405                cfg.layernorm_epsilon,
406                vb.pp("post_attention_layernorm"),
407            )?
408        };
409        let self_attention = SelfAttention::new(layer_number, cfg, vb.pp("self_attention"))?;
410        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
411        Ok(Self {
412            input_layernorm,
413            self_attention,
414            post_attention_layernorm,
415            mlp,
416            apply_residual_connection_post_layernorm: cfg.apply_residual_connection_post_layernorm,
417        })
418    }
419
420    fn reset_kv_cache(&mut self) {
421        self.self_attention.reset_kv_cache()
422    }
423
424    fn forward(
425        &mut self,
426        xs: &Tensor,
427        attention_mask: &Option<Tensor>,
428        rotary_emb: &RotaryEmbedding,
429    ) -> Result<Tensor> {
430        let layernorm_output = xs.apply(&self.input_layernorm)?;
431        let attention_output =
432            self.self_attention
433                .forward(&layernorm_output, attention_mask, rotary_emb)?;
434        let residual = if self.apply_residual_connection_post_layernorm {
435            &layernorm_output
436        } else {
437            xs
438        };
439        let layernorm_input = (residual + attention_output)?;
440        let layernorm_output = layernorm_input.apply(&self.post_attention_layernorm)?;
441        let mlp_output = layernorm_output.apply(&self.mlp)?;
442        let residual = if self.apply_residual_connection_post_layernorm {
443            &layernorm_output
444        } else {
445            &layernorm_input
446        };
447        mlp_output + residual
448    }
449}
450
451#[derive(Debug, Clone)]
452struct Transformer {
453    layers: Vec<Block>,
454    final_layernorm: Option<candle_nn::LayerNorm>,
455    rotary_emb: RotaryEmbedding,
456}
457
458impl Transformer {
459    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
460        let vb_l = vb.pp("layers");
461        let mut layers = Vec::with_capacity(cfg.num_layers);
462        for layer_index in 0..cfg.num_layers {
463            let block = Block::new(layer_index + 1, cfg, vb_l.pp(layer_index))?;
464            layers.push(block)
465        }
466        let final_layernorm = if cfg.post_layer_norm {
467            let ln = if cfg.rmsnorm {
468                candle_nn::rms_norm(
469                    cfg.hidden_size,
470                    cfg.layernorm_epsilon,
471                    vb.pp("final_layernorm"),
472                )?
473                .into_inner()
474            } else {
475                candle_nn::layer_norm(
476                    cfg.hidden_size,
477                    cfg.layernorm_epsilon,
478                    vb.pp("final_layernorm"),
479                )?
480            };
481            Some(ln)
482        } else {
483            None
484        };
485        let rotary_emb = RotaryEmbedding::new(cfg, vb.dtype(), vb.device())?;
486        Ok(Self {
487            layers,
488            final_layernorm,
489            rotary_emb,
490        })
491    }
492
493    fn reset_kv_cache(&mut self) {
494        for block in self.layers.iter_mut() {
495            block.reset_kv_cache()
496        }
497    }
498
499    fn forward(&mut self, xs: &Tensor, attention_mask: &Option<Tensor>) -> Result<Tensor> {
500        let mut xs = xs.clone();
501        for block in self.layers.iter_mut() {
502            xs = block.forward(&xs, attention_mask, &self.rotary_emb)?
503        }
504        match self.final_layernorm.as_ref() {
505            None => Ok(xs),
506            Some(ln) => xs.apply(ln),
507        }
508    }
509}
510
511#[derive(Debug, Clone)]
512struct Embedding {
513    word_embeddings: candle_nn::Embedding,
514    fp32_residual_connection: bool,
515}
516
517impl Embedding {
518    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
519        let word_embeddings = candle_nn::embedding(
520            cfg.padded_vocab_size,
521            cfg.hidden_size,
522            vb.pp("word_embeddings"),
523        )?;
524        Ok(Self {
525            word_embeddings,
526            fp32_residual_connection: cfg.fp32_residual_connection,
527        })
528    }
529}
530
531impl Module for Embedding {
532    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
533        let xs = self.word_embeddings.forward(xs)?.transpose(0, 1)?; // b,s,h -> s,b,h
534        if self.fp32_residual_connection {
535            xs.to_dtype(candle::DType::F32)
536        } else {
537            xs.contiguous()
538        }
539    }
540}
541
542#[derive(Debug, Clone)]
543pub struct Model {
544    embedding: Embedding,
545    encoder: Transformer,
546    output_layer: Linear,
547}
548
549fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
550    let mask: Vec<_> = (0..size)
551        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
552        .collect();
553    Tensor::from_slice(&mask, (size, size), device)
554}
555
556impl Model {
557    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
558        let vb = vb.pp("transformer");
559        let embedding = Embedding::new(cfg, vb.pp("embedding"))?;
560        let encoder = Transformer::new(cfg, vb.pp("encoder"))?;
561        let output_layer = linear(
562            cfg.hidden_size,
563            cfg.padded_vocab_size,
564            false,
565            vb.pp("output_layer"),
566        )?;
567        Ok(Self {
568            embedding,
569            encoder,
570            output_layer,
571        })
572    }
573
574    pub fn reset_kv_cache(&mut self) {
575        self.encoder.reset_kv_cache()
576    }
577
578    pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
579        let (_b_size, seq_len) = xs.dims2()?;
580        let input_embeds = xs.apply(&self.embedding)?;
581        let attention_mask = if seq_len <= 1 {
582            None
583        } else {
584            Some(get_mask(seq_len, xs.device())?)
585        };
586        let xs = self.encoder.forward(&input_embeds, &attention_mask)?;
587        let lm_logits = xs.i(seq_len - 1)?.apply(&self.output_layer)?;
588        Ok(lm_logits)
589    }
590}