1use crate::quantized_nn::{linear_b as linear, Embedding, Linear};
19pub use crate::quantized_var_builder::VarBuilder;
20use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
21use std::sync::Arc;
22
23use crate::models::recurrent_gemma::{Config, Rglru, RmsNorm, RotaryEmbedding, TemporalBlockType};
24
25fn rms_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<RmsNorm> {
26 let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
27 Ok(RmsNorm::from_weight(weight, eps))
28}
29
30#[derive(Debug, Clone)]
31struct Mlp {
32 gate_proj: Linear,
33 up_proj: Linear,
34 down_proj: Linear,
35 act_fn: candle_nn::Activation,
36}
37
38impl Mlp {
39 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
40 let h = cfg.hidden_size;
41 let intermediate_size = cfg.intermediate_size / 2;
42 let gate_proj = linear(h, intermediate_size, true, vb.pp("gate_proj"))?;
43 let up_proj = linear(h, intermediate_size, true, vb.pp("up_proj"))?;
44 let down_proj = linear(intermediate_size, h, true, vb.pp("down_proj"))?;
45 Ok(Self {
46 gate_proj,
47 up_proj,
48 down_proj,
49 act_fn: cfg.hidden_activation,
50 })
51 }
52}
53
54impl Module for Mlp {
55 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
56 let gate = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
57 (gate * xs.apply(&self.up_proj))?.apply(&self.down_proj)
58 }
59}
60
61fn rglru(cfg: &Config, vb: VarBuilder) -> Result<Rglru> {
62 let h = cfg.hidden_size;
63 let lru_width = cfg.lru_width.unwrap_or(h);
64 let n_heads = cfg.num_attention_heads;
65 let block_width = lru_width / n_heads;
66 let recurrent_param = vb.get((lru_width,), "recurrent_param")?;
67 let input_gate_weight = vb.get((n_heads, block_width, block_width), "input_gate_weight")?;
68 let input_gate_bias = vb.get((n_heads, block_width), "input_gate_bias")?;
69 let recurrent_gate_weight =
70 vb.get((n_heads, block_width, block_width), "recurrent_gate_weight")?;
71 let recurrent_gate_bias = vb.get((n_heads, block_width), "recurrent_gate_bias")?;
72 Ok(Rglru {
73 recurrent_param: recurrent_param.dequantize(vb.device())?,
74 input_gate_bias: input_gate_bias.dequantize(vb.device())?,
75 input_gate_weight: input_gate_weight.dequantize(vb.device())?,
76 recurrent_gate_bias: recurrent_gate_bias.dequantize(vb.device())?,
77 recurrent_gate_weight: recurrent_gate_weight.dequantize(vb.device())?,
78 block_width,
79 n_heads,
80 recurrent_states: None,
81 })
82}
83
84#[derive(Debug, Clone)]
85struct RecurrentBlock {
86 linear_y: Linear,
87 linear_x: Linear,
88 linear_out: Linear,
89 conv_1d: candle_nn::Conv1d,
90 conv1d_state: Option<Tensor>,
91 conv1d_width: usize,
92 rg_lru: Rglru,
93 act_fn: candle_nn::Activation,
94}
95
96impl RecurrentBlock {
97 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
98 let h = cfg.hidden_size;
99 let lru_width = cfg.lru_width.unwrap_or(h);
100 let linear_y = linear(h, lru_width, true, vb.pp("linear_y"))?;
101 let linear_x = linear(h, lru_width, true, vb.pp("linear_x"))?;
102 let linear_out = linear(lru_width, h, true, vb.pp("linear_out"))?;
103
104 let conv_1d = {
105 let ws = vb
106 .get((lru_width, 1, cfg.conv1d_width), "conv_1d.weight")?
107 .dequantize(vb.device())?;
108 let bs = vb.get(lru_width, "conv_1d.bias")?.dequantize(vb.device())?;
109 let config = candle_nn::Conv1dConfig {
110 groups: lru_width,
111 padding: cfg.conv1d_width - 1,
112 ..Default::default()
113 };
114 candle_nn::Conv1d::new(ws, Some(bs), config)
115 };
116 let rg_lru = rglru(cfg, vb.pp("rg_lru"))?;
117 Ok(Self {
118 linear_y,
119 linear_x,
120 linear_out,
121 conv_1d,
122 conv1d_state: None,
123 conv1d_width: cfg.conv1d_width,
124 rg_lru,
125 act_fn: cfg.hidden_activation,
126 })
127 }
128
129 pub fn forward(&mut self, xs: &Tensor, pos: usize) -> Result<Tensor> {
130 let (_b_sz, seq_len, _) = xs.dims3()?;
131
132 let y_branch = xs.apply(&self.linear_y)?.apply(&self.act_fn)?;
133 let x_branch = xs.apply(&self.linear_x)?.transpose(1, 2)?;
134 let x_branch = if pos == 0 {
135 let x_len = x_branch.dim(D::Minus1)?;
136 let pad = self.conv1d_width as i64 - x_len as i64 - 1;
137 let padded = match pad.cmp(&0) {
138 std::cmp::Ordering::Equal => x_branch.clone(),
139 std::cmp::Ordering::Less => {
140 let rev_pad = (-pad) as usize;
141 x_branch.narrow(D::Minus1, rev_pad, x_len - rev_pad)?
142 }
143 std::cmp::Ordering::Greater => {
144 x_branch.pad_with_zeros(D::Minus1, pad as usize, 0)?
145 }
146 };
147 self.conv1d_state = Some(padded);
148 x_branch
149 .apply(&self.conv_1d)?
150 .narrow(D::Minus1, 0, seq_len)?
151 } else {
152 let conv_state = match self.conv1d_state.as_ref() {
153 None => candle::bail!("empty cache despite pos > 0"),
154 Some(s) => Tensor::cat(&[s, &x_branch], D::Minus1)?,
155 };
156 let w = self.conv_1d.weight().i((.., 0, ..))?;
157 let x_branch = conv_state.broadcast_mul(&w)?.sum(D::Minus1)?;
158 let x_branch = match self.conv_1d.bias() {
159 None => x_branch,
160 Some(b) => x_branch.broadcast_add(b)?,
161 };
162 let x_branch = x_branch.unsqueeze(D::Minus1)?;
163 self.conv1d_state = Some(conv_state.i((.., .., 1..))?);
164 x_branch
165 };
166 let x_branch = x_branch.transpose(1, 2)?;
167 let x_branch = self.rg_lru.forward(&x_branch, pos)?;
168 (x_branch * y_branch)?.apply(&self.linear_out)
169 }
170}
171
172#[derive(Debug, Clone)]
173struct SdpaAttention {
174 q_proj: Linear,
175 k_proj: Linear,
176 v_proj: Linear,
177 o_proj: Linear,
178 n_heads: usize,
179 n_kv_heads: usize,
180 head_dim: usize,
181 hidden_size: usize,
182 kv_cache: Option<(Tensor, Tensor)>,
183 rotary_emb: Arc<RotaryEmbedding>,
184}
185
186impl SdpaAttention {
187 fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
188 let h = cfg.hidden_size;
189 let n_heads = cfg.num_attention_heads;
190 let n_kv_heads = cfg.num_key_value_heads;
191 let hd = cfg.head_dim;
192 let q_proj = linear(h, n_heads * hd, cfg.attention_bias, vb.pp("q_proj"))?;
193 let k_proj = linear(h, n_kv_heads * hd, cfg.attention_bias, vb.pp("k_proj"))?;
194 let v_proj = linear(h, n_kv_heads * hd, cfg.attention_bias, vb.pp("v_proj"))?;
195 let o_proj = linear(n_heads * hd, h, true, vb.pp("o_proj"))?;
196 Ok(Self {
197 q_proj,
198 k_proj,
199 v_proj,
200 o_proj,
201 n_heads,
202 n_kv_heads,
203 head_dim: hd,
204 hidden_size: h,
205 kv_cache: None,
206 rotary_emb,
207 })
208 }
209
210 fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
211 let n_rep = self.n_heads / self.n_kv_heads;
212 crate::utils::repeat_kv(x, n_rep)
213 }
214
215 fn forward(
216 &mut self,
217 xs: &Tensor,
218 attention_mask: Option<&Tensor>,
219 pos: usize,
220 ) -> Result<Tensor> {
221 let (bsz, q_len, _) = xs.dims3()?;
222
223 let query_states = xs.apply(&self.q_proj)?;
224 let key_states = xs.apply(&self.k_proj)?;
225 let value_states = xs.apply(&self.v_proj)?;
226
227 let query_states = query_states
228 .reshape((bsz, q_len, self.n_heads, self.head_dim))?
229 .transpose(1, 2)?;
230 let key_states = key_states
231 .reshape((bsz, q_len, self.n_kv_heads, self.head_dim))?
232 .transpose(1, 2)?;
233 let value_states = value_states
234 .reshape((bsz, q_len, self.n_kv_heads, self.head_dim))?
235 .transpose(1, 2)?;
236 let query_states = query_states.chunk(2, D::Minus1)?;
237 let key_states = key_states.chunk(2, D::Minus1)?;
238 let (query_rot, key_rot) =
239 self.rotary_emb
240 .apply_rotary_emb_qkv(&query_states[0], &key_states[0], pos)?;
241 let query_states = Tensor::cat(&[&query_rot, &query_states[1]], D::Minus1)?.contiguous()?;
242 let key_states = Tensor::cat(&[&key_rot, &key_states[1]], D::Minus1)?.contiguous()?;
243
244 let (key_states, value_states) = match &self.kv_cache {
245 None => (key_states, value_states),
246 Some((prev_k, prev_v)) => {
247 let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
248 let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
249 (key_states, value_states)
250 }
251 };
252 self.kv_cache = Some((key_states.clone(), value_states.clone()));
253
254 let key_states = self.repeat_kv(key_states)?;
255 let value_states = self.repeat_kv(value_states)?;
256 let xs = {
257 let att = (query_states.matmul(&key_states.t()?)? / (self.head_dim as f64).sqrt())?;
258 let att = if q_len == 1 {
259 att
260 } else {
261 match attention_mask {
262 None => att,
263 Some(mask) => att.broadcast_add(mask)?,
264 }
265 };
266 let att = candle_nn::ops::softmax_last_dim(&att)?;
267 att.matmul(&value_states.contiguous()?)?
268 };
269
270 let xs = xs
271 .transpose(1, 2)?
272 .reshape((bsz, q_len, self.hidden_size))?;
273 self.o_proj.forward(&xs)
274 }
275}
276
277#[derive(Debug, Clone)]
278enum TemporalBlock {
279 Recurrent(RecurrentBlock),
280 Attention(SdpaAttention),
281}
282
283impl TemporalBlock {
284 fn forward(
285 &mut self,
286 xs: &Tensor,
287 attention_mask: Option<&Tensor>,
288 pos: usize,
289 ) -> Result<Tensor> {
290 match self {
291 Self::Recurrent(b) => b.forward(xs, pos),
292 Self::Attention(b) => b.forward(xs, attention_mask, pos),
293 }
294 }
295}
296
297#[derive(Debug, Clone)]
298struct DecoderLayer {
299 temporal_pre_norm: RmsNorm,
300 channel_pre_norm: RmsNorm,
301 temporal_block: TemporalBlock,
302 mlp_block: Mlp,
303}
304
305impl DecoderLayer {
306 fn new(
307 block_idx: usize,
308 rotary_emb: Arc<RotaryEmbedding>,
309 cfg: &Config,
310 vb: VarBuilder,
311 ) -> Result<Self> {
312 let h = cfg.hidden_size;
313 let temporal_pre_norm = rms_norm(h, cfg.rms_norm_eps, vb.pp("temporal_pre_norm"))?;
314 let channel_pre_norm = rms_norm(h, cfg.rms_norm_eps, vb.pp("channel_pre_norm"))?;
315 let temporal_block = match cfg.block_types[block_idx % cfg.block_types.len()] {
316 TemporalBlockType::Recurrent => {
317 let block = RecurrentBlock::new(cfg, vb.pp("temporal_block"))?;
318 TemporalBlock::Recurrent(block)
319 }
320 TemporalBlockType::Attention => {
321 let block = SdpaAttention::new(rotary_emb, cfg, vb.pp("temporal_block"))?;
322 TemporalBlock::Attention(block)
323 }
324 };
325 let mlp_block = Mlp::new(cfg, vb.pp("mlp_block"))?;
326 Ok(Self {
327 temporal_pre_norm,
328 channel_pre_norm,
329 temporal_block,
330 mlp_block,
331 })
332 }
333
334 fn forward(
335 &mut self,
336 xs: &Tensor,
337 attention_mask: Option<&Tensor>,
338 pos: usize,
339 ) -> Result<Tensor> {
340 let residual = xs;
341 let xs = xs.apply(&self.temporal_pre_norm)?;
342 let xs = self.temporal_block.forward(&xs, attention_mask, pos)?;
343 let xs = (xs + residual)?;
344 let residual = &xs;
345 let xs = xs.apply(&self.channel_pre_norm)?.apply(&self.mlp_block)?;
346 xs + residual
347 }
348}
349
350#[derive(Debug, Clone)]
351pub struct Model {
352 embed_tokens: Embedding,
353 layers: Vec<DecoderLayer>,
354 final_norm: RmsNorm,
355 lm_head: Linear,
356 hidden_size: usize,
357 logits_soft_cap: f64,
358 device: Device,
359}
360
361impl Model {
362 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
363 let embed_tokens = Embedding::new(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?;
364 let rotary_emb = Arc::new(RotaryEmbedding::new(DType::F32, cfg, vb.device())?);
365 let vb_b = vb.pp("layers");
366 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
367 for idx in 0..cfg.num_hidden_layers {
368 let layer = DecoderLayer::new(idx, rotary_emb.clone(), cfg, vb_b.pp(idx))?;
369 layers.push(layer)
370 }
371 let final_norm = rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("final_norm"))?;
372 let lm_head = linear(
373 cfg.hidden_size,
374 cfg.vocab_size,
375 false,
376 vb.pp("embed_tokens"),
377 )?;
378 Ok(Self {
379 embed_tokens,
380 layers,
381 final_norm,
382 lm_head,
383 hidden_size: cfg.hidden_size,
384 logits_soft_cap: cfg.logits_soft_cap,
385 device: vb.device().clone(),
386 })
387 }
388
389 fn prepare_decoder_attention_mask(
390 &self,
391 b_size: usize,
392 tgt_len: usize,
393 seqlen_offset: usize,
394 ) -> Result<Tensor> {
395 let mask: Vec<_> = (0..tgt_len)
396 .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
397 .collect();
398 let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
399 let mask = if seqlen_offset > 0 {
400 let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
401 Tensor::cat(&[&mask0, &mask], D::Minus1)?
402 } else {
403 mask
404 };
405 mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
406 .to_dtype(DType::F32)
407 }
408
409 pub fn forward(&mut self, xs: &Tensor, pos: usize) -> Result<Tensor> {
410 let (b_size, seq_len) = xs.dims2()?;
411 let attention_mask = if seq_len <= 1 {
412 None
413 } else {
414 let mask = self.prepare_decoder_attention_mask(b_size, seq_len, pos)?;
415 Some(mask)
416 };
417 let xs = xs.apply(&self.embed_tokens)?;
418 let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
419 for layer in self.layers.iter_mut() {
420 xs = layer.forward(&xs, attention_mask.as_ref(), pos)?;
421 }
422 let logits = xs
423 .narrow(1, seq_len - 1, 1)?
424 .apply(&self.final_norm)?
425 .apply(&self.lm_head)?;
426 let logits = ((logits / self.logits_soft_cap)?.tanh()? * self.logits_soft_cap)?;
427 Ok(logits)
428 }
429}