mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 12:16:35 +00:00
142c6c4de8
- Replaced module-level path variables in constants.py with a structured Pydantic Config class. - Updated all relevant modules (train.py, augmentation.py, exports.py, dataset-visualiser.py, manual_run.py) to access paths through the new config structure. - Fixed bugs related to image processing and model saving. - Enhanced test infrastructure to accommodate the new configuration approach. This refactor improves code maintainability and clarity by centralizing configuration management.
141 lines
4.7 KiB
Python
141 lines
4.7 KiB
Python
import random
|
|
import shutil
|
|
import sys
|
|
import types
|
|
from pathlib import Path
|
|
from types import SimpleNamespace
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from tests.conftest import apply_constants_patch
|
|
|
|
# Register empty placeholder modules for ``matplotlib`` / ``matplotlib.pyplot``
# before any project module is imported, presumably so importing the
# augmentation code does not require (or initialise) a real matplotlib —
# TODO confirm. Because only ``sys.modules`` is checked, this also shadows an
# installed matplotlib that has not been imported yet in this process.
if "matplotlib" not in sys.modules:
    _mpl, _plt = (
        types.ModuleType(module_name)
        for module_name in ("matplotlib", "matplotlib.pyplot")
    )
    # Expose the submodule as an attribute so ``from matplotlib import pyplot``
    # resolves to the same placeholder as ``import matplotlib.pyplot``.
    _mpl.pyplot = _plt
    sys.modules.update({"matplotlib": _mpl, "matplotlib.pyplot": _plt})
|
|
|
|
|
|
def _patch_augmentation_paths(monkeypatch, base: Path) -> None:
    """Apply the shared constants patch rooted at ``base`` for the current test.

    Thin local alias over ``tests.conftest.apply_constants_patch``. The tests
    below pass ``tmp_path`` as ``base`` and then read directories from
    ``constants.config``. Presumably the patch re-roots those configured paths
    under ``base``; confirm in ``tests/conftest.py``.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture, forwarded unchanged.
        base: Root directory forwarded to ``apply_constants_patch``.
    """
    apply_constants_patch(monkeypatch, base)
|
|
|
|
|
|
def _augment_annotation_with_total(monkeypatch):
    """Patch ``Augmentator.augment_annotation`` to set ``total_to_process`` first.

    The replacement copies ``self.total_images_to_process`` onto
    ``self.total_to_process`` and then delegates to the original method,
    returning its result. The wrapped method presumably reads
    ``total_to_process``; TODO confirm against ``augmentation.py``.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture used to install the
            replacement on the ``Augmentator`` class.
    """
    from augmentation import Augmentator

    real_augment_annotation = Augmentator.augment_annotation

    def _with_total(self, image_file):
        self.total_to_process = self.total_images_to_process
        return real_augment_annotation(self, image_file)

    monkeypatch.setattr(Augmentator, "augment_annotation", _with_total)
|
|
|
|
|
|
def _seed():
    """Seed Python's ``random`` and NumPy's global RNG with 42, in that order."""
    for reseed in (random.seed, np.random.seed):
        reseed(42)
|
|
|
|
|
|
@pytest.mark.resilience
def test_rt_aug_01_corrupted_image_skipped(
    tmp_path, monkeypatch, fixture_images_dir, fixture_labels_dir
):
    """A truncated JPEG next to a valid image must not break the batch run.

    The first fixture image and its label are copied into the configured data
    directories, together with ``corrupted_trunc.jpg``, which holds only the
    first 200 bytes of that image. After ``augment_annotations()`` the
    processed-images directory must contain exactly 8 JPEGs. These are
    presumably the variants of the single valid image.
    """
    _patch_augmentation_paths(monkeypatch, tmp_path)
    _augment_annotation_with_total(monkeypatch)
    _seed()
    import constants as c
    from augmentation import Augmentator

    images_dir = Path(c.config.data_images_dir)
    labels_dir = Path(c.config.data_labels_dir)
    for target_dir in (images_dir, labels_dir):
        target_dir.mkdir(parents=True, exist_ok=True)

    stem = sorted(fixture_images_dir.glob("*.jpg"))[0].stem
    fixture_image = fixture_images_dir / f"{stem}.jpg"
    shutil.copy2(fixture_image, images_dir / f"{stem}.jpg")
    shutil.copy2(fixture_labels_dir / f"{stem}.txt", labels_dir / f"{stem}.txt")

    # A 200-byte prefix yields a truncated (presumably undecodable) JPEG.
    truncated_bytes = fixture_image.read_bytes()[:200]
    (images_dir / "corrupted_trunc.jpg").write_bytes(truncated_bytes)

    Augmentator().augment_annotations()

    produced = list(Path(c.config.processed_images_dir).glob("*.jpg"))
    assert len(produced) == 8
