diff --git a/tests/test_export.py b/tests/test_export.py index 866fd00..1e6f58f 100644 --- a/tests/test_export.py +++ b/tests/test_export.py @@ -1,4 +1,3 @@ -import shutil from pathlib import Path import cv2 @@ -17,29 +16,23 @@ _DATASET_IMAGES = _TESTS_DIR / "root" / "data" / "images" @pytest.fixture(scope="module") def exported_models(tmp_path_factory): + # Arrange tmp = tmp_path_factory.mktemp("export") model_dir = tmp / "models" model_dir.mkdir() - # Arrange pt_path = str(model_dir / "test.pt") - model = YOLO("yolo11n.pt") - model.save(pt_path) + YOLO("yolo11n.pt").save(pt_path) old_config = c.config c.config = c.Config.from_yaml(str(_CONFIG_TEST), root=str(tmp)) - onnx_out = model_dir / "test.onnx" - coreml_out = model_dir / "test.mlpackage" - # Act exports_mod.export_onnx(pt_path) exports_mod.export_coreml(pt_path) yield { - "pt": pt_path, - "onnx": str(onnx_out) if onnx_out.exists() else str(next(model_dir.glob("*.onnx"))), - "coreml": str(coreml_out) if coreml_out.exists() else str(next(model_dir.glob("*.mlpackage"), "")), + "onnx": str(next(model_dir.glob("*.onnx"))), "model_dir": model_dir, }