1use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm};
19use candle::{DType, Device, Error, IndexOp, Module, Result, Tensor, D};
20use candle_nn::{layer_norm, Activation, LayerNorm, VarBuilder};
21use std::sync::Arc;
22
23#[derive(Debug, Copy, Clone, PartialEq, serde::Deserialize)]
25pub enum ModelVariant {
26 Large, Small, }
29
30impl Default for ModelVariant {
31 fn default() -> Self {
32 Self::Large
33 }
34}
35
36#[derive(Debug, Default, Clone, PartialEq, serde::Deserialize)]
39pub struct Config {
40 pub variant: ModelVariant,
41 pub vocab_size: usize,
42 pub hidden_size: usize,
43 pub intermediate_size: usize,
44 pub num_hidden_layers: usize,
45 pub num_attention_heads: usize,
46 pub max_position_embeddings: usize,
47 pub rope_theta: f64,
48 pub embed_head: EmbedHead,
49 pub norm_eps: f64, pub activation_fn: Activation, pub num_key_value_heads: usize,
53 pub type_vocab_size: usize,
55 pub scaling_factor: f64,
56}
57
58#[derive(Debug, Default, Clone, PartialEq, serde::Deserialize)]
62pub struct EmbedHead {
63 pub in_features: usize,
64 pub out_features: usize,
65}
66
/// Supported output widths for the embedding head.
#[derive(Debug, Clone, Copy)]
pub enum EmbedDim {
    /// 256-dimensional embeddings.
    Dim256,
    /// 768-dimensional embeddings.
    Dim768,
    /// 1024-dimensional embeddings.
    Dim1024,
    /// 2048-dimensional embeddings.
    Dim2048,
    /// 4096-dimensional embeddings.
    Dim4096,
    /// 6144-dimensional embeddings.
    Dim6144,
    /// 8192-dimensional embeddings.
    Dim8192,
}
79
80impl Default for EmbedDim {
81 fn default() -> Self {
82 Self::Dim1024
83 }
84}
85
86impl EmbedDim {
87 pub fn config(&self, in_features: usize) -> EmbedHead {
88 EmbedHead {
89 in_features,
90 out_features: match &self {
91 Self::Dim256 => 256,
92 Self::Dim768 => 768,
93 Self::Dim1024 => 1024,
94 Self::Dim2048 => 2048,
95 Self::Dim4096 => 4096,
96 Self::Dim6144 => 6144,
97 Self::Dim8192 => 8192,
98 },
99 }
100 }
101}
102
103impl Config {
105 pub fn new_1_5_b_v5(embed_dim: EmbedDim) -> Self {
107 Self {
110 variant: ModelVariant::Large,
111 activation_fn: candle_nn::Activation::Silu,
112 vocab_size: 151646,
113 hidden_size: 1536,
114 intermediate_size: 8960,
115 num_hidden_layers: 28,
116 num_attention_heads: 12,
117 num_key_value_heads: 2,
118 max_position_embeddings: 131072,
119 rope_theta: 1000000.,
120 norm_eps: 1e-06,
121 embed_head: embed_dim.config(1536),
122 ..Default::default()
123 }
124 }
125
126 pub fn new_400_m_v5(embed_dim: EmbedDim) -> Self {
128 Self {
129 variant: ModelVariant::Small,
130 vocab_size: 30528,
131 hidden_size: 1024,
132 intermediate_size: 4096,
133 num_hidden_layers: 24,
134 num_attention_heads: 16,
135 max_position_embeddings: 8192,
136 type_vocab_size: 2,
137 norm_eps: 1e-12,
138 scaling_factor: 2.0,
139 rope_theta: 160000.0,
140 activation_fn: Activation::Gelu,
141 embed_head: embed_dim.config(1024),
142 ..Default::default()
143 }
144 }
145}
146
147#[derive(Debug, Clone)]
148struct RotaryEmbedding {
149 sin: Tensor,
150 cos: Tensor,
151}
152
153impl RotaryEmbedding {
154 fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
155 let dim = cfg.hidden_size / cfg.num_attention_heads;
156 let max_seq_len = if cfg.scaling_factor == 0. {
158 cfg.max_position_embeddings
159 } else {
160 ((cfg.max_position_embeddings as f64) * cfg.scaling_factor) as usize
161 };
162
163 let inv_freq: Vec<_> = (0..dim)
165 .step_by(2)
166 .map(|i| {
167 let rope_theta = if cfg.scaling_factor == 0. {
169 cfg.rope_theta
170 } else {
171 cfg.rope_theta * cfg.scaling_factor
172 };
173 let mut freq = 1. / rope_theta.powf(i as f64 / dim as f64);
174
175 if cfg.scaling_factor != 0. {
176 freq /= cfg.scaling_factor.powf(2.0 / (dim as f64))
177 }
178
179 freq as f32
180 })
181 .collect();
182
183 let inv_freq_len = inv_freq.len();
184 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
185
186 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
188 .to_dtype(dtype)?
189 .reshape((max_seq_len, 1))?;
190 let freqs = t.matmul(&inv_freq)?;
191 Ok(Self {
196 sin: freqs.sin()?,
197 cos: freqs.cos()?,
198 })
199 }
200
201 fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
203 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
204 let cos = self.cos.narrow(0, 0, seq_len)?;
205 let sin = self.sin.narrow(0, 0, seq_len)?;
206
207 let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
208 let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
209 Ok((q_embed, k_embed))
210 }
211}
212
213#[derive(Debug, Clone)]
214#[allow(clippy::upper_case_acronyms)]
215struct MLP {
216 variant: ModelVariant,
217 gate_proj: Linear,
218 up_proj: Option<Linear>, down_proj: Linear,
220 act_fn: Activation,
221}
222
223impl MLP {
224 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
225 let hidden_sz = cfg.hidden_size;
226 let intermediate_sz = cfg.intermediate_size;
227
228 let (gate_proj, up_proj, down_proj) = match cfg.variant {
229 ModelVariant::Large => (
230 linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?,
231 Some(linear_no_bias(
232 hidden_sz,
233 intermediate_sz,
234 vb.pp("up_proj"),
235 )?),
236 linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?,
237 ),
238 ModelVariant::Small => (
239 linear_no_bias(hidden_sz, intermediate_sz * 2, vb.pp("up_gate_proj"))?,
240 None,
241 linear(intermediate_sz, hidden_sz, vb.pp("down_proj"))?,
242 ),
243 };
244
245 Ok(Self {
246 variant: cfg.variant,
247 gate_proj,
248 up_proj,
249 down_proj,
250 act_fn: cfg.activation_fn,
251 })
252 }
253}
254
255impl Module for MLP {
256 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
257 let up = self.gate_proj.forward(xs)?;
258
259 let (lhs, rhs) = match self.variant {
260 ModelVariant::Large => {
261 let lhs = up.apply(&self.act_fn)?;
262 let rhs = xs.apply(self.up_proj.as_ref().unwrap())?;
263
264 (lhs, rhs)
265 }
266 ModelVariant::Small => {
267 let (_batch_size, _seq_len, hidden_dim) = up.dims3()?;
269 let split_size = hidden_dim / 2;
270
271 let up_states = up.narrow(2, 0, split_size)?;
273 let gate = up.narrow(2, split_size, split_size)?.apply(&self.act_fn)?;
274
275 (up_states, gate)
276 }
277 };
278
279 (lhs * rhs)?.apply(&self.down_proj)
280 }
281}
282
283#[derive(Debug, Clone)]
284struct Attention {
285 qkv_proj: Linear,
286 o_proj: Linear,
287 num_heads: usize,
288 num_kv_heads: usize,
289 num_kv_groups: usize,
290 head_dim: usize,
291 hidden_size: usize,
292 rotary_emb: Arc<RotaryEmbedding>,
293 variant: ModelVariant,
294}
295
296impl Attention {
297 fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
298 let hidden_sz = cfg.hidden_size;
299 let num_heads = cfg.num_attention_heads;
300 let num_kv_heads = cfg.num_key_value_heads;
301 let num_kv_groups = if num_kv_heads > 0 {
302 num_heads / num_kv_heads
303 } else {
304 0
305 };
306 let head_dim = hidden_sz / num_heads;
307
308 let (qkv_proj, o_proj) = match cfg.variant {
309 ModelVariant::Large => {
310 let q_w = vb
313 .pp("q_proj")
314 .get((num_heads * head_dim, hidden_sz), "weight")?;
315 let k_w = vb
316 .pp("k_proj")
317 .get((num_kv_heads * head_dim, hidden_sz), "weight")?;
318 let v_w = vb
319 .pp("v_proj")
320 .get((num_kv_heads * head_dim, hidden_sz), "weight")?;
321 let q_b = vb.pp("q_proj").get(num_heads * head_dim, "bias")?;
323 let k_b = vb.pp("k_proj").get(num_kv_heads * head_dim, "bias")?;
324 let v_b = vb.pp("v_proj").get(num_kv_heads * head_dim, "bias")?;
325
326 let qkv_w = Tensor::cat(&[&q_w, &k_w, &v_w], 0)?;
327 let qkv_b = Tensor::cat(&[&q_b, &k_b, &v_b], 0)?;
328
329 (
330 Linear::from_weights(qkv_w, Some(qkv_b)),
331 linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?,
332 )
333 }
334 ModelVariant::Small => (
335 linear(hidden_sz, 3 * num_heads * head_dim, vb.pp("qkv_proj"))?,
336 linear(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?,
337 ),
338 };
339
340 Ok(Self {
341 qkv_proj,
342 o_proj,
343 num_heads,
344 num_kv_heads,
345 num_kv_groups,
346 head_dim,
347 hidden_size: hidden_sz,
348 rotary_emb,
349 variant: cfg.variant,
350 })
351 }
352
353 fn forward(&mut self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
354 let (b_sz, q_len, _) = xs.dims3()?;
355
356 let qkv = self.qkv_proj.forward(xs)?;
357
358 let n_kv_heads = match self.variant {
359 ModelVariant::Large => self.num_kv_heads,
360 ModelVariant::Small => self.num_heads,
361 };
362
363 let (query_states, key_states, value_states) = match self.variant {
364 ModelVariant::Large => {
365 let q_sz = self.num_heads * self.head_dim;
366 let kv_sz = n_kv_heads * self.head_dim;
367
368 let q = qkv.narrow(D::Minus1, 0, q_sz)?.reshape((
369 b_sz,
370 q_len,
371 self.num_heads,
372 self.head_dim,
373 ))?;
374 let k = qkv.narrow(D::Minus1, q_sz, kv_sz)?.reshape((
375 b_sz,
376 q_len,
377 n_kv_heads,
378 self.head_dim,
379 ))?;
380 let v = qkv.narrow(D::Minus1, q_sz + kv_sz, kv_sz)?.reshape((
381 b_sz,
382 q_len,
383 n_kv_heads,
384 self.head_dim,
385 ))?;
386
387 (q, k, v)
388 }
389 ModelVariant::Small => {
390 let qkv = qkv.reshape((b_sz, q_len, 3, self.num_heads, self.head_dim))?;
392
393 (
394 qkv.i((.., .., 0, .., ..))?,
395 qkv.i((.., .., 1, .., ..))?,
396 qkv.i((.., .., 2, .., ..))?,
397 )
398 }
399 };
400
401 let query_states = query_states.transpose(1, 2)?.contiguous()?;
402 let key_states = key_states.transpose(1, 2)?.contiguous()?;
403 let value_states = value_states.transpose(1, 2)?.contiguous()?;
404
405 let (query_states, key_states) = self
406 .rotary_emb
407 .apply_rotary_emb_qkv(&query_states, &key_states)?;
408
409 let (key_states, value_states) = if self.variant == ModelVariant::Large {
411 (
412 crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?,
413 crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?,
414 )
415 } else {
416 (key_states, value_states)
417 };
418
419 let attn_output = {
420 let scale = 1f64 / f64::sqrt(self.head_dim as f64);
421 let attn_weights = query_states.matmul(&key_states.transpose(2, 3)?)?;
422 let attn_weights = (attn_weights * scale)?;
423
424 let attn_weights = match attention_mask {
425 None => attn_weights,
426 Some(mask) => attn_weights.broadcast_add(mask)?,
427 };
428 let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
429
430 attn_weights.matmul(&value_states)?
431 };
432
433 attn_output
434 .transpose(1, 2)?
435 .reshape((b_sz, q_len, self.hidden_size))?
436 .apply(&self.o_proj)
437 }
438}
439
440#[derive(Debug, Clone)]
441enum NormType {
442 Layer(LayerNorm),
443 Rms(RmsNorm),
444}
445
446#[derive(Debug, Clone)]
447struct Layer {
448 variant: ModelVariant,
449 attention: Attention,
450 mlp: MLP,
451 layernorm: NormType,
454 post_attention_layernorm: NormType,
455}
456
457impl Layer {
458 fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
459 let attention = Attention::new(
460 rotary_emb,
461 cfg,
462 vb.pp(if cfg.variant == ModelVariant::Large {
463 "self_attn"
464 } else {
465 "attention"
466 }),
467 )?;
468 let mlp = MLP::new(cfg, vb.pp("mlp"))?;
469 let (layernorm, post_attention_layernorm) = match cfg.variant {
470 ModelVariant::Large => (
471 NormType::Rms(RmsNorm::new(
472 cfg.hidden_size,
473 cfg.norm_eps,
474 vb.pp("input_layernorm"),
475 )?),
476 NormType::Rms(RmsNorm::new(
477 cfg.hidden_size,
478 cfg.norm_eps,
479 vb.pp("post_attention_layernorm"),
480 )?),
481 ),
482 ModelVariant::Small => (
483 NormType::Layer(layer_norm(
484 cfg.hidden_size,
485 candle_nn::LayerNormConfig {
486 eps: cfg.norm_eps,
487 ..Default::default()
488 },
489 vb.pp("mlp_ln"),
490 )?),
491 NormType::Layer(layer_norm(
492 cfg.hidden_size,
493 candle_nn::LayerNormConfig {
494 eps: cfg.norm_eps,
495 ..Default::default()
496 },
497 vb.pp("attn_ln"),
498 )?),
499 ),
500 };
501
502 Ok(Self {
503 variant: cfg.variant,
504 attention,
505 mlp,
506 layernorm,
507 post_attention_layernorm,
508 })
509 }
510
511 fn forward(&mut self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
512 let residual = xs;
531
532 match self.variant {
533 ModelVariant::Large => {
534 let (attn_ln, input_ln) = if let (NormType::Rms(attn_ln), NormType::Rms(input_ln)) =
535 (&self.post_attention_layernorm, &self.layernorm)
536 {
537 (attn_ln, input_ln)
538 } else {
539 return Err(candle::error::Error::Msg(
540 "Stella 1.5B expects RMSNorm".to_string(),
541 ));
542 };
543
544 let xs = input_ln.forward(xs)?;
545 let xs = (self.attention.forward(&xs, attention_mask)? + residual)?;
546
547 let residual = &xs;
548 let xs = xs.apply(attn_ln)?.apply(&self.mlp)?;
549
550 residual + xs
551 }
552 ModelVariant::Small => {
553 let (attn_ln, output_ln) =
554 if let (NormType::Layer(attn_ln), NormType::Layer(input_ln)) =
555 (&self.post_attention_layernorm, &self.layernorm)
556 {
557 (attn_ln, input_ln)
558 } else {
559 return Err(candle::error::Error::Msg(
560 "Stella 400M expects RMSNorm".to_string(),
561 ));
562 };
563
564 let xs = (self.attention.forward(xs, attention_mask)? + residual)?;
565 let xs = attn_ln.forward(&xs)?;
566
567 let residual = &xs;
568 let xs = (self.mlp.forward(&xs)? + residual)?;
569
570 output_ln.forward(&xs)
571 }
572 }
573 }
574}
575
576#[derive(Debug, Clone)]
577pub struct Embeddings {
578 variant: ModelVariant,
579 embeddings: candle_nn::Embedding,
582 token_type_embeddings: Option<candle_nn::Embedding>,
584 layer_norm: Option<LayerNorm>,
585 position_ids: Option<Tensor>,
586}
587
588impl Embeddings {
589 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
590 let (embeddings, token_type_embeddings, layer_norm, position_ids) = match cfg.variant {
591 ModelVariant::Large => (
592 candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?,
593 None,
594 None,
595 None,
596 ),
597 ModelVariant::Small => {
598 let vb = vb.pp("embeddings");
599 let weight = vb.pp("LayerNorm").get_with_hints(
600 cfg.hidden_size,
601 "weight",
602 candle_nn::Init::Const(1.0),
603 )?;
604 let bias = vb.pp("LayerNorm").get_with_hints(
605 cfg.hidden_size,
606 "bias",
607 candle_nn::Init::Const(0.0),
608 )?;
609 let dev = bias.device().clone();
610
611 let layer_norm = candle_nn::LayerNorm::new(weight, bias, cfg.norm_eps);
612
613 (
614 candle_nn::embedding(
615 cfg.vocab_size,
616 cfg.hidden_size,
617 vb.pp("word_embeddings"),
618 )?,
619 Some(candle_nn::embedding(
620 cfg.type_vocab_size,
621 cfg.hidden_size,
622 vb.pp("token_type_embeddings"),
623 )?),
624 Some(layer_norm),
625 Some(Tensor::arange(
626 0u32,
627 cfg.max_position_embeddings as u32,
628 &dev,
629 )?),
630 )
631 }
632 };
633
634 Ok(Self {
635 variant: cfg.variant,
636 embeddings,
637 token_type_embeddings,
638 layer_norm,
639 position_ids,
640 })
641 }
642}
643
644impl Module for Embeddings {
645 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
646 let embd = self.embeddings.forward(xs)?;
647 if self.variant == ModelVariant::Large {
649 return Ok(embd);
650 }
651
652 let (token_type_embed, layer_norm, pos_ids) =
653 if let (Some(token_type_embd), Some(layer_norm), Some(position_ids)) = (
654 &self.token_type_embeddings,
655 &self.layer_norm,
656 &self.position_ids,
657 ) {
658 (token_type_embd, layer_norm, position_ids)
659 } else {
660 return Err(Error::Msg(
661 "Stella 400M requires `token_type_embeddings`, `layer_norm` and `position_ids`"
662 .to_string(),
663 ));
664 };
665
666 let (batch_size, seq_length) = xs.dims2()?;
667
668 let pos_ids = pos_ids
669 .as_ref()
670 .narrow(0, 0, seq_length)?
671 .expand((batch_size, seq_length))?;
672
673 layer_norm.forward(&embd.add(&token_type_embed.forward(&pos_ids.zeros_like()?)?)?)
674 }
675}
676
677#[derive(Debug, Clone)]
678pub struct Model {
679 embeddings: Embeddings,
680 layers: Vec<Layer>,
681 norm: Option<RmsNorm>,
682 device: Device,
683 dtype: DType,
684}
685
686impl Model {
687 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
688 let vb_m = match cfg.variant {
689 ModelVariant::Large => vb.pp("model"),
690 ModelVariant::Small => vb.pp("new"),
691 };
692 let embeddings = Embeddings::new(cfg, vb_m.clone())?;
695 let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
696 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
697 let vb_l = match cfg.variant {
698 ModelVariant::Large => vb_m.pp("layers"),
699 ModelVariant::Small => vb_m.pp("encoder").pp("layer"),
700 };
701 for layer_idx in 0..cfg.num_hidden_layers {
702 let layer = Layer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
703 layers.push(layer)
704 }
705 let norm = match cfg.variant {
706 ModelVariant::Large => Some(RmsNorm::new(
707 cfg.hidden_size,
708 cfg.norm_eps,
709 vb_m.pp("norm"),
710 )?),
711 ModelVariant::Small => None,
712 };
713 Ok(Self {
714 embeddings,
715 layers,
716 norm,
717 device: vb.device().clone(),
718 dtype: vb.dtype(),
719 })
720 }
721
722 fn prepare_attention_mask(&self, attn_mask: &Tensor) -> Result<Tensor> {
723 let (b_sz, sql_len) = attn_mask.dims2()?;
724 let mut mask: Vec<Tensor> = vec![];
725 for b in 0..b_sz {
726 mask.push(attn_mask.i((b, ..))?.expand((1, 1, sql_len, sql_len))?);
727 }
728 let mask = Tensor::cat(&mask, 0)?;
729 let on_true = mask.zeros_like()?.to_dtype(self.dtype)?;
730 let on_false = Tensor::new(f32::NEG_INFINITY, &self.device)?
731 .broadcast_as(mask.shape())?
732 .to_dtype(self.dtype)?;
733 mask.where_cond(&on_true, &on_false)
734 }
735
736 pub fn forward(&mut self, input_ids: &Tensor, mask: &Tensor) -> Result<Tensor> {
737 let (_, seq_len) = input_ids.dims2()?;
738 let attention_mask = if seq_len <= 1 {
739 None
740 } else {
741 Some(self.prepare_attention_mask(mask)?)
743 };
744
745 let mut xs = self.embeddings.forward(input_ids)?;
746 for layer in self.layers.iter_mut() {
747 xs = layer.forward(&xs, attention_mask.as_ref())?
748 }
749
750 if let Some(n) = &self.norm {
751 xs.apply(n)
752 } else {
753 Ok(xs)
754 }
755 }
756}
757
758#[derive(Debug)]
759pub struct EmbeddingModel {
760 base_model: Model,
761 lm_head: Linear,
762}
763
764impl EmbeddingModel {
765 pub fn new(cfg: &Config, base_vb: VarBuilder, embed_vb: VarBuilder) -> Result<Self> {
766 let base_model = Model::new(cfg, base_vb.clone())?;
767 let lm_head = linear(
768 cfg.embed_head.in_features,
769 cfg.embed_head.out_features,
770 embed_vb.pp("linear"),
771 )?;
772
773 Ok(Self {
774 base_model,
775 lm_head,
776 })
777 }
778
779 pub fn forward(&mut self, input_ids: &Tensor, mask: &Tensor) -> Result<Tensor> {
780 let x = self.base_model.forward(input_ids, mask)?;
781 let x = self.pool(&x, mask)?;
782
783 self.lm_head.forward(&x.to_dtype(DType::F32)?) }
786
787 pub fn forward_norm(&mut self, input_ids: &Tensor, mask: &Tensor) -> Result<Tensor> {
789 let x = self.forward(input_ids, mask)?;
790 x.broadcast_div(&x.sqr()?.sum_keepdim(1)?.sqrt()?)
792 }
793
794 fn pool(&self, x: &Tensor, mask: &Tensor) -> Result<Tensor> {
795 let mask = mask.to_dtype(x.dtype())?; let (batch_size, seq_len, hidden_dim) = x.dims3()?;
797 let mask_expanded = mask
799 .unsqueeze(2)?
800 .broadcast_as((batch_size, seq_len, hidden_dim))?; let x = (x * &mask_expanded)?;
803
804 let sum_mask = mask
806 .sum(1)?
807 .unsqueeze(1)?
808 .expand((batch_size, hidden_dim))?;
809 x.sum(1)? / sum_mask
810 }
811}