mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 10:36:35 +00:00
142c6c4de8
- Replaced module-level path variables in constants.py with a structured Pydantic Config class. - Updated all relevant modules (train.py, augmentation.py, exports.py, dataset-visualiser.py, manual_run.py) to access paths through the new config structure. - Fixed bugs related to image processing and model saving. - Enhanced test infrastructure to accommodate the new configuration approach. This refactor improves code maintainability and clarity by centralizing configuration management.
253 lines
8.5 KiB
Python
253 lines
8.5 KiB
Python
import random
|
|
import shutil
|
|
import sys
|
|
import types
|
|
from pathlib import Path
|
|
|
|
if "matplotlib" not in sys.modules:
    # Register empty placeholder modules so that any later
    # `import matplotlib` / `import matplotlib.pyplot` performed by the code
    # under test resolves to these stubs instead of loading the real package.
    # Nothing is installed when matplotlib was already imported.
    _plt = types.ModuleType("matplotlib.pyplot")
    _mpl = types.ModuleType("matplotlib")
    setattr(_mpl, "pyplot", _plt)
    sys.modules.update({"matplotlib": _mpl, "matplotlib.pyplot": _plt})
|
|
|
|
import cv2
|
|
import numpy as np
|
|
|
|
from tests.conftest import apply_constants_patch
|
|
|
|
|
|
def _patch_augmentation_paths(monkeypatch, base: Path):
    """Apply the shared constants/config patch rooted at ``base`` for one test.

    Thin module-local alias: all work is delegated to
    ``tests.conftest.apply_constants_patch``. It presumably redirects the
    configured data/processed directories under ``base`` (the pipeline tests
    below create and use ``c.config.*_dir`` right after calling it).
    """
    apply_constants_patch(monkeypatch, base)
|
|
|
|
|
|
def _seed():
|
|
random.seed(42)
|
|
np.random.seed(42)
|
|
|
|
|
|
def _augment_annotation_with_total(monkeypatch):
    """Patch ``Augmentator.augment_annotation`` to set ``total_to_process`` first.

    The replacement copies ``self.total_images_to_process`` into
    ``self.total_to_process`` and then delegates to the original method with
    the same arguments. Because the patch goes through ``monkeypatch``, it is
    undone when the test finishes.

    NOTE(review): presumably needed because ``augment_annotation`` reads
    ``total_to_process``, which ``augment_annotations`` does not set on its
    own. Confirm against ``augmentation.py``.
    """
    import functools

    import augmentation as aug

    original = aug.Augmentator.augment_annotation

    # functools.wraps keeps the original name/docstring and exposes
    # __wrapped__, so tracebacks and introspection still point at the real method.
    @functools.wraps(original)
    def wrapped(self, image_file):
        self.total_to_process = self.total_images_to_process
        return original(self, image_file)

    monkeypatch.setattr(aug.Augmentator, "augment_annotation", wrapped)
|
|
|
|
|
|
def test_bt_aug_01_augment_inner_returns_eight_image_labels(
    tmp_path, monkeypatch, fixture_images_dir, fixture_labels_dir
):
    """``augment_inner`` returns eight ImageLabels for a single labelled sample."""
    _patch_augmentation_paths(monkeypatch, tmp_path)
    _seed()
    import constants as c
    from augmentation import Augmentator
    from dto.imageLabel import ImageLabel

    stem = sorted(fixture_images_dir.glob("*.jpg"))[0].stem
    img_path = fixture_images_dir / f"{stem}.jpg"
    lbl_path = fixture_labels_dir / f"{stem}.txt"
    img = cv2.imdecode(np.fromfile(str(img_path), dtype=np.uint8), cv2.IMREAD_COLOR)
    # cv2.imdecode returns None instead of raising on undecodable data.
    assert img is not None, f"could not decode fixture image {img_path}"

    aug = Augmentator()
    labels = aug.read_labels(lbl_path)

    # Take the processed locations from the patched config (as the pipeline
    # tests do) instead of hard-coding a copy of its directory layout.
    proc_img = Path(c.config.processed_images_dir) / f"{stem}.jpg"
    proc_lbl = Path(c.config.processed_labels_dir) / f"{stem}.txt"
    proc_img.parent.mkdir(parents=True, exist_ok=True)
    proc_lbl.parent.mkdir(parents=True, exist_ok=True)

    img_ann = ImageLabel(
        image_path=str(proc_img),
        image=img,
        labels_path=str(proc_lbl),
        labels=labels,
    )
    out = aug.augment_inner(img_ann)
    assert len(out) == 8
