mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 10:46:35 +00:00
142c6c4de8
- Replaced module-level path variables in constants.py with a structured Pydantic Config class.
- Updated all relevant modules (train.py, augmentation.py, exports.py, dataset-visualiser.py, manual_run.py) to access paths through the new config structure.
- Fixed bugs related to image processing and model saving.
- Enhanced test infrastructure to accommodate the new configuration approach.

This refactor improves code maintainability and clarity by centralizing configuration management.
114 lines
3.5 KiB
Python
114 lines
3.5 KiB
Python
import sys
|
|
import types
|
|
import importlib
|
|
import shutil
|
|
from os import path as osp
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
# Register empty placeholder modules for optional dependencies that are not
# already present in sys.modules, so later imports of these names resolve.
for _n in ("boto3", "netron", "requests"):
    sys.modules.setdefault(_n, types.ModuleType(_n))

# Drop every cached "ultralytics" module (the package and all submodules) so
# the import below is resolved from scratch.
for _k in [name for name in sys.modules if name.partition(".")[0] == "ultralytics"]:
    sys.modules.pop(_k)

from ultralytics import YOLO

# Re-execute project modules that were already imported, now that the
# ultralytics entries in sys.modules have been refreshed.
for _m in ("exports", "train"):
    if _m not in sys.modules:
        continue
    importlib.reload(sys.modules[_m])

import constants as c
import train as train_mod
import exports as exports_mod

# Repository root: two directory levels above this file.
_PROJECT_ROOT = Path(__file__).resolve().parents[1]
# Source images and labels copied into the temporary workspace by the fixture.
_DATASET_IMAGES = _PROJECT_ROOT.joinpath("_docs/00_problem/input_data/dataset/images")
_DATASET_LABELS = _PROJECT_ROOT.joinpath("_docs/00_problem/input_data/dataset/labels")
# YAML configuration loaded for the end-to-end run.
_CONFIG_TEST = _PROJECT_ROOT.joinpath("config.test.yaml")
|
|
|
|
|
|
@pytest.fixture(scope="module")
def e2e_result(tmp_path_factory):
    """Run the full pipeline once per module inside a temporary workspace.

    Swaps ``c.config`` for one loaded from ``_CONFIG_TEST`` and rooted at a
    fresh temporary directory, copies the sample images (and any matching
    label files) into the data directories, then runs augmentation, training
    and the ONNX / CoreML exports.

    Yields:
        dict: ``{"today_dataset": <path>}``, where the path is
        ``c.config.datasets_dir`` joined with ``train_mod.today_folder``.

    The original ``c.config`` is restored in a ``finally`` clause. With a
    plain post-``yield`` restore, an exception raised during setup would leave
    the temporary config installed globally, because pytest does not resume a
    fixture generator that failed before reaching its ``yield``.
    """
    base = tmp_path_factory.mktemp("e2e")

    old_config = c.config
    try:
        c.config = c.Config.from_yaml(str(_CONFIG_TEST), root=str(base / "azaion"))

        data_img = Path(c.config.data_images_dir)
        data_lbl = Path(c.config.data_labels_dir)
        data_img.mkdir(parents=True)
        data_lbl.mkdir(parents=True)
        Path(c.config.models_dir).mkdir(parents=True)

        # Copy each sample image; a label is copied only when a file with the
        # same stem exists in the labels directory.
        for img in sorted(_DATASET_IMAGES.glob("*.jpg")):
            shutil.copy2(img, data_img / img.name)
            lbl = _DATASET_LABELS / f"{img.stem}.txt"
            if lbl.exists():
                shutil.copy2(lbl, data_lbl / lbl.name)

        # Imported here, after the config swap.
        # NOTE(review): presumably so augmentation sees the test config - confirm.
        from augmentation import Augmentator

        Augmentator().augment_annotations()

        train_mod.train_dataset()

        exports_mod.export_onnx(c.config.current_pt_model)
        exports_mod.export_coreml(c.config.current_pt_model)

        today_ds = osp.join(c.config.datasets_dir, train_mod.today_folder)

        yield {
            "today_dataset": today_ds,
        }
    finally:
        # Always put the original configuration back, even if setup failed.
        c.config = old_config
|
|
|
|
|
|
@pytest.mark.e2e
class TestTrainingPipeline:
    """Checks on the artifacts left behind by the ``e2e_result`` fixture run."""

    def test_augmentation_produced_output(self, e2e_result):
        """The processed-images directory contains exactly 800 ``.jpg`` files."""
        processed_dir = Path(c.config.processed_images_dir)
        jpg_count = sum(1 for _ in processed_dir.glob("*.jpg"))
        assert jpg_count == 800

    def test_dataset_formed(self, e2e_result):
        """Each split has images/labels directories; 800 images in total."""
        dataset_root = Path(e2e_result["today_dataset"])
        splits = ("train", "valid", "test")

        for split in splits:
            split_dir = dataset_root / split
            assert (split_dir / "images").is_dir()
            assert (split_dir / "labels").is_dir()

        image_total = 0
        for split in splits:
            image_total += len(list((dataset_root / split / "images").glob("*.jpg")))
        assert image_total == 800

    def test_data_yaml_created(self, e2e_result):
        """``data.yaml`` exists and contains the expected keys."""
        data_yaml = Path(e2e_result["today_dataset"]) / "data.yaml"
        assert data_yaml.exists()

        text = data_yaml.read_text()
        for expected in ("nc: 80", "train:", "val:"):
            assert expected in text

    def test_training_produces_pt(self, e2e_result):
        """The current ``.pt`` model file exists and is non-empty."""
        weights = Path(c.config.current_pt_model)
        assert weights.exists()
        assert weights.stat().st_size > 0

    def test_export_onnx(self, e2e_result):
        """The current ONNX model exists, ends in ``.onnx`` and is non-empty."""
        onnx_file = Path(c.config.current_onnx_model)
        assert onnx_file.exists()
        assert onnx_file.suffix == ".onnx"
        assert onnx_file.stat().st_size > 0

    def test_export_coreml(self, e2e_result):
        """At least one ``*.mlpackage`` entry exists in the models directory."""
        models_dir = Path(c.config.models_dir)
        mlpackages = tuple(models_dir.glob("*.mlpackage"))
        assert len(mlpackages) >= 1

    def test_onnx_inference(self, e2e_result):
        """Predicting on the first sample image yields one result with boxes set."""
        model = YOLO(c.config.current_onnx_model)
        first_image = sorted(_DATASET_IMAGES.glob("*.jpg"))[0]
        predictions = model.predict(
            source=str(first_image),
            imgsz=c.config.export.onnx_imgsz,
            verbose=False,
        )
        assert len(predictions) == 1
        assert predictions[0].boxes is not None
|