candle_transformers/
object_detection.rs

//! Bounding boxes and intersection utilities.
//!
//! This module provides types and functions for working with bounding boxes in the
//! context of object detection: intersection over union (IoU), greedy non-maximum
//! suppression (NMS), and soft non-maximum suppression (Soft-NMS).
6
/// A bounding box around a detected object, together with its confidence score and an
/// arbitrary caller-defined payload.
///
/// NOTE(review): [`iou`] measures extents as `max - min + 1`, i.e. it treats the bounds
/// as inclusive pixel coordinates; confirm callers use that convention.
#[derive(Debug, Clone)]
pub struct Bbox<D> {
    /// Minimum x coordinate of the box.
    pub xmin: f32,
    /// Minimum y coordinate of the box.
    pub ymin: f32,
    /// Maximum x coordinate of the box.
    pub xmax: f32,
    /// Maximum y coordinate of the box.
    pub ymax: f32,
    /// Detection score. The suppression routines in this module sort by it (descending),
    /// and soft-NMS overwrites it in place with the decayed score.
    pub confidence: f32,
    /// Caller-defined payload carried along with the box
    /// (presumably e.g. `Vec<KeyPoint>` for pose models — TODO confirm).
    pub data: D,
}
17
/// A single 2-D keypoint attached to a detection
/// (presumably a pose/landmark point — TODO confirm against callers).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct KeyPoint {
    /// X coordinate of the keypoint.
    pub x: f32,
    /// Y coordinate of the keypoint.
    pub y: f32,
    /// Per-keypoint score.
    /// NOTE(review): presumably a visibility/confidence value — confirm with the producing model.
    pub mask: f32,
}
24
25/// Intersection over union of two bounding boxes.
26pub fn iou<D>(b1: &Bbox<D>, b2: &Bbox<D>) -> f32 {
27    let b1_area = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.);
28    let b2_area = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.);
29    let i_xmin = b1.xmin.max(b2.xmin);
30    let i_xmax = b1.xmax.min(b2.xmax);
31    let i_ymin = b1.ymin.max(b2.ymin);
32    let i_ymax = b1.ymax.min(b2.ymax);
33    let i_area = (i_xmax - i_xmin + 1.).max(0.) * (i_ymax - i_ymin + 1.).max(0.);
34    i_area / (b1_area + b2_area - i_area)
35}
36
37pub fn non_maximum_suppression<D>(bboxes: &mut [Vec<Bbox<D>>], threshold: f32) {
38    // Perform non-maximum suppression.
39    for bboxes_for_class in bboxes.iter_mut() {
40        bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
41        let mut current_index = 0;
42        for index in 0..bboxes_for_class.len() {
43            let mut drop = false;
44            for prev_index in 0..current_index {
45                let iou = iou(&bboxes_for_class[prev_index], &bboxes_for_class[index]);
46                if iou > threshold {
47                    drop = true;
48                    break;
49                }
50            }
51            if !drop {
52                bboxes_for_class.swap(current_index, index);
53                current_index += 1;
54            }
55        }
56        bboxes_for_class.truncate(current_index);
57    }
58}
59
60// Updates confidences starting at highest and comparing subsequent boxes.
61fn update_confidences<D>(
62    bboxes_for_class: &[Bbox<D>],
63    updated_confidences: &mut [f32],
64    iou_threshold: f32,
65    sigma: f32,
66) {
67    let len = bboxes_for_class.len();
68    for current_index in 0..len {
69        let current_bbox = &bboxes_for_class[current_index];
70        for index in (current_index + 1)..len {
71            let iou_val = iou(current_bbox, &bboxes_for_class[index]);
72            if iou_val > iou_threshold {
73                // Decay calculation from page 4 of: https://arxiv.org/pdf/1704.04503
74                let decay = (-iou_val * iou_val / sigma).exp();
75                let updated_confidence = bboxes_for_class[index].confidence * decay;
76                updated_confidences[index] = updated_confidence;
77            }
78        }
79    }
80}
81
82// Sorts the bounding boxes by confidence and applies soft non-maximum suppression.
83// This function is based on the algorithm described in https://arxiv.org/pdf/1704.04503
84pub fn soft_non_maximum_suppression<D>(
85    bboxes: &mut [Vec<Bbox<D>>],
86    iou_threshold: Option<f32>,
87    confidence_threshold: Option<f32>,
88    sigma: Option<f32>,
89) {
90    let iou_threshold = iou_threshold.unwrap_or(0.5);
91    let confidence_threshold = confidence_threshold.unwrap_or(0.1);
92    let sigma = sigma.unwrap_or(0.5);
93
94    for bboxes_for_class in bboxes.iter_mut() {
95        // Sort boxes by confidence in descending order
96        bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
97        let mut updated_confidences = bboxes_for_class
98            .iter()
99            .map(|bbox| bbox.confidence)
100            .collect::<Vec<_>>();
101        update_confidences(
102            bboxes_for_class,
103            &mut updated_confidences,
104            iou_threshold,
105            sigma,
106        );
107        // Update confidences, set to 0.0 if below threshold
108        for (i, &confidence) in updated_confidences.iter().enumerate() {
109            bboxes_for_class[i].confidence = if confidence < confidence_threshold {
110                0.0
111            } else {
112                confidence
113            };
114        }
115    }
116}