candle_transformers/models/
moondream.rs

1//! MoonDream Model vision-to-text
2//!
3//!
4//! Moondream is a computer-vision model that can answer real-world questions about images.
5//! It's lightweight with only 1.6B parameters, enabling it to run on mobile phones and edge devices.
6//! [MoonDream Original Implementation](https://github.com/vikhyat/moondream)
7//!
8//! The model consists of:
9//! - Vision encoder using a ViT-style architecture
10//! - Text decoder based on Microsoft's Phi model
11//! - Vision projection module to align vision and text embeddings
12//!
13//! # Examples
14//!
15//! <img src="https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg" width="200">
16//!
17//! ```bash
18//! # download an example image
19//! wget https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg
20//!
21//! # Now you can run Moondream from the `candle-examples` crate:
22//! cargo run --example moondream \
23//!   --release -- \
//!   --prompt "What is the girl eating?" \
25//!   --image "./demo-1.jpg"
26//!
//! > avx: false, neon: true, simd128: false, f16c: false
28//! > temp: 0.00 repeat-penalty: 1.00 repeat-last-n: 64
29//! > retrieved the files in 3.395583ms
30//! > Running on CPU, to run on GPU(metal), build this example with `--features metal`
31//! > loaded the model in 5.485493792s
32//! > loaded and encoded the image Tensor[dims 3, 378, 378; f32] in 4.801396417s
33//! > starting the inference loop
34//! > The girl is eating a hamburger.<
35//! > 9 tokens generated (0.68 token/s)
36//! ```
37
38use crate::models::mixformer::{Config as PhiConfig, MixFormerSequentialForCausalLM as PhiModel};
39use crate::models::with_tracing::{layer_norm, linear_b, LayerNorm, Linear};
40use candle::{IndexOp, Module, Result, Tensor, D};
41use candle_nn::VarBuilder;
42
43#[derive(Debug, Clone, serde::Deserialize)]
44pub struct Config {
45    pub phi_config: PhiConfig,
46    pub vision_config: VisionConfig,
47}
48
49impl Config {
50    pub fn v2() -> Self {
51        Self {
52            phi_config: PhiConfig::v1_5(),
53            vision_config: VisionConfig::v2(),
54        }
55    }
56}
57
58fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
59    let dim = q.dim(D::Minus1)?;
60    let scale_factor = 1.0 / (dim as f64).sqrt();
61    let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?;
62    candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(v)
63}
64
65#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
66pub struct VisionConfig {
67    pub(crate) image_embedding_dim: usize,
68    pub(crate) model_dim: usize,
69    pub(crate) hidden_dim: usize,
70    pub(crate) hidden_features: usize,
71    pub(crate) embed_len: usize,
72    pub(crate) embed_dim: usize,
73    pub(crate) num_blocks: usize,
74    pub(crate) num_heads: usize,
75    pub(crate) act: candle_nn::Activation,
76}
77
78impl VisionConfig {
79    pub fn v2() -> Self {
80        Self {
81            image_embedding_dim: 1152,
82            model_dim: 2048,
83            hidden_dim: 2048 * 4,
84            hidden_features: 4304,
85            embed_len: 729,
86            embed_dim: 1152,
87            num_blocks: 27,
88            num_heads: 16,
89            act: candle_nn::Activation::GeluPytorchTanh,
90        }
91    }
92}
93
94#[derive(Debug, Clone)]
95struct LinearPatchEmbedding {
96    linear: Linear,
97}
98
99impl LinearPatchEmbedding {
100    fn new(vb: VarBuilder) -> Result<Self> {
101        let linear = linear_b(588, 1152, true, vb.pp("linear"))?;
102        Ok(Self { linear })
103    }
104}
105
106impl Module for LinearPatchEmbedding {
107    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
108        xs.apply(&self.linear)
109    }
110}
111
112#[derive(Debug, Clone)]
113struct Attention {
114    num_heads: usize,
115    head_dim: usize,
116    qkv: Linear,
117    proj: Linear,
118    span: tracing::Span,
119}
120
121impl Attention {
122    pub fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
123        let qkv = linear_b(dim, dim * 3, true, vb.pp("qkv"))?;
124        let proj = linear_b(dim, dim, true, vb.pp("proj"))?;
125        Ok(Self {
126            num_heads,
127            head_dim: dim / num_heads,
128            qkv,
129            proj,
130            span: tracing::span!(tracing::Level::TRACE, "vit-attn"),
131        })
132    }
133}
134
135impl Module for Attention {
136    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
137        let _enter = self.span.enter();
138        let (b, n, c) = xs.dims3()?;
139        let qkv = xs
140            .apply(&self.qkv)?
141            .reshape((b, n, 3, self.num_heads, self.head_dim))?
142            .permute((2, 0, 3, 1, 4))?;
143        let (q, k, v) = (
144            qkv.i(0)?.contiguous()?,
145            qkv.i(1)?.contiguous()?,
146            qkv.i(2)?.contiguous()?,
147        );
148        scaled_dot_product_attention(&q, &k, &v)?
149            .transpose(1, 2)?
150            .reshape((b, n, c))?
151            .apply(&self.proj)
152    }
153}
154
155#[derive(Debug, Clone)]
156struct VitBlock {
157    attn: Attention,
158    mlp: Mlp,
159    norm1: LayerNorm,
160    norm2: LayerNorm,
161    span: tracing::Span,
162}
163
164impl VitBlock {
165    fn new(vb: VarBuilder, dim: usize, num_heads: usize, cfg: &VisionConfig) -> Result<Self> {
166        let attn = Attention::new(vb.pp("attn"), dim, num_heads)?;
167        let mlp = Mlp::new(vb.pp("mlp"), dim, cfg.hidden_features, dim, cfg.act)?;
168        let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
169        let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
170        Ok(Self {
171            attn,
172            mlp,
173            norm1,
174            norm2,
175            span: tracing::span!(tracing::Level::TRACE, "vit-block"),
176        })
177    }
178}
179
180impl Module for VitBlock {
181    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
182        let _enter = self.span.enter();
183        let ys = xs.apply(&self.norm1)?.apply(&self.attn)?;
184        let xs = (xs + &ys)?;
185        let ys = xs.apply(&self.norm2)?.apply(&self.mlp)?;
186        let xs = (&xs + &ys)?;
187        Ok(xs)
188    }
189}
190
191#[derive(Debug, Clone)]
192struct VisionTransformer {
193    patch_embed: LinearPatchEmbedding,
194    pos_embed: Tensor,
195    blocks: Vec<VitBlock>,
196    norm: LayerNorm,
197    span: tracing::Span,
198}
199
200impl VisionTransformer {
201    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
202        let patch_embed = LinearPatchEmbedding::new(vb.pp("patch_embed"))?;
203        let pos_embed = vb.get((1, cfg.embed_len, cfg.embed_dim), "pos_embed")?;
204        let blocks = (0..cfg.num_blocks)
205            .map(|i| {
206                VitBlock::new(
207                    vb.pp(format!("blocks.{}", i)),
208                    cfg.embed_dim,
209                    cfg.num_heads,
210                    cfg,
211                )
212            })
213            .collect::<Result<_>>()?;
214        let norm = layer_norm(cfg.embed_dim, 1e-5, vb.pp("norm"))?;
215        Ok(Self {
216            patch_embed,
217            pos_embed,
218            blocks,
219            norm,
220            span: tracing::span!(tracing::Level::TRACE, "vit"),
221        })
222    }
223}
224
225impl Module for VisionTransformer {
226    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
227        let _enter = self.span.enter();
228        let mut xs = (&xs.apply(&self.patch_embed)? + &self.pos_embed)?;
229        for block in self.blocks.iter() {
230            xs = xs.apply(block)?;
231        }
232        xs.apply(&self.norm)
233    }
234}
235
236#[derive(Debug, Clone)]
237pub struct Encoder {
238    model: VisionTransformer,
239}
240
241impl Encoder {
242    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
243        let model = VisionTransformer::new(cfg, vb.pp("model.visual"))?;
244        Ok(Self { model })
245    }
246}
247
248impl Module for Encoder {
249    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
250        xs.apply(&self.model)
251    }
252}
253
254#[derive(Debug, Clone)]
255struct Mlp {
256    fc1: Linear,
257    act: candle_nn::Activation,
258    fc2: Linear,
259    span: tracing::Span,
260}
261
262impl Mlp {
263    fn new(
264        vb: VarBuilder,
265        in_features: usize,
266        hidden_features: usize,
267        out_features: usize,
268        act: candle_nn::Activation,
269    ) -> Result<Self> {
270        let fc1 = linear_b(in_features, hidden_features, true, vb.pp("fc1"))?;
271        let fc2 = linear_b(hidden_features, out_features, true, vb.pp("fc2"))?;
272        Ok(Self {
273            fc1,
274            act,
275            fc2,
276            span: tracing::span!(tracing::Level::TRACE, "mlp"),
277        })
278    }
279}
280
281impl Module for Mlp {
282    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
283        let _enter = self.span.enter();
284        xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
285    }
286}
287
288#[derive(Debug, Clone)]
289struct VisionProjection {
290    mlp: Mlp,
291}
292
293impl VisionProjection {
294    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
295        let mlp = Mlp::new(
296            vb.pp("mlp"),
297            cfg.image_embedding_dim,
298            cfg.hidden_dim,
299            cfg.model_dim,
300            cfg.act,
301        )?;
302        Ok(Self { mlp })
303    }
304}
305
306impl Module for VisionProjection {
307    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
308        xs.apply(&self.mlp)
309    }
310}
311
312#[derive(Debug, Clone)]
313pub struct VisionEncoder {
314    encoder: Encoder,
315    projection: VisionProjection,
316}
317
318impl VisionEncoder {
319    pub fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
320        let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
321        let projection = VisionProjection::new(cfg, vb.pp("projection"))?;
322        Ok(Self {
323            encoder,
324            projection,
325        })
326    }
327}
328
329impl Module for VisionEncoder {
330    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
331        let (b, c, hp1, wp2) = xs.dims4()?;
332        let (p1, p2) = (14, 14);
333        let h = hp1 / p1;
334        let w = wp2 / p2;
335        xs.reshape((b, c, h, p1, h, p2))?
336            .permute((0, 2, 4, 1, 3, 5))?
337            .reshape((b, h * w, c * p1 * p2))?
338            .apply(&self.encoder)?
339            .apply(&self.projection)
340    }
341}
342
343#[derive(Debug, Clone)]
344pub struct Model {
345    pub text_model: PhiModel,
346    pub vision_encoder: VisionEncoder,
347}
348
349impl Model {
350    pub fn new(config: &Config, vb: VarBuilder) -> Result<Self> {
351        let text_model = PhiModel::new_v2(&config.phi_config, vb.pp("text_model"))?;
352        let vision_encoder = VisionEncoder::new(&config.vision_config, vb.pp("vision_encoder"))?;
353        Ok(Self {
354            text_model,
355            vision_encoder,
356        })
357    }
358
359    pub fn vision_encoder(&self) -> &VisionEncoder {
360        &self.vision_encoder
361    }
362
363    pub fn text_model(&mut self) -> &mut PhiModel {
364        &mut self.text_model
365    }
366}