// candle_transformers/models/openclip/text_model.rs

//! Text encoder as used in most OpenCLIP pretrained models.
//! <https://github.com/mlfoundations/open_clip>

use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{
    embedding, layer_norm, linear, ops::softmax_last_dim, Embedding, LayerNorm, Linear, Module,
    VarBuilder,
};

/// Hyper-parameters of the OpenCLIP text transformer.
#[derive(Debug, Clone)]
pub struct Config {
    /// Number of rows in the token embedding table.
    pub vocab_size: usize,
    /// Width of the residual stream (token and position embedding size).
    pub embed_dim: usize,
    /// Hidden width of the feed-forward block inside each encoder layer.
    pub intermediate_size: usize,
    /// Rows in the learned position table; the longest sequence the embeddings accept.
    pub max_position_embeddings: usize,
    /// Padding specification; not read anywhere in this module.
    pub pad_with: Option<String>,
    /// Number of stacked encoder layers (`transformer.resblocks.{i}`).
    pub num_hidden_layers: usize,
    /// Attention heads per layer; `embed_dim` is divided between them.
    pub num_attention_heads: usize,
    /// Projection width; not read anywhere in this module.
    pub projection_dim: usize,
}

impl Config {
    /// Text-tower settings for the ViT-B/32 variant: 12 layers of width 512 with 8 heads,
    /// a 4x feed-forward expansion and a 77-token context.
    pub fn vit_base_patch32() -> Self {
        let embed_dim = 512;
        Self {
            vocab_size: 49408,
            embed_dim,
            intermediate_size: 4 * embed_dim,
            max_position_embeddings: 77,
            pad_with: None,
            num_hidden_layers: 12,
            num_attention_heads: 8,
            projection_dim: embed_dim,
        }
    }
}

