# Mirror of https://github.com/azaion/ai-training.git (synced 2026-04-22 22:26:36 +00:00)
# Commit 18b88ba9bf:
# - Updated `.gitignore` to remove committed test fixture data exclusions.
# - Increased batch size in `config.test.yaml` from 4 to 128 for training.
# - Simplified directory structure in `config.yaml` by removing unnecessary data paths.
# - Adjusted paths in `augmentation.py`, `dataset-visualiser.py`, and `exports.py` to align with the new configuration structure.
# - Enhanced `annotation_queue_handler.py` to utilize the updated configuration for directory management.
# - Added CSV logging of test results in `conftest.py` for better test reporting.
# These changes streamline the configuration management and enhance the testing framework, ensuring better organization and clarity in the project.
# File: 116 lines, 3.7 KiB, Python
import os
|
|
import shutil
|
|
from os import path
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
from ultralytics import YOLO
|
|
|
|
import constants as c
|
|
import train as train_mod
|
|
import exports as exports_mod
|
|
|
|
# Every path below is anchored to this file's own (resolved) directory rather than the
# current working directory, so the module behaves the same wherever pytest is started.
_TESTS_DIR = Path(__file__).resolve().parents[0]

# Root passed to Config.from_yaml(..., root=...) by the e2e fixture.
_TEST_ROOT = _TESTS_DIR.joinpath("root")

# "data/images" beneath the test root.
_DATASET_IMAGES = _TEST_ROOT.joinpath("data", "images")

# Test training configuration, one level above this file's directory.
_CONFIG_TEST = _TESTS_DIR.parent.joinpath("config.test.yaml")

# Committed sample dataset; the fixture expects "images" and "labels" subfolders in it.
_SOURCE_DATASET = _TESTS_DIR.parent.joinpath("_docs", "00_problem", "input_data", "dataset")
|
|
|
|
|
|
def _hardlink_tree(src_dir: Path, dst_dir: Path):
    """Mirror the regular files directly inside ``src_dir`` into ``dst_dir``.

    Files are hard-linked so the sample dataset is not duplicated on disk. Only the
    top level of ``src_dir`` is processed; subdirectories are ignored. Entries that
    already exist in ``dst_dir`` are left untouched, so repeated calls are harmless.

    If a hard link cannot be created (for example ``src_dir`` and ``dst_dir`` live on
    different filesystems, or the filesystem does not support hard links), the file
    is copied instead so the e2e setup does not abort.

    Args:
        src_dir: Directory whose top-level files are mirrored.
        dst_dir: Destination directory; created (including parents) if missing.
    """
    dst_dir.mkdir(parents=True, exist_ok=True)
    for f in src_dir.iterdir():
        if not f.is_file():
            continue
        target = dst_dir / f.name
        try:
            os.link(f, target)
        except FileExistsError:
            # Already present (e.g. from an earlier call): keep the existing entry.
            continue
        except OSError:
            # Cross-device link (EXDEV), no hard-link support, or link not permitted:
            # fall back to a plain copy so the test data is still available.
            shutil.copy2(f, target)
