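//! Object-detection helpers (`candle_transformers/object_detection.rs`): bounding boxes,
//! keypoints, intersection over union, and hard / soft non-maximum suppression.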
/// An axis-aligned bounding box with a detection confidence and a caller-defined payload.
///
/// The box is given by its corners (`xmin`, `ymin`) and (`xmax`, `ymax`). [`iou`] treats both
/// extents as inclusive (`max - min + 1`).
#[derive(Debug, Clone, PartialEq)]
pub struct Bbox<D> {
    /// Minimum x coordinate.
    pub xmin: f32,
    /// Minimum y coordinate.
    pub ymin: f32,
    /// Maximum x coordinate.
    pub xmax: f32,
    /// Maximum y coordinate.
    pub ymax: f32,
    /// Score used to rank boxes during (soft) non-maximum suppression.
    pub confidence: f32,
    /// Arbitrary payload carried along with the box; not interpreted by this module.
    pub data: D,
}
/// A single keypoint: a 2-D position `(x, y)` plus a `mask` value.
///
/// NOTE(review): `mask` looks like a per-keypoint visibility / confidence score — confirm
/// against the model code that produces these.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct KeyPoint {
    /// Horizontal position.
    pub x: f32,
    /// Vertical position.
    pub y: f32,
    /// Per-keypoint mask value (semantics defined by the producing model — TODO confirm).
    pub mask: f32,
}
25pub fn iou<D>(b1: &Bbox<D>, b2: &Bbox<D>) -> f32 {
27 let b1_area = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.);
28 let b2_area = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.);
29 let i_xmin = b1.xmin.max(b2.xmin);
30 let i_xmax = b1.xmax.min(b2.xmax);
31 let i_ymin = b1.ymin.max(b2.ymin);
32 let i_ymax = b1.ymax.min(b2.ymax);
33 let i_area = (i_xmax - i_xmin + 1.).max(0.) * (i_ymax - i_ymin + 1.).max(0.);
34 i_area / (b1_area + b2_area - i_area)
35}
37pub fn non_maximum_suppression<D>(bboxes: &mut [Vec<Bbox<D>>], threshold: f32) {
38 for bboxes_for_class in bboxes.iter_mut() {
40 bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
41 let mut current_index = 0;
42 for index in 0..bboxes_for_class.len() {
43 let mut drop = false;
44 for prev_index in 0..current_index {
45 let iou = iou(&bboxes_for_class[prev_index], &bboxes_for_class[index]);
46 if iou > threshold {
47 drop = true;
48 break;
49 }
50 }
51 if !drop {
52 bboxes_for_class.swap(current_index, index);
53 current_index += 1;
54 }
55 }
56 bboxes_for_class.truncate(current_index);
57 }
58}
60fn update_confidences<D>(
62 bboxes_for_class: &[Bbox<D>],
63 updated_confidences: &mut [f32],
64 iou_threshold: f32,
65 sigma: f32,
66) {
67 let len = bboxes_for_class.len();
68 for current_index in 0..len {
69 let current_bbox = &bboxes_for_class[current_index];
70 for index in (current_index + 1)..len {
71 let iou_val = iou(current_bbox, &bboxes_for_class[index]);
72 if iou_val > iou_threshold {
73 let decay = (-iou_val * iou_val / sigma).exp();
75 let updated_confidence = bboxes_for_class[index].confidence * decay;
76 updated_confidences[index] = updated_confidence;
77 }
78 }
79 }
80}
82pub fn soft_non_maximum_suppression<D>(
85 bboxes: &mut [Vec<Bbox<D>>],
86 iou_threshold: Option<f32>,
87 confidence_threshold: Option<f32>,
88 sigma: Option<f32>,
89) {
90 let iou_threshold = iou_threshold.unwrap_or(0.5);
91 let confidence_threshold = confidence_threshold.unwrap_or(0.1);
92 let sigma = sigma.unwrap_or(0.5);
93
94 for bboxes_for_class in bboxes.iter_mut() {
95 bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
97 let mut updated_confidences = bboxes_for_class
98 .iter()
99 .map(|bbox| bbox.confidence)
100 .collect::<Vec<_>>();
101 update_confidences(
102 bboxes_for_class,
103 &mut updated_confidences,
104 iou_threshold,
105 sigma,
106 );
107 for (i, &confidence) in updated_confidences.iter().enumerate() {
109 bboxes_for_class[i].confidence = if confidence < confidence_threshold {
110 0.0
111 } else {
112 confidence
113 };
114 }
115 }
116}