candle_transformers/models/
quantized_moondream.rs1use crate::models::moondream::{Config, VisionConfig};
17use crate::models::quantized_mixformer::MixFormerSequentialForCausalLM as PhiModel;
18use crate::quantized_nn::{layer_norm, linear_b, Linear};
19use crate::quantized_var_builder::VarBuilder;
20use candle::{IndexOp, Module, Result, Tensor, D};
21
22fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
23 let dim = q.dim(D::Minus1)?;
24 let scale_factor = 1.0 / (dim as f64).sqrt();
25 let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?;
26 candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(v)
27}
28
29#[derive(Debug, Clone)]
30struct LinearPatchEmbedding {
31 linear: Linear,
32}
33
34impl LinearPatchEmbedding {
35 fn new(vb: VarBuilder) -> Result<Self> {
36 let linear = linear_b(588, 1152, true, vb.pp("linear"))?;
37 Ok(Self { linear })
38 }
39}
40
41impl Module for LinearPatchEmbedding {
42 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
43 xs.apply(&self.linear)
44 }
45}
46
47#[derive(Debug, Clone)]
48struct Attention {
49 num_heads: usize,
50 head_dim: usize,
51 qkv: Linear,
52 proj: Linear,
53}
54
55impl Attention {
56 pub fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
57 let qkv = linear_b(dim, dim * 3, true, vb.pp("qkv"))?;
58 let proj = linear_b(dim, dim, true, vb.pp("proj"))?;
59 Ok(Self {
60 num_heads,
61 head_dim: dim / num_heads,
62 qkv,
63 proj,
64 })
65 }
66}
67
68impl Module for Attention {
69 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
70 let (b, n, c) = xs.dims3()?;
71 let qkv = xs
72 .apply(&self.qkv)?
73 .reshape((b, n, 3, self.num_heads, self.head_dim))?
74 .permute((2, 0, 3, 1, 4))?;
75 let (q, k, v) = (
76 qkv.i(0)?.contiguous()?,
77 qkv.i(1)?.contiguous()?,
78 qkv.i(2)?.contiguous()?,
79 );
80 scaled_dot_product_attention(&q, &k, &v)?
81 .transpose(1, 2)?
82 .reshape((b, n, c))?
83 .apply(&self.proj)
84 }
85}
86
87#[derive(Debug, Clone)]
88struct VitBlock {
89 attn: Attention,
90 mlp: Mlp,
91 norm1: candle_nn::LayerNorm,
92 norm2: candle_nn::LayerNorm,
93}
94
95impl VitBlock {
96 fn new(vb: VarBuilder, dim: usize, num_heads: usize, cfg: &VisionConfig) -> Result<Self> {
97 let attn = Attention::new(vb.pp("attn"), dim, num_heads)?;
98 let mlp = Mlp::new(vb.pp("mlp"), dim, cfg.hidden_features, dim, cfg.act)?;
99 let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
100 let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
101 Ok(Self {
102 attn,
103 mlp,
104 norm1,
105 norm2,
106 })
107 }
108}
109
110impl Module for VitBlock {
111 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
112 let ys = xs.apply(&self.norm1)?.apply(&self.attn)?;
113 let xs = (xs + &ys)?;
114 let ys = xs.apply(&self.norm2)?.apply(&self.mlp)?;
115 let xs = (&xs + &ys)?;
116 Ok(xs)
117 }
118}
119
120#[derive(Debug, Clone)]
121struct VisionTransformer {
122 patch_embed: LinearPatchEmbedding,
123 pos_embed: Tensor,
124 blocks: Vec<VitBlock>,
125 norm: candle_nn::LayerNorm,
126}
127
128impl VisionTransformer {
129 fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
130 let patch_embed = LinearPatchEmbedding::new(vb.pp("patch_embed"))?;
131 let pos_embed = vb
132 .get((1, cfg.embed_len, cfg.embed_dim), "pos_embed")?
133 .dequantize(vb.device())?;
134 let blocks = (0..cfg.num_blocks)
135 .map(|i| {
136 VitBlock::new(
137 vb.pp(format!("blocks.{}", i)),
138 cfg.embed_dim,
139 cfg.num_heads,
140 cfg,
141 )
142 })
143 .collect::<Result<_>>()?;
144 let norm = layer_norm(cfg.embed_dim, 1e-5, vb.pp("norm"))?;
145 Ok(Self {
146 patch_embed,
147 pos_embed,
148 blocks,
149 norm,
150 })
151 }
152}
153
154impl Module for VisionTransformer {
155 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
156 let mut xs = (&xs.apply(&self.patch_embed)? + &self.pos_embed)?;
157 for block in self.blocks.iter() {
158 xs = xs.apply(block)?;
159 }
160 xs.apply(&self.norm)
161 }
162}
163
164#[derive(Debug, Clone)]
165pub struct Encoder {
166 model: VisionTransformer,
167}
168
169impl Encoder {
170 fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
171 let model = VisionTransformer::new(cfg, vb.pp("model.visual"))?;
172 Ok(Self { model })
173 }
174}
175
176impl Module for Encoder {
177 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
178 xs.apply(&self.model)
179 }
180}
181
182#[derive(Debug, Clone)]
183struct Mlp {
184 fc1: Linear,
185 act: candle_nn::Activation,
186 fc2: Linear,
187}
188
189impl Mlp {
190 fn new(
191 vb: VarBuilder,
192 in_features: usize,
193 hidden_features: usize,
194 out_features: usize,
195 act: candle_nn::Activation,
196 ) -> Result<Self> {
197 let fc1 = linear_b(in_features, hidden_features, true, vb.pp("fc1"))?;
198 let fc2 = linear_b(hidden_features, out_features, true, vb.pp("fc2"))?;
199 Ok(Self { fc1, act, fc2 })
200 }
201}
202
203impl Module for Mlp {
204 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
205 xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
206 }
207}
208
209#[derive(Debug, Clone)]
210struct VisionProjection {
211 mlp: Mlp,
212}
213
214impl VisionProjection {
215 fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
216 let mlp = Mlp::new(
217 vb.pp("mlp"),
218 cfg.image_embedding_dim,
219 cfg.hidden_dim,
220 cfg.model_dim,
221 cfg.act,
222 )?;
223 Ok(Self { mlp })
224 }
225}
226
227impl Module for VisionProjection {
228 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
229 xs.apply(&self.mlp)
230 }
231}
232
233#[derive(Debug, Clone)]
234pub struct VisionEncoder {
235 encoder: Encoder,
236 projection: VisionProjection,
237}
238
239impl VisionEncoder {
240 pub fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
241 let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
242 let projection = VisionProjection::new(cfg, vb.pp("projection"))?;
243 Ok(Self {
244 encoder,
245 projection,
246 })
247 }
248}
249
250impl Module for VisionEncoder {
251 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
252 let (b, c, hp1, wp2) = xs.dims4()?;
253 let (p1, p2) = (14, 14);
254 let h = hp1 / p1;
255 let w = wp2 / p2;
256 xs.reshape((b, c, h, p1, h, p2))?
257 .permute((0, 2, 4, 1, 3, 5))?
258 .reshape((b, h * w, c * p1 * p2))?
259 .apply(&self.encoder)?
260 .apply(&self.projection)
261 }
262}
263
264pub struct Model {
265 pub text_model: PhiModel,
266 pub vision_encoder: VisionEncoder,
267}
268
269impl Model {
270 pub fn new(config: &Config, vb: VarBuilder) -> Result<Self> {
271 let text_model = PhiModel::new_v2(&config.phi_config, vb.pp("text_model"))?;
272 let vision_encoder = VisionEncoder::new(&config.vision_config, vb.pp("vision_encoder"))?;
273 Ok(Self {
274 text_model,
275 vision_encoder,
276 })
277 }
278
279 pub fn vision_encoder(&self) -> &VisionEncoder {
280 &self.vision_encoder
281 }
282
283 pub fn text_model(&mut self) -> &mut PhiModel {
284 &mut self.text_model
285 }
286}