Mirror of https://github.com/azaion/ai-training.git (synced 2026-04-22 07:06:36 +00:00)
Refactor constants management to use a Pydantic BaseModel for configuration.

- Replaced module-level path variables in constants.py with a structured Pydantic Config class.
- Updated all relevant modules (train.py, augmentation.py, exports.py, dataset-visualiser.py, manual_run.py) to access paths through the new config structure.
- Fixed bugs related to image processing and model saving.
- Enhanced test infrastructure to accommodate the new configuration approach.

This refactor improves code maintainability and clarity by centralizing configuration management.
This commit is contained in:
+2
-23
@@ -8,35 +8,14 @@ _DATASET_IMAGES = _PROJECT_ROOT / "_docs/00_problem/input_data/dataset/images"
|
||||
_DATASET_LABELS = _PROJECT_ROOT / "_docs/00_problem/input_data/dataset/labels"
|
||||
_ONNX_MODEL = _PROJECT_ROOT / "_docs/00_problem/input_data/azaion.onnx"
|
||||
_CLASSES_JSON = _PROJECT_ROOT / "classes.json"
|
||||
_CONFIG_TEST = _PROJECT_ROOT / "config.test.yaml"
|
||||
|
||||
collect_ignore = ["security_test.py", "imagelabel_visualize_test.py"]
|
||||
|
||||
|
||||
def apply_constants_patch(monkeypatch, base: Path):
|
||||
import constants as c
|
||||
from os import path
|
||||
|
||||
root = str(base.resolve())
|
||||
azaion = path.join(root, "azaion")
|
||||
monkeypatch.setattr(c, "azaion", azaion)
|
||||
data_dir = path.join(azaion, "data")
|
||||
monkeypatch.setattr(c, "data_dir", data_dir)
|
||||
monkeypatch.setattr(c, "data_images_dir", path.join(data_dir, c.images))
|
||||
monkeypatch.setattr(c, "data_labels_dir", path.join(data_dir, c.labels))
|
||||
processed_dir = path.join(azaion, "data-processed")
|
||||
monkeypatch.setattr(c, "processed_dir", processed_dir)
|
||||
monkeypatch.setattr(c, "processed_images_dir", path.join(processed_dir, c.images))
|
||||
monkeypatch.setattr(c, "processed_labels_dir", path.join(processed_dir, c.labels))
|
||||
corrupted_dir = path.join(azaion, "data-corrupted")
|
||||
monkeypatch.setattr(c, "corrupted_dir", corrupted_dir)
|
||||
monkeypatch.setattr(c, "corrupted_images_dir", path.join(corrupted_dir, c.images))
|
||||
monkeypatch.setattr(c, "corrupted_labels_dir", path.join(corrupted_dir, c.labels))
|
||||
monkeypatch.setattr(c, "sample_dir", path.join(azaion, "data-sample"))
|
||||
monkeypatch.setattr(c, "datasets_dir", path.join(azaion, "datasets"))
|
||||
models_dir = path.join(azaion, "models")
|
||||
monkeypatch.setattr(c, "models_dir", models_dir)
|
||||
monkeypatch.setattr(c, "CURRENT_PT_MODEL", path.join(models_dir, f"{c.prefix[:-1]}.pt"))
|
||||
monkeypatch.setattr(c, "CURRENT_ONNX_MODEL", path.join(models_dir, f"{c.prefix[:-1]}.onnx"))
|
||||
monkeypatch.setattr(c, "config", c.Config.from_yaml(str(_CONFIG_TEST), root=str(base / "azaion")))
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
||||
@@ -20,15 +20,7 @@ if "matplotlib" not in sys.modules:
|
||||
|
||||
|
||||
def _patch_augmentation_paths(monkeypatch, base: Path):
|
||||
import augmentation as aug
|
||||
import constants as c
|
||||
|
||||
apply_constants_patch(monkeypatch, base)
|
||||
monkeypatch.setattr(aug, "data_images_dir", c.data_images_dir)
|
||||
monkeypatch.setattr(aug, "data_labels_dir", c.data_labels_dir)
|
||||
monkeypatch.setattr(aug, "processed_images_dir", c.processed_images_dir)
|
||||
monkeypatch.setattr(aug, "processed_labels_dir", c.processed_labels_dir)
|
||||
monkeypatch.setattr(aug, "processed_dir", c.processed_dir)
|
||||
|
||||
|
||||
def _augment_annotation_with_total(monkeypatch):
|
||||
@@ -58,8 +50,8 @@ def test_pt_aug_01_throughput_ten_images_sixty_seconds(
|
||||
import constants as c
|
||||
from augmentation import Augmentator
|
||||
|
||||
img_dir = Path(c.data_images_dir)
|
||||
lbl_dir = Path(c.data_labels_dir)
|
||||
img_dir = Path(c.config.data_images_dir)
|
||||
lbl_dir = Path(c.config.data_labels_dir)
|
||||
img_dir.mkdir(parents=True, exist_ok=True)
|
||||
lbl_dir.mkdir(parents=True, exist_ok=True)
|
||||
src_img, src_lbl = sample_images_labels(10)
|
||||
@@ -83,9 +75,9 @@ def test_pt_aug_02_parallel_at_least_one_point_five_x_faster(
|
||||
import constants as c
|
||||
from augmentation import Augmentator
|
||||
|
||||
img_dir = Path(c.data_images_dir)
|
||||
lbl_dir = Path(c.data_labels_dir)
|
||||
proc_dir = Path(c.processed_dir)
|
||||
img_dir = Path(c.config.data_images_dir)
|
||||
lbl_dir = Path(c.config.data_labels_dir)
|
||||
proc_dir = Path(c.config.processed_dir)
|
||||
img_dir.mkdir(parents=True, exist_ok=True)
|
||||
lbl_dir.mkdir(parents=True, exist_ok=True)
|
||||
src_img, src_lbl = sample_images_labels(10)
|
||||
@@ -93,8 +85,8 @@ def test_pt_aug_02_parallel_at_least_one_point_five_x_faster(
|
||||
shutil.copy2(p, img_dir / p.name)
|
||||
for p in src_lbl.glob("*.txt"):
|
||||
shutil.copy2(p, lbl_dir / p.name)
|
||||
Path(c.processed_images_dir).mkdir(parents=True, exist_ok=True)
|
||||
Path(c.processed_labels_dir).mkdir(parents=True, exist_ok=True)
|
||||
Path(c.config.processed_images_dir).mkdir(parents=True, exist_ok=True)
|
||||
Path(c.config.processed_labels_dir).mkdir(parents=True, exist_ok=True)
|
||||
names = sorted(p.name for p in img_dir.glob("*.jpg"))
|
||||
|
||||
class _E:
|
||||
@@ -113,8 +105,8 @@ def test_pt_aug_02_parallel_at_least_one_point_five_x_faster(
|
||||
seq_elapsed = time.perf_counter() - t0
|
||||
|
||||
shutil.rmtree(proc_dir)
|
||||
Path(c.processed_images_dir).mkdir(parents=True, exist_ok=True)
|
||||
Path(c.processed_labels_dir).mkdir(parents=True, exist_ok=True)
|
||||
Path(c.config.processed_images_dir).mkdir(parents=True, exist_ok=True)
|
||||
Path(c.config.processed_labels_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
aug_par = Augmentator()
|
||||
aug_par.total_images_to_process = len(entries)
|
||||
|
||||
@@ -56,8 +56,8 @@ def _prepare_form_dataset(
|
||||
constants_patch(tmp_path)
|
||||
import train
|
||||
|
||||
proc_img = Path(c_mod.processed_images_dir)
|
||||
proc_lbl = Path(c_mod.processed_labels_dir)
|
||||
proc_img = Path(c_mod.config.processed_images_dir)
|
||||
proc_lbl = Path(c_mod.config.processed_labels_dir)
|
||||
proc_img.mkdir(parents=True, exist_ok=True)
|
||||
proc_lbl.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -70,14 +70,8 @@ def _prepare_form_dataset(
|
||||
if stem in corrupt_stems:
|
||||
dst.write_text("0 1.5 0.5 0.1 0.1\n", encoding="utf-8")
|
||||
|
||||
today_ds = osp.join(c_mod.datasets_dir, train.today_folder)
|
||||
monkeypatch.setattr(train, "today_dataset", today_ds)
|
||||
monkeypatch.setattr(train, "processed_images_dir", c_mod.processed_images_dir)
|
||||
monkeypatch.setattr(train, "processed_labels_dir", c_mod.processed_labels_dir)
|
||||
monkeypatch.setattr(train, "corrupted_images_dir", c_mod.corrupted_images_dir)
|
||||
monkeypatch.setattr(train, "corrupted_labels_dir", c_mod.corrupted_labels_dir)
|
||||
monkeypatch.setattr(train, "datasets_dir", c_mod.datasets_dir)
|
||||
return train
|
||||
today_ds = osp.join(c_mod.config.datasets_dir, train.today_folder)
|
||||
return train, today_ds
|
||||
|
||||
|
||||
@pytest.mark.performance
|
||||
@@ -88,7 +82,7 @@ def test_pt_dsf_01_dataset_formation_under_thirty_seconds(
|
||||
fixture_images_dir,
|
||||
fixture_labels_dir,
|
||||
):
|
||||
train = _prepare_form_dataset(
|
||||
train, today_ds = _prepare_form_dataset(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
constants_patch,
|
||||
|
||||
@@ -42,9 +42,13 @@ def data_yaml_text(monkeypatch, tmp_path, fixture_classes_json):
|
||||
_stub_train_imports()
|
||||
import train
|
||||
|
||||
monkeypatch.setattr(train, "today_dataset", str(tmp_path))
|
||||
import constants as c
|
||||
monkeypatch.setattr(c, "config", c.Config(dirs=c.DirsConfig(root=str(tmp_path))))
|
||||
monkeypatch.setattr(train, "today_folder", "")
|
||||
from pathlib import Path
|
||||
Path(c.config.datasets_dir).mkdir(parents=True, exist_ok=True)
|
||||
train.create_yaml()
|
||||
return (tmp_path / "data.yaml").read_text(encoding="utf-8")
|
||||
return (Path(c.config.datasets_dir) / "data.yaml").read_text(encoding="utf-8")
|
||||
|
||||
|
||||
def test_bt_cls_01_base_classes(fixture_classes_json):
|
||||
|
||||
@@ -18,15 +18,7 @@ from tests.conftest import apply_constants_patch
|
||||
|
||||
|
||||
def _patch_augmentation_paths(monkeypatch, base: Path):
|
||||
import augmentation as aug
|
||||
import constants as c
|
||||
|
||||
apply_constants_patch(monkeypatch, base)
|
||||
monkeypatch.setattr(aug, "data_images_dir", c.data_images_dir)
|
||||
monkeypatch.setattr(aug, "data_labels_dir", c.data_labels_dir)
|
||||
monkeypatch.setattr(aug, "processed_images_dir", c.processed_images_dir)
|
||||
monkeypatch.setattr(aug, "processed_labels_dir", c.processed_labels_dir)
|
||||
monkeypatch.setattr(aug, "processed_dir", c.processed_dir)
|
||||
|
||||
|
||||
def _seed():
|
||||
@@ -210,8 +202,8 @@ def test_bt_aug_07_full_pipeline_five_images_forty_outputs(
|
||||
import constants as c
|
||||
from augmentation import Augmentator
|
||||
|
||||
img_dir = Path(c.data_images_dir)
|
||||
lbl_dir = Path(c.data_labels_dir)
|
||||
img_dir = Path(c.config.data_images_dir)
|
||||
lbl_dir = Path(c.config.data_labels_dir)
|
||||
img_dir.mkdir(parents=True, exist_ok=True)
|
||||
lbl_dir.mkdir(parents=True, exist_ok=True)
|
||||
src_img, src_lbl = sample_images_labels(5)
|
||||
@@ -220,8 +212,8 @@ def test_bt_aug_07_full_pipeline_five_images_forty_outputs(
|
||||
for p in src_lbl.glob("*.txt"):
|
||||
shutil.copy2(p, lbl_dir / p.name)
|
||||
Augmentator().augment_annotations()
|
||||
proc_img = Path(c.processed_images_dir)
|
||||
proc_lbl = Path(c.processed_labels_dir)
|
||||
proc_img = Path(c.config.processed_images_dir)
|
||||
proc_lbl = Path(c.config.processed_labels_dir)
|
||||
assert len(list(proc_img.glob("*.jpg"))) == 40
|
||||
assert len(list(proc_lbl.glob("*.txt"))) == 40
|
||||
|
||||
@@ -233,10 +225,10 @@ def test_bt_aug_08_skips_already_processed(tmp_path, monkeypatch, sample_images_
|
||||
import constants as c
|
||||
from augmentation import Augmentator
|
||||
|
||||
img_dir = Path(c.data_images_dir)
|
||||
lbl_dir = Path(c.data_labels_dir)
|
||||
proc_img = Path(c.processed_images_dir)
|
||||
proc_lbl = Path(c.processed_labels_dir)
|
||||
img_dir = Path(c.config.data_images_dir)
|
||||
lbl_dir = Path(c.config.data_labels_dir)
|
||||
proc_img = Path(c.config.processed_images_dir)
|
||||
proc_lbl = Path(c.config.processed_labels_dir)
|
||||
img_dir.mkdir(parents=True, exist_ok=True)
|
||||
lbl_dir.mkdir(parents=True, exist_ok=True)
|
||||
proc_img.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -20,15 +20,7 @@ if "matplotlib" not in sys.modules:
|
||||
|
||||
|
||||
def _patch_augmentation_paths(monkeypatch, base: Path):
|
||||
import augmentation as aug
|
||||
import constants as c
|
||||
|
||||
apply_constants_patch(monkeypatch, base)
|
||||
monkeypatch.setattr(aug, "data_images_dir", c.data_images_dir)
|
||||
monkeypatch.setattr(aug, "data_labels_dir", c.data_labels_dir)
|
||||
monkeypatch.setattr(aug, "processed_images_dir", c.processed_images_dir)
|
||||
monkeypatch.setattr(aug, "processed_labels_dir", c.processed_labels_dir)
|
||||
monkeypatch.setattr(aug, "processed_dir", c.processed_dir)
|
||||
|
||||
|
||||
def _augment_annotation_with_total(monkeypatch):
|
||||
@@ -58,8 +50,8 @@ def test_rt_aug_01_corrupted_image_skipped(
|
||||
import constants as c
|
||||
from augmentation import Augmentator
|
||||
|
||||
img_dir = Path(c.data_images_dir)
|
||||
lbl_dir = Path(c.data_labels_dir)
|
||||
img_dir = Path(c.config.data_images_dir)
|
||||
lbl_dir = Path(c.config.data_labels_dir)
|
||||
img_dir.mkdir(parents=True, exist_ok=True)
|
||||
lbl_dir.mkdir(parents=True, exist_ok=True)
|
||||
stem = sorted(fixture_images_dir.glob("*.jpg"))[0].stem
|
||||
@@ -68,7 +60,7 @@ def test_rt_aug_01_corrupted_image_skipped(
|
||||
raw = (fixture_images_dir / f"{stem}.jpg").read_bytes()[:200]
|
||||
(img_dir / "corrupted_trunc.jpg").write_bytes(raw)
|
||||
Augmentator().augment_annotations()
|
||||
proc_img = Path(c.processed_images_dir)
|
||||
proc_img = Path(c.config.processed_images_dir)
|
||||
assert len(list(proc_img.glob("*.jpg"))) == 8
|
||||
|
||||
|
||||
@@ -79,8 +71,8 @@ def test_rt_aug_02_missing_label_no_crash(tmp_path, monkeypatch, fixture_images_
|
||||
import constants as c
|
||||
from augmentation import Augmentator
|
||||
|
||||
img_dir = Path(c.data_images_dir)
|
||||
lbl_dir = Path(c.data_labels_dir)
|
||||
img_dir = Path(c.config.data_images_dir)
|
||||
lbl_dir = Path(c.config.data_labels_dir)
|
||||
img_dir.mkdir(parents=True, exist_ok=True)
|
||||
lbl_dir.mkdir(parents=True, exist_ok=True)
|
||||
stem = "no_label_here"
|
||||
@@ -88,7 +80,7 @@ def test_rt_aug_02_missing_label_no_crash(tmp_path, monkeypatch, fixture_images_
|
||||
aug = Augmentator()
|
||||
aug.total_images_to_process = 1
|
||||
aug.augment_annotation(SimpleNamespace(name=f"{stem}.jpg"))
|
||||
assert len(list(Path(c.processed_images_dir).glob("*.jpg"))) == 0
|
||||
assert len(list(Path(c.config.processed_images_dir).glob("*.jpg"))) == 0
|
||||
|
||||
|
||||
@pytest.mark.resilience
|
||||
|
||||
@@ -55,8 +55,8 @@ def _prepare_form_dataset(
|
||||
constants_patch(tmp_path)
|
||||
import train
|
||||
|
||||
proc_img = Path(c_mod.processed_images_dir)
|
||||
proc_lbl = Path(c_mod.processed_labels_dir)
|
||||
proc_img = Path(c_mod.config.processed_images_dir)
|
||||
proc_lbl = Path(c_mod.config.processed_labels_dir)
|
||||
proc_img.mkdir(parents=True, exist_ok=True)
|
||||
proc_lbl.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -69,14 +69,8 @@ def _prepare_form_dataset(
|
||||
if stem in corrupt_stems:
|
||||
dst.write_text("0 1.5 0.5 0.1 0.1\n", encoding="utf-8")
|
||||
|
||||
today_ds = osp.join(c_mod.datasets_dir, train.today_folder)
|
||||
monkeypatch.setattr(train, "today_dataset", today_ds)
|
||||
monkeypatch.setattr(train, "processed_images_dir", c_mod.processed_images_dir)
|
||||
monkeypatch.setattr(train, "processed_labels_dir", c_mod.processed_labels_dir)
|
||||
monkeypatch.setattr(train, "corrupted_images_dir", c_mod.corrupted_images_dir)
|
||||
monkeypatch.setattr(train, "corrupted_labels_dir", c_mod.corrupted_labels_dir)
|
||||
monkeypatch.setattr(train, "datasets_dir", c_mod.datasets_dir)
|
||||
return train
|
||||
today_ds = osp.join(c_mod.config.datasets_dir, train.today_folder)
|
||||
return train, today_ds
|
||||
|
||||
|
||||
def _count_jpg(p):
|
||||
@@ -90,7 +84,7 @@ def test_bt_dsf_01_split_ratio_70_20_10(
|
||||
fixture_images_dir,
|
||||
fixture_labels_dir,
|
||||
):
|
||||
train = _prepare_form_dataset(
|
||||
train, today_ds = _prepare_form_dataset(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
constants_patch,
|
||||
@@ -100,10 +94,9 @@ def test_bt_dsf_01_split_ratio_70_20_10(
|
||||
set(),
|
||||
)
|
||||
train.form_dataset()
|
||||
base = train.today_dataset
|
||||
assert _count_jpg(Path(base, "train", "images")) == 70
|
||||
assert _count_jpg(Path(base, "valid", "images")) == 20
|
||||
assert _count_jpg(Path(base, "test", "images")) == 10
|
||||
assert _count_jpg(Path(today_ds, "train", "images")) == 70
|
||||
assert _count_jpg(Path(today_ds, "valid", "images")) == 20
|
||||
assert _count_jpg(Path(today_ds, "test", "images")) == 10
|
||||
|
||||
|
||||
def test_bt_dsf_02_six_subdirectories(
|
||||
@@ -113,7 +106,7 @@ def test_bt_dsf_02_six_subdirectories(
|
||||
fixture_images_dir,
|
||||
fixture_labels_dir,
|
||||
):
|
||||
train = _prepare_form_dataset(
|
||||
train, today_ds = _prepare_form_dataset(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
constants_patch,
|
||||
@@ -123,7 +116,7 @@ def test_bt_dsf_02_six_subdirectories(
|
||||
set(),
|
||||
)
|
||||
train.form_dataset()
|
||||
base = Path(train.today_dataset)
|
||||
base = Path(today_ds)
|
||||
assert (base / "train" / "images").is_dir()
|
||||
assert (base / "train" / "labels").is_dir()
|
||||
assert (base / "valid" / "images").is_dir()
|
||||
@@ -139,7 +132,7 @@ def test_bt_dsf_03_total_files_one_hundred(
|
||||
fixture_images_dir,
|
||||
fixture_labels_dir,
|
||||
):
|
||||
train = _prepare_form_dataset(
|
||||
train, today_ds = _prepare_form_dataset(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
constants_patch,
|
||||
@@ -149,11 +142,10 @@ def test_bt_dsf_03_total_files_one_hundred(
|
||||
set(),
|
||||
)
|
||||
train.form_dataset()
|
||||
base = train.today_dataset
|
||||
n = (
|
||||
_count_jpg(Path(base, "train", "images"))
|
||||
+ _count_jpg(Path(base, "valid", "images"))
|
||||
+ _count_jpg(Path(base, "test", "images"))
|
||||
_count_jpg(Path(today_ds, "train", "images"))
|
||||
+ _count_jpg(Path(today_ds, "valid", "images"))
|
||||
+ _count_jpg(Path(today_ds, "test", "images"))
|
||||
)
|
||||
assert n == 100
|
||||
|
||||
@@ -167,7 +159,7 @@ def test_bt_dsf_04_corrupted_labels_quarantined(
|
||||
):
|
||||
stems = [p.stem for p in sorted(fixture_images_dir.glob("*.jpg"))[:100]]
|
||||
corrupt = set(stems[:5])
|
||||
train = _prepare_form_dataset(
|
||||
train, today_ds = _prepare_form_dataset(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
constants_patch,
|
||||
@@ -177,15 +169,14 @@ def test_bt_dsf_04_corrupted_labels_quarantined(
|
||||
corrupt,
|
||||
)
|
||||
train.form_dataset()
|
||||
base = train.today_dataset
|
||||
split_total = (
|
||||
_count_jpg(Path(base, "train", "images"))
|
||||
+ _count_jpg(Path(base, "valid", "images"))
|
||||
+ _count_jpg(Path(base, "test", "images"))
|
||||
_count_jpg(Path(today_ds, "train", "images"))
|
||||
+ _count_jpg(Path(today_ds, "valid", "images"))
|
||||
+ _count_jpg(Path(today_ds, "test", "images"))
|
||||
)
|
||||
assert split_total == 95
|
||||
assert _count_jpg(c_mod.corrupted_images_dir) == 5
|
||||
assert len(list(Path(c_mod.corrupted_labels_dir).glob("*.txt"))) == 5
|
||||
assert _count_jpg(c_mod.config.corrupted_images_dir) == 5
|
||||
assert len(list(Path(c_mod.config.corrupted_labels_dir).glob("*.txt"))) == 5
|
||||
|
||||
|
||||
@pytest.mark.resilience
|
||||
@@ -196,7 +187,7 @@ def test_rt_dsf_01_empty_processed_no_crash(
|
||||
fixture_images_dir,
|
||||
fixture_labels_dir,
|
||||
):
|
||||
train = _prepare_form_dataset(
|
||||
train, today_ds = _prepare_form_dataset(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
constants_patch,
|
||||
@@ -206,8 +197,7 @@ def test_rt_dsf_01_empty_processed_no_crash(
|
||||
set(),
|
||||
)
|
||||
train.form_dataset()
|
||||
base = Path(train.today_dataset)
|
||||
assert base.is_dir()
|
||||
assert Path(today_ds).is_dir()
|
||||
|
||||
|
||||
@pytest.mark.resource_limit
|
||||
@@ -225,7 +215,7 @@ def test_rl_dsf_02_no_filename_duplication_across_splits(
|
||||
fixture_images_dir,
|
||||
fixture_labels_dir,
|
||||
):
|
||||
train = _prepare_form_dataset(
|
||||
train, today_ds = _prepare_form_dataset(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
constants_patch,
|
||||
@@ -235,7 +225,7 @@ def test_rl_dsf_02_no_filename_duplication_across_splits(
|
||||
set(),
|
||||
)
|
||||
train.form_dataset()
|
||||
base = Path(train.today_dataset)
|
||||
base = Path(today_ds)
|
||||
names = []
|
||||
for split in ("train", "valid", "test"):
|
||||
for f in (base / split / "images").glob("*.jpg"):
|
||||
|
||||
@@ -54,6 +54,6 @@ def test_empty_label_file(empty_label):
|
||||
|
||||
def test_constants_patch_uses_tmp(constants_patch, tmp_path):
|
||||
constants_patch(tmp_path)
|
||||
assert c.azaion.startswith(str(tmp_path))
|
||||
assert c.data_dir.startswith(str(tmp_path))
|
||||
assert c.CURRENT_ONNX_MODEL.startswith(str(tmp_path))
|
||||
assert c.config.azaion.startswith(str(tmp_path))
|
||||
assert c.config.data_dir.startswith(str(tmp_path))
|
||||
assert c.config.current_onnx_model.startswith(str(tmp_path))
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
import sys
import types
import importlib
import shutil
from os import path as osp
from pathlib import Path

import pytest

# Stub heavy optional third-party dependencies so that importing the project
# modules below does not require them to be installed in the test environment.
for _n in ("boto3", "netron", "requests"):
    if _n not in sys.modules:
        sys.modules[_n] = types.ModuleType(_n)

# Other test modules may have stubbed out ultralytics; drop any such entries
# so the REAL package is imported for this end-to-end run.
for _k in [k for k in sys.modules if k == "ultralytics" or k.startswith("ultralytics.")]:
    del sys.modules[_k]
from ultralytics import YOLO

# Project modules may already have been imported against the stubs above —
# reload them so they bind to the freshly imported real dependencies.
for _m in ("exports", "train"):
    if _m in sys.modules:
        importlib.reload(sys.modules[_m])

import constants as c
import train as train_mod
import exports as exports_mod

# Fixture dataset and the test configuration shipped with the repository.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
_DATASET_IMAGES = _PROJECT_ROOT / "_docs/00_problem/input_data/dataset/images"
_DATASET_LABELS = _PROJECT_ROOT / "_docs/00_problem/input_data/dataset/labels"
_CONFIG_TEST = _PROJECT_ROOT / "config.test.yaml"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def e2e_result(tmp_path_factory):
    """Run the full augment/train/export pipeline once per module.

    Swaps the global ``constants.config`` for a test configuration rooted in a
    temporary directory, seeds the data directories from the fixture dataset,
    runs augmentation, training and the ONNX/CoreML exports, then yields a dict
    with the resulting dataset path.

    The original config is restored in a ``finally`` block so that a failure
    anywhere during setup (mkdir, training, export) cannot leave the mutated
    global config visible to every later test in the session — the original
    code only restored it after the yield, which is skipped on setup errors.
    """
    base = tmp_path_factory.mktemp("e2e")

    old_config = c.config
    c.config = c.Config.from_yaml(str(_CONFIG_TEST), root=str(base / "azaion"))
    try:
        data_img = Path(c.config.data_images_dir)
        data_lbl = Path(c.config.data_labels_dir)
        data_img.mkdir(parents=True)
        data_lbl.mkdir(parents=True)
        Path(c.config.models_dir).mkdir(parents=True)

        # Copy every fixture image plus its matching label file (when present).
        for img in sorted(_DATASET_IMAGES.glob("*.jpg")):
            shutil.copy2(img, data_img / img.name)
            lbl = _DATASET_LABELS / f"{img.stem}.txt"
            if lbl.exists():
                shutil.copy2(lbl, data_lbl / lbl.name)

        # Imported here so the module binds against the patched config.
        from augmentation import Augmentator
        Augmentator().augment_annotations()

        train_mod.train_dataset()

        exports_mod.export_onnx(c.config.current_pt_model)
        exports_mod.export_coreml(c.config.current_pt_model)

        today_ds = osp.join(c.config.datasets_dir, train_mod.today_folder)

        yield {
            "today_dataset": today_ds,
        }
    finally:
        # Always restore the session-wide config, even on setup failure.
        c.config = old_config
|
||||
|
||||
|
||||
@pytest.mark.e2e
class TestTrainingPipeline:
    """End-to-end assertions over the artifacts built by the e2e_result fixture."""

    def test_augmentation_produced_output(self, e2e_result):
        """Augmentation emits exactly 800 processed images."""
        processed = Path(c.config.processed_images_dir)
        jpg_count = sum(1 for _ in processed.glob("*.jpg"))
        assert jpg_count == 800

    def test_dataset_formed(self, e2e_result):
        """Every split directory exists and the splits together hold 800 images."""
        dataset_root = Path(e2e_result["today_dataset"])
        splits = ("train", "valid", "test")
        for split in splits:
            assert (dataset_root / split / "images").is_dir()
            assert (dataset_root / split / "labels").is_dir()
        total = 0
        for split in splits:
            total += len(list((dataset_root / split / "images").glob("*.jpg")))
        assert total == 800

    def test_data_yaml_created(self, e2e_result):
        """Dataset formation writes a data.yaml naming classes and split paths."""
        yaml_path = Path(e2e_result["today_dataset"]) / "data.yaml"
        assert yaml_path.exists()
        text = yaml_path.read_text()
        for needle in ("nc: 80", "train:", "val:"):
            assert needle in text

    def test_training_produces_pt(self, e2e_result):
        """Training leaves a non-empty .pt checkpoint at the configured path."""
        checkpoint = Path(c.config.current_pt_model)
        assert checkpoint.exists()
        assert checkpoint.stat().st_size > 0

    def test_export_onnx(self, e2e_result):
        """ONNX export writes a non-empty file with the .onnx suffix."""
        onnx_path = Path(c.config.current_onnx_model)
        assert onnx_path.exists()
        assert onnx_path.suffix == ".onnx"
        assert onnx_path.stat().st_size > 0

    def test_export_coreml(self, e2e_result):
        """CoreML export leaves at least one .mlpackage in the models dir."""
        packages = list(Path(c.config.models_dir).glob("*.mlpackage"))
        assert len(packages) >= 1

    def test_onnx_inference(self, e2e_result):
        """The exported ONNX model runs inference on a sample fixture image."""
        model = YOLO(c.config.current_onnx_model)
        sample = sorted(_DATASET_IMAGES.glob("*.jpg"))[0]
        results = model.predict(source=str(sample), imgsz=c.config.export.onnx_imgsz, verbose=False)
        assert len(results) == 1
        assert results[0].boxes is not None
|
||||
Reference in New Issue
Block a user