candle_transformers/models/openclip/
text_model.rs1use candle::{DType, IndexOp, Result, Tensor, D};
5use candle_nn::{
6 embedding, layer_norm, linear, ops::softmax_last_dim, Embedding, LayerNorm, Linear, Module,
7 VarBuilder,
8};
9
/// Hyper-parameters for the text transformer defined in this file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Config {
    /// Number of rows in the token embedding table.
    pub vocab_size: usize,
    /// Width of the token/positional embeddings and of every transformer layer.
    pub embed_dim: usize,
    /// Hidden width of the feed-forward block inside each layer.
    pub intermediate_size: usize,
    /// Number of rows in the learned positional embedding table.
    pub max_position_embeddings: usize,
    /// NOTE(review): not read by any code visible in this file; presumably the padding token
    /// used by the tokenizer — confirm against callers.
    pub pad_with: Option<String>,
    /// Number of residual blocks stacked in the encoder.
    pub num_hidden_layers: usize,
    /// Number of attention heads per layer.
    pub num_attention_heads: usize,
    /// NOTE(review): not read by any code visible in this file; presumably the width of a
    /// projection applied by the caller — confirm.
    pub projection_dim: usize,
}

impl Config {
    /// Returns the preset named `vit_base_patch32`: 12 layers, 8 heads, width 512,
    /// feed-forward width 2048, vocabulary of 49408 tokens and 77 positions.
    #[must_use]
    pub fn vit_base_patch32() -> Self {
        Self {
            vocab_size: 49408,
            embed_dim: 512,
            intermediate_size: 2048,
            max_position_embeddings: 77,
            pad_with: None,
            num_hidden_layers: 12,
            num_attention_heads: 8,
            projection_dim: 512,
        }
    }
}
36
37#[derive(Clone, Debug)]
38struct TextEmbeddings {
39 token_embedding: Embedding,
40 position_embedding: Tensor,
41}
42
43impl TextEmbeddings {
44 fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
45 let token_embedding = embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
46 let position_embedding = vs.get(
47 (c.max_position_embeddings, c.embed_dim),
48 "positional_embedding",
49 )?;
50 Ok(TextEmbeddings {
51 token_embedding,
52 position_embedding,
53 })
54 }
55}
56
57impl Module for TextEmbeddings {
58 fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
59 let seq_length = input_ids.dim(D::Minus1)?;
60 let inputs_embeds = self.token_embedding.forward(input_ids)?;
61
62 let position_embedding = self.position_embedding.narrow(0, 0, seq_length)?;
63
64 inputs_embeds.broadcast_add(&position_embedding)
65 }
66}
67
68#[derive(Clone, Debug)]
69struct Attention {
70 k_proj: candle_nn::Linear,
71 v_proj: candle_nn::Linear,
72 q_proj: candle_nn::Linear,
73 out_proj: Linear,
74 head_dim: usize,
75 scale: f64,
76 num_attention_heads: usize,
77}
78
79impl Attention {
80 fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
81 let embed_dim = c.embed_dim;
82 let num_attention_heads = c.num_attention_heads;
83
84 let in_proj_weights = vs
85 .get((embed_dim * 3, embed_dim), "in_proj_weight")?
86 .chunk(3, 0)?;
87 let (q_w, k_w, v_w) = (
88 &in_proj_weights[0],
89 &in_proj_weights[1],
90 &in_proj_weights[2],
91 );
92 let in_proj_biases = vs.get(embed_dim * 3, "in_proj_bias")?.chunk(3, 0)?;
93 let (q_b, k_b, v_b) = (&in_proj_biases[0], &in_proj_biases[1], &in_proj_biases[2]);
94
95 let q_proj = Linear::new(q_w.clone(), Some(q_b.clone()));
96 let k_proj = Linear::new(k_w.clone(), Some(k_b.clone()));
97 let v_proj = Linear::new(v_w.clone(), Some(v_b.clone()));
98 let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?;
99 let head_dim = embed_dim / num_attention_heads;
100 let scale = (head_dim as f64).powf(-0.5);
101
102 Ok(Attention {
103 k_proj,
104 v_proj,
105 q_proj,
106 out_proj,
107 head_dim,
108 scale,
109 num_attention_heads,
110 })
111 }
112
113 fn shape_multihead(&self, xs: &Tensor, bsz: usize, seq_len: usize) -> Result<Tensor> {
114 xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
115 .transpose(1, 2)?
116 .contiguous()?
117 .to_dtype(DType::F32)
118 }
119
120 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
121 let in_dtype = xs.dtype();
122 let (bsz, seq_len, embed_dim) = xs.dims3()?;
123
124 let q = self.shape_multihead(&self.q_proj.forward(xs)?, bsz, seq_len)?;
125 let k = self.shape_multihead(&self.k_proj.forward(xs)?, bsz, seq_len)?;
126 let v = self.shape_multihead(&self.v_proj.forward(xs)?, bsz, seq_len)?;
127 let q = (q * self.scale)?;
128
129 let attn_weights = q.matmul(&k.transpose(D::Minus1, D::Minus2)?)?;
130
131 let attn_weights = softmax_last_dim(&attn_weights)?;
132
133 let attn_output = attn_weights.matmul(&v)?.to_dtype(in_dtype)?;
134 let attn_output = attn_output
135 .transpose(1, 2)?
136 .contiguous()?
137 .reshape((bsz, seq_len, embed_dim))?;
138 let out = self.out_proj.forward(&attn_output)?;
139 Ok(out)
140 }
141}
142
143#[derive(Clone, Debug)]
144struct Mlp {
145 fc1: Linear,
146 fc2: Linear,
147}
148
149impl Mlp {
150 fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
151 let fc1 = linear(c.embed_dim, c.intermediate_size, vs.pp("c_fc"))?;
152 let fc2 = linear(c.intermediate_size, c.embed_dim, vs.pp("c_proj"))?;
153
154 Ok(Mlp { fc1, fc2 })
155 }
156}
157
158impl Mlp {
159 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
160 let xs = self.fc1.forward(xs)?;
161 self.fc2.forward(&xs.gelu_erf()?)
162 }
163}
164
165#[derive(Clone, Debug)]
166struct EncoderLayer {
167 self_attn: Attention,
168 layer_norm1: LayerNorm,
169 mlp: Mlp,
170 layer_norm2: LayerNorm,
171}
172
173impl EncoderLayer {
174 fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
175 let self_attn = Attention::new(vs.pp("attn"), c)?;
176 let layer_norm1 = layer_norm(c.embed_dim, 1e-5, vs.pp("ln_1"))?;
177 let mlp = Mlp::new(vs.pp("mlp"), c)?;
178 let layer_norm2 = layer_norm(c.embed_dim, 1e-5, vs.pp("ln_2"))?;
179
180 Ok(EncoderLayer {
181 self_attn,
182 layer_norm1,
183 mlp,
184 layer_norm2,
185 })
186 }
187
188 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
189 let residual = xs;
190 let xs = self.layer_norm1.forward(xs)?;
191 let xs = self.self_attn.forward(&xs)?;
192 let xs = (xs + residual)?;
193
194 let residual = &xs;
195 let xs = self.layer_norm2.forward(&xs)?;
196 let xs = self.mlp.forward(&xs)?;
197 let out = (xs + residual)?;
198 Ok(out)
199 }
200}
201
202#[derive(Clone, Debug)]
203pub struct Encoder {
204 layers: Vec<EncoderLayer>,
205}
206
207impl Encoder {
208 pub fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
209 let vs = vs.pp("resblocks");
210 let mut layers: Vec<EncoderLayer> = Vec::new();
211 for index in 0..c.num_hidden_layers {
212 let layer = EncoderLayer::new(vs.pp(index.to_string()), c)?;
213 layers.push(layer)
214 }
215 Ok(Encoder { layers })
216 }
217
218 pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
219 let mut xs = xs.clone();
220 for layer in self.layers.iter() {
221 xs = layer.forward(&xs)?;
222 }
223 Ok(xs)
224 }
225}
226
227#[derive(Clone, Debug)]
229pub struct OpenClipTextTransformer {
230 embeddings: TextEmbeddings,
231 encoder: Encoder,
232 final_layer_norm: LayerNorm,
233}
234
235impl OpenClipTextTransformer {
236 pub fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
237 let embeddings = TextEmbeddings::new(vs.clone(), c)?;
238 let final_layer_norm = layer_norm(c.embed_dim, 1e-5, vs.pp("ln_final"))?;
239 let encoder = Encoder::new(vs.pp("transformer"), c)?;
240 Ok(OpenClipTextTransformer {
241 embeddings,
242 encoder,
243 final_layer_norm,
244 })
245 }
246
247 pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
248 let input_ids = self.embeddings.forward(input_ids)?;
249 let input_ids = self.encoder.forward(&input_ids)?;
250 self.final_layer_norm.forward(&input_ids)
251 }
252}
253
254impl Module for OpenClipTextTransformer {
255 fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
256 let output = self.forward(input_ids)?;
257 let sequence_max_indices = input_ids.argmax(D::Minus1)?.to_dtype(DType::I64)?;
258
259 let mut indices = Vec::new();
260 for (batch_idx, &seq_idx) in sequence_max_indices.to_vec1::<i64>()?.iter().enumerate() {
261 let index = output.i((batch_idx, seq_idx as usize))?.unsqueeze(0)?;
262 indices.push(index);
263 }
264 Tensor::cat(&indices, 0)
265 }
266}