Mirror of https://github.com/azaion/ai-training.git
synced 2026-04-22 22:06:36 +00:00
462a4826e8
Made-with: Cursor
68 lines
2.1 KiB
Python
68 lines
2.1 KiB
Python
import cv2
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from inference.inference import Inference
|
|
from inference.onnx_engine import OnnxEngine
|
|
|
|
|
|
@pytest.fixture(scope="session")
def onnx_engine_session(fixture_onnx_model):
    """Provide a single OnnxEngine shared across the whole test session.

    Because the fixture is session-scoped, the engine is constructed once
    and the same instance is handed to every test that requests it.
    ``fixture_onnx_model`` comes from outside this file (presumably a
    conftest fixture supplying the model path — TODO confirm).
    """
    engine = OnnxEngine(fixture_onnx_model)
    return engine
|
|
|
|
|
|
def test_bt_inf_01_model_loads(onnx_engine_session):
    """Check that the engine exposes a usable input shape and batch size.

    The raw ``input_shape`` must exist and have at least four dimensions.
    The (height, width) pair from ``get_input_shape()`` must be positive,
    and so must the batch size.
    """
    raw_shape = onnx_engine_session.input_shape
    assert raw_shape is not None
    assert len(raw_shape) >= 4

    in_height, in_width = onnx_engine_session.get_input_shape()
    assert in_height > 0
    assert in_width > 0

    assert onnx_engine_session.get_batch_size() > 0
|
|
|
|
|
|
def test_bt_inf_02_inference_returns_output(onnx_engine_session, fixture_images_dir):
    """Run the engine on a hand-built batch blob and sanity-check the raw output.

    The first fixture JPEG (by sorted name) is read with OpenCV and
    converted with ``cv2.dnn.blobFromImage``:

    - scaled by 1/255;
    - resized to the model's input (width, height);
    - zero mean, BGR->RGB swap, no crop.

    The conversion is repeated once per batch slot and the results are
    stacked into one blob. The first output tensor must be a 3-D ndarray
    whose axis 0 equals the batch size and whose axis 2 has at least 6
    entries (presumably box + score + class per candidate — TODO confirm).
    """
    engine = onnx_engine_session

    image_paths = sorted(fixture_images_dir.glob("*.jpg"))
    assert image_paths
    image = cv2.imread(str(image_paths[0]))
    assert image is not None

    in_h, in_w = engine.get_input_shape()
    batch_size = engine.get_batch_size()

    def to_blob(img):
        # blobFromImage takes its target size as (width, height).
        return cv2.dnn.blobFromImage(
            img, 1.0 / 255.0, (in_w, in_h), (0, 0, 0), swapRB=True, crop=False
        )

    blob = np.vstack([to_blob(image) for _ in range(batch_size)])

    outputs = engine.run(blob)
    assert isinstance(outputs, list)
    assert outputs

    first = outputs[0]
    assert isinstance(first, np.ndarray)
    assert first.ndim == 3
    assert first.shape[0] == batch_size
    assert first.shape[2] >= 6
|
|
|
|
|
|
def test_bt_inf_03_postprocess_valid_detections(onnx_engine_session, fixture_images_dir):
    """Run preprocess -> engine -> postprocess and validate every detection.

    The first fixture JPEG is replicated to fill one full engine batch and
    goes through ``Inference.preprocess``, ``engine.run`` and
    ``Inference.postprocess``. The test requires one annotation per input
    frame. It also requires every detection to have x, y, w, h and
    confidence in [0, 1] and a class id in [0, 79].
    """
    engine = onnx_engine_session
    # NOTE(review): the positional args presumably are (confidence threshold,
    # IoU/NMS threshold) — confirm against Inference.__init__.
    inf = Inference(engine, 0.1, 0.3)

    imgs = sorted(fixture_images_dir.glob("*.jpg"))
    # Guard before indexing so an empty fixture directory produces a clear
    # assertion failure instead of an IndexError on imgs[0].
    assert imgs, f"no *.jpg fixture images found in {fixture_images_dir}"

    frame = cv2.imread(str(imgs[0]))
    assert frame is not None, f"cv2.imread could not read {imgs[0]}"

    n = engine.get_batch_size()
    batch_frames = [frame] * n
    # One entry per frame; presumably timestamps / frame indices — TODO confirm.
    ts = list(range(n))

    blob = inf.preprocess(batch_frames)
    output = engine.run(blob)
    anns = inf.postprocess(batch_frames, ts, output)

    assert isinstance(anns, list)
    assert len(anns) == len(batch_frames)

    for ann in anns:
        assert ann.detections is not None
        for d in ann.detections:
            # Box geometry must be in normalized [0, 1] image coordinates.
            assert 0 <= d.x <= 1
            assert 0 <= d.y <= 1
            assert 0 <= d.w <= 1
            assert 0 <= d.h <= 1
            # NOTE(review): the upper bound 79 assumes an 80-class model — confirm.
            assert 0 <= d.cls <= 79
            assert 0 <= d.confidence <= 1
|