|
|
|
|
|
|
def test_bt_aug_02_naming_convention(tmp_path, monkeypatch, fixture_images_dir, fixture_labels_dir):
    """Outputs are named ``<stem>.*`` followed by ``<stem>_1.*`` .. ``<stem>_7.*``.

    The ImageLabel paths use the stem "test_image", which differs from the
    fixture file's name. The expected names therefore check that output names
    derive from the ImageLabel's ``image_path`` / ``labels_path``.
    """
    _patch_augmentation_paths(monkeypatch, tmp_path)
    _seed()
    import constants as c
    from augmentation import Augmentator
    from dto.imageLabel import ImageLabel

    stem = "test_image"
    proc_img = Path(c.config.processed_images_dir) / f"{stem}.jpg"
    proc_lbl = Path(c.config.processed_labels_dir) / f"{stem}.txt"
    proc_img.parent.mkdir(parents=True, exist_ok=True)
    proc_lbl.parent.mkdir(parents=True, exist_ok=True)

    src_img = sorted(fixture_images_dir.glob("*.jpg"))[0]
    img = cv2.imdecode(np.fromfile(str(src_img), dtype=np.uint8), cv2.IMREAD_COLOR)
    # cv2.imdecode returns None instead of raising on undecodable data.
    assert img is not None, f"could not decode fixture image {src_img}"
    lbl_path = fixture_labels_dir / f"{src_img.stem}.txt"

    # A single Augmentator both parses the labels and runs the augmentation.
    aug = Augmentator()
    labels = aug.read_labels(lbl_path)
    img_ann = ImageLabel(
        image_path=str(proc_img),
        image=img,
        labels_path=str(proc_lbl),
        labels=labels,
    )
    out = aug.augment_inner(img_ann)

    suffixes = [""] + [f"_{i}" for i in range(1, 8)]
    assert [Path(o.image_path).name for o in out] == [f"{stem}{s}.jpg" for s in suffixes]
    assert [Path(o.labels_path).name for o in out] == [f"{stem}{s}.txt" for s in suffixes]
|
|
|
|
|
|
def _all_coords_in_unit(labels_list):
|
|
for row in labels_list:
|
|
for j in range(4):
|
|
v = float(row[j])
|
|
if v < 0.0 or v > 1.0:
|
|
return False
|
|
return True
|
|
|
|
|
|
def test_bt_aug_03_all_bbox_coords_in_zero_one(
    tmp_path, monkeypatch, fixture_images_dir, fixture_labels_dir
):
    """Every augmented label row keeps its box coordinates inside [0, 1]."""
    _patch_augmentation_paths(monkeypatch, tmp_path)
    _seed()
    import constants as c
    from augmentation import Augmentator
    from dto.imageLabel import ImageLabel

    stem = sorted(fixture_images_dir.glob("*.jpg"))[0].stem
    proc_img = Path(c.config.processed_images_dir) / f"{stem}.jpg"
    proc_lbl = Path(c.config.processed_labels_dir) / f"{stem}.txt"
    proc_img.parent.mkdir(parents=True, exist_ok=True)
    proc_lbl.parent.mkdir(parents=True, exist_ok=True)

    img_path = fixture_images_dir / f"{stem}.jpg"
    lbl_path = fixture_labels_dir / f"{stem}.txt"
    img = cv2.imdecode(np.fromfile(str(img_path), dtype=np.uint8), cv2.IMREAD_COLOR)
    # cv2.imdecode returns None instead of raising on undecodable data.
    assert img is not None, f"could not decode fixture image {img_path}"

    aug = Augmentator()
    labels = aug.read_labels(lbl_path)
    img_ann = ImageLabel(
        image_path=str(proc_img),
        image=img,
        labels_path=str(proc_lbl),
        labels=labels,
    )
    out = aug.augment_inner(img_ann)
    for o in out:
        for row in o.labels:
            # 4 box values plus at least one trailing field (presumably the class id).
            assert len(row) >= 5
        assert _all_coords_in_unit(o.labels), f"out-of-range bbox in {o.image_path}"
|
|
|
|
|
|
def test_bt_aug_04_correct_bboxes_clips_edge(tmp_path, monkeypatch):
    """A box extending past the right edge is clipped to lie within the margin.

    Input centre x=0.99 with width 0.2 reaches x=1.09. After ``correct_bboxes``,
    all four box edges must sit at least ``correct_margin`` away from the
    image border (with a 1e-9 float tolerance).
    """
    _patch_augmentation_paths(monkeypatch, tmp_path)
    from augmentation import Augmentator

    augmentator = Augmentator()
    margin = augmentator.correct_margin
    eps = 1e-9

    result = augmentator.correct_bboxes([[0.99, 0.5, 0.2, 0.1, 0]])
    assert len(result) == 1

    cx, cy, width, height, _cls = result[0]
    half_w = 0.5 * width
    half_h = 0.5 * height

    low, high = margin - eps, 1.0 - margin + eps
    assert cx - half_w >= low
    assert cx + half_w <= high
    assert cy - half_h >= low
    assert cy + half_h <= high
