mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 19:56:36 +00:00
Refactor constants management to use Pydantic BaseModel for configuration
- Replaced module-level path variables in constants.py with a structured Pydantic Config class. - Updated all relevant modules (train.py, augmentation.py, exports.py, dataset-visualiser.py, manual_run.py) to access paths through the new config structure. - Fixed bugs related to image processing and model saving. - Enhanced test infrastructure to accommodate the new configuration approach. This refactor improves code maintainability and clarity by centralizing configuration management.
This commit is contained in:
@@ -0,0 +1,113 @@
|
||||
import sys
|
||||
import types
|
||||
import importlib
|
||||
import shutil
|
||||
from os import path as osp
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
for _n in ("boto3", "netron", "requests"):
|
||||
if _n not in sys.modules:
|
||||
sys.modules[_n] = types.ModuleType(_n)
|
||||
|
||||
for _k in [k for k in sys.modules if k == "ultralytics" or k.startswith("ultralytics.")]:
|
||||
del sys.modules[_k]
|
||||
from ultralytics import YOLO
|
||||
|
||||
for _m in ("exports", "train"):
|
||||
if _m in sys.modules:
|
||||
importlib.reload(sys.modules[_m])
|
||||
|
||||
import constants as c
|
||||
import train as train_mod
|
||||
import exports as exports_mod
|
||||
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
_DATASET_IMAGES = _PROJECT_ROOT / "_docs/00_problem/input_data/dataset/images"
|
||||
_DATASET_LABELS = _PROJECT_ROOT / "_docs/00_problem/input_data/dataset/labels"
|
||||
_CONFIG_TEST = _PROJECT_ROOT / "config.test.yaml"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def e2e_result(tmp_path_factory):
|
||||
base = tmp_path_factory.mktemp("e2e")
|
||||
|
||||
old_config = c.config
|
||||
c.config = c.Config.from_yaml(str(_CONFIG_TEST), root=str(base / "azaion"))
|
||||
|
||||
data_img = Path(c.config.data_images_dir)
|
||||
data_lbl = Path(c.config.data_labels_dir)
|
||||
data_img.mkdir(parents=True)
|
||||
data_lbl.mkdir(parents=True)
|
||||
Path(c.config.models_dir).mkdir(parents=True)
|
||||
|
||||
for img in sorted(_DATASET_IMAGES.glob("*.jpg")):
|
||||
shutil.copy2(img, data_img / img.name)
|
||||
lbl = _DATASET_LABELS / f"{img.stem}.txt"
|
||||
if lbl.exists():
|
||||
shutil.copy2(lbl, data_lbl / lbl.name)
|
||||
|
||||
from augmentation import Augmentator
|
||||
Augmentator().augment_annotations()
|
||||
|
||||
train_mod.train_dataset()
|
||||
|
||||
exports_mod.export_onnx(c.config.current_pt_model)
|
||||
exports_mod.export_coreml(c.config.current_pt_model)
|
||||
|
||||
today_ds = osp.join(c.config.datasets_dir, train_mod.today_folder)
|
||||
|
||||
yield {
|
||||
"today_dataset": today_ds,
|
||||
}
|
||||
|
||||
c.config = old_config
|
||||
|
||||
|
||||
@pytest.mark.e2e
|
||||
class TestTrainingPipeline:
|
||||
def test_augmentation_produced_output(self, e2e_result):
|
||||
proc = Path(c.config.processed_images_dir)
|
||||
assert len(list(proc.glob("*.jpg"))) == 800
|
||||
|
||||
def test_dataset_formed(self, e2e_result):
|
||||
base = Path(e2e_result["today_dataset"])
|
||||
for split in ("train", "valid", "test"):
|
||||
assert (base / split / "images").is_dir()
|
||||
assert (base / split / "labels").is_dir()
|
||||
total = sum(
|
||||
len(list((base / s / "images").glob("*.jpg")))
|
||||
for s in ("train", "valid", "test")
|
||||
)
|
||||
assert total == 800
|
||||
|
||||
def test_data_yaml_created(self, e2e_result):
|
||||
yaml_path = Path(e2e_result["today_dataset"]) / "data.yaml"
|
||||
assert yaml_path.exists()
|
||||
content = yaml_path.read_text()
|
||||
assert "nc: 80" in content
|
||||
assert "train:" in content
|
||||
assert "val:" in content
|
||||
|
||||
def test_training_produces_pt(self, e2e_result):
|
||||
pt = Path(c.config.current_pt_model)
|
||||
assert pt.exists()
|
||||
assert pt.stat().st_size > 0
|
||||
|
||||
def test_export_onnx(self, e2e_result):
|
||||
p = Path(c.config.current_onnx_model)
|
||||
assert p.exists()
|
||||
assert p.suffix == ".onnx"
|
||||
assert p.stat().st_size > 0
|
||||
|
||||
def test_export_coreml(self, e2e_result):
|
||||
pkgs = list(Path(c.config.models_dir).glob("*.mlpackage"))
|
||||
assert len(pkgs) >= 1
|
||||
|
||||
def test_onnx_inference(self, e2e_result):
|
||||
onnx_model = YOLO(c.config.current_onnx_model)
|
||||
img = sorted(_DATASET_IMAGES.glob("*.jpg"))[0]
|
||||
results = onnx_model.predict(source=str(img), imgsz=c.config.export.onnx_imgsz, verbose=False)
|
||||
assert len(results) == 1
|
||||
assert results[0].boxes is not None
|
||||
Reference in New Issue
Block a user