|
|
|
|
|
|
@pytest.mark.resilience
def test_rt_aug_02_missing_label_no_crash(tmp_path, monkeypatch, fixture_images_dir):
    """An image without a matching label file is processed without raising.

    One fixture image is copied in as ``no_label_here.jpg`` while the labels
    directory stays empty. ``augment_annotation`` is then called directly. It
    must return normally and leave no JPEGs in the processed-images directory.
    """
    _patch_augmentation_paths(monkeypatch, tmp_path)
    _augment_annotation_with_total(monkeypatch)
    import constants as c
    from augmentation import Augmentator

    images_dir = Path(c.config.data_images_dir)
    labels_dir = Path(c.config.data_labels_dir)
    for target_dir in (images_dir, labels_dir):
        target_dir.mkdir(parents=True, exist_ok=True)

    stem = "no_label_here"
    first_fixture = sorted(fixture_images_dir.glob("*.jpg"))[0]
    shutil.copy2(first_fixture, images_dir / f"{stem}.jpg")

    augmentator = Augmentator()
    augmentator.total_images_to_process = 1
    # A namespace stands in for a directory entry; presumably only ``.name``
    # is read by ``augment_annotation`` (TODO confirm).
    augmentator.augment_annotation(SimpleNamespace(name=f"{stem}.jpg"))

    leftovers = list(Path(c.config.processed_images_dir).glob("*.jpg"))
    assert len(leftovers) == 0
|
|
|
|
|
|
@pytest.mark.resilience
def test_rt_aug_03_narrow_bbox_fewer_or_eight_variants(
    tmp_path, monkeypatch, fixture_images_dir
):
    """``augment_inner`` copes with a near-zero-size bounding box.

    The first fixture image is paired with a single label row
    ``[0.5, 0.5, 0.0005, 0.0005, 0]``. The layout is presumably
    ``[x_center, y_center, width, height, class]`` in normalised units;
    confirm against ``Augmentator.read_labels``. Between one and eight
    augmented variants are accepted.
    """
    _patch_augmentation_paths(monkeypatch, tmp_path)
    _seed()
    import constants as c
    from augmentation import Augmentator
    from dto.imageLabel import ImageLabel

    stem = "narrow_bbox"
    # Resolve output locations through the patched config, as the other tests
    # do, instead of re-deriving the conftest directory layout by hand.
    proc_img = Path(c.config.processed_images_dir) / f"{stem}.jpg"
    proc_lbl = Path(c.config.processed_labels_dir) / f"{stem}.txt"
    proc_img.parent.mkdir(parents=True, exist_ok=True)
    proc_lbl.parent.mkdir(parents=True, exist_ok=True)

    src_img = sorted(fixture_images_dir.glob("*.jpg"))[0]
    img = cv2.imdecode(np.fromfile(str(src_img), dtype=np.uint8), cv2.IMREAD_COLOR)
    # cv2.imdecode returns None on failure; fail here with a clear message
    # rather than deep inside augment_inner.
    assert img is not None, f"could not decode fixture image {src_img}"

    aug = Augmentator()
    labels = [[0.5, 0.5, 0.0005, 0.0005, 0]]
    img_ann = ImageLabel(
        image_path=str(proc_img),
        image=img,
        labels_path=str(proc_lbl),
        labels=labels,
    )
    out = aug.augment_inner(img_ann)
    assert 1 <= len(out) <= 8
|
|
|
|
|
|
@pytest.mark.resource_limit
def test_rl_aug_01_augment_inner_exactly_eight_outputs(
    tmp_path, monkeypatch, fixture_images_dir, fixture_labels_dir
):
    """``augment_inner`` yields exactly eight variants for a real fixture pair.

    The first fixture image is decoded and its labels are read via
    ``Augmentator.read_labels``. Both are wrapped in an ``ImageLabel`` whose
    output paths point into the configured processed directories.
    """
    _patch_augmentation_paths(monkeypatch, tmp_path)
    _seed()
    import constants as c
    from augmentation import Augmentator
    from dto.imageLabel import ImageLabel

    stem = sorted(fixture_images_dir.glob("*.jpg"))[0].stem
    img_path = fixture_images_dir / f"{stem}.jpg"
    lbl_path = fixture_labels_dir / f"{stem}.txt"
    img = cv2.imdecode(np.fromfile(str(img_path), dtype=np.uint8), cv2.IMREAD_COLOR)
    # cv2.imdecode returns None on failure; fail here with a clear message
    # rather than deep inside augment_inner.
    assert img is not None, f"could not decode fixture image {img_path}"

    aug = Augmentator()
    labels = aug.read_labels(lbl_path)

    # Resolve output locations through the patched config, as the other tests
    # do, instead of re-deriving the conftest directory layout by hand.
    proc_img = Path(c.config.processed_images_dir) / f"{stem}.jpg"
    proc_lbl = Path(c.config.processed_labels_dir) / f"{stem}.txt"
    proc_img.parent.mkdir(parents=True, exist_ok=True)
    proc_lbl.parent.mkdir(parents=True, exist_ok=True)

    img_ann = ImageLabel(
        image_path=str(proc_img),
        image=img,
        labels_path=str(proc_lbl),
        labels=labels,
    )
    out = aug.augment_inner(img_ann)
    assert len(out) == 8
|