candle_transformers/models/nvembed_v2/
model.rs1use super::embedding::Model as EmbeddingModel;
2use crate::models::{
3 mistral::Config,
4 with_tracing::{layer_norm, linear, linear_no_bias, LayerNorm, Linear},
5};
6use candle::{DType, Device, Result, Tensor, D};
7use candle_nn::{ops::softmax_last_dim, LayerNormConfig, Module, VarBuilder};
8
9#[derive(Debug)]
11struct GeGlu {
12 proj: Linear,
13 span: tracing::Span,
14}
15
16impl GeGlu {
17 fn new(vs: VarBuilder, dim_in: usize, dim_out: usize) -> Result<Self> {
18 let proj = linear(dim_in, dim_out * 2, vs)?;
19 let span = tracing::span!(tracing::Level::TRACE, "geglu");
20 Ok(Self { proj, span })
21 }
22}
23
24impl Module for GeGlu {
25 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
26 let _enter = self.span.enter();
27 let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?;
28 &hidden_states_and_gate[0] * hidden_states_and_gate[1].gelu()?
29 }
30}
31
32#[derive(Debug)]
33struct FeedForward {
34 project_in: GeGlu,
35 linear: Linear,
36 span: tracing::Span,
37}
38
39impl FeedForward {
40 fn new(vs: VarBuilder, dim: usize, dim_out: Option<usize>, mult: usize) -> Result<Self> {
41 let inner_dim = dim * mult;
42 let dim_out = dim_out.unwrap_or(dim);
43 let vs = vs.pp("net");
44 let project_in = GeGlu::new(vs.pp("0"), dim, inner_dim)?;
45 let linear = linear(inner_dim, dim_out, vs.pp("2"))?;
46 let span = tracing::span!(tracing::Level::TRACE, "ff");
47 Ok(Self {
48 project_in,
49 linear,
50 span,
51 })
52 }
53}
54
55impl Module for FeedForward {
56 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
57 let _enter = self.span.enter();
58 let xs = self.project_in.forward(xs)?;
59 self.linear.forward(&xs)
60 }
61}
62
63#[derive(Debug)]
65struct CrossAttention {
66 to_q: Linear,
67 to_kv: Linear,
68 to_out: Linear,
69 heads: usize,
70 scale: f64,
71 span: tracing::Span,
72 span_attn: tracing::Span,
73 span_softmax: tracing::Span,
74}
75
76impl CrossAttention {
77 fn new(
78 vs: VarBuilder,
79 query_dim: usize,
80 context_dim: Option<usize>,
81 heads: usize,
82 dim_head: usize,
83 ) -> Result<Self> {
84 let inner_dim = dim_head * heads;
85 let context_dim = context_dim.unwrap_or(query_dim);
86 let scale = 1.0 / f64::sqrt(dim_head as f64);
87 let to_q = linear_no_bias(query_dim, inner_dim, vs.pp("to_q"))?;
88 let to_kv = linear_no_bias(context_dim, inner_dim * 2, vs.pp("to_kv"))?;
89 let to_out = linear_no_bias(inner_dim, query_dim, vs.pp("to_out"))?;
90 let span = tracing::span!(tracing::Level::TRACE, "xa");
91 let span_attn = tracing::span!(tracing::Level::TRACE, "xa-attn");
92 let span_softmax = tracing::span!(tracing::Level::TRACE, "xa-softmax");
93 Ok(Self {
94 to_q,
95 to_kv,
96 to_out,
97 heads,
98 scale,
99 span,
100 span_attn,
101 span_softmax,
102 })
103 }
104
105 fn reshape_heads_to_batch_dim(&self, xs: &Tensor) -> Result<Tensor> {
106 let (batch_size, seq_len, dim) = xs.dims3()?;
107 xs.reshape((batch_size, seq_len, self.heads, dim / self.heads))?
108 .transpose(1, 2)?
109 .reshape((batch_size * self.heads, seq_len, dim / self.heads))
110 }
111
112 fn reshape_batch_dim_to_heads(&self, xs: &Tensor) -> Result<Tensor> {
113 let (batch_size, seq_len, dim) = xs.dims3()?;
114 xs.reshape((batch_size / self.heads, self.heads, seq_len, dim))?
115 .transpose(1, 2)?
116 .reshape((batch_size / self.heads, seq_len, dim * self.heads))
117 }
118
119 fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> {
120 let _enter = self.span_attn.enter();
121
122 let in_dtype = query.dtype();
123 let query = query.to_dtype(DType::F32)?;
124 let key = key.to_dtype(DType::F32)?;
125 let value = value.to_dtype(DType::F32)?;
126 let xs = query.matmul(&(key.t()? * self.scale)?)?;
127 let xs = {
128 let _enter = self.span_softmax.enter();
129 softmax_last_dim(&xs)?
130 };
131 let xs = xs.matmul(&value)?.to_dtype(in_dtype)?;
132
133 self.reshape_batch_dim_to_heads(&xs)
134 }
135
136 fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
137 let _enter = self.span.enter();
138 let query = self.to_q.forward(xs)?;
139 let context = context.unwrap_or(xs).contiguous()?;
140 let kv_chunks = self
141 .to_kv
142 .forward(&context)?
143 .chunk(2, context.shape().dims().len() - 1)?;
144 let (key, value) = (kv_chunks[0].clone(), kv_chunks[1].clone());
145 let query = self.reshape_heads_to_batch_dim(&query)?;
146 let key = self.reshape_heads_to_batch_dim(&key)?;
147 let value = self.reshape_heads_to_batch_dim(&value)?;
148
149 let xs = self.attention(&query, &key, &value)?;
150 self.to_out.forward(&xs)
151 }
152}
153
154#[derive(Debug)]
155pub struct Model {
156 embedding_model: EmbeddingModel,
157 cross_attn: CrossAttention,
158 cross_attn_norm: LayerNorm,
159 cross_attn_context_norm: LayerNorm,
160 ff: FeedForward,
161 ff_norm: LayerNorm,
162 latents: Tensor,
163 pub device: Device,
164 pub dtype: DType,
165}
166
167impl Model {
168 pub fn new(vb: VarBuilder) -> Result<Self> {
169 let cfg = Config::config_7b_v0_1(false);
171 let embedding_model = EmbeddingModel::new(&cfg, vb.pp("embedding_model"))?;
172
173 let dim = 4096;
175 let vb = vb.pp("latent_attention_model");
176 let latents = vb.get((512, dim), "latents")?;
177
178 let vb = vb.pp("cross_attend_blocks");
180 let cross_attn_norm = layer_norm(dim, LayerNormConfig::default(), vb.pp("0.norm"))?;
181 let cross_attn_context_norm = layer_norm(
182 dim,
183 candle_nn::LayerNormConfig::default(),
184 vb.pp("0.norm_context"),
185 )?;
186 let cross_attn = CrossAttention::new(vb.pp("0.fn"), dim, None, 8, 4096)?;
187
188 let ff_norm = layer_norm(dim, LayerNormConfig::default(), vb.pp("1.norm"))?;
189 let ff = FeedForward::new(vb.pp("1.fn"), dim, None, 4)?;
190
191 Ok(Self {
192 embedding_model,
193 cross_attn,
194 cross_attn_norm,
195 cross_attn_context_norm,
196 ff,
197 ff_norm,
198 latents,
199 device: vb.device().clone(),
200 dtype: vb.dtype(),
201 })
202 }
203
204 pub fn forward(
205 &mut self,
206 input_ids: &Tensor,
207 attn_mask: &Tensor,
208 pool_mask: &Tensor,
209 ) -> Result<Tensor> {
210 let hiddens = self
212 .embedding_model
213 .forward(attn_mask, input_ids, self.dtype)?;
214
215 let b = hiddens.dims()[0];
217 let x = self.latents.unsqueeze(0)?.repeat((b, 1, 1))?;
218 let original_hiddens = &hiddens;
219
220 let hiddens = self.cross_attn_norm.forward(original_hiddens)?;
221 let x = self.cross_attn_context_norm.forward(&x)?;
222 let cross_hiddens = (self.cross_attn.forward(&hiddens, Some(&x))? + original_hiddens)?;
223
224 let hiddens = self.ff_norm.forward(&cross_hiddens)?;
225 let hiddens = (self.ff.forward(&hiddens)? + cross_hiddens)?;
226
227 let hiddens_masked = hiddens.broadcast_mul(&pool_mask.unsqueeze(D::Minus1)?)?;
229 let s = hiddens_masked.sum(1)?;
230 let d = pool_mask.sum_keepdim(1)?;
231 s.broadcast_div(&d)
232 }
233}