mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 08:56:35 +00:00
142c6c4de8
- Replaced module-level path variables in constants.py with a structured Pydantic Config class. - Updated all relevant modules (train.py, augmentation.py, exports.py, dataset-visualiser.py, manual_run.py) to access paths through the new config structure. - Fixed bugs related to image processing and model saving. - Enhanced test infrastructure to accommodate the new configuration approach. This refactor improves code maintainability and clarity by centralizing configuration management.
129 lines
3.7 KiB
Python
129 lines
3.7 KiB
Python
import shutil
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
# Repository root: two levels up from this file's resolved location.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent

# Fixture inputs kept under _docs/00_problem/input_data.
_DATASET_IMAGES = _PROJECT_ROOT.joinpath("_docs", "00_problem", "input_data", "dataset", "images")
_DATASET_LABELS = _PROJECT_ROOT.joinpath("_docs", "00_problem", "input_data", "dataset", "labels")
_ONNX_MODEL = _PROJECT_ROOT.joinpath("_docs", "00_problem", "input_data", "azaion.onnx")

# Files expected directly in the repository root.
_CLASSES_JSON = _PROJECT_ROOT.joinpath("classes.json")
_CONFIG_TEST = _PROJECT_ROOT.joinpath("config.test.yaml")

# Test modules pytest must not collect (read by pytest from conftest files).
collect_ignore = ["security_test.py", "imagelabel_visualize_test.py"]
|
|
|
|
|
|
def apply_constants_patch(monkeypatch, base: Path):
    """Replace ``constants.config`` with a Config built from config.test.yaml.

    Args:
        monkeypatch: pytest ``MonkeyPatch`` used to install the replacement
            (pytest undoes it when the test finishes).
        base: directory under which the config root ``base / "azaion"`` lives.
    """
    # Deferred import of the project module; presumably kept local so that
    # loading this conftest does not import project code — TODO confirm.
    import constants as c

    config_factory = c.Config.from_yaml
    replacement = config_factory(str(_CONFIG_TEST), root=str(base / "azaion"))
    monkeypatch.setattr(c, "config", replacement)
|
|
|
|
|
|
@pytest.fixture(scope="session")
def fixture_images_dir():
    """Return the sample dataset image directory (one lookup per session).

    The requesting test is skipped when the directory does not exist.
    """
    images = _DATASET_IMAGES
    if images.is_dir():
        return images
    pytest.skip(f"missing dataset images: {images}")
|
|
|
|
|
|
@pytest.fixture(scope="session")
def fixture_labels_dir():
    """Return the sample dataset label directory (one lookup per session).

    The requesting test is skipped when the directory does not exist.
    """
    labels = _DATASET_LABELS
    if labels.is_dir():
        return labels
    pytest.skip(f"missing dataset labels: {labels}")
|
|
|
|
|
|
@pytest.fixture(scope="session")
def fixture_onnx_model():
    """Return the raw bytes of the sample ONNX model, read once per session.

    The requesting test is skipped when the model file does not exist.
    """
    model_file = _ONNX_MODEL
    if model_file.is_file():
        return model_file.read_bytes()
    pytest.skip(f"missing onnx model: {model_file}")
|
|
|
|
|
|
@pytest.fixture(scope="session")
def fixture_classes_json():
    """Return the path (not the contents) of the repository's classes.json.

    The requesting test is skipped when the file does not exist.
    """
    classes_file = _CLASSES_JSON
    if classes_file.is_file():
        return classes_file
    pytest.skip(f"missing classes.json: {classes_file}")
|
|
|
|
|
|
@pytest.fixture
def constants_patch(monkeypatch):
    """Provide a callable that re-points ``constants.config`` at a test root.

    Calling the returned function with a base directory applies
    ``apply_constants_patch`` using this test's ``monkeypatch``, so the change
    is undone automatically at teardown. The callable returns ``None``.
    """

    def _patch_under(base: Path):
        apply_constants_patch(monkeypatch, base)

    return _patch_under
|
|
|
|
|
|
@pytest.fixture
def work_dir(tmp_path):
    """Create ``tmp_path / "work"`` with empty ``images`` and ``labels`` subfolders.

    Returns:
        The ``work`` directory path.
    """
    root = tmp_path / "work"
    for subdir in ("images", "labels"):
        # parents=True on the first pass also creates ``work`` itself.
        (root / subdir).mkdir(parents=True)
    return root
