Mirror of https://github.com/azaion/ai-training.git, synced 2026-04-22 10:56:36 +00:00, at commit 142c6c4de8.
- Replaced module-level path variables in constants.py with a structured Pydantic Config class. - Updated all relevant modules (train.py, augmentation.py, exports.py, dataset-visualiser.py, manual_run.py) to access paths through the new config structure. - Fixed bugs related to image processing and model saving. - Enhanced test infrastructure to accommodate the new configuration approach. This refactor improves code maintainability and clarity by centralizing configuration management.
119 lines
3.5 KiB
Python
119 lines
3.5 KiB
Python
import concurrent.futures
|
|
import random
|
|
import shutil
|
|
import sys
|
|
import time
|
|
import types
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from tests.conftest import apply_constants_patch
|
|
|
|
# Install lightweight stand-ins for matplotlib when it has not been imported,
# so modules that do `import matplotlib.pyplot` can load without the real
# package (or a display) being available in the test environment.
if "matplotlib" not in sys.modules:
    _plt = types.ModuleType("matplotlib.pyplot")
    _mpl = types.ModuleType("matplotlib")
    _mpl.pyplot = _plt
    sys.modules.update({"matplotlib": _mpl, "matplotlib.pyplot": _plt})
|
|
|
|
|
|
def _patch_augmentation_paths(monkeypatch, base: Path) -> None:
    """Redirect the project's configured paths to *base* for this test.

    Thin wrapper over the shared conftest helper so the tests below read
    uniformly; *base* is typically pytest's ``tmp_path``.
    """
    apply_constants_patch(monkeypatch, base)
|
|
|
|
|
|
def _augment_annotation_with_total(monkeypatch) -> None:
    """Wrap ``Augmentator.augment_annotation`` to pre-set ``total_to_process``.

    Before delegating to the original method, the wrapper copies the
    instance's ``total_images_to_process`` into ``total_to_process`` — the
    attribute the augmentation code reads while processing a single file.
    The patch is applied via *monkeypatch*, so it is undone automatically
    when the test finishes.
    """
    import augmentation as aug

    original = aug.Augmentator.augment_annotation

    def _with_total(self, image_file):
        # Mirror the planned total onto the per-call counter, then delegate.
        self.total_to_process = self.total_images_to_process
        return original(self, image_file)

    monkeypatch.setattr(aug.Augmentator, "augment_annotation", _with_total)
|
|
|
|
|
|
def _seed():
|
|
random.seed(42)
|
|
np.random.seed(42)
|
|
|
|
|
|
@pytest.mark.performance
def test_pt_aug_01_throughput_ten_images_sixty_seconds(
    tmp_path, monkeypatch, sample_images_labels
):
    """End-to-end augmentation of ten images must finish within 60 seconds."""
    _patch_augmentation_paths(monkeypatch, tmp_path)
    _augment_annotation_with_total(monkeypatch)
    _seed()

    # Import after patching so the modules pick up the redirected paths.
    import constants as c
    from augmentation import Augmentator

    img_dir = Path(c.config.data_images_dir)
    lbl_dir = Path(c.config.data_labels_dir)

    # Stage the sample dataset (10 image/label pairs) under the patched dirs.
    src_img, src_lbl = sample_images_labels(10)
    for src, dst, pattern in (
        (src_img, img_dir, "*.jpg"),
        (src_lbl, lbl_dir, "*.txt"),
    ):
        dst.mkdir(parents=True, exist_ok=True)
        for path in src.glob(pattern):
            shutil.copy2(path, dst / path.name)

    start = time.perf_counter()
    Augmentator().augment_annotations()
    elapsed = time.perf_counter() - start

    assert elapsed <= 60.0
|
|
|
|
|
|
@pytest.mark.performance
def test_pt_aug_02_parallel_at_least_one_point_five_x_faster(
    tmp_path, monkeypatch, sample_images_labels
):
    """Thread-pooled augmentation must be at least 1.5x faster than serial.

    Runs the same ten-image workload twice — once sequentially, once through
    a ``ThreadPoolExecutor`` — with the output directories reset in between,
    and compares wall-clock times.
    """
    _patch_augmentation_paths(monkeypatch, tmp_path)
    _augment_annotation_with_total(monkeypatch)
    _seed()

    # Import after patching so the modules pick up the redirected paths.
    import constants as c
    from augmentation import Augmentator

    img_dir = Path(c.config.data_images_dir)
    lbl_dir = Path(c.config.data_labels_dir)
    proc_dir = Path(c.config.processed_dir)

    # Stage the sample dataset (10 image/label pairs) under the patched dirs.
    src_img, src_lbl = sample_images_labels(10)
    for src, dst, pattern in (
        (src_img, img_dir, "*.jpg"),
        (src_lbl, lbl_dir, "*.txt"),
    ):
        dst.mkdir(parents=True, exist_ok=True)
        for path in src.glob(pattern):
            shutil.copy2(path, dst / path.name)

    def _fresh_output_dirs():
        # Recreate the (empty) trees augmentation writes its results into.
        Path(c.config.processed_images_dir).mkdir(parents=True, exist_ok=True)
        Path(c.config.processed_labels_dir).mkdir(parents=True, exist_ok=True)

    _fresh_output_dirs()

    class _Entry:
        # Minimal directory-entry stand-in exposing only a `name` attribute.
        __slots__ = ("name",)

        def __init__(self, name):
            self.name = name

    entries = [_Entry(n) for n in sorted(p.name for p in img_dir.glob("*.jpg"))]

    # Sequential baseline: one image at a time on a single Augmentator.
    sequential = Augmentator()
    sequential.total_images_to_process = len(entries)
    start = time.perf_counter()
    for entry in entries:
        sequential.augment_annotation(entry)
    seq_elapsed = time.perf_counter() - start

    # Reset outputs so the parallel run starts from the same clean state.
    shutil.rmtree(proc_dir)
    _fresh_output_dirs()

    parallel = Augmentator()
    parallel.total_images_to_process = len(entries)
    start = time.perf_counter()
    with concurrent.futures.ThreadPoolExecutor() as pool:
        # list() drains the map so all work completes before the clock stops.
        list(pool.map(parallel.augment_annotation, entries))
    par_elapsed = time.perf_counter() - start

    assert seq_elapsed >= par_elapsed * 1.5
|