"""Black-box tests for ``Inference.remove_overlapping_detections`` (NMS)."""

from unittest.mock import MagicMock, patch

import pytest

from inference.dto import Detection
from inference.inference import Inference


@pytest.fixture
def inference_nms():
    """Yield an ``Inference`` built on a mocked engine for NMS tests.

    The mocked engine reports batch size 1 and input shape (1280, 1280);
    the instance is created with confidence_threshold=0.5 and
    iou_threshold=0.3.
    """
    mock_engine = MagicMock()
    mock_engine.get_batch_size.return_value = 1
    mock_engine.get_input_shape.return_value = (1280, 1280)
    # NOTE(review): presumably stubs out loading annotation-class JSON so
    # construction needs no files on disk — confirm. Because the yield is
    # inside the ``with``, the patch stays active for the whole test.
    with patch("inference.inference.AnnotationClass.read_json", return_value={}):
        yield Inference(mock_engine, confidence_threshold=0.5, iou_threshold=0.3)


def test_bt_nms_01_overlapping_keeps_higher_confidence(inference_nms):
    """Two identical boxes collapse to the one with higher confidence."""
    d1 = Detection(0.5, 0.5, 0.1, 0.1, 0, 0.9)
    d2 = Detection(0.5, 0.5, 0.1, 0.1, 0, 0.5)

    out = inference_nms.remove_overlapping_detections([d1, d2])

    assert len(out) == 1
    assert out[0].confidence == 0.9


def test_bt_nms_02_non_overlapping_both_preserved(inference_nms):
    """Boxes in opposite corners do not suppress each other."""
    d1 = Detection(0.1, 0.1, 0.05, 0.05, 0, 0.8)
    d2 = Detection(0.9, 0.9, 0.05, 0.05, 0, 0.8)

    out = inference_nms.remove_overlapping_detections([d1, d2])

    assert len(out) == 2


def test_bt_nms_03_overlap_pair_and_distant_kept(inference_nms):
    """Only the lower-confidence box of an overlapping pair is dropped."""
    # Assuming Detection(x_center, y_center, width, height, class_id,
    # confidence) — TODO confirm — a and b overlap with IoU ~0.33, just
    # above the fixture's iou_threshold of 0.3; c is far from both.
    a = Detection(0.5, 0.5, 0.1, 0.1, 0, 0.9)
    b = Detection(0.55, 0.5, 0.1, 0.1, 0, 0.7)
    c = Detection(0.1, 0.1, 0.1, 0.1, 0, 0.8)

    out = inference_nms.remove_overlapping_detections([a, b, c])

    # Exactly two survivors: the higher-confidence box of the pair (a)
    # and the distant box (c).
    assert len(out) == 2
    assert {d.confidence for d in out} == {0.9, 0.8}