37#[derive(Clone, Debug)]
38struct TextEmbeddings {
39    token_embedding: Embedding,
40    position_embedding: Tensor,
41}
42
43impl TextEmbeddings {
44    fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
45        let token_embedding = embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
46        let position_embedding = vs.get(
47            (c.max_position_embeddings, c.embed_dim),
48            "positional_embedding",
49        )?;
50        Ok(TextEmbeddings {
51            token_embedding,
52            position_embedding,
53        })
54    }
55}
56
57impl Module for TextEmbeddings {
58    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
59        let seq_length = input_ids.dim(D::Minus1)?;
60        let inputs_embeds = self.token_embedding.forward(input_ids)?;
61
62        let position_embedding = self.position_embedding.narrow(0, 0, seq_length)?;
63
64        inputs_embeds.broadcast_add(&position_embedding)
65    }
66}
67
68#[derive(Clone, Debug)]
69struct Attention {
70    k_proj: candle_nn::Linear,
71    v_proj: candle_nn::Linear,
72    q_proj: candle_nn::Linear,
73    out_proj: Linear,
74    head_dim: usize,
75    scale: f64,
76    num_attention_heads: usize,
77}
78
79impl Attention {
80    fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
81        let embed_dim = c.embed_dim;
82        let num_attention_heads = c.num_attention_heads;
83
84        let in_proj_weights = vs
85            .get((embed_dim * 3, embed_dim), "in_proj_weight")?
86            .chunk(3, 0)?;
87        let (q_w, k_w, v_w) = (
88            &in_proj_weights[0],
89            &in_proj_weights[1],
90            &in_proj_weights[2],
91        );
92        let in_proj_biases = vs.get(embed_dim * 3, "in_proj_bias")?.chunk(3, 0)?;
93        let (q_b, k_b, v_b) = (&in_proj_biases[0], &in_proj_biases[1], &in_proj_biases[2]);
94
95        let q_proj = Linear::new(q_w.clone(), Some(q_b.clone()));
96        let k_proj = Linear::new(k_w.clone(), Some(k_b.clone()));
97        let v_proj = Linear::new(v_w.clone(), Some(v_b.clone()));
98        let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?;
99        let head_dim = embed_dim / num_attention_heads;
100        let scale = (head_dim as f64).powf(-0.5);
101
102        Ok(Attention {
103            k_proj,
104            v_proj,
105            q_proj,
106            out_proj,
107            head_dim,
108            scale,
109            num_attention_heads,
110        })
111    }
112
113    fn shape_multihead(&self, xs: &Tensor, bsz: usize, seq_len: usize) -> Result<Tensor> {
114        xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
115            .transpose(1, 2)?
116            .contiguous()?
117            .to_dtype(DType::F32)
118    }
119
120    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
121        let in_dtype = xs.dtype();
122        let (bsz, seq_len, embed_dim) = xs.dims3()?;
123
124        let q = self.shape_multihead(&self.q_proj.forward(xs)?, bsz, seq_len)?;
125        let k = self.shape_multihead(&self.k_proj.forward(xs)?, bsz, seq_len)?;
126        let v = self.shape_multihead(&self.v_proj.forward(xs)?, bsz, seq_len)?;
127        let q = (q * self.scale)?;
128
129        let attn_weights = q.matmul(&k.transpose(D::Minus1, D::Minus2)?)?;
130
131        let attn_weights = softmax_last_dim(&attn_weights)?;
132
133        let attn_output = attn_weights.matmul(&v)?.to_dtype(in_dtype)?;
134        let attn_output = attn_output
135            .transpose(1, 2)?
136            .contiguous()?
137            .reshape((bsz, seq_len, embed_dim))?;
138        let out = self.out_proj.forward(&attn_output)?;
139        Ok(out)
140    }
141}
142
143#[derive(Clone, Debug)]
144struct Mlp {
145    fc1: Linear,
146    fc2: Linear,
147}
148
149impl Mlp {
150    fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
151        let fc1 = linear(c.embed_dim, c.intermediate_size, vs.pp("c_fc"))?;
152        let fc2 = linear(c.intermediate_size, c.embed_dim, vs.pp("c_proj"))?;
153
154        Ok(Mlp { fc1, fc2 })
155    }
156}
157
158impl Mlp {
159    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
160        let xs = self.fc1.forward(xs)?;
161        self.fc2.forward(&xs.gelu_erf()?)
162    }
163}
164
165#[derive(Clone, Debug)]
166struct EncoderLayer {
167    self_attn: Attention,
168    layer_norm1: LayerNorm,
169    mlp: Mlp,
170    layer_norm2: LayerNorm,
171}
172
173impl EncoderLayer {
174    fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
175        let self_attn = Attention::new(vs.pp("attn"), c)?;
176        let layer_norm1 = layer_norm(c.embed_dim, 1e-5, vs.pp("ln_1"))?;
177        let mlp = Mlp::new(vs.pp("mlp"), c)?;
178        let layer_norm2 = layer_norm(c.embed_dim, 1e-5, vs.pp("ln_2"))?;
179
180        Ok(EncoderLayer {
181            self_attn,
182            layer_norm1,
183            mlp,
184            layer_norm2,
185        })
186    }
187
188    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
189        let residual = xs;
190        let xs = self.layer_norm1.forward(xs)?;
191        let xs = self.self_attn.forward(&xs)?;
192        let xs = (xs + residual)?;
193
194        let residual = &xs;
195        let xs = self.layer_norm2.forward(&xs)?;
196        let xs = self.mlp.forward(&xs)?;
197        let out = (xs + residual)?;
198        Ok(out)
199    }
200}
201
202#[derive(Clone, Debug)]
203pub struct Encoder {
204    layers: Vec<EncoderLayer>,
205}
206
207impl Encoder {
208    pub fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
209        let vs = vs.pp("resblocks");
210        let mut layers: Vec<EncoderLayer> = Vec::new();
211        for index in 0..c.num_hidden_layers {
212            let layer = EncoderLayer::new(vs.pp(index.to_string()), c)?;
213            layers.push(layer)
214        }
215        Ok(Encoder { layers })
216    }
217
218    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
219        let mut xs = xs.clone();
220        for layer in self.layers.iter() {
221            xs = layer.forward(&xs)?;
222        }
223        Ok(xs)
224    }
225}
226
227/// A text transformer as used in CLIP variants.
228#[derive(Clone, Debug)]
229pub struct OpenClipTextTransformer {
230    embeddings: TextEmbeddings,
231    encoder: Encoder,
232    final_layer_norm: LayerNorm,
233}
234
235impl OpenClipTextTransformer {
236    pub fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
237        let embeddings = TextEmbeddings::new(vs.clone(), c)?;
238        let final_layer_norm = layer_norm(c.embed_dim, 1e-5, vs.pp("ln_final"))?;
239        let encoder = Encoder::new(vs.pp("transformer"), c)?;
240        Ok(OpenClipTextTransformer {
241            embeddings,
242            encoder,
243            final_layer_norm,
244        })
245    }
246
247    pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
248        let input_ids = self.embeddings.forward(input_ids)?;
249        let input_ids = self.encoder.forward(&input_ids)?;
250        self.final_layer_norm.forward(&input_ids)
251    }
252}
253
254impl Module for OpenClipTextTransformer {
255    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
256        let output = self.forward(input_ids)?;
257        let sequence_max_indices = input_ids.argmax(D::Minus1)?.to_dtype(DType::I64)?;
258
259        let mut indices = Vec::new();
260        for (batch_idx, &seq_idx) in sequence_max_indices.to_vec1::<i64>()?.iter().enumerate() {
261            let index = output.i((batch_idx, seq_idx as usize))?.unsqueeze(0)?;
262            indices.push(index);
263        }
264        Tensor::cat(&indices, 0)
265    }
266}