mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 09:16:36 +00:00
[AZ-161] [AZ-162] [AZ-163] Add ONNX inference, NMS, annotation queue tests
Made-with: Cursor
This commit is contained in:
@@ -0,0 +1,33 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from inference.onnx_engine import OnnxEngine
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def onnx_engine_session(fixture_onnx_model):
|
||||||
|
return OnnxEngine(fixture_onnx_model)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.performance
|
||||||
|
def test_pt_inf_01_single_image_onnx_latency(onnx_engine_session, fixture_images_dir):
|
||||||
|
engine = onnx_engine_session
|
||||||
|
imgs = sorted(fixture_images_dir.glob("*.jpg"))
|
||||||
|
assert imgs
|
||||||
|
frame = cv2.imread(str(imgs[0]))
|
||||||
|
assert frame is not None
|
||||||
|
model_height, model_width = engine.get_input_shape()
|
||||||
|
n = engine.get_batch_size()
|
||||||
|
frames = [frame] * n
|
||||||
|
blobs = [
|
||||||
|
cv2.dnn.blobFromImage(f, 1.0 / 255.0, (model_width, model_height), (0, 0, 0), swapRB=True, crop=False)
|
||||||
|
for f in frames
|
||||||
|
]
|
||||||
|
blob = np.vstack(blobs)
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
engine.run(blob)
|
||||||
|
elapsed = time.perf_counter() - t0
|
||||||
|
assert elapsed <= 10.0
|
||||||
@@ -0,0 +1,76 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import msgpack
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "annotation-queue"))
|
||||||
|
from annotation_queue_dto import AnnotationBulkMessage, AnnotationMessage, AnnotationStatus, RoleEnum
|
||||||
|
|
||||||
|
|
||||||
|
def _pack_created_message():
|
||||||
|
ts = msgpack.Timestamp.from_datetime(datetime(2024, 1, 1, 12, 0, 0))
|
||||||
|
detections_json = json.dumps(
|
||||||
|
[{"an": "test", "cl": 0, "x": 0.5, "y": 0.5, "w": 0.1, "h": 0.1, "p": 0.9}]
|
||||||
|
)
|
||||||
|
data = [
|
||||||
|
ts,
|
||||||
|
"test-annotation",
|
||||||
|
"media001.jpg",
|
||||||
|
10000000,
|
||||||
|
".jpg",
|
||||||
|
detections_json,
|
||||||
|
b"\xff\xd8\xff\xe0",
|
||||||
|
20,
|
||||||
|
"test@example.com",
|
||||||
|
0,
|
||||||
|
10,
|
||||||
|
]
|
||||||
|
return msgpack.packb(data)
|
||||||
|
|
||||||
|
|
||||||
|
def test_bt_aqm_01_parse_created_annotation_message():
|
||||||
|
msg = AnnotationMessage(_pack_created_message())
|
||||||
|
assert msg.name == "test-annotation"
|
||||||
|
assert len(msg.detections) == 1
|
||||||
|
assert msg.detections[0].annotation_name == "test"
|
||||||
|
assert msg.status == AnnotationStatus.Created
|
||||||
|
assert msg.createdRole == RoleEnum.Validator
|
||||||
|
assert msg.image == b"\xff\xd8\xff\xe0"
|
||||||
|
|
||||||
|
|
||||||
|
def test_bt_aqm_02_parse_validated_bulk_message():
|
||||||
|
ts = msgpack.Timestamp.from_datetime(datetime(2024, 1, 1, 12, 0, 0))
|
||||||
|
packed = msgpack.packb([["n1", "n2"], 30, "admin@example.com", ts])
|
||||||
|
msg = AnnotationBulkMessage(packed)
|
||||||
|
assert msg.annotation_status == AnnotationStatus.Validated
|
||||||
|
assert msg.annotation_names == ["n1", "n2"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_bt_aqm_03_parse_deleted_bulk_message():
|
||||||
|
ts = msgpack.Timestamp.from_datetime(datetime(2024, 1, 1, 12, 0, 0))
|
||||||
|
packed = msgpack.packb([["d1", "d2", "d3"], 40, "user@example.com", ts])
|
||||||
|
msg = AnnotationBulkMessage(packed)
|
||||||
|
assert msg.annotation_status == AnnotationStatus.Deleted
|
||||||
|
assert msg.annotation_names == ["d1", "d2", "d3"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_bt_aqm_04_malformed_message_raises():
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
AnnotationMessage(os.urandom(512))
|
||||||
|
|
||||||
|
|
||||||
|
def _unpack_or_catch(raw):
|
||||||
|
try:
|
||||||
|
msgpack.unpackb(raw, strict_map_key=False)
|
||||||
|
except Exception:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.resilience
|
||||||
|
def test_rt_aqm_01_malformed_msgpack_bytes_handled():
|
||||||
|
assert _unpack_or_catch(os.urandom(256))
|
||||||
@@ -0,0 +1,38 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from inference.dto import Detection
|
||||||
|
from inference.inference import Inference
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def inference_nms():
|
||||||
|
mock_engine = MagicMock()
|
||||||
|
mock_engine.get_batch_size.return_value = 1
|
||||||
|
mock_engine.get_input_shape.return_value = (1280, 1280)
|
||||||
|
with patch("inference.inference.AnnotationClass.read_json", return_value={}):
|
||||||
|
yield Inference(mock_engine, confidence_threshold=0.5, iou_threshold=0.3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_bt_nms_01_overlapping_keeps_higher_confidence(inference_nms):
|
||||||
|
d1 = Detection(0.5, 0.5, 0.1, 0.1, 0, 0.9)
|
||||||
|
d2 = Detection(0.5, 0.5, 0.1, 0.1, 0, 0.5)
|
||||||
|
out = inference_nms.remove_overlapping_detections([d1, d2])
|
||||||
|
assert len(out) == 1
|
||||||
|
assert out[0].confidence == 0.9
|
||||||
|
|
||||||
|
|
||||||
|
def test_bt_nms_02_non_overlapping_both_preserved(inference_nms):
|
||||||
|
d1 = Detection(0.1, 0.1, 0.05, 0.05, 0, 0.8)
|
||||||
|
d2 = Detection(0.9, 0.9, 0.05, 0.05, 0, 0.8)
|
||||||
|
out = inference_nms.remove_overlapping_detections([d1, d2])
|
||||||
|
assert len(out) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_bt_nms_03_overlap_pair_and_distant_kept(inference_nms):
|
||||||
|
a = Detection(0.5, 0.5, 0.1, 0.1, 0, 0.9)
|
||||||
|
b = Detection(0.55, 0.5, 0.1, 0.1, 0, 0.7)
|
||||||
|
c = Detection(0.1, 0.1, 0.1, 0.1, 0, 0.8)
|
||||||
|
out = inference_nms.remove_overlapping_detections([a, b, c])
|
||||||
|
assert len(out) <= 2
|
||||||
|
assert {d.confidence for d in out} == {0.9, 0.8}
|
||||||
@@ -0,0 +1,67 @@
|
|||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from inference.inference import Inference
|
||||||
|
from inference.onnx_engine import OnnxEngine
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def onnx_engine_session(fixture_onnx_model):
|
||||||
|
return OnnxEngine(fixture_onnx_model)
|
||||||
|
|
||||||
|
|
||||||
|
def test_bt_inf_01_model_loads(onnx_engine_session):
|
||||||
|
engine = onnx_engine_session
|
||||||
|
assert engine.input_shape is not None
|
||||||
|
assert len(engine.input_shape) >= 4
|
||||||
|
h, w = engine.get_input_shape()
|
||||||
|
assert h > 0 and w > 0
|
||||||
|
assert engine.get_batch_size() > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_bt_inf_02_inference_returns_output(onnx_engine_session, fixture_images_dir):
|
||||||
|
engine = onnx_engine_session
|
||||||
|
imgs = sorted(fixture_images_dir.glob("*.jpg"))
|
||||||
|
assert imgs
|
||||||
|
frame = cv2.imread(str(imgs[0]))
|
||||||
|
assert frame is not None
|
||||||
|
model_height, model_width = engine.get_input_shape()
|
||||||
|
frames = [frame] * engine.get_batch_size()
|
||||||
|
blobs = [
|
||||||
|
cv2.dnn.blobFromImage(f, 1.0 / 255.0, (model_width, model_height), (0, 0, 0), swapRB=True, crop=False)
|
||||||
|
for f in frames
|
||||||
|
]
|
||||||
|
blob = np.vstack(blobs)
|
||||||
|
out = engine.run(blob)
|
||||||
|
assert isinstance(out, list)
|
||||||
|
assert len(out) > 0
|
||||||
|
assert isinstance(out[0], np.ndarray)
|
||||||
|
assert out[0].ndim == 3
|
||||||
|
assert out[0].shape[0] == engine.get_batch_size()
|
||||||
|
assert out[0].shape[2] >= 6
|
||||||
|
|
||||||
|
|
||||||
|
def test_bt_inf_03_postprocess_valid_detections(onnx_engine_session, fixture_images_dir):
|
||||||
|
engine = onnx_engine_session
|
||||||
|
inf = Inference(engine, 0.1, 0.3)
|
||||||
|
imgs = sorted(fixture_images_dir.glob("*.jpg"))
|
||||||
|
frame = cv2.imread(str(imgs[0]))
|
||||||
|
assert frame is not None
|
||||||
|
n = engine.get_batch_size()
|
||||||
|
batch_frames = [frame] * n
|
||||||
|
ts = list(range(n))
|
||||||
|
blob = inf.preprocess(batch_frames)
|
||||||
|
output = engine.run(blob)
|
||||||
|
anns = inf.postprocess(batch_frames, ts, output)
|
||||||
|
assert isinstance(anns, list)
|
||||||
|
assert len(anns) == len(batch_frames)
|
||||||
|
for ann in anns:
|
||||||
|
assert ann.detections is not None
|
||||||
|
for d in ann.detections:
|
||||||
|
assert 0 <= d.x <= 1
|
||||||
|
assert 0 <= d.y <= 1
|
||||||
|
assert 0 <= d.w <= 1
|
||||||
|
assert 0 <= d.h <= 1
|
||||||
|
assert 0 <= d.cls <= 79
|
||||||
|
assert 0 <= d.confidence <= 1
|
||||||
Reference in New Issue
Block a user