Files
ai-training/tests/test_training_e2e.py
T
Oleksandr Bezdieniezhnykh 18b88ba9bf Refactor configuration and update test structure for improved clarity
- Updated `.gitignore` to remove the exclusions for committed test fixture data.
- Increased batch size in `config.test.yaml` from 4 to 128 for training.
- Simplified directory structure in `config.yaml` by removing unnecessary data paths.
- Adjusted paths in `augmentation.py`, `dataset-visualiser.py`, and `exports.py` to align with the new configuration structure.
- Enhanced `annotation_queue_handler.py` to utilize the updated configuration for directory management.
- Added CSV logging of test results in `conftest.py` for better test reporting.

These changes streamline configuration management and strengthen the testing framework, improving the project's organization and clarity.
2026-03-28 07:32:40 +02:00

116 lines
3.7 KiB
Python

import os
import shutil
from os import path
from pathlib import Path
import pytest
from ultralytics import YOLO
import constants as c
import train as train_mod
import exports as exports_mod
_TESTS_DIR = Path(__file__).resolve().parent
_TEST_ROOT = _TESTS_DIR / "root"
_DATASET_IMAGES = _TEST_ROOT / "data" / "images"
_CONFIG_TEST = _TESTS_DIR.parent / "config.test.yaml"
_SOURCE_DATASET = _TESTS_DIR.parent / "_docs" / "00_problem" / "input_data" / "dataset"
def _hardlink_tree(src_dir: Path, dst_dir: Path):
dst_dir.mkdir(parents=True, exist_ok=True)
for f in src_dir.iterdir():
if f.is_file():
target = dst_dir / f.name
if not target.exists():
os.link(f, target)
@pytest.fixture(scope="module")
def e2e_result():
    """Run the train + export pipeline once for this module and yield its outputs.

    Swaps ``c.config`` for a config loaded from ``config.test.yaml`` rooted at
    ``_TEST_ROOT``. Populates the configured images/labels directories from the
    source dataset via ``_hardlink_tree``. Then calls
    ``train.train_dataset()`` and exports ``c.config.current_pt_model`` with
    ``exports.export_onnx`` and ``exports.export_coreml``.

    Skips the module if the source dataset's images/ or labels/ directory is
    missing.

    Yields:
        dict with ``"today_dataset"`` (``datasets_dir`` joined with
        ``train.today_folder``) and ``"linked_count"`` (number of ``*.jpg``
        files placed in the images directory before training).

    The generated directories are removed and the previous ``c.config`` is
    restored on teardown, including when a setup step raises.
    """
    # Arrange
    src_images = _SOURCE_DATASET / "images"
    src_labels = _SOURCE_DATASET / "labels"
    if not src_images.is_dir() or not src_labels.is_dir():
        pytest.skip("source dataset not found")

    old_config = c.config
    # Directories this fixture (re)creates. They are recorded as soon as they
    # are known, so the finally-block can clean them up even if a later step fails.
    generated: tuple = ()
    try:
        c.config = c.Config.from_yaml(str(_CONFIG_TEST), root=str(_TEST_ROOT))
        dst_images = Path(c.config.images_dir)
        dst_labels = Path(c.config.labels_dir)
        generated = (
            dst_images,
            dst_labels,
            c.config.datasets_dir,
            c.config.models_dir,
            c.config.corrupted_dir,
        )
        # Start from empty directories so files from an earlier run don't leak in.
        for d in generated:
            shutil.rmtree(str(d), ignore_errors=True)
        _hardlink_tree(src_images, dst_images)
        _hardlink_tree(src_labels, dst_labels)
        linked_count = len(list(dst_images.glob("*.jpg")))
        Path(c.config.models_dir).mkdir(parents=True, exist_ok=True)

        # Act
        train_mod.train_dataset()
        exports_mod.export_onnx(c.config.current_pt_model)
        exports_mod.export_coreml(c.config.current_pt_model)
        today_ds = path.join(c.config.datasets_dir, train_mod.today_folder)

        yield {
            "today_dataset": today_ds,
            "linked_count": linked_count,
        }
    finally:
        # pytest only runs code after `yield` when setup succeeded; the finally
        # also covers a crash in training/export. Without it, c.config would
        # stay pointed at the test config for every later test module.
        for d in generated:
            shutil.rmtree(str(d), ignore_errors=True)
        c.config = old_config
@pytest.mark.e2e
class TestTrainingPipeline:
    """Checks on the artifacts produced by the module-scoped ``e2e_result`` fixture.

    Every test requests ``e2e_result``. This ensures the pipeline has run, and
    that ``c.config`` is the fixture's test configuration while the test executes.
    """

    def test_dataset_formed(self, e2e_result):
        """Each split has images/ and labels/ dirs holding 1..linked_count images in total."""
        # Assert
        base = Path(e2e_result["today_dataset"])
        splits = ("train", "valid", "test")
        for split in splits:
            assert (base / split / "images").is_dir()
            assert (base / split / "labels").is_dir()
        total = sum(len(list((base / s / "images").glob("*.jpg"))) for s in splits)
        assert 0 < total <= e2e_result["linked_count"]

    def test_data_yaml_created(self, e2e_result):
        """data.yaml exists and contains the expected ``nc``, ``train`` and ``val`` entries."""
        # Assert
        yaml_path = Path(e2e_result["today_dataset"]) / "data.yaml"
        assert yaml_path.exists()
        content = yaml_path.read_text()
        assert "nc: 80" in content
        assert "train:" in content
        assert "val:" in content

    def test_training_produces_pt(self, e2e_result):
        """Training leaves a non-empty model file at ``current_pt_model``."""
        # Assert
        pt = Path(c.config.current_pt_model)
        assert pt.exists()
        assert pt.stat().st_size > 0

    def test_export_onnx(self, e2e_result):
        """ONNX export leaves a non-empty ``.onnx`` file at ``current_onnx_model``."""
        # Assert
        p = Path(c.config.current_onnx_model)
        assert p.exists()
        assert p.suffix == ".onnx"
        assert p.stat().st_size > 0

    def test_export_coreml(self, e2e_result):
        """CoreML export leaves at least one ``.mlpackage`` in ``models_dir``."""
        # Assert
        pkgs = list(Path(c.config.models_dir).glob("*.mlpackage"))
        assert len(pkgs) >= 1

    def test_onnx_inference(self, e2e_result):
        """The exported ONNX model loads via YOLO and predicts on one sample image."""
        # Arrange
        # First image in sorted order; fail with a readable message instead of
        # an opaque IndexError when the directory is empty or missing.
        img = min(_DATASET_IMAGES.glob("*.jpg"), default=None)
        assert img is not None, f"no .jpg images found in {_DATASET_IMAGES}"
        onnx_model = YOLO(c.config.current_onnx_model)
        # Act
        results = onnx_model.predict(
            source=str(img), imgsz=c.config.export.onnx_imgsz, verbose=False
        )
        # Assert
        assert len(results) == 1
        assert results[0].boxes is not None