mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 08:46:36 +00:00
142c6c4de8
- Replaced module-level path variables in constants.py with a structured Pydantic Config class. - Updated all relevant modules (train.py, augmentation.py, exports.py, dataset-visualiser.py, manual_run.py) to access paths through the new config structure. - Fixed bugs related to image processing and model saving. - Enhanced test infrastructure to accommodate the new configuration approach. This refactor improves code maintainability and clarity by centralizing configuration management.
84 lines
2.4 KiB
Python
84 lines
2.4 KiB
Python
import re
|
|
import sys
|
|
import types
|
|
|
|
import pytest
|
|
|
|
from dto.annotationClass import AnnotationClass
|
|
|
|
|
|
def _stub_train_imports():
    """Register lightweight stand-ins for third-party modules so `train` can be imported.

    Runs at most once per process; the ``_done`` attribute set on this function
    is the guard. For each of ``ultralytics``, ``boto3``, ``netron`` and
    ``requests`` that is not already present in ``sys.modules``, an empty
    ``types.ModuleType`` is registered under that name. Afterwards
    ``ultralytics.YOLO`` is set to an empty class and ``boto3.client`` to a
    callable that accepts anything and returns ``None``.

    NOTE(review): the two attribute assignments also run when the real package
    was already imported, replacing its attributes for the rest of the
    process -- confirm that is intended.
    """
    if getattr(_stub_train_imports, "_done", False):
        return

    stub_names = ("ultralytics", "boto3", "netron", "requests")
    absent = [mod for mod in stub_names if mod not in sys.modules]
    sys.modules.update({mod: types.ModuleType(mod) for mod in absent})

    yolo_placeholder = type("YOLO", (), {})
    sys.modules["ultralytics"].YOLO = yolo_placeholder
    sys.modules["boto3"].client = lambda *args, **kwargs: None

    _stub_train_imports._done = True
|
|
|
|
|
|
def _name_lines_under_names(text):
    """Return the stripped list-item lines of the ``names:`` block in YAML *text*.

    Each returned entry is a whole line with surrounding whitespace removed,
    e.g. ``"- ArmorVehicle"``. Collection starts after the first line that is
    exactly ``names:`` once stripped. It stops at the first later line that is
    non-blank, not a ``#`` comment, and not a ``-`` list item, i.e. the next
    key such as ``nc:``. Returns an empty list when there is no ``names:`` line.

    This is a line scanner, not a YAML parser. It only understands
    block-style sequences (``- item``), not flow style (``[a, b]``) or
    mapping-style ``names``.
    """
    out = []
    in_block = False
    for line in text.splitlines():
        s = line.strip()
        if not in_block:
            # Keys before `names:` (including an `nc:` written first) are
            # skipped rather than terminating the scan.
            in_block = s == "names:"
            continue
        if not s or s.startswith("#"):
            continue
        if not s.startswith("-"):
            # The next key ends the block, so list items under later keys are
            # not miscounted as class names.
            break
        out.append(s)
    return out
|
|
|
|
|
|
# Matches a placeholder entry such as "- Class-42" among the stripped lines
# returned by _name_lines_under_names. The YAML tests count these matches
# separately from every other "- ..." entry, which they treat as a named class.
_PLACEHOLDER_RE = re.compile(r"^-\s+Class-\d+\s*$")
|
|
|
|
|
|
@pytest.fixture
def data_yaml_text(monkeypatch, tmp_path, fixture_classes_json):
    """Generate data.yaml with ``train.create_yaml()`` under a temp root and return its text.

    ``fixture_classes_json`` is not used in the body; it is presumably requested
    for its setup side effect. For the duration of the test, ``monkeypatch``
    replaces ``constants.config`` with a fresh ``Config`` rooted at
    ``tmp_path`` and sets ``train.today_folder`` to ``""``. The datasets
    directory is created before ``create_yaml()`` runs, and the resulting
    ``data.yaml`` is read back as UTF-8.
    """
    from pathlib import Path

    # Stub modules must be registered before `train` is imported.
    _stub_train_imports()
    import train
    import constants as c

    temp_config = c.Config(dirs=c.DirsConfig(root=str(tmp_path)))
    monkeypatch.setattr(c, "config", temp_config)
    monkeypatch.setattr(train, "today_folder", "")

    Path(c.config.datasets_dir).mkdir(parents=True, exist_ok=True)
    train.create_yaml()

    yaml_path = Path(c.config.datasets_dir) / "data.yaml"
    return yaml_path.read_text(encoding="utf-8")
|
|
|
|
|
|
def test_bt_cls_01_base_classes(fixture_classes_json):
    """Base classes 0-16 are all present and carry 17 distinct ``id`` values."""
    d = AnnotationClass.read_json()
    # `d[k]` raises if any base id 0-16 is absent, which fails the test.
    # Only id uniqueness needs an explicit assertion.
    base_ids = {d[k].id for k in range(17)}
    assert len(base_ids) == 17, f"duplicate ids among base classes 0-16: {base_ids!r}"
|
|
|
|
|
|
def test_bt_cls_02_weather_expansion(fixture_classes_json):
    """Class 0 and the entries at 20 and 40 carry the expected base/weather names."""
    classes = AnnotationClass.read_json()
    expected_names = {
        0: "ArmorVehicle",
        20: "ArmorVehicle(Wint)",
        40: "ArmorVehicle(Night)",
    }
    for class_id, expected in expected_names.items():
        assert classes[class_id].name == expected
|
|
|
|
|
|
# NOTE(review): this bt_ test carries the resource_limit marker used by the
# rl_ tests -- confirm that is intentional.
@pytest.mark.resource_limit
def test_bt_cls_03_yaml_generation(data_yaml_text):
    """data.yaml declares ``nc: 80`` and lists 80 names: 29 placeholders and 51 named classes."""
    text = data_yaml_text
    # Match the whole `nc:` line; a plain substring test would also accept
    # e.g. "nc: 800".
    assert re.search(r"^[ \t]*nc:[ \t]*80[ \t]*$", text, re.MULTILINE), (
        "expected an 'nc: 80' line in data.yaml"
    )
    names = _name_lines_under_names(text)
    placeholders = [ln for ln in names if _PLACEHOLDER_RE.match(ln)]
    named = [ln for ln in names if not _PLACEHOLDER_RE.match(ln)]
    assert len(names) == 80
    assert len(placeholders) == 29
    assert len(named) == 51
|
|
|
|
|
|
@pytest.mark.resource_limit
def test_rl_cls_01_total_class_count(data_yaml_text):
    """The generated data.yaml lists exactly 80 entries under ``names:``."""
    expected_total = 80
    assert len(_name_lines_under_names(data_yaml_text)) == expected_total
|