mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 07:06:36 +00:00
[AZ-171] Remove unused onnx_batch config, add macOS skip for CoreML tests
Made-with: Cursor
This commit is contained in:
@@ -8,4 +8,3 @@ training:
|
|||||||
|
|
||||||
export:
|
export:
|
||||||
onnx_imgsz: 320
|
onnx_imgsz: 320
|
||||||
onnx_batch: 1
|
|
||||||
|
|||||||
@@ -27,4 +27,3 @@ training:
|
|||||||
|
|
||||||
export:
|
export:
|
||||||
onnx_imgsz: 1280
|
onnx_imgsz: 1280
|
||||||
onnx_batch: 4
|
|
||||||
|
|||||||
@@ -38,7 +38,6 @@ class TrainingConfig(BaseModel):
|
|||||||
|
|
||||||
class ExportConfig(BaseModel):
|
class ExportConfig(BaseModel):
|
||||||
onnx_imgsz: int = 1280
|
onnx_imgsz: int = 1280
|
||||||
onnx_batch: int = 4
|
|
||||||
|
|
||||||
|
|
||||||
class Config(BaseModel):
|
class Config(BaseModel):
|
||||||
|
|||||||
+1
-4
@@ -25,9 +25,7 @@ def export_rknn(model_path):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def export_onnx(model_path, batch_size=None):
|
def export_onnx(model_path):
|
||||||
if batch_size is None:
|
|
||||||
batch_size = constants.config.export.onnx_batch
|
|
||||||
model = YOLO(model_path)
|
model = YOLO(model_path)
|
||||||
onnx_path = Path(model_path).stem + '.onnx'
|
onnx_path = Path(model_path).stem + '.onnx'
|
||||||
if path.exists(onnx_path):
|
if path.exists(onnx_path):
|
||||||
@@ -36,7 +34,6 @@ def export_onnx(model_path, batch_size=None):
|
|||||||
model.export(
|
model.export(
|
||||||
format="onnx",
|
format="onnx",
|
||||||
imgsz=constants.config.export.onnx_imgsz,
|
imgsz=constants.config.export.onnx_imgsz,
|
||||||
batch=batch_size,
|
|
||||||
dynamic=True,
|
dynamic=True,
|
||||||
simplify=True,
|
simplify=True,
|
||||||
nms=True,
|
nms=True,
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
@@ -92,6 +93,7 @@ class TestOnnxExport:
|
|||||||
assert out[0].shape[0] == 4
|
assert out[0].shape[0] == 4
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.platform != "darwin", reason="CoreML requires macOS")
|
||||||
class TestCoremlExport:
|
class TestCoremlExport:
|
||||||
def test_coreml_package_created(self, exported_models):
|
def test_coreml_package_created(self, exported_models):
|
||||||
# Assert
|
# Assert
|
||||||
|
|||||||
Reference in New Issue
Block a user