|
|
|
|
|
|
@pytest.fixture
def sample_image_label(fixture_images_dir, fixture_labels_dir, tmp_path):
    """Copy the first sample image and its paired label into ``tmp_path``.

    The image is the lexicographically first ``*.jpg`` in
    ``fixture_images_dir``. Its label is the same-stem ``.txt`` file in
    ``fixture_labels_dir``.

    Returns:
        ``(image_copy, label_copy)``, located at
        ``tmp_path / "images" / <stem>.jpg`` and
        ``tmp_path / "labels" / <stem>.txt``.

    Raises:
        RuntimeError: ``fixture_images_dir`` contains no ``*.jpg`` files.
        FileNotFoundError: the selected image has no matching label file.
            This is checked before anything is copied.
    """
    imgs = sorted(fixture_images_dir.glob("*.jpg"))
    if not imgs:
        raise RuntimeError("no images in fixture_images_dir")

    src_img = imgs[0]
    stem = src_img.stem
    src_lbl = fixture_labels_dir / f"{stem}.txt"
    # Fail with an explicit message instead of an opaque copy error, and
    # before the image has been copied.
    if not src_lbl.is_file():
        raise FileNotFoundError(f"missing label for sample image {src_img.name}: {src_lbl}")

    dst_img = tmp_path / "images" / f"{stem}.jpg"
    dst_lbl = tmp_path / "labels" / f"{stem}.txt"
    dst_img.parent.mkdir(parents=True, exist_ok=True)
    dst_lbl.parent.mkdir(parents=True, exist_ok=True)

    # copy2 also preserves file metadata such as modification time.
    shutil.copy2(src_img, dst_img)
    shutil.copy2(src_lbl, dst_lbl)
    return dst_img, dst_lbl
|
|
|
|
|
|
@pytest.fixture
def sample_images_labels(fixture_images_dir, fixture_labels_dir, tmp_path):
    """Return a factory that copies the first N image/label pairs into ``tmp_path``.

    ``_factory(count)`` takes the ``count`` lexicographically first ``*.jpg``
    files from ``fixture_images_dir``. It copies each image, together with its
    same-stem ``.txt`` label from ``fixture_labels_dir``, into
    ``tmp_path / "images"`` and ``tmp_path / "labels"``.

    The factory returns ``(images_dir, labels_dir)``.

    The factory raises:
        ValueError: ``count`` is below 1 or exceeds the number of images.
        FileNotFoundError: any selected image lacks a label. This is checked
            before anything is copied, so no partial set is left behind.
    """

    def _factory(count: int):
        if count < 1:
            raise ValueError("count must be >= 1")
        imgs = sorted(fixture_images_dir.glob("*.jpg"))
        if count > len(imgs):
            raise ValueError("count exceeds available images")

        pairs = [(img, fixture_labels_dir / f"{img.stem}.txt") for img in imgs[:count]]
        # Validate every label up front; previously a missing k-th label left
        # k-1 copied pairs behind and surfaced only as a bare copy error.
        missing = [str(lbl) for _, lbl in pairs if not lbl.is_file()]
        if missing:
            raise FileNotFoundError(f"missing label(s) for sample images: {', '.join(missing)}")

        out_img = tmp_path / "images"
        out_lbl = tmp_path / "labels"
        out_img.mkdir(parents=True, exist_ok=True)
        out_lbl.mkdir(parents=True, exist_ok=True)
        for img, lbl in pairs:
            shutil.copy2(img, out_img / f"{img.stem}.jpg")
            shutil.copy2(lbl, out_lbl / f"{img.stem}.txt")
        return out_img, out_lbl

    return _factory
|
|
|
|
|
|
@pytest.fixture
def corrupted_label(tmp_path):
    """Write ``tmp_path/labels/corrupted.txt`` with an out-of-range coordinate.

    The single line ``0 1.5 0.5 0.1 0.1`` has 1.5 in the second field,
    presumably an x-centre outside the normalised [0, 1] range. Verify this
    against the label parser.
    """
    labels = tmp_path / "labels"
    labels.mkdir(parents=True, exist_ok=True)
    target = labels / "corrupted.txt"
    target.write_text("0 1.5 0.5 0.1 0.1\n", encoding="utf-8")
    return target
|
|
|
|
|
|
@pytest.fixture
def edge_bbox_label(tmp_path):
    """Write ``tmp_path/labels/edge.txt`` holding one box at the image edge.

    The line is ``0 0.01 0.5 0.02 0.3``. Assuming YOLO ``cls cx cy w h``
    format, the box spans x in [0, 0.02] and touches the left border.
    """
    labels = tmp_path / "labels"
    labels.mkdir(parents=True, exist_ok=True)
    target = labels / "edge.txt"
    target.write_text("0 0.01 0.5 0.02 0.3\n", encoding="utf-8")
    return target
|
|
|
|
|
|
@pytest.fixture
def empty_label(tmp_path):
    """Create a zero-length label file at ``tmp_path/labels/empty.txt``."""
    labels = tmp_path / "labels"
    labels.mkdir(parents=True, exist_ok=True)
    target = labels / "empty.txt"
    target.write_text("", encoding="utf-8")
    return target
|