mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 09:06:35 +00:00
[AZ-161] [AZ-162] [AZ-163] Add ONNX inference, NMS, annotation queue tests
Made-with: Cursor
This commit is contained in:
@@ -0,0 +1,67 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from inference.inference import Inference
|
||||
from inference.onnx_engine import OnnxEngine
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def onnx_engine_session(fixture_onnx_model):
|
||||
return OnnxEngine(fixture_onnx_model)
|
||||
|
||||
|
||||
def test_bt_inf_01_model_loads(onnx_engine_session):
|
||||
engine = onnx_engine_session
|
||||
assert engine.input_shape is not None
|
||||
assert len(engine.input_shape) >= 4
|
||||
h, w = engine.get_input_shape()
|
||||
assert h > 0 and w > 0
|
||||
assert engine.get_batch_size() > 0
|
||||
|
||||
|
||||
def test_bt_inf_02_inference_returns_output(onnx_engine_session, fixture_images_dir):
|
||||
engine = onnx_engine_session
|
||||
imgs = sorted(fixture_images_dir.glob("*.jpg"))
|
||||
assert imgs
|
||||
frame = cv2.imread(str(imgs[0]))
|
||||
assert frame is not None
|
||||
model_height, model_width = engine.get_input_shape()
|
||||
frames = [frame] * engine.get_batch_size()
|
||||
blobs = [
|
||||
cv2.dnn.blobFromImage(f, 1.0 / 255.0, (model_width, model_height), (0, 0, 0), swapRB=True, crop=False)
|
||||
for f in frames
|
||||
]
|
||||
blob = np.vstack(blobs)
|
||||
out = engine.run(blob)
|
||||
assert isinstance(out, list)
|
||||
assert len(out) > 0
|
||||
assert isinstance(out[0], np.ndarray)
|
||||
assert out[0].ndim == 3
|
||||
assert out[0].shape[0] == engine.get_batch_size()
|
||||
assert out[0].shape[2] >= 6
|
||||
|
||||
|
||||
def test_bt_inf_03_postprocess_valid_detections(onnx_engine_session, fixture_images_dir):
|
||||
engine = onnx_engine_session
|
||||
inf = Inference(engine, 0.1, 0.3)
|
||||
imgs = sorted(fixture_images_dir.glob("*.jpg"))
|
||||
frame = cv2.imread(str(imgs[0]))
|
||||
assert frame is not None
|
||||
n = engine.get_batch_size()
|
||||
batch_frames = [frame] * n
|
||||
ts = list(range(n))
|
||||
blob = inf.preprocess(batch_frames)
|
||||
output = engine.run(blob)
|
||||
anns = inf.postprocess(batch_frames, ts, output)
|
||||
assert isinstance(anns, list)
|
||||
assert len(anns) == len(batch_frames)
|
||||
for ann in anns:
|
||||
assert ann.detections is not None
|
||||
for d in ann.detections:
|
||||
assert 0 <= d.x <= 1
|
||||
assert 0 <= d.y <= 1
|
||||
assert 0 <= d.w <= 1
|
||||
assert 0 <= d.h <= 1
|
||||
assert 0 <= d.cls <= 79
|
||||
assert 0 <= d.confidence <= 1
|
||||
Reference in New Issue
Block a user