candle_transformers/models/
quantized_moondream.rs

//! Implementation of a quantized Moondream vision language model.
//!
//! Moondream is a lightweight vision-language model that produces text conditioned on an input image.
//! This module provides a quantized version for reduced memory usage and faster inference.
//!
//! Key features:
//! - ViT-based vision encoder
//! - Phi-2 text decoder model
//! - Quantized weights for reduced memory usage
//! - Optimized for efficient deployment
//!
//! References:
//! - [Moondream Model](https://github.com/vikhyat/moondream)
//!
15
16use crate::models::moondream::{Config, VisionConfig};
17use crate::models::quantized_mixformer::MixFormerSequentialForCausalLM as PhiModel;
18use crate::quantized_nn::{layer_norm, linear_b, Linear};
19use crate::quantized_var_builder::VarBuilder;
20use candle::{IndexOp, Module, Result, Tensor, D};
21
22fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
23    let dim = q.dim(D::Minus1)?;
24    let scale_factor = 1.0 / (dim as f64).sqrt();
25    let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?;
26    candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(v)
27}
28
29#[derive(Debug, Clone)]
30struct LinearPatchEmbedding {
31    linear: Linear,
32}
33
34impl LinearPatchEmbedding {
35    fn new(vb: VarBuilder) -> Result<Self> {
36        let linear = linear_b(588, 1152, true, vb.pp("linear"))?;
37        Ok(Self { linear })
38    }
39}
40
41impl Module for LinearPatchEmbedding {
42    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
43        xs.apply(&self.linear)
44    }
45}
46
47#[derive(Debug, Clone)]
48struct Attention {
49    num_heads: usize,
50    head_dim: usize,
51    qkv: Linear,
52    proj: Linear,
53}
54
55impl Attention {
56    pub fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
57        let qkv = linear_b(dim, dim * 3, true, vb.pp("qkv"))?;
58        let proj = linear_b(dim, dim, true, vb.pp("proj"))?;
59        Ok(Self {
60            num_heads,
61            head_dim: dim / num_heads,
62            qkv,
63            proj,
64        })
65    }
66}
67
68impl Module for Attention {
69    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
70        let (b, n, c) = xs.dims3()?;
71        let qkv = xs
72            .apply(&self.qkv)?
73            .reshape((b, n, 3, self.num_heads, self.head_dim))?
74            .permute((2, 0, 3, 1, 4))?;
75        let (q, k, v) = (
76            qkv.i(0)?.contiguous()?,
77            qkv.i(1)?.contiguous()?,
78            qkv.i(2)?.contiguous()?,
79        );
80        scaled_dot_product_attention(&q, &k, &v)?
81            .transpose(1, 2)?
82            .reshape((b, n, c))?
83            .apply(&self.proj)
84    }
85}
86
87#[derive(Debug, Clone)]
88struct VitBlock {
89    attn: Attention,
90    mlp: Mlp,
91    norm1: candle_nn::LayerNorm,
92    norm2: candle_nn::LayerNorm,
93}
94
95impl VitBlock {
96    fn new(vb: VarBuilder, dim: usize, num_heads: usize, cfg: &VisionConfig) -> Result<Self> {
97        let attn = Attention::new(vb.pp("attn"), dim, num_heads)?;
98        let mlp = Mlp::new(vb.pp("mlp"), dim, cfg.hidden_features, dim, cfg.act)?;
99        let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
100        let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
101        Ok(Self {
102            attn,
103            mlp,
104            norm1,
105            norm2,
106        })
107    }
108}
109
110impl Module for VitBlock {
111    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
112        let ys = xs.apply(&self.norm1)?.apply(&self.attn)?;
113        let xs = (xs + &ys)?;
114        let ys = xs.apply(&self.norm2)?.apply(&self.mlp)?;
115        let xs = (&xs + &ys)?;
116        Ok(xs)
117    }
118}
119
120#[derive(Debug, Clone)]
121struct VisionTransformer {
122    patch_embed: LinearPatchEmbedding,
123    pos_embed: Tensor,
124    blocks: Vec<VitBlock>,
125    norm: candle_nn::LayerNorm,
126}
127
128impl VisionTransformer {
129    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
130        let patch_embed = LinearPatchEmbedding::new(vb.pp("patch_embed"))?;
131        let pos_embed = vb
132            .get((1, cfg.embed_len, cfg.embed_dim), "pos_embed")?
133            .dequantize(vb.device())?;
134        let blocks = (0..cfg.num_blocks)
135            .map(|i| {
136                VitBlock::new(
137                    vb.pp(format!("blocks.{}", i)),
138                    cfg.embed_dim,
139                    cfg.num_heads,
140                    cfg,
141                )
142            })
143            .collect::<Result<_>>()?;
144        let norm = layer_norm(cfg.embed_dim, 1e-5, vb.pp("norm"))?;
145        Ok(Self {
146            patch_embed,
147            pos_embed,
148            blocks,
149            norm,
150        })
151    }
152}
153
154impl Module for VisionTransformer {
155    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
156        let mut xs = (&xs.apply(&self.patch_embed)? + &self.pos_embed)?;
157        for block in self.blocks.iter() {
158            xs = xs.apply(block)?;
159        }
160        xs.apply(&self.norm)
161    }
162}
163
164#[derive(Debug, Clone)]
165pub struct Encoder {
166    model: VisionTransformer,
167}
168
169impl Encoder {
170    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
171        let model = VisionTransformer::new(cfg, vb.pp("model.visual"))?;
172        Ok(Self { model })
173    }
174}
175
176impl Module for Encoder {
177    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
178        xs.apply(&self.model)
179    }
180}
181
182#[derive(Debug, Clone)]
183struct Mlp {
184    fc1: Linear,
185    act: candle_nn::Activation,
186    fc2: Linear,
187}
188
189impl Mlp {
190    fn new(
191        vb: VarBuilder,
192        in_features: usize,
193        hidden_features: usize,
194        out_features: usize,
195        act: candle_nn::Activation,
196    ) -> Result<Self> {
197        let fc1 = linear_b(in_features, hidden_features, true, vb.pp("fc1"))?;
198        let fc2 = linear_b(hidden_features, out_features, true, vb.pp("fc2"))?;
199        Ok(Self { fc1, act, fc2 })
200    }
201}
202
203impl Module for Mlp {
204    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
205        xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
206    }
207}
208
209#[derive(Debug, Clone)]
210struct VisionProjection {
211    mlp: Mlp,
212}
213
214impl VisionProjection {
215    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
216        let mlp = Mlp::new(
217            vb.pp("mlp"),
218            cfg.image_embedding_dim,
219            cfg.hidden_dim,
220            cfg.model_dim,
221            cfg.act,
222        )?;
223        Ok(Self { mlp })
224    }
225}
226
227impl Module for VisionProjection {
228    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
229        xs.apply(&self.mlp)
230    }
231}
232
233#[derive(Debug, Clone)]
234pub struct VisionEncoder {
235    encoder: Encoder,
236    projection: VisionProjection,
237}
238
239impl VisionEncoder {
240    pub fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
241        let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
242        let projection = VisionProjection::new(cfg, vb.pp("projection"))?;
243        Ok(Self {
244            encoder,
245            projection,
246        })
247    }
248}
249
250impl Module for VisionEncoder {
251    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
252        let (b, c, hp1, wp2) = xs.dims4()?;
253        let (p1, p2) = (14, 14);
254        let h = hp1 / p1;
255        let w = wp2 / p2;
256        xs.reshape((b, c, h, p1, h, p2))?
257            .permute((0, 2, 4, 1, 3, 5))?
258            .reshape((b, h * w, c * p1 * p2))?
259            .apply(&self.encoder)?
260            .apply(&self.projection)
261    }
262}
263
264pub struct Model {
265    pub text_model: PhiModel,
266    pub vision_encoder: VisionEncoder,
267}
268
269impl Model {
270    pub fn new(config: &Config, vb: VarBuilder) -> Result<Self> {
271        let text_model = PhiModel::new_v2(&config.phi_config, vb.pp("text_model"))?;
272        let vision_encoder = VisionEncoder::new(&config.vision_config, vb.pp("vision_encoder"))?;
273        Ok(Self {
274            text_model,
275            vision_encoder,
276        })
277    }
278
279    pub fn vision_encoder(&self) -> &VisionEncoder {
280        &self.vision_encoder
281    }
282
283    pub fn text_model(&mut self) -> &mut PhiModel {
284        &mut self.text_model
285    }
286}