Files
ai-training/tests/test_nms.py
T
2026-03-26 23:23:42 +02:00

39 lines
1.3 KiB
Python

import pytest
from unittest.mock import MagicMock, patch
from inference.dto import Detection
from inference.inference import Inference
@pytest.fixture
def inference_nms():
    """Yield an ``Inference`` built on a mocked engine, for NMS tests.

    The mock engine reports a batch size of 1 and an input shape of
    ``(1280, 1280)``. ``inference.inference.AnnotationClass.read_json`` is
    patched to return ``{}``. The ``yield`` sits inside the ``with``, so the
    patch stays active while the dependent test runs. Presumably this keeps
    ``Inference`` construction from depending on an annotation-class file;
    confirm against ``Inference.__init__``.

    The instance is created with ``confidence_threshold=0.5`` and
    ``iou_threshold=0.3``.
    """
    engine = MagicMock(**{
        "get_batch_size.return_value": 1,
        "get_input_shape.return_value": (1280, 1280),
    })
    read_json_target = "inference.inference.AnnotationClass.read_json"
    with patch(read_json_target, return_value={}):
        instance = Inference(engine, confidence_threshold=0.5, iou_threshold=0.3)
        yield instance
def test_bt_nms_01_overlapping_keeps_higher_confidence(inference_nms):
    """Two detections identical except for confidence collapse to the stronger one."""
    strong = Detection(0.5, 0.5, 0.1, 0.1, 0, 0.9)
    weak = Detection(0.5, 0.5, 0.1, 0.1, 0, 0.5)
    survivors = inference_nms.remove_overlapping_detections([strong, weak])
    assert len(survivors) == 1
    # The surviving detection must be the 0.9-confidence one, not the 0.5 one.
    assert survivors[0].confidence == 0.9
def test_bt_nms_02_non_overlapping_both_preserved(inference_nms):
    """Two small, widely separated detections must both survive suppression."""
    detections = [
        Detection(0.1, 0.1, 0.05, 0.05, 0, 0.8),
        Detection(0.9, 0.9, 0.05, 0.05, 0, 0.8),
    ]
    assert len(inference_nms.remove_overlapping_detections(detections)) == 2
def test_bt_nms_03_overlap_pair_and_distant_kept(inference_nms):
    """An overlapping pair collapses to its stronger member; a distant box survives.

    ``a`` and ``b`` differ only by a 0.05 shift in the first coordinate, so
    ``b`` (confidence 0.7) should be suppressed in favour of ``a`` (0.9).
    ``c`` (0.8) is far from both and must be kept, which leaves exactly two
    detections.

    NOTE(review): assuming the first four ``Detection`` args are an xywh box,
    IoU(a, b) = 0.005 / 0.015 = 1/3. That is only just above the fixture's
    ``iou_threshold=0.3``, so a change in the IoU convention may flip this test.
    """
    a = Detection(0.5, 0.5, 0.1, 0.1, 0, 0.9)
    b = Detection(0.55, 0.5, 0.1, 0.1, 0, 0.7)
    c = Detection(0.1, 0.1, 0.1, 0.1, 0, 0.8)
    out = inference_nms.remove_overlapping_detections([a, b, c])
    # State the expected count exactly: one of three inputs is suppressed.
    assert len(out) == 2
    # Sorted list rather than a set so duplicated survivors would also fail.
    assert sorted(d.confidence for d in out) == [0.8, 0.9]