mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 09:06:35 +00:00
[AZ-171] Add TensorRT tests, AC coverage gate in implement skill, optimize test infrastructure
- Add TensorRT export tests with graceful skip when no GPU available - Add AC test coverage verification step (Step 8) to implement skill - Add test coverage gap analysis to new-task skill - Move exported_models fixture to conftest.py as session-scoped (shared across modules) - Reorder tests: e2e training runs first so images/labels are available for all tests - Consolidate teardown into single session-level cleanup in conftest.py - Fix infrastructure tests to count files dynamically instead of hardcoded 20 Made-with: Cursor
This commit is contained in:
+86
-28
@@ -5,41 +5,23 @@ import cv2
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import pytest
|
||||
import torch
|
||||
from ultralytics import YOLO
|
||||
|
||||
import constants as c
|
||||
import exports as exports_mod
|
||||
|
||||
_HAS_TENSORRT = torch.cuda.is_available()
|
||||
try:
|
||||
import tensorrt
|
||||
except ImportError:
|
||||
_HAS_TENSORRT = False
|
||||
|
||||
_TESTS_DIR = Path(__file__).resolve().parent
|
||||
_CONFIG_TEST = _TESTS_DIR.parent / "config.test.yaml"
|
||||
_DATASET_IMAGES = _TESTS_DIR / "root" / "data" / "images"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def exported_models(tmp_path_factory):
|
||||
# Arrange
|
||||
tmp = tmp_path_factory.mktemp("export")
|
||||
model_dir = tmp / "models"
|
||||
model_dir.mkdir()
|
||||
|
||||
pt_path = str(model_dir / "test.pt")
|
||||
YOLO("yolo11n.pt").save(pt_path)
|
||||
|
||||
old_config = c.config
|
||||
c.config = c.Config.from_yaml(str(_CONFIG_TEST), root=str(tmp))
|
||||
|
||||
# Act
|
||||
exports_mod.export_onnx(pt_path)
|
||||
exports_mod.export_coreml(pt_path)
|
||||
|
||||
yield {
|
||||
"onnx": str(next(model_dir.glob("*.onnx"))),
|
||||
"model_dir": model_dir,
|
||||
}
|
||||
|
||||
c.config = old_config
|
||||
|
||||
|
||||
class TestOnnxExport:
|
||||
def test_onnx_file_created(self, exported_models):
|
||||
# Assert
|
||||
@@ -59,7 +41,7 @@ class TestOnnxExport:
|
||||
# Arrange
|
||||
session = ort.InferenceSession(exported_models["onnx"], providers=["CPUExecutionProvider"])
|
||||
meta = session.get_inputs()[0]
|
||||
imgsz = c.config.export.onnx_imgsz
|
||||
imgsz = exported_models["imgsz"]
|
||||
imgs = sorted(_DATASET_IMAGES.glob("*.jpg"))
|
||||
if not imgs:
|
||||
pytest.skip("no test images")
|
||||
@@ -77,7 +59,7 @@ class TestOnnxExport:
|
||||
# Arrange
|
||||
session = ort.InferenceSession(exported_models["onnx"], providers=["CPUExecutionProvider"])
|
||||
meta = session.get_inputs()[0]
|
||||
imgsz = c.config.export.onnx_imgsz
|
||||
imgsz = exported_models["imgsz"]
|
||||
imgs = sorted(_DATASET_IMAGES.glob("*.jpg"))
|
||||
if not imgs:
|
||||
pytest.skip("no test images")
|
||||
@@ -93,6 +75,82 @@ class TestOnnxExport:
|
||||
assert out[0].shape[0] == 4
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _HAS_TENSORRT, reason="TensorRT requires NVIDIA GPU and tensorrt package")
|
||||
class TestTensorrtExport:
|
||||
@pytest.fixture(scope="class")
|
||||
def tensorrt_model(self, exported_models):
|
||||
# Arrange
|
||||
model_dir = exported_models["model_dir"]
|
||||
pt_path = exported_models["pt_path"]
|
||||
old_config = c.config
|
||||
c.config = c.Config.from_yaml(str(_CONFIG_TEST), root=str(model_dir.parent))
|
||||
|
||||
# Act
|
||||
exports_mod.export_tensorrt(pt_path)
|
||||
|
||||
c.config = old_config
|
||||
engines = list(model_dir.glob("*.engine"))
|
||||
yield {
|
||||
"engine": str(engines[0]) if engines else None,
|
||||
"model_dir": model_dir,
|
||||
"imgsz": exported_models["imgsz"],
|
||||
}
|
||||
|
||||
for e in model_dir.glob("*.engine"):
|
||||
e.unlink(missing_ok=True)
|
||||
|
||||
def test_tensorrt_engine_created(self, tensorrt_model):
|
||||
# Assert
|
||||
assert tensorrt_model["engine"] is not None
|
||||
p = Path(tensorrt_model["engine"])
|
||||
assert p.exists()
|
||||
assert p.stat().st_size > 0
|
||||
|
||||
def test_tensorrt_inference_batch_1(self, tensorrt_model):
|
||||
# Arrange
|
||||
assert tensorrt_model["engine"] is not None
|
||||
imgs = sorted(_DATASET_IMAGES.glob("*.jpg"))
|
||||
if not imgs:
|
||||
pytest.skip("no test images")
|
||||
model = YOLO(tensorrt_model["engine"])
|
||||
|
||||
# Act
|
||||
results = model.predict(source=str(imgs[0]), imgsz=tensorrt_model["imgsz"], verbose=False)
|
||||
|
||||
# Assert
|
||||
assert len(results) == 1
|
||||
assert results[0].boxes is not None
|
||||
|
||||
def test_tensorrt_inference_batch_multiple(self, tensorrt_model):
|
||||
# Arrange
|
||||
assert tensorrt_model["engine"] is not None
|
||||
imgs = sorted(_DATASET_IMAGES.glob("*.jpg"))
|
||||
if len(imgs) < 4:
|
||||
pytest.skip("need at least 4 test images")
|
||||
model = YOLO(tensorrt_model["engine"])
|
||||
|
||||
# Act
|
||||
results = model.predict(source=[str(p) for p in imgs[:4]], imgsz=tensorrt_model["imgsz"], verbose=False)
|
||||
|
||||
# Assert
|
||||
assert len(results) == 4
|
||||
|
||||
def test_tensorrt_inference_batch_max(self, tensorrt_model):
|
||||
# Arrange
|
||||
assert tensorrt_model["engine"] is not None
|
||||
imgs = sorted(_DATASET_IMAGES.glob("*.jpg"))
|
||||
if not imgs:
|
||||
pytest.skip("no test images")
|
||||
model = YOLO(tensorrt_model["engine"])
|
||||
sources = [str(imgs[0])] * 8
|
||||
|
||||
# Act
|
||||
results = model.predict(source=sources, imgsz=tensorrt_model["imgsz"], verbose=False)
|
||||
|
||||
# Assert
|
||||
assert len(results) == 8
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform != "darwin", reason="CoreML requires macOS")
|
||||
class TestCoremlExport:
|
||||
def test_coreml_package_created(self, exported_models):
|
||||
@@ -117,7 +175,7 @@ class TestCoremlExport:
|
||||
model = YOLO(str(pkgs[0]))
|
||||
|
||||
# Act
|
||||
results = model.predict(source=str(imgs[0]), imgsz=c.config.export.onnx_imgsz, verbose=False)
|
||||
results = model.predict(source=str(imgs[0]), imgsz=exported_models["imgsz"], verbose=False)
|
||||
|
||||
# Assert
|
||||
assert len(results) == 1
|
||||
|
||||
Reference in New Issue
Block a user