Files
ai-training/tests/test_augmentation_nonfunc.py
T

149 lines
5.1 KiB
Python

import random
import shutil
import sys
import types
from pathlib import Path
from types import SimpleNamespace
import cv2
import numpy as np
import pytest
from tests.conftest import apply_constants_patch
if "matplotlib" not in sys.modules:
    # Register empty placeholder modules for `matplotlib` and
    # `matplotlib.pyplot`. Any later `import matplotlib.pyplot` then resolves
    # to these stubs through sys.modules instead of loading the real package.
    # NOTE(review): presumably code imported by these tests pulls in
    # matplotlib only for plotting that is not exercised here — confirm.
    _plt = types.ModuleType("matplotlib.pyplot")
    _mpl = types.ModuleType("matplotlib")
    _mpl.pyplot = _plt
    sys.modules.update({"matplotlib": _mpl, "matplotlib.pyplot": _plt})
def _patch_augmentation_paths(monkeypatch, base: Path):
    """Patch the path constants and mirror them onto the ``augmentation`` module.

    First ``apply_constants_patch(monkeypatch, base)`` patches the
    ``constants`` module (presumably pointing its directories under *base*).
    Then the five directory attributes are copied from ``constants`` onto
    ``augmentation`` via ``monkeypatch.setattr``, so pytest restores them at
    teardown.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture.
        base: Root directory handed to ``apply_constants_patch``.
    """
    import augmentation as aug
    import constants as c

    apply_constants_patch(monkeypatch, base)
    # `augmentation` has its own module-level attributes with these names.
    # monkeypatch.setattr raises if they are missing. Presumably they were
    # bound via `from constants import ...`, so patching `constants` alone
    # would not reach them.
    for attr_name in (
        "data_images_dir",
        "data_labels_dir",
        "processed_images_dir",
        "processed_labels_dir",
        "processed_dir",
    ):
        monkeypatch.setattr(aug, attr_name, getattr(c, attr_name))
def _augment_annotation_with_total(monkeypatch):
    """Wrap ``Augmentator.augment_annotation`` so it sets ``total_to_process`` first.

    Before delegating to the original method, the wrapper copies
    ``self.total_images_to_process`` into ``self.total_to_process``. The
    replacement is installed with ``monkeypatch.setattr``, so pytest restores
    the original method at teardown.

    NOTE(review): this looks like a shim for an attribute-name mismatch
    inside ``augmentation``. The method presumably reads ``total_to_process``
    while the class maintains ``total_images_to_process``. Confirm this and
    fix it at the source.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture.
    """
    import functools

    import augmentation as aug

    orig = aug.Augmentator.augment_annotation

    # functools.wraps keeps the original name/docstring for tracebacks and
    # introspection. *args/**kwargs forwards every call shape the real
    # method accepts, including `image_file=` passed by keyword.
    @functools.wraps(orig)
    def wrapped(self, *args, **kwargs):
        self.total_to_process = self.total_images_to_process
        return orig(self, *args, **kwargs)

    monkeypatch.setattr(aug.Augmentator, "augment_annotation", wrapped)
def _seed():
    """Reset the global ``random`` and ``numpy.random`` generators to seed 42.

    Makes random-driven augmentation reproducible across runs, assuming it
    draws only from these two global generators.
    """
    seed_value = 42
    for reseed in (random.seed, np.random.seed):
        reseed(seed_value)
@pytest.mark.resilience
def test_rt_aug_01_corrupted_image_skipped(
    tmp_path, monkeypatch, fixture_images_dir, fixture_labels_dir
):
    """A truncated JPEG next to a valid sample is skipped without aborting the run.

    The input directories get one valid image/label pair from the fixtures
    and one truncated image that also has a valid label. After
    ``augment_annotations()`` only the valid image should have produced
    output: 8 processed images.
    """
    _patch_augmentation_paths(monkeypatch, tmp_path)
    _augment_annotation_with_total(monkeypatch)
    _seed()
    import constants as c
    from augmentation import Augmentator

    img_dir = Path(c.data_images_dir)
    lbl_dir = Path(c.data_labels_dir)
    img_dir.mkdir(parents=True, exist_ok=True)
    lbl_dir.mkdir(parents=True, exist_ok=True)

    stem = sorted(fixture_images_dir.glob("*.jpg"))[0].stem
    src_img = fixture_images_dir / f"{stem}.jpg"
    src_lbl = fixture_labels_dir / f"{stem}.txt"
    shutil.copy2(src_img, img_dir / f"{stem}.jpg")
    shutil.copy2(src_lbl, lbl_dir / f"{stem}.txt")

    # Keep only the first 200 bytes of a real JPEG. The result is presumably
    # undecodable.
    (img_dir / "corrupted_trunc.jpg").write_bytes(src_img.read_bytes()[:200])
    # Give the corrupted image a valid label. Without one it could be skipped
    # by the missing-label path (covered by test_rt_aug_02), and decoding of
    # the corrupted file would never be attempted.
    shutil.copy2(src_lbl, lbl_dir / "corrupted_trunc.txt")

    Augmentator().augment_annotations()

    proc_img = Path(c.processed_images_dir)
    # 8 outputs for the single valid image (cf. augment_inner returning 8
    # results in test_rl_aug_01). The corrupted image must contribute none.
    assert len(list(proc_img.glob("*.jpg"))) == 8
