mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 08:56:35 +00:00
142c6c4de8
Commit message:
- Replaced the module-level path variables in constants.py with a structured Pydantic `Config` class.
- Updated all relevant modules (train.py, augmentation.py, exports.py, dataset-visualiser.py, manual_run.py) to access paths through the new config structure.
- Fixed bugs related to image processing and model saving.
- Enhanced the test infrastructure to accommodate the new configuration approach.

This refactor improves code maintainability and clarity by centralizing configuration management.
98 lines
2.2 KiB
Python
98 lines
2.2 KiB
Python
import shutil
|
|
import sys
|
|
import time
|
|
import types
|
|
from os import path as osp
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
import constants as c_mod
|
|
|
|
|
|
def _stub_train_dependencies():
    """Register placeholder modules for ``train``'s heavy dependencies.

    Ensures ``sys.modules`` holds entries for ``ultralytics``, ``boto3``,
    ``netron`` and ``requests``. A fresh empty module is created for any
    name not already registered. Then ``ultralytics.YOLO`` is set to an
    empty class, and ``boto3.client`` is set to a factory. That factory
    returns an object whose ``upload_fileobj`` / ``download_file`` accept
    anything and return ``None``.

    Presumably this lets ``train`` be imported without the real packages
    or any network access; confirm against ``train``'s imports.

    The work happens at most once per process. Completion is recorded in
    a ``_done`` attribute on this function.

    NOTE(review): if a real ``ultralytics`` / ``boto3`` module is already
    in ``sys.modules``, it is reused and its ``YOLO`` / ``client``
    attribute is overwritten for the rest of the session.
    """
    if getattr(_stub_train_dependencies, "_done", False):
        return

    def ensure_module(module_name):
        # setdefault keeps whatever is already registered under the name
        # (even a None entry) and only inserts a new module when absent.
        return sys.modules.setdefault(module_name, types.ModuleType(module_name))

    class YOLO:
        """Empty stand-in for ``ultralytics.YOLO``."""

    def _ignore(*_args, **_kwargs):
        return None

    def fake_client(*_args, **_kwargs):
        # Mimics the two S3 client methods the code under test may touch.
        return types.SimpleNamespace(
            upload_fileobj=_ignore,
            download_file=_ignore,
        )

    ensure_module("ultralytics").YOLO = YOLO
    ensure_module("boto3").client = fake_client
    for bare_name in ("netron", "requests"):
        ensure_module(bare_name)

    _stub_train_dependencies._done = True
|
|
|
|
|
|
# Install the placeholder modules at import time so they are already in
# sys.modules before the helpers below execute ``import train``.
# Presumably this is so train's own imports resolve to the stubs.
_stub_train_dependencies()
|
|
|
|
|
|
def _prepare_form_dataset(
    monkeypatch,
    tmp_path,
    constants_patch,
    fixture_images_dir,
    fixture_labels_dir,
    count,
    corrupt_stems,
):
    """Stage fixture images/labels as processed data and import ``train``.

    Steps:
    1. Calls ``constants_patch(tmp_path)``, then imports ``train``.
    2. Creates the processed image/label directories named by
       ``c_mod.config``.
    3. Copies the first ``count`` fixture ``*.jpg`` images (sorted by path),
       plus the matching ``.txt`` labels, into those directories.
    4. For every stem in ``corrupt_stems``, overwrites the copied label
       with a single fixed line.

    ``monkeypatch`` is accepted but not used here.

    Returns:
        ``(train_module, today_ds)``, where ``today_ds`` is the string path
        ``<config.datasets_dir>/<train.today_folder>``.
    """
    constants_patch(tmp_path)
    # Deliberately imported after patching, so train sees the patched config.
    import train

    images_out = Path(c_mod.config.processed_images_dir)
    labels_out = Path(c_mod.config.processed_labels_dir)
    for target_dir in (images_out, labels_out):
        target_dir.mkdir(parents=True, exist_ok=True)

    selected = sorted(fixture_images_dir.glob("*.jpg"))[:count]
    for image_path in selected:
        name = image_path.stem
        shutil.copy2(image_path, images_out / f"{name}.jpg")

        label_out = labels_out / f"{name}.txt"
        shutil.copy2(fixture_labels_dir / f"{name}.txt", label_out)
        if name in corrupt_stems:
            # Assumes YOLO-style "class cx cy w h" labels; 1.5 would then lie
            # outside the normalised [0, 1] range. TODO confirm against
            # train's label validation.
            label_out.write_text("0 1.5 0.5 0.1 0.1\n", encoding="utf-8")

    dataset_dir = osp.join(c_mod.config.datasets_dir, train.today_folder)
    return train, dataset_dir
|
|
|
|
|
|
@pytest.mark.performance
def test_pt_dsf_01_dataset_formation_under_thirty_seconds(
    monkeypatch,
    tmp_path,
    constants_patch,
    fixture_images_dir,
    fixture_labels_dir,
):
    """``train.form_dataset()`` finishes within the 30 s budget.

    The input is up to 100 fixture image/label pairs, none of them
    corrupted. Only the ``form_dataset()`` call itself is timed; staging
    the fixtures is excluded.
    """
    time_budget_s = 30.0
    # The dataset path returned by the helper is not needed for a timing check.
    train, _today_ds = _prepare_form_dataset(
        monkeypatch,
        tmp_path,
        constants_patch,
        fixture_images_dir,
        fixture_labels_dir,
        100,
        set(),
    )

    started = time.perf_counter()
    train.form_dataset()
    elapsed = time.perf_counter() - started

    # Report the measured time so a regression is diagnosable from the log.
    assert elapsed <= time_budget_s, (
        f"form_dataset() took {elapsed:.2f}s; budget is {time_budget_s:.0f}s"
    )
|