|
|
|
|
|
|
def test_bt_aug_05_tiny_bbox_removed_after_clipping(tmp_path, monkeypatch):
    """A thin box at the right edge is dropped once ``correct_bboxes`` clips it."""
    _patch_augmentation_paths(monkeypatch, tmp_path)
    from augmentation import Augmentator

    sliver = [0.995, 0.5, 0.01, 0.5, 0]
    assert Augmentator().correct_bboxes([sliver]) == []
|
|
|
|
|
|
def test_bt_aug_06_empty_label_eight_outputs_empty_labels(
    tmp_path, monkeypatch, fixture_images_dir
):
    """An image with no labels still yields eight outputs, each with no labels."""
    _patch_augmentation_paths(monkeypatch, tmp_path)
    _seed()
    import constants as c
    from augmentation import Augmentator
    from dto.imageLabel import ImageLabel

    stem = "empty_case"
    proc_img = Path(c.config.processed_images_dir) / f"{stem}.jpg"
    proc_lbl = Path(c.config.processed_labels_dir) / f"{stem}.txt"
    proc_img.parent.mkdir(parents=True, exist_ok=True)
    proc_lbl.parent.mkdir(parents=True, exist_ok=True)

    src_img = sorted(fixture_images_dir.glob("*.jpg"))[0]
    img = cv2.imdecode(np.fromfile(str(src_img), dtype=np.uint8), cv2.IMREAD_COLOR)
    # cv2.imdecode returns None instead of raising on undecodable data.
    assert img is not None, f"could not decode fixture image {src_img}"

    aug = Augmentator()
    img_ann = ImageLabel(
        image_path=str(proc_img),
        image=img,
        labels_path=str(proc_lbl),
        labels=[],
    )
    out = aug.augment_inner(img_ann)
    assert len(out) == 8
    for o in out:
        assert o.labels == []
|
|
|
|
|
|
def test_bt_aug_07_full_pipeline_five_images_forty_outputs(
    tmp_path, monkeypatch, sample_images_labels
):
    """``augment_annotations`` over 5 image/label pairs writes 5 * 8 of each."""
    _patch_augmentation_paths(monkeypatch, tmp_path)
    _augment_annotation_with_total(monkeypatch)
    _seed()
    import constants as c
    from augmentation import Augmentator

    data_images = Path(c.config.data_images_dir)
    data_labels = Path(c.config.data_labels_dir)
    for directory in (data_images, data_labels):
        directory.mkdir(parents=True, exist_ok=True)

    src_img, src_lbl = sample_images_labels(5)
    copy_plan = (
        ("*.jpg", src_img, data_images),
        ("*.txt", src_lbl, data_labels),
    )
    for pattern, source_dir, target_dir in copy_plan:
        for source in source_dir.glob(pattern):
            shutil.copy2(source, target_dir / source.name)

    Augmentator().augment_annotations()

    expected = 5 * 8
    processed_images = Path(c.config.processed_images_dir)
    processed_labels = Path(c.config.processed_labels_dir)
    assert sum(1 for _ in processed_images.glob("*.jpg")) == expected
    assert sum(1 for _ in processed_labels.glob("*.txt")) == expected
|
|
|
|
|
|
def test_bt_aug_08_skips_already_processed(tmp_path, monkeypatch, sample_images_labels):
    """Images already present in the processed folder are left untouched.

    Three of the five source images are pre-copied into the processed images
    folder, without labels. After the pipeline runs, the expected counts are:
    3 pre-seeded + 2 * 8 new = 19 images, and 2 * 8 = 16 label files. The
    pre-seeded images must keep their exact bytes.
    """
    _patch_augmentation_paths(monkeypatch, tmp_path)
    _augment_annotation_with_total(monkeypatch)
    _seed()
    import constants as c
    from augmentation import Augmentator

    img_dir = Path(c.config.data_images_dir)
    lbl_dir = Path(c.config.data_labels_dir)
    proc_img = Path(c.config.processed_images_dir)
    proc_lbl = Path(c.config.processed_labels_dir)
    for directory in (img_dir, lbl_dir, proc_img, proc_lbl):
        directory.mkdir(parents=True, exist_ok=True)

    src_img, src_lbl = sample_images_labels(5)
    jpgs = sorted(src_img.glob("*.jpg"))
    for source in jpgs:
        shutil.copy2(source, img_dir / source.name)
    for source in src_lbl.glob("*.txt"):
        shutil.copy2(source, lbl_dir / source.name)

    # Pre-seed the first three images and remember their exact contents.
    snapshots = {}
    for source in jpgs[:3]:
        target = proc_img / source.name
        shutil.copy2(source, target)
        snapshots[source.name] = target.read_bytes()

    Augmentator().augment_annotations()

    assert len(list(proc_img.glob("*.jpg"))) == 19
    assert len(list(proc_lbl.glob("*.txt"))) == 16
    for name, original_bytes in snapshots.items():
        assert (proc_img / name).read_bytes() == original_bytes
|