// candle_transformers/models/nvembed_v2/model.rs

1use super::embedding::Model as EmbeddingModel;
2use crate::models::{
3    mistral::Config,
4    with_tracing::{layer_norm, linear, linear_no_bias, LayerNorm, Linear},
5};
6use candle::{DType, Device, Result, Tensor, D};
7use candle_nn::{ops::softmax_last_dim, LayerNormConfig, Module, VarBuilder};
8
9// Geglu and feedforward from candle-transformers/src/models/stable_diffusion/attention.rs
10#[derive(Debug)]
11struct GeGlu {
12    proj: Linear,
13    span: tracing::Span,
14}
15
16impl GeGlu {
17    fn new(vs: VarBuilder, dim_in: usize, dim_out: usize) -> Result<Self> {
18        let proj = linear(dim_in, dim_out * 2, vs)?;
19        let span = tracing::span!(tracing::Level::TRACE, "geglu");
20        Ok(Self { proj, span })
21    }
22}
23
24impl Module for GeGlu {
25    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
26        let _enter = self.span.enter();
27        let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?;
28        &hidden_states_and_gate[0] * hidden_states_and_gate[1].gelu()?
29    }
30}
31
32#[derive(Debug)]
33struct FeedForward {
34    project_in: GeGlu,
35    linear: Linear,
36    span: tracing::Span,
37}
38
39impl FeedForward {
40    fn new(vs: VarBuilder, dim: usize, dim_out: Option<usize>, mult: usize) -> Result<Self> {
41        let inner_dim = dim * mult;
42        let dim_out = dim_out.unwrap_or(dim);
43        let vs = vs.pp("net");
44        let project_in = GeGlu::new(vs.pp("0"), dim, inner_dim)?;
45        let linear = linear(inner_dim, dim_out, vs.pp("2"))?;
46        let span = tracing::span!(tracing::Level::TRACE, "ff");
47        Ok(Self {
48            project_in,
49            linear,
50            span,
51        })
52    }
53}
54
55impl Module for FeedForward {
56    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
57        let _enter = self.span.enter();
58        let xs = self.project_in.forward(xs)?;
59        self.linear.forward(&xs)
60    }
61}
62
63// CrossAttention from candle-transformers/src/models/stable_diffusion/attention.rs
64#[derive(Debug)]
65struct CrossAttention {
66    to_q: Linear,
67    to_kv: Linear,
68    to_out: Linear,
69    heads: usize,
70    scale: f64,
71    span: tracing::Span,
72    span_attn: tracing::Span,
73    span_softmax: tracing::Span,
74}
75
76impl CrossAttention {
77    fn new(
78        vs: VarBuilder,
79        query_dim: usize,
80        context_dim: Option<usize>,
81        heads: usize,
82        dim_head: usize,
83    ) -> Result<Self> {
84        let inner_dim = dim_head * heads;
85        let context_dim = context_dim.unwrap_or(query_dim);
86        let scale = 1.0 / f64::sqrt(dim_head as f64);
87        let to_q = linear_no_bias(query_dim, inner_dim, vs.pp("to_q"))?;
88        let to_kv = linear_no_bias(context_dim, inner_dim * 2, vs.pp("to_kv"))?;
89        let to_out = linear_no_bias(inner_dim, query_dim, vs.pp("to_out"))?;
90        let span = tracing::span!(tracing::Level::TRACE, "xa");
91        let span_attn = tracing::span!(tracing::Level::TRACE, "xa-attn");
92        let span_softmax = tracing::span!(tracing::Level::TRACE, "xa-softmax");
93        Ok(Self {
94            to_q,
95            to_kv,
96            to_out,
97            heads,
98            scale,
99            span,
100            span_attn,
101            span_softmax,
102        })
103    }
104
105    fn reshape_heads_to_batch_dim(&self, xs: &Tensor) -> Result<Tensor> {
106        let (batch_size, seq_len, dim) = xs.dims3()?;
107        xs.reshape((batch_size, seq_len, self.heads, dim / self.heads))?
108            .transpose(1, 2)?
109            .reshape((batch_size * self.heads, seq_len, dim / self.heads))
110    }
111
112    fn reshape_batch_dim_to_heads(&self, xs: &Tensor) -> Result<Tensor> {
113        let (batch_size, seq_len, dim) = xs.dims3()?;
114        xs.reshape((batch_size / self.heads, self.heads, seq_len, dim))?
115            .transpose(1, 2)?
116            .reshape((batch_size / self.heads, seq_len, dim * self.heads))
117    }
118
119    fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> {
120        let _enter = self.span_attn.enter();
121
122        let in_dtype = query.dtype();
123        let query = query.to_dtype(DType::F32)?;
124        let key = key.to_dtype(DType::F32)?;
125        let value = value.to_dtype(DType::F32)?;
126        let xs = query.matmul(&(key.t()? * self.scale)?)?;
127        let xs = {
128            let _enter = self.span_softmax.enter();
129            softmax_last_dim(&xs)?
130        };
131        let xs = xs.matmul(&value)?.to_dtype(in_dtype)?;
132
133        self.reshape_batch_dim_to_heads(&xs)
134    }
135
136    fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
137        let _enter = self.span.enter();
138        let query = self.to_q.forward(xs)?;
139        let context = context.unwrap_or(xs).contiguous()?;
140        let kv_chunks = self
141            .to_kv
142            .forward(&context)?
143            .chunk(2, context.shape().dims().len() - 1)?;
144        let (key, value) = (kv_chunks[0].clone(), kv_chunks[1].clone());
145        let query = self.reshape_heads_to_batch_dim(&query)?;
146        let key = self.reshape_heads_to_batch_dim(&key)?;
147        let value = self.reshape_heads_to_batch_dim(&value)?;
148
149        let xs = self.attention(&query, &key, &value)?;
150        self.to_out.forward(&xs)
151    }
152}
153
154#[derive(Debug)]
155pub struct Model {
156    embedding_model: EmbeddingModel,
157    cross_attn: CrossAttention,
158    cross_attn_norm: LayerNorm,
159    cross_attn_context_norm: LayerNorm,
160    ff: FeedForward,
161    ff_norm: LayerNorm,
162    latents: Tensor,
163    pub device: Device,
164    pub dtype: DType,
165}
166
167impl Model {
168    pub fn new(vb: VarBuilder) -> Result<Self> {
169        // Embedding model
170        let cfg = Config::config_7b_v0_1(false);
171        let embedding_model = EmbeddingModel::new(&cfg, vb.pp("embedding_model"))?;
172
173        // Latent attention
174        let dim = 4096;
175        let vb = vb.pp("latent_attention_model");
176        let latents = vb.get((512, dim), "latents")?;
177
178        // Cross attend blocks
179        let vb = vb.pp("cross_attend_blocks");
180        let cross_attn_norm = layer_norm(dim, LayerNormConfig::default(), vb.pp("0.norm"))?;
181        let cross_attn_context_norm = layer_norm(
182            dim,
183            candle_nn::LayerNormConfig::default(),
184            vb.pp("0.norm_context"),
185        )?;
186        let cross_attn = CrossAttention::new(vb.pp("0.fn"), dim, None, 8, 4096)?;
187
188        let ff_norm = layer_norm(dim, LayerNormConfig::default(), vb.pp("1.norm"))?;
189        let ff = FeedForward::new(vb.pp("1.fn"), dim, None, 4)?;
190
191        Ok(Self {
192            embedding_model,
193            cross_attn,
194            cross_attn_norm,
195            cross_attn_context_norm,
196            ff,
197            ff_norm,
198            latents,
199            device: vb.device().clone(),
200            dtype: vb.dtype(),
201        })
202    }
203
204    pub fn forward(
205        &mut self,
206        input_ids: &Tensor,
207        attn_mask: &Tensor,
208        pool_mask: &Tensor,
209    ) -> Result<Tensor> {
210        // Embedding model
211        let hiddens = self
212            .embedding_model
213            .forward(attn_mask, input_ids, self.dtype)?;
214
215        // Latent attention
216        let b = hiddens.dims()[0];
217        let x = self.latents.unsqueeze(0)?.repeat((b, 1, 1))?;
218        let original_hiddens = &hiddens;
219
220        let hiddens = self.cross_attn_norm.forward(original_hiddens)?;
221        let x = self.cross_attn_context_norm.forward(&x)?;
222        let cross_hiddens = (self.cross_attn.forward(&hiddens, Some(&x))? + original_hiddens)?;
223
224        let hiddens = self.ff_norm.forward(&cross_hiddens)?;
225        let hiddens = (self.ff.forward(&hiddens)? + cross_hiddens)?;
226
227        // Mean pooling
228        let hiddens_masked = hiddens.broadcast_mul(&pool_mask.unsqueeze(D::Minus1)?)?;
229        let s = hiddens_masked.sum(1)?;
230        let d = pool_mask.sum_keepdim(1)?;
231        s.broadcast_div(&d)
232    }
233}