# Mirror of https://github.com/azaion/ai-training.git
# Synced 2026-04-22 21:46:35 +00:00 at commit 462a4826e8 (made with Cursor).
# Source file: 39 lines, 1.3 KiB, Python.
import pytest
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from inference.dto import Detection
|
|
from inference.inference import Inference
|
|
|
|
|
|
@pytest.fixture
def inference_nms():
    """Yield an ``Inference`` built on a mocked engine, for NMS tests.

    The engine mock reports a batch size of 1 and an input shape of
    (1280, 1280). ``Inference`` is created with confidence_threshold=0.5 and
    iou_threshold=0.3. ``inference.inference.AnnotationClass.read_json`` is
    patched to return ``{}`` for as long as the fixture is active, because the
    ``yield`` sits inside the ``with``. Presumably this keeps construction from
    depending on a real annotation file; confirm against ``Inference.__init__``.
    """
    engine = MagicMock()
    engine.get_batch_size.return_value = 1
    engine.get_input_shape.return_value = (1280, 1280)

    read_json_stub = patch(
        "inference.inference.AnnotationClass.read_json", return_value={}
    )
    with read_json_stub:
        yield Inference(engine, confidence_threshold=0.5, iou_threshold=0.3)
|
|
|
|
|
|
def test_bt_nms_01_overlapping_keeps_higher_confidence(inference_nms):
    """Two detections with identical box arguments reduce to the 0.9 one."""
    # The last positional argument is the confidence (read back via
    # ``.confidence``). The rest are assumed to be box geometry followed by a
    # class id; confirm against inference.dto.Detection.
    strong = Detection(0.5, 0.5, 0.1, 0.1, 0, 0.9)
    weak = Detection(0.5, 0.5, 0.1, 0.1, 0, 0.5)
    candidates = [strong, weak]

    survivors = inference_nms.remove_overlapping_detections(candidates)

    assert len(survivors) == 1
    assert survivors[0].confidence == 0.9
|
|
|
|
|
|
def test_bt_nms_02_non_overlapping_both_preserved(inference_nms):
    """Two small, far-apart detections are both kept by overlap removal."""
    # Boxes sit near (0.1, 0.1) and (0.9, 0.9). Assuming normalized
    # coordinates with 0.05-sized boxes, they cannot overlap.
    far_apart = [
        Detection(0.1, 0.1, 0.05, 0.05, 0, 0.8),
        Detection(0.9, 0.9, 0.05, 0.05, 0, 0.8),
    ]

    survivors = inference_nms.remove_overlapping_detections(far_apart)

    assert len(survivors) == 2
|
|
|
|
|
|
def test_bt_nms_03_overlap_pair_and_distant_kept(inference_nms):
    """An overlapping pair collapses to its stronger member; a distant box stays.

    ``a`` and ``b`` differ only by 0.05 in their first coordinate (half of the
    0.1 box size, presumably the width), so they overlap. ``c`` is placed far
    from both. Expect exactly two survivors: ``a`` (confidence 0.9) and ``c``
    (confidence 0.8).
    """
    a = Detection(0.5, 0.5, 0.1, 0.1, 0, 0.9)
    b = Detection(0.55, 0.5, 0.1, 0.1, 0, 0.7)
    c = Detection(0.1, 0.1, 0.1, 0.1, 0, 0.8)

    out = inference_nms.remove_overlapping_detections([a, b, c])

    # State the exact survivor count. The set check below already requires two
    # distinct survivors, so a looser `<= 2` only obscured the intended contract.
    assert len(out) == 2
    assert {d.confidence for d in out} == {0.9, 0.8}
|