|
|
|
|
|
|
@pytest.fixture(scope="module")
def e2e_result():
    """Run the training + export pipeline once and yield what the tests inspect.

    Dependent tests are skipped when the committed sample dataset is missing.
    Otherwise the global ``c.config`` is replaced by one loaded from
    ``config.test.yaml`` with ``root=_TEST_ROOT``. The sample images/labels are then
    hard-linked into the configured images/labels directories and
    ``train_mod.train_dataset()`` is run. Finally the current ``.pt`` model is
    exported to ONNX and CoreML.

    Yields:
        dict: ``"today_dataset"`` is ``datasets_dir`` joined with
        ``train_mod.today_folder``. ``"linked_count"`` is the number of ``*.jpg``
        files present in the images directory after linking.

    Teardown (removing every generated directory and restoring the original config)
    lives in a ``finally`` block. It therefore also runs when training or export
    raises before the ``yield``; otherwise a failed run would leave the test config
    installed in ``c.config`` for every later test module.
    """
    # Arrange
    src_images = _SOURCE_DATASET / "images"
    src_labels = _SOURCE_DATASET / "labels"
    if not src_images.is_dir() or not src_labels.is_dir():
        pytest.skip("source dataset not found")

    old_config = c.config
    c.config = c.Config.from_yaml(str(_CONFIG_TEST), root=str(_TEST_ROOT))
    # Directories this run writes to. They are captured once from the test config so
    # setup and teardown clean exactly the same set without re-reading c.config.
    generated_dirs = ()
    try:
        dst_images = Path(c.config.images_dir)
        dst_labels = Path(c.config.labels_dir)
        generated_dirs = (
            dst_images,
            dst_labels,
            c.config.datasets_dir,
            c.config.models_dir,
            c.config.corrupted_dir,
        )
        # Start from a clean slate so leftovers from an earlier run cannot leak in.
        for d in generated_dirs:
            shutil.rmtree(str(d), ignore_errors=True)

        _hardlink_tree(src_images, dst_images)
        _hardlink_tree(src_labels, dst_labels)
        linked_count = len(list(dst_images.glob("*.jpg")))

        Path(c.config.models_dir).mkdir(parents=True, exist_ok=True)

        # Act
        train_mod.train_dataset()
        exports_mod.export_onnx(c.config.current_pt_model)
        exports_mod.export_coreml(c.config.current_pt_model)

        today_ds = path.join(c.config.datasets_dir, train_mod.today_folder)
        yield {
            "today_dataset": today_ds,
            "linked_count": linked_count,
        }
    finally:
        for d in generated_dirs:
            shutil.rmtree(str(d), ignore_errors=True)
        c.config = old_config
|
|
|
|
|
|
@pytest.mark.e2e
class TestTrainingPipeline:
    """End-to-end checks on the artefacts produced by the module-scoped ``e2e_result`` fixture.

    The fixture runs training and export once; each test inspects one resulting artefact.
    """

    def test_dataset_formed(self, e2e_result):
        """Training forms train/valid/test splits holding a non-empty subset of the linked images."""
        # Assert
        base = Path(e2e_result["today_dataset"])
        splits = ("train", "valid", "test")
        for split in splits:
            assert (base / split / "images").is_dir(), f"missing {split}/images in {base}"
            assert (base / split / "labels").is_dir(), f"missing {split}/labels in {base}"
        total = sum(len(list((base / s / "images").glob("*.jpg"))) for s in splits)
        assert 0 < total <= e2e_result["linked_count"]

    def test_data_yaml_created(self, e2e_result):
        """A data.yaml with a class count and train/val entries is written into the dataset folder."""
        yaml_path = Path(e2e_result["today_dataset"]) / "data.yaml"
        assert yaml_path.exists(), f"{yaml_path} was not created"
        content = yaml_path.read_text()
        # NOTE(review): 80 is hard-coded here; presumably it mirrors the project's class list.
        assert "nc: 80" in content
        assert "train:" in content
        assert "val:" in content

    def test_training_produces_pt(self, e2e_result):
        """Training leaves a non-empty .pt checkpoint at the configured current model path."""
        pt = Path(c.config.current_pt_model)
        assert pt.exists(), f"{pt} was not created"
        assert pt.stat().st_size > 0

    def test_export_onnx(self, e2e_result):
        """ONNX export writes a non-empty .onnx file at the configured path."""
        p = Path(c.config.current_onnx_model)
        assert p.exists(), f"{p} was not created"
        assert p.suffix == ".onnx"
        assert p.stat().st_size > 0

    def test_export_coreml(self, e2e_result):
        """CoreML export places at least one .mlpackage in the models directory."""
        pkgs = list(Path(c.config.models_dir).glob("*.mlpackage"))
        assert pkgs, f"no .mlpackage found in {c.config.models_dir}"

    def test_onnx_inference(self, e2e_result):
        """The exported ONNX model loads and runs a prediction on one of the linked images."""
        # Use the directory the fixture actually populated (c.config.images_dir) instead of
        # a separately hard-coded path, so the two cannot drift apart; fail with a clear
        # message rather than an IndexError if it is empty.
        images = sorted(Path(c.config.images_dir).glob("*.jpg"))
        assert images, f"no *.jpg images found in {c.config.images_dir}"
        onnx_model = YOLO(c.config.current_onnx_model)
        results = onnx_model.predict(
            source=str(images[0]),
            imgsz=c.config.export.onnx_imgsz,
            verbose=False,
        )
        assert len(results) == 1
        assert results[0].boxes is not None
|