1use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
7use candle::{DType, Device, Result, Tensor};
8use candle_nn::{Embedding, Module, VarBuilder};
9use serde::Deserialize;
10
/// `f32` dtype constant exported by this module.
///
/// NOTE(review): not referenced anywhere in this file; presumably callers use it to
/// choose the dtype when building the `VarBuilder` — confirm.
pub const DTYPE: DType = DType::F32;
12
13fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
14 let shape = mask.shape();
15 let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
16 let m = mask.where_cond(&on_true, on_false)?;
17 Ok(m)
18}
19
/// Activation used between the two linear layers of the feed-forward network.
///
/// Deserialized from lowercase strings (`"gelu"`, `"relu"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
enum HiddenAct {
    Gelu,
    Relu,
}
26
/// `Module` wrapper that applies a [`HiddenAct`] inside a tracing span.
struct HiddenActLayer {
    // Activation applied by `forward`.
    act: HiddenAct,
    // TRACE-level span entered on every `forward` call.
    span: tracing::Span,
}
31
32impl HiddenActLayer {
33 fn new(act: HiddenAct) -> Self {
34 let span = tracing::span!(tracing::Level::TRACE, "hidden-act");
35 Self { act, span }
36 }
37}
38
39impl Module for HiddenActLayer {
40 fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
41 let _enter = self.span.enter();
42 match self.act {
43 HiddenAct::Gelu => xs.gelu(),
45 HiddenAct::Relu => xs.relu(),
46 }
47 }
48}
49
/// Kind of position embedding named in the config; only `Absolute` exists.
///
/// The value is deserialized (lowercase, defaulting to `Absolute`) but is not consulted
/// anywhere in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
enum PositionEmbeddingType {
    #[default]
    Absolute,
}
56
/// DistilBERT hyper-parameters, deserialized with serde (field names are the JSON keys).
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    // Number of rows in the word-embedding table.
    vocab_size: usize,
    // Hidden size: embedding width and input/output width of every block.
    dim: usize,
    // Number of `TransformerBlock`s loaded as `layer.0..n_layers`.
    n_layers: usize,
    // Attention heads; each head gets `dim / n_heads` channels.
    n_heads: usize,
    // Inner width of the feed-forward network (`lin1` output size).
    hidden_dim: usize,
    // Activation used inside the feed-forward network.
    activation: HiddenAct,
    // Rows in the position-embedding table, i.e. the longest supported sequence.
    max_position_embeddings: usize,
    // Not read by this file.
    initializer_range: f64,
    // Not read by this file.
    pad_token_id: usize,
    // Not read by this file; defaults to `Absolute` when absent.
    #[serde(default)]
    position_embedding_type: PositionEmbeddingType,
    // Not read by this file; defaults to `false` when absent from the JSON.
    #[serde(default)]
    use_cache: bool,
    // When set, `DistilBertModel::load` retries weight lookup under this prefix.
    model_type: Option<String>,
}
74
75impl Default for Config {
76 fn default() -> Self {
77 Self {
78 vocab_size: 30522,
79 dim: 768,
80 n_layers: 12,
81 n_heads: 12,
82 hidden_dim: 3072,
83 activation: HiddenAct::Gelu,
84 max_position_embeddings: 512,
85 initializer_range: 0.02,
86 pad_token_id: 0,
87 position_embedding_type: PositionEmbeddingType::Absolute,
88 use_cache: true,
89 model_type: Some("distilbert".to_string()),
90 }
91 }
92}
93
/// Token + absolute position embeddings followed by a LayerNorm.
struct Embeddings {
    // `vocab_size x dim` lookup table indexed by token id.
    word_embeddings: Embedding,
    // `max_position_embeddings x dim` lookup table indexed by position `0..seq_len`.
    position_embeddings: Embedding,
    // LayerNorm (eps 1e-12) applied to the summed embeddings.
    layer_norm: LayerNorm,
    span: tracing::Span,
}
100
101impl Embeddings {
102 fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
103 let word_embeddings =
104 candle_nn::embedding(config.vocab_size, config.dim, vb.pp("word_embeddings"))?;
105 let position_embeddings = candle_nn::embedding(
106 config.max_position_embeddings,
107 config.dim,
108 vb.pp("position_embeddings"),
109 )?;
110 let layer_norm = layer_norm(config.dim, 1e-12, vb.pp("LayerNorm"))?;
111 Ok(Self {
112 word_embeddings,
113 position_embeddings,
114 layer_norm,
115 span: tracing::span!(tracing::Level::TRACE, "embeddings"),
116 })
117 }
118
119 fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
120 let _enter = self.span.enter();
121 let (_bsize, seq_len) = input_ids.dims2()?;
122 let input_embeddings = self.word_embeddings.forward(input_ids)?;
123 let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
124 let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
125 let embeddings =
126 input_embeddings.broadcast_add(&self.position_embeddings.forward(&position_ids)?)?;
127
128 let embeddings = self.layer_norm.forward(&embeddings)?;
129 Ok(embeddings)
130 }
131}
132
/// Multi-head self-attention with separate query/key/value/output projections.
struct MultiHeadSelfAttention {
    q_lin: Linear,
    k_lin: Linear,
    v_lin: Linear,
    // Projects the concatenated heads (`n_heads * attention_head_size`) back to `dim`.
    out_lin: Linear,
    n_heads: usize,
    // Per-head width, `dim / n_heads` (integer division).
    attention_head_size: usize,
    span: tracing::Span,
}
142
143impl MultiHeadSelfAttention {
144 fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
145 let attention_head_size = config.dim / config.n_heads;
146 let all_head_size = config.n_heads * attention_head_size;
147 let dim = config.dim;
148 let q_lin = linear(dim, all_head_size, vb.pp("q_lin"))?;
149 let v_lin = linear(dim, all_head_size, vb.pp("v_lin"))?;
150 let k_lin = linear(dim, all_head_size, vb.pp("k_lin"))?;
151 let out_lin = linear(all_head_size, dim, vb.pp("out_lin"))?;
152 Ok(Self {
153 q_lin,
154 k_lin,
155 v_lin,
156 out_lin,
157 n_heads: config.n_heads,
158 attention_head_size,
159 span: tracing::span!(tracing::Level::TRACE, "attention"),
160 })
161 }
162}
163
164impl MultiHeadSelfAttention {
165 fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
166 let _enter = self.span.enter();
167 let (bs, q_length, _dim) = hidden_states.dims3()?;
168
169 let dim_per_head = self.attention_head_size;
170 let q = self.q_lin.forward(hidden_states)?;
171 let k = self.k_lin.forward(hidden_states)?;
172 let v = self.v_lin.forward(hidden_states)?;
173
174 let q = q
175 .reshape((bs, q_length, self.n_heads, dim_per_head))?
176 .transpose(1, 2)?;
177 let k = k
178 .reshape((bs, q_length, self.n_heads, dim_per_head))?
179 .transpose(1, 2)?;
180 let v = v
181 .reshape((bs, q_length, self.n_heads, dim_per_head))?
182 .transpose(1, 2)?;
183
184 let q: Tensor = (q / (dim_per_head as f64).sqrt())?;
185 let scores = q.matmul(&k.transpose(2, 3)?.contiguous()?)?;
186 let mask = attention_mask.broadcast_as(scores.shape())?;
187
188 let scores = masked_fill(&scores.to_dtype(DType::F32)?, &mask, f32::NEG_INFINITY)?;
189 let weights = candle_nn::ops::softmax(&scores, candle::D::Minus1)?;
190
191 let context = weights.matmul(&v.contiguous()?)?;
192 let context = context
193 .transpose(1, 2)?
194 .reshape((bs, q_length, self.n_heads * dim_per_head))?
195 .contiguous()?;
196 let context = self.out_lin.forward(&context)?;
197
198 Ok(context)
199 }
200}
201
/// Position-wise feed-forward network: `lin1` (dim -> hidden_dim), activation,
/// then `lin2` (hidden_dim -> dim).
// The type keeps the all-caps acronym name, so silence clippy's naming lint.
#[allow(clippy::upper_case_acronyms)]
struct FFN {
    lin1: Linear,
    lin2: Linear,
    // Activation applied between `lin1` and `lin2`.
    activation: HiddenActLayer,
    span: tracing::Span,
}
209
210impl FFN {
211 fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
212 let lin1 = linear(config.dim, config.hidden_dim, vb.pp("lin1"))?;
213 let lin2 = linear(config.hidden_dim, config.dim, vb.pp("lin2"))?;
214 Ok(Self {
215 lin1,
216 lin2,
217 activation: HiddenActLayer::new(config.activation),
218 span: tracing::span!(tracing::Level::TRACE, "ffn"),
219 })
220 }
221}
222
223impl Module for FFN {
224 fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
225 let _enter = self.span.enter();
226 hidden_states
227 .apply(&self.lin1)?
228 .apply(&self.activation)?
229 .apply(&self.lin2)
230 }
231}
232
/// One encoder layer. Self-attention and the FFN are each followed by a residual add and
/// a LayerNorm (post-norm ordering).
struct TransformerBlock {
    attention: MultiHeadSelfAttention,
    // LayerNorm after the attention residual.
    sa_layer_norm: LayerNorm,
    ffn: FFN,
    // LayerNorm after the FFN residual.
    output_layer_norm: LayerNorm,
    span: tracing::Span,
}
240
241impl TransformerBlock {
242 fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
243 let attention = MultiHeadSelfAttention::load(vb.pp("attention"), config)?;
244 let sa_layer_norm = layer_norm(config.dim, 1e-12, vb.pp("sa_layer_norm"))?;
245 let ffn = FFN::load(vb.pp("ffn"), config)?;
246 let output_layer_norm = layer_norm(config.dim, 1e-12, vb.pp("output_layer_norm"))?;
247 Ok(Self {
248 attention,
249 sa_layer_norm,
250 ffn,
251 output_layer_norm,
252 span: tracing::span!(tracing::Level::TRACE, "layer"),
253 })
254 }
255}
256
257impl TransformerBlock {
258 fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
259 let _enter = self.span.enter();
260 let sa_output = self.attention.forward(hidden_states, attention_mask)?;
261 let sa_output = sa_output.broadcast_add(hidden_states)?;
265 let sa_output = self.sa_layer_norm.forward(&sa_output)?;
266
267 let ffn_output = self.ffn.forward(&sa_output)?;
268 let ffn_output = (&ffn_output + sa_output)?;
269 let output = self.output_layer_norm.forward(&ffn_output)?;
270 Ok(output)
271 }
272}
273
/// Stack of `TransformerBlock`s applied one after another.
struct Transformer {
    // Loaded from `layer.0` .. `layer.{n_layers - 1}`.
    layers: Vec<TransformerBlock>,
    span: tracing::Span,
}
279
280impl Transformer {
281 fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
282 let layers = (0..config.n_layers)
283 .map(|index| TransformerBlock::load(vb.pp(format!("layer.{index}")), config))
284 .collect::<Result<Vec<_>>>()?;
285 let span = tracing::span!(tracing::Level::TRACE, "encoder");
286 Ok(Transformer { layers, span })
287 }
288}
289
290impl Transformer {
291 fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
292 let _enter = self.span.enter();
293 let mut hidden_states = hidden_states.clone();
294 for layer in self.layers.iter() {
296 hidden_states = layer.forward(&hidden_states, attention_mask)?;
297 }
298 Ok(hidden_states)
299 }
300}
301
/// DistilBERT encoder: embeddings followed by the transformer stack.
///
/// `forward` returns the final hidden states. There is no pooling or task head here.
pub struct DistilBertModel {
    embeddings: Embeddings,
    transformer: Transformer,
    /// Device of the `VarBuilder` the weights were loaded from.
    pub device: Device,
    span: tracing::Span,
}
308
309impl DistilBertModel {
310 pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
311 let (embeddings, transformer) = match (
312 Embeddings::load(vb.pp("embeddings"), config),
313 Transformer::load(vb.pp("transformer"), config),
314 ) {
315 (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
316 (Err(err), _) | (_, Err(err)) => {
317 if let Some(model_type) = &config.model_type {
318 if let (Ok(embeddings), Ok(encoder)) = (
319 Embeddings::load(vb.pp(format!("{model_type}.embeddings")), config),
320 Transformer::load(vb.pp(format!("{model_type}.transformer")), config),
321 ) {
322 (embeddings, encoder)
323 } else {
324 return Err(err);
325 }
326 } else {
327 return Err(err);
328 }
329 }
330 };
331 Ok(Self {
332 embeddings,
333 transformer,
334 device: vb.device().clone(),
335 span: tracing::span!(tracing::Level::TRACE, "model"),
336 })
337 }
338
339 pub fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
340 let _enter = self.span.enter();
341 let embedding_output = self.embeddings.forward(input_ids)?;
342 let sequence_output = self
343 .transformer
344 .forward(&embedding_output, attention_mask)?;
345 Ok(sequence_output)
346 }
347}