1use crate::models::mixformer::{Config as PhiConfig, MixFormerSequentialForCausalLM as PhiModel};
39use crate::models::with_tracing::{layer_norm, linear_b, LayerNorm, Linear};
40use candle::{IndexOp, Module, Result, Tensor, D};
41use candle_nn::VarBuilder;
42
43#[derive(Debug, Clone, serde::Deserialize)]
44pub struct Config {
45 pub phi_config: PhiConfig,
46 pub vision_config: VisionConfig,
47}
48
49impl Config {
50 pub fn v2() -> Self {
51 Self {
52 phi_config: PhiConfig::v1_5(),
53 vision_config: VisionConfig::v2(),
54 }
55 }
56}
57
58fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
59 let dim = q.dim(D::Minus1)?;
60 let scale_factor = 1.0 / (dim as f64).sqrt();
61 let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?;
62 candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(v)
63}
64
65#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
66pub struct VisionConfig {
67 pub(crate) image_embedding_dim: usize,
68 pub(crate) model_dim: usize,
69 pub(crate) hidden_dim: usize,
70 pub(crate) hidden_features: usize,
71 pub(crate) embed_len: usize,
72 pub(crate) embed_dim: usize,
73 pub(crate) num_blocks: usize,
74 pub(crate) num_heads: usize,
75 pub(crate) act: candle_nn::Activation,
76}
77
78impl VisionConfig {
79 pub fn v2() -> Self {
80 Self {
81 image_embedding_dim: 1152,
82 model_dim: 2048,
83 hidden_dim: 2048 * 4,
84 hidden_features: 4304,
85 embed_len: 729,
86 embed_dim: 1152,
87 num_blocks: 27,
88 num_heads: 16,
89 act: candle_nn::Activation::GeluPytorchTanh,
90 }
91 }
92}
93
94#[derive(Debug, Clone)]
95struct LinearPatchEmbedding {
96 linear: Linear,
97}
98
99impl LinearPatchEmbedding {
100 fn new(vb: VarBuilder) -> Result<Self> {
101 let linear = linear_b(588, 1152, true, vb.pp("linear"))?;
102 Ok(Self { linear })
103 }
104}
105
106impl Module for LinearPatchEmbedding {
107 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
108 xs.apply(&self.linear)
109 }
110}
111
112#[derive(Debug, Clone)]
113struct Attention {
114 num_heads: usize,
115 head_dim: usize,
116 qkv: Linear,
117 proj: Linear,
118 span: tracing::Span,
119}
120
121impl Attention {
122 pub fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
123 let qkv = linear_b(dim, dim * 3, true, vb.pp("qkv"))?;
124 let proj = linear_b(dim, dim, true, vb.pp("proj"))?;
125 Ok(Self {
126 num_heads,
127 head_dim: dim / num_heads,
128 qkv,
129 proj,
130 span: tracing::span!(tracing::Level::TRACE, "vit-attn"),
131 })
132 }
133}
134
135impl Module for Attention {
136 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
137 let _enter = self.span.enter();
138 let (b, n, c) = xs.dims3()?;
139 let qkv = xs
140 .apply(&self.qkv)?
141 .reshape((b, n, 3, self.num_heads, self.head_dim))?
142 .permute((2, 0, 3, 1, 4))?;
143 let (q, k, v) = (
144 qkv.i(0)?.contiguous()?,
145 qkv.i(1)?.contiguous()?,
146 qkv.i(2)?.contiguous()?,
147 );
148 scaled_dot_product_attention(&q, &k, &v)?
149 .transpose(1, 2)?
150 .reshape((b, n, c))?
151 .apply(&self.proj)
152 }
153}
154
155#[derive(Debug, Clone)]
156struct VitBlock {
157 attn: Attention,
158 mlp: Mlp,
159 norm1: LayerNorm,
160 norm2: LayerNorm,
161 span: tracing::Span,
162}
163
164impl VitBlock {
165 fn new(vb: VarBuilder, dim: usize, num_heads: usize, cfg: &VisionConfig) -> Result<Self> {
166 let attn = Attention::new(vb.pp("attn"), dim, num_heads)?;
167 let mlp = Mlp::new(vb.pp("mlp"), dim, cfg.hidden_features, dim, cfg.act)?;
168 let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
169 let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
170 Ok(Self {
171 attn,
172 mlp,
173 norm1,
174 norm2,
175 span: tracing::span!(tracing::Level::TRACE, "vit-block"),
176 })
177 }
178}
179
180impl Module for VitBlock {
181 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
182 let _enter = self.span.enter();
183 let ys = xs.apply(&self.norm1)?.apply(&self.attn)?;
184 let xs = (xs + &ys)?;
185 let ys = xs.apply(&self.norm2)?.apply(&self.mlp)?;
186 let xs = (&xs + &ys)?;
187 Ok(xs)
188 }
189}
190
191#[derive(Debug, Clone)]
192struct VisionTransformer {
193 patch_embed: LinearPatchEmbedding,
194 pos_embed: Tensor,
195 blocks: Vec<VitBlock>,
196 norm: LayerNorm,
197 span: tracing::Span,
198}
199
200impl VisionTransformer {
201 fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
202 let patch_embed = LinearPatchEmbedding::new(vb.pp("patch_embed"))?;
203 let pos_embed = vb.get((1, cfg.embed_len, cfg.embed_dim), "pos_embed")?;
204 let blocks = (0..cfg.num_blocks)
205 .map(|i| {
206 VitBlock::new(
207 vb.pp(format!("blocks.{}", i)),
208 cfg.embed_dim,
209 cfg.num_heads,
210 cfg,
211 )
212 })
213 .collect::<Result<_>>()?;
214 let norm = layer_norm(cfg.embed_dim, 1e-5, vb.pp("norm"))?;
215 Ok(Self {
216 patch_embed,
217 pos_embed,
218 blocks,
219 norm,
220 span: tracing::span!(tracing::Level::TRACE, "vit"),
221 })
222 }
223}
224
225impl Module for VisionTransformer {
226 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
227 let _enter = self.span.enter();
228 let mut xs = (&xs.apply(&self.patch_embed)? + &self.pos_embed)?;
229 for block in self.blocks.iter() {
230 xs = xs.apply(block)?;
231 }
232 xs.apply(&self.norm)
233 }
234}
235
236#[derive(Debug, Clone)]
237pub struct Encoder {
238 model: VisionTransformer,
239}
240
241impl Encoder {
242 fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
243 let model = VisionTransformer::new(cfg, vb.pp("model.visual"))?;
244 Ok(Self { model })
245 }
246}
247
248impl Module for Encoder {
249 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
250 xs.apply(&self.model)
251 }
252}
253
254#[derive(Debug, Clone)]
255struct Mlp {
256 fc1: Linear,
257 act: candle_nn::Activation,
258 fc2: Linear,
259 span: tracing::Span,
260}
261
262impl Mlp {
263 fn new(
264 vb: VarBuilder,
265 in_features: usize,
266 hidden_features: usize,
267 out_features: usize,
268 act: candle_nn::Activation,
269 ) -> Result<Self> {
270 let fc1 = linear_b(in_features, hidden_features, true, vb.pp("fc1"))?;
271 let fc2 = linear_b(hidden_features, out_features, true, vb.pp("fc2"))?;
272 Ok(Self {
273 fc1,
274 act,
275 fc2,
276 span: tracing::span!(tracing::Level::TRACE, "mlp"),
277 })
278 }
279}
280
281impl Module for Mlp {
282 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
283 let _enter = self.span.enter();
284 xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
285 }
286}
287
288#[derive(Debug, Clone)]
289struct VisionProjection {
290 mlp: Mlp,
291}
292
293impl VisionProjection {
294 fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
295 let mlp = Mlp::new(
296 vb.pp("mlp"),
297 cfg.image_embedding_dim,
298 cfg.hidden_dim,
299 cfg.model_dim,
300 cfg.act,
301 )?;
302 Ok(Self { mlp })
303 }
304}
305
306impl Module for VisionProjection {
307 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
308 xs.apply(&self.mlp)
309 }
310}
311
312#[derive(Debug, Clone)]
313pub struct VisionEncoder {
314 encoder: Encoder,
315 projection: VisionProjection,
316}
317
318impl VisionEncoder {
319 pub fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
320 let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
321 let projection = VisionProjection::new(cfg, vb.pp("projection"))?;
322 Ok(Self {
323 encoder,
324 projection,
325 })
326 }
327}
328
329impl Module for VisionEncoder {
330 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
331 let (b, c, hp1, wp2) = xs.dims4()?;
332 let (p1, p2) = (14, 14);
333 let h = hp1 / p1;
334 let w = wp2 / p2;
335 xs.reshape((b, c, h, p1, h, p2))?
336 .permute((0, 2, 4, 1, 3, 5))?
337 .reshape((b, h * w, c * p1 * p2))?
338 .apply(&self.encoder)?
339 .apply(&self.projection)
340 }
341}
342
343#[derive(Debug, Clone)]
344pub struct Model {
345 pub text_model: PhiModel,
346 pub vision_encoder: VisionEncoder,
347}
348
349impl Model {
350 pub fn new(config: &Config, vb: VarBuilder) -> Result<Self> {
351 let text_model = PhiModel::new_v2(&config.phi_config, vb.pp("text_model"))?;
352 let vision_encoder = VisionEncoder::new(&config.vision_config, vb.pp("vision_encoder"))?;
353 Ok(Self {
354 text_model,
355 vision_encoder,
356 })
357 }
358
359 pub fn vision_encoder(&self) -> &VisionEncoder {
360 &self.vision_encoder
361 }
362
363 pub fn text_model(&mut self) -> &mut PhiModel {
364 &mut self.text_model
365 }
366}