@pytest.mark.resilience
def test_rt_aug_02_missing_label_no_crash(tmp_path, monkeypatch, fixture_images_dir):
    """An image with no matching label file produces no processed output and no error."""
    _patch_augmentation_paths(monkeypatch, tmp_path)
    _augment_annotation_with_total(monkeypatch)
    import constants as c
    from augmentation import Augmentator

    images_in = Path(c.data_images_dir)
    labels_in = Path(c.data_labels_dir)
    for directory in (images_in, labels_in):
        directory.mkdir(parents=True, exist_ok=True)

    image_name = "no_label_here.jpg"
    first_fixture = sorted(fixture_images_dir.glob("*.jpg"))[0]
    shutil.copy2(first_fixture, images_in / image_name)
    # The label directory is deliberately left empty.

    augmentator = Augmentator()
    augmentator.total_images_to_process = 1
    # A SimpleNamespace exposing `.name` stands in for a directory entry
    # (presumably an os.DirEntry in production — confirm).
    augmentator.augment_annotation(SimpleNamespace(name=image_name))

    produced = list(Path(c.processed_images_dir).glob("*.jpg"))
    assert len(produced) == 0
@pytest.mark.resilience
def test_rt_aug_03_narrow_bbox_fewer_or_eight_variants(
    tmp_path, monkeypatch, fixture_images_dir
):
    """A tiny (0.0005 x 0.0005) bounding box still yields between 1 and 8 variants.

    ``augment_inner`` is called directly on a fixture image paired with a
    single synthetic label. The box is small enough that some variants may
    be dropped, but at least one and at most eight must come back.
    """
    _patch_augmentation_paths(monkeypatch, tmp_path)
    _seed()
    import constants as c
    from augmentation import Augmentator
    from dto.imageLabel import ImageLabel

    stem = "narrow_bbox"
    # Derive output locations from the patched constants instead of
    # re-spelling the directory layout. This keeps the test in step with
    # whatever layout apply_constants_patch sets up, as the other tests do.
    proc_img = Path(c.processed_images_dir) / f"{stem}.jpg"
    proc_lbl = Path(c.processed_labels_dir) / f"{stem}.txt"
    proc_img.parent.mkdir(parents=True, exist_ok=True)
    proc_lbl.parent.mkdir(parents=True, exist_ok=True)

    src_img = sorted(fixture_images_dir.glob("*.jpg"))[0]
    img = cv2.imdecode(np.fromfile(str(src_img), dtype=np.uint8), cv2.IMREAD_COLOR)
    # cv2.imdecode returns None on failure. Fail here with a clear message
    # rather than deep inside augment_inner.
    assert img is not None, f"fixture image {src_img} could not be decoded"

    aug = Augmentator()
    # Row layout presumably [x_center, y_center, width, height, class_id]
    # in normalized coordinates — confirm against Augmentator.read_labels.
    labels = [[0.5, 0.5, 0.0005, 0.0005, 0]]
    img_ann = ImageLabel(
        image_path=str(proc_img),
        image=img,
        labels_path=str(proc_lbl),
        labels=labels,
    )
    out = aug.augment_inner(img_ann)
    assert 1 <= len(out) <= 8
@pytest.mark.resource_limit
def test_rl_aug_01_augment_inner_exactly_eight_outputs(
    tmp_path, monkeypatch, fixture_images_dir, fixture_labels_dir
):
    """``augment_inner`` on a fixture image and its real labels returns exactly 8 results."""
    _patch_augmentation_paths(monkeypatch, tmp_path)
    _seed()
    import constants as c
    from augmentation import Augmentator
    from dto.imageLabel import ImageLabel

    stem = sorted(fixture_images_dir.glob("*.jpg"))[0].stem
    img_path = fixture_images_dir / f"{stem}.jpg"
    lbl_path = fixture_labels_dir / f"{stem}.txt"
    img = cv2.imdecode(np.fromfile(str(img_path), dtype=np.uint8), cv2.IMREAD_COLOR)
    # cv2.imdecode returns None on failure. Fail here with a clear message
    # rather than deep inside augment_inner.
    assert img is not None, f"fixture image {img_path} could not be decoded"

    aug = Augmentator()
    labels = aug.read_labels(lbl_path)

    # Derive output locations from the patched constants instead of
    # re-spelling the directory layout. This keeps the test in step with
    # whatever layout apply_constants_patch sets up, as the other tests do.
    proc_img = Path(c.processed_images_dir) / f"{stem}.jpg"
    proc_lbl = Path(c.processed_labels_dir) / f"{stem}.txt"
    proc_img.parent.mkdir(parents=True, exist_ok=True)
    proc_lbl.parent.mkdir(parents=True, exist_ok=True)

    img_ann = ImageLabel(
        image_path=str(proc_img),
        image=img,
        labels_path=str(proc_lbl),
        labels=labels,
    )
    out = aug.augment_inner(img_ann)
    assert len(out) == 8