From 462a4826e8cabe6b788cbd4fde70d3b8820cbf5e Mon Sep 17 00:00:00 2001 From: Oleksandr Bezdieniezhnykh Date: Thu, 26 Mar 2026 23:23:42 +0200 Subject: [PATCH] [AZ-161] [AZ-162] [AZ-163] Add ONNX inference, NMS, annotation queue tests Made-with: Cursor --- tests/performance/test_inference_perf.py | 33 ++++++++++ tests/test_annotation_queue.py | 76 ++++++++++++++++++++++++ tests/test_nms.py | 38 ++++++++++++ tests/test_onnx_inference.py | 67 +++++++++++++++++++++ 4 files changed, 214 insertions(+) create mode 100644 tests/performance/test_inference_perf.py create mode 100644 tests/test_annotation_queue.py create mode 100644 tests/test_nms.py create mode 100644 tests/test_onnx_inference.py diff --git a/tests/performance/test_inference_perf.py b/tests/performance/test_inference_perf.py new file mode 100644 index 0000000..8849d64 --- /dev/null +++ b/tests/performance/test_inference_perf.py @@ -0,0 +1,33 @@ +import time + +import cv2 +import numpy as np +import pytest + +from inference.onnx_engine import OnnxEngine + + +@pytest.fixture(scope="session") +def onnx_engine_session(fixture_onnx_model): + return OnnxEngine(fixture_onnx_model) + + +@pytest.mark.performance +def test_pt_inf_01_single_image_onnx_latency(onnx_engine_session, fixture_images_dir): + engine = onnx_engine_session + imgs = sorted(fixture_images_dir.glob("*.jpg")) + assert imgs + frame = cv2.imread(str(imgs[0])) + assert frame is not None + model_height, model_width = engine.get_input_shape() + n = engine.get_batch_size() + frames = [frame] * n + blobs = [ + cv2.dnn.blobFromImage(f, 1.0 / 255.0, (model_width, model_height), (0, 0, 0), swapRB=True, crop=False) + for f in frames + ] + blob = np.vstack(blobs) + t0 = time.perf_counter() + engine.run(blob) + elapsed = time.perf_counter() - t0 + assert elapsed <= 10.0 diff --git a/tests/test_annotation_queue.py b/tests/test_annotation_queue.py new file mode 100644 index 0000000..93f51d3 --- /dev/null +++ b/tests/test_annotation_queue.py @@ -0,0 +1,76 @@ +import json +import os +import sys +from datetime import datetime +from pathlib import Path + +import msgpack +import pytest + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "annotation-queue")) +from annotation_queue_dto import AnnotationBulkMessage, AnnotationMessage, AnnotationStatus, RoleEnum + + +def _pack_created_message(): + ts = msgpack.Timestamp.from_datetime(datetime(2024, 1, 1, 12, 0, 0)) + detections_json = json.dumps( + [{"an": "test", "cl": 0, "x": 0.5, "y": 0.5, "w": 0.1, "h": 0.1, "p": 0.9}] + ) + data = [ + ts, + "test-annotation", + "media001.jpg", + 10000000, + ".jpg", + detections_json, + b"\xff\xd8\xff\xe0", + 20, + "test@example.com", + 0, + 10, + ] + return msgpack.packb(data) + + +def test_bt_aqm_01_parse_created_annotation_message(): + msg = AnnotationMessage(_pack_created_message()) + assert msg.name == "test-annotation" + assert len(msg.detections) == 1 + assert msg.detections[0].annotation_name == "test" + assert msg.status == AnnotationStatus.Created + assert msg.createdRole == RoleEnum.Validator + assert msg.image == b"\xff\xd8\xff\xe0" + + +def test_bt_aqm_02_parse_validated_bulk_message(): + ts = msgpack.Timestamp.from_datetime(datetime(2024, 1, 1, 12, 0, 0)) + packed = msgpack.packb([["n1", "n2"], 30, "admin@example.com", ts]) + msg = AnnotationBulkMessage(packed) + assert msg.annotation_status == AnnotationStatus.Validated + assert msg.annotation_names == ["n1", "n2"] + + +def test_bt_aqm_03_parse_deleted_bulk_message(): + ts = msgpack.Timestamp.from_datetime(datetime(2024, 1, 1, 12, 0, 0)) + packed = msgpack.packb([["d1", "d2", "d3"], 40, "user@example.com", ts]) + msg = AnnotationBulkMessage(packed) + assert msg.annotation_status == AnnotationStatus.Deleted + assert msg.annotation_names == ["d1", "d2", "d3"] + + +def test_bt_aqm_04_malformed_message_raises(): + with pytest.raises(Exception): + AnnotationMessage(os.urandom(512)) + + +def _unpack_or_catch(raw): + try: + msgpack.unpackb(raw, strict_map_key=False) + except Exception: + return True + return False + + +@pytest.mark.resilience +def test_rt_aqm_01_malformed_msgpack_bytes_handled(): + assert _unpack_or_catch(os.urandom(256)) diff --git a/tests/test_nms.py b/tests/test_nms.py new file mode 100644 index 0000000..999bb0f --- /dev/null +++ b/tests/test_nms.py @@ -0,0 +1,38 @@ +import pytest +from unittest.mock import MagicMock, patch + +from inference.dto import Detection +from inference.inference import Inference + + +@pytest.fixture +def inference_nms(): + mock_engine = MagicMock() + mock_engine.get_batch_size.return_value = 1 + mock_engine.get_input_shape.return_value = (1280, 1280) + with patch("inference.inference.AnnotationClass.read_json", return_value={}): + yield Inference(mock_engine, confidence_threshold=0.5, iou_threshold=0.3) + + +def test_bt_nms_01_overlapping_keeps_higher_confidence(inference_nms): + d1 = Detection(0.5, 0.5, 0.1, 0.1, 0, 0.9) + d2 = Detection(0.5, 0.5, 0.1, 0.1, 0, 0.5) + out = inference_nms.remove_overlapping_detections([d1, d2]) + assert len(out) == 1 + assert out[0].confidence == 0.9 + + +def test_bt_nms_02_non_overlapping_both_preserved(inference_nms): + d1 = Detection(0.1, 0.1, 0.05, 0.05, 0, 0.8) + d2 = Detection(0.9, 0.9, 0.05, 0.05, 0, 0.8) + out = inference_nms.remove_overlapping_detections([d1, d2]) + assert len(out) == 2 + + +def test_bt_nms_03_overlap_pair_and_distant_kept(inference_nms): + a = Detection(0.5, 0.5, 0.1, 0.1, 0, 0.9) + b = Detection(0.55, 0.5, 0.1, 0.1, 0, 0.7) + c = Detection(0.1, 0.1, 0.1, 0.1, 0, 0.8) + out = inference_nms.remove_overlapping_detections([a, b, c]) + assert len(out) <= 2 + assert {d.confidence for d in out} == {0.9, 0.8} diff --git a/tests/test_onnx_inference.py b/tests/test_onnx_inference.py new file mode 100644 index 0000000..65cd0e4 --- /dev/null +++ b/tests/test_onnx_inference.py @@ -0,0 +1,67 @@ +import cv2 +import numpy as np +import pytest + +from inference.inference import Inference +from inference.onnx_engine import OnnxEngine + + +@pytest.fixture(scope="session") +def onnx_engine_session(fixture_onnx_model): + return OnnxEngine(fixture_onnx_model) + + +def test_bt_inf_01_model_loads(onnx_engine_session): + engine = onnx_engine_session + assert engine.input_shape is not None + assert len(engine.input_shape) >= 4 + h, w = engine.get_input_shape() + assert h > 0 and w > 0 + assert engine.get_batch_size() > 0 + + +def test_bt_inf_02_inference_returns_output(onnx_engine_session, fixture_images_dir): + engine = onnx_engine_session + imgs = sorted(fixture_images_dir.glob("*.jpg")) + assert imgs + frame = cv2.imread(str(imgs[0])) + assert frame is not None + model_height, model_width = engine.get_input_shape() + frames = [frame] * engine.get_batch_size() + blobs = [ + cv2.dnn.blobFromImage(f, 1.0 / 255.0, (model_width, model_height), (0, 0, 0), swapRB=True, crop=False) + for f in frames + ] + blob = np.vstack(blobs) + out = engine.run(blob) + assert isinstance(out, list) + assert len(out) > 0 + assert isinstance(out[0], np.ndarray) + assert out[0].ndim == 3 + assert out[0].shape[0] == engine.get_batch_size() + assert out[0].shape[2] >= 6 + + +def test_bt_inf_03_postprocess_valid_detections(onnx_engine_session, fixture_images_dir): + engine = onnx_engine_session + inf = Inference(engine, 0.1, 0.3) + imgs = sorted(fixture_images_dir.glob("*.jpg")) + frame = cv2.imread(str(imgs[0])) + assert frame is not None + n = engine.get_batch_size() + batch_frames = [frame] * n + ts = list(range(n)) + blob = inf.preprocess(batch_frames) + output = engine.run(blob) + anns = inf.postprocess(batch_frames, ts, output) + assert isinstance(anns, list) + assert len(anns) == len(batch_frames) + for ann in anns: + assert ann.detections is not None + for d in ann.detections: + assert 0 <= d.x <= 1 + assert 0 <= d.y <= 1 + assert 0 <= d.w <= 1 + assert 0 <= d.h <= 1 + assert 0 <= d.cls <= 79 + assert 0 <= d.confidence <= 1