Files
ai-training/tests/test_export.py
T
2026-03-28 17:33:40 +02:00

107 lines
3.2 KiB
Python

from pathlib import Path
import cv2
import numpy as np
import onnxruntime as ort
import pytest
from ultralytics import YOLO
import constants as c
import exports as exports_mod
# Directory holding this test module. The paths below are built from it rather
# than from the current working directory, so the suite can be run from anywhere.
_TESTS_DIR = Path(__file__).resolve().parent
# YAML config loaded by the export fixture; it sits one level above tests/.
_CONFIG_TEST = _TESTS_DIR.parent.joinpath("config.test.yaml")
# Directory whose *.jpg files are fed to the ONNX inference tests.
_DATASET_IMAGES = _TESTS_DIR.joinpath("root", "data", "images")
@pytest.fixture(scope="module")
def exported_models(tmp_path_factory):
    """Export a YOLO checkpoint to ONNX and CoreML once for this test module.

    A ``yolo11n.pt`` model is re-saved as ``models/test.pt`` inside a fresh
    temporary directory. The global ``constants.config`` is then temporarily
    replaced by one loaded from ``config.test.yaml`` and rooted at that
    directory, and both exporters are run on the checkpoint.

    Yields:
        dict with ``"onnx"`` (str path of the exported ``.onnx`` file) and
        ``"model_dir"`` (Path of the directory the exports were written to).

    The previous ``constants.config`` is restored on teardown. This also
    happens when an export raises before the ``yield``, so a failed export
    cannot leak the test configuration into other test modules.
    """
    # Arrange
    tmp = tmp_path_factory.mktemp("export")
    model_dir = tmp / "models"
    model_dir.mkdir()
    pt_path = str(model_dir / "test.pt")
    # NOTE(review): ultralytics resolves "yolo11n.pt" itself and may need to
    # fetch it on first use; confirm whether CI has it cached.
    YOLO("yolo11n.pt").save(pt_path)
    old_config = c.config
    c.config = c.Config.from_yaml(str(_CONFIG_TEST), root=str(tmp))
    try:
        # Act
        exports_mod.export_onnx(pt_path)
        exports_mod.export_coreml(pt_path)
        onnx_files = sorted(model_dir.glob("*.onnx"))
        if not onnx_files:
            # next() on an empty glob would raise StopIteration here, which a
            # generator turns into an unhelpful RuntimeError (PEP 479).
            pytest.fail(f"export_onnx produced no .onnx file in {model_dir}")
        yield {
            "onnx": str(onnx_files[0]),
            "model_dir": model_dir,
        }
    finally:
        # Always undo the global config swap, even if an export failed.
        c.config = old_config
class TestOnnxExport:
    """Checks on the ONNX model produced by the ``exported_models`` fixture."""

    @staticmethod
    def _session(onnx_path):
        """Open *onnx_path* in an ONNX Runtime session on the CPU provider."""
        return ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

    @staticmethod
    def _first_image_blob():
        """Return a single-image (batch size 1) input blob for the first test image.

        The image is scaled by 1/255, resized to the configured
        ``export.onnx_imgsz`` square and channel-swapped (``swapRB=True``).
        The calling test is skipped when no ``.jpg`` images are available.
        """
        imgs = sorted(_DATASET_IMAGES.glob("*.jpg"))
        if not imgs:
            pytest.skip("no test images")
        image = cv2.imread(str(imgs[0]))
        # cv2.imread reports failure by returning None instead of raising.
        assert image is not None, f"could not read test image {imgs[0]}"
        imgsz = c.config.export.onnx_imgsz
        return cv2.dnn.blobFromImage(
            image, 1.0 / 255.0, (imgsz, imgsz), (0, 0, 0), swapRB=True, crop=False,
        )

    def test_onnx_file_created(self, exported_models):
        # Assert
        p = Path(exported_models["onnx"])
        assert p.exists()
        assert p.stat().st_size > 0

    def test_onnx_batch_dimension_is_dynamic(self, exported_models):
        # Arrange
        session = self._session(exported_models["onnx"])
        batch_dim = session.get_inputs()[0].shape[0]
        # Assert
        # ONNX Runtime's NodeArg.shape gives a named symbolic dim as its str
        # name and an unnamed dynamic dim as None. Only a positive int means
        # the batch size was baked into the graph.
        assert not (isinstance(batch_dim, int) and batch_dim > 0), (
            f"batch dimension is fixed: {batch_dim!r}"
        )

    def test_onnx_inference_batch_1(self, exported_models):
        # Arrange
        blob = self._first_image_blob()
        session = self._session(exported_models["onnx"])
        input_name = session.get_inputs()[0].name
        # Act
        out = session.run(None, {input_name: blob})
        # Assert
        assert out[0].shape[0] == 1

    def test_onnx_inference_batch_multiple(self, exported_models):
        # Arrange
        batch_size = 4
        single = self._first_image_blob()
        batch = np.concatenate([single] * batch_size, axis=0)
        session = self._session(exported_models["onnx"])
        input_name = session.get_inputs()[0].name
        # Act
        out = session.run(None, {input_name: batch})
        # Assert
        assert out[0].shape[0] == batch_size
class TestCoremlExport:
    """Checks on the CoreML package(s) produced by the ``exported_models`` fixture."""

    @staticmethod
    def _packages(model_dir):
        """Return the ``*.mlpackage`` entries in *model_dir*, sorted by path."""
        return sorted(model_dir.glob("*.mlpackage"))

    def test_coreml_package_created(self, exported_models):
        # Assert
        model_dir = exported_models["model_dir"]
        pkgs = self._packages(model_dir)
        assert pkgs, f"no .mlpackage found in {model_dir}"

    def test_coreml_package_has_model(self, exported_models):
        # Assert
        model_dir = exported_models["model_dir"]
        pkgs = self._packages(model_dir)
        assert pkgs, f"no .mlpackage found in {model_dir}"
        # Check every package instead of whichever one glob() happened to
        # list first (glob order follows the filesystem and is not stable).
        missing = [
            pkg
            for pkg in pkgs
            if not (pkg / "Data" / "com.apple.CoreML" / "model.mlmodel").exists()
        ]
        assert not missing, f"mlpackage(s) without Data/com.apple.CoreML/model.mlmodel: {missing}"