mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 17:46:41 +00:00
Refactor configuration and update test structure for improved clarity
- Updated `.gitignore` to remove committed test fixture data exclusions. - Increased batch size in `config.test.yaml` from 4 to 128 for training. - Simplified directory structure in `config.yaml` by removing unnecessary data paths. - Adjusted paths in `augmentation.py`, `dataset-visualiser.py`, and `exports.py` to align with the new configuration structure. - Enhanced `annotation_queue_handler.py` to utilize the updated configuration for directory management. - Added CSV logging of test results in `conftest.py` for better test reporting. These changes streamline the configuration management and enhance the testing framework, ensuring better organization and clarity in the project.
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import shutil
|
||||
from os import path
|
||||
from pathlib import Path
|
||||
@@ -13,15 +14,42 @@ _TESTS_DIR = Path(__file__).resolve().parent
|
||||
_TEST_ROOT = _TESTS_DIR / "root"
|
||||
_DATASET_IMAGES = _TEST_ROOT / "data" / "images"
|
||||
_CONFIG_TEST = _TESTS_DIR.parent / "config.test.yaml"
|
||||
_SOURCE_DATASET = _TESTS_DIR.parent / "_docs" / "00_problem" / "input_data" / "dataset"
|
||||
|
||||
|
||||
def _hardlink_tree(src_dir: Path, dst_dir: Path):
|
||||
dst_dir.mkdir(parents=True, exist_ok=True)
|
||||
for f in src_dir.iterdir():
|
||||
if f.is_file():
|
||||
target = dst_dir / f.name
|
||||
if not target.exists():
|
||||
os.link(f, target)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def e2e_result():
|
||||
# Arrange
|
||||
src_images = _SOURCE_DATASET / "images"
|
||||
src_labels = _SOURCE_DATASET / "labels"
|
||||
if not src_images.is_dir() or not src_labels.is_dir():
|
||||
pytest.skip("source dataset not found")
|
||||
|
||||
old_config = c.config
|
||||
c.config = c.Config.from_yaml(str(_CONFIG_TEST), root=str(_TEST_ROOT))
|
||||
|
||||
dst_images = Path(c.config.images_dir)
|
||||
dst_labels = Path(c.config.labels_dir)
|
||||
for d in (dst_images, dst_labels, c.config.datasets_dir, c.config.models_dir, c.config.corrupted_dir):
|
||||
shutil.rmtree(str(d), ignore_errors=True)
|
||||
|
||||
_hardlink_tree(src_images, dst_images)
|
||||
_hardlink_tree(src_labels, dst_labels)
|
||||
|
||||
linked_count = len(list(dst_images.glob("*.jpg")))
|
||||
|
||||
Path(c.config.models_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Act
|
||||
train_mod.train_dataset()
|
||||
|
||||
exports_mod.export_onnx(c.config.current_pt_model)
|
||||
@@ -31,8 +59,11 @@ def e2e_result():
|
||||
|
||||
yield {
|
||||
"today_dataset": today_ds,
|
||||
"linked_count": linked_count,
|
||||
}
|
||||
|
||||
shutil.rmtree(str(dst_images), ignore_errors=True)
|
||||
shutil.rmtree(str(dst_labels), ignore_errors=True)
|
||||
shutil.rmtree(c.config.datasets_dir, ignore_errors=True)
|
||||
shutil.rmtree(c.config.models_dir, ignore_errors=True)
|
||||
shutil.rmtree(c.config.corrupted_dir, ignore_errors=True)
|
||||
@@ -42,6 +73,7 @@ def e2e_result():
|
||||
@pytest.mark.e2e
|
||||
class TestTrainingPipeline:
|
||||
def test_dataset_formed(self, e2e_result):
|
||||
# Assert
|
||||
base = Path(e2e_result["today_dataset"])
|
||||
for split in ("train", "valid", "test"):
|
||||
assert (base / split / "images").is_dir()
|
||||
@@ -50,7 +82,7 @@ class TestTrainingPipeline:
|
||||
len(list((base / s / "images").glob("*.jpg")))
|
||||
for s in ("train", "valid", "test")
|
||||
)
|
||||
assert total == 20
|
||||
assert 0 < total <= e2e_result["linked_count"]
|
||||
|
||||
def test_data_yaml_created(self, e2e_result):
|
||||
yaml_path = Path(e2e_result["today_dataset"]) / "data.yaml"
|
||||
|
||||
Reference in New Issue
Block a user