mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-23 03:06:35 +00:00
[AZ-171] Add TensorRT tests, AC coverage gate in implement skill, optimize test infrastructure
- Add TensorRT export tests with graceful skip when no GPU available
- Add AC test coverage verification step (Step 8) to implement skill
- Add test coverage gap analysis to new-task skill
- Move exported_models fixture to conftest.py as session-scoped (shared across modules)
- Reorder tests: e2e training runs first so images/labels are available for all tests
- Consolidate teardown into single session-level cleanup in conftest.py
- Fix infrastructure tests to count files dynamically instead of hardcoded 20

Made-with: Cursor
@@ -94,6 +94,7 @@ For each task in the batch, launch an `implementer` subagent with:
|
|||||||
- List of files OWNED (exclusive write access)
|
- List of files OWNED (exclusive write access)
|
||||||
- List of files READ-ONLY
|
- List of files READ-ONLY
|
||||||
- List of files FORBIDDEN
|
- List of files FORBIDDEN
|
||||||
|
- **Explicit instruction**: the implementer must write or update tests that validate each acceptance criterion in the task spec. If a test cannot run in the current environment (e.g., TensorRT requires GPU), the test must still be written and skip with a clear reason.
|
||||||
|
|
||||||
Launch all subagents immediately — no user confirmation.
|
Launch all subagents immediately — no user confirmation.
|
||||||
|
|
||||||
@@ -108,46 +109,64 @@ Launch all subagents immediately — no user confirmation.
|
|||||||
- Subagent has not produced new output for an extended period → flag as potentially hung
|
- Subagent has not produced new output for an extended period → flag as potentially hung
|
||||||
- If a subagent is flagged as stuck, do NOT let it continue looping — stop it and record the blocker in the batch report
|
- If a subagent is flagged as stuck, do NOT let it continue looping — stop it and record the blocker in the batch report
|
||||||
|
|
||||||
### 8. Code Review
|
### 8. AC Test Coverage Verification
|
||||||
|
|
||||||
|
Before code review, verify that every acceptance criterion in each task spec has at least one test that validates it. For each task in the batch:
|
||||||
|
|
||||||
|
1. Read the task spec's **Acceptance Criteria** section
|
||||||
|
2. Search the test files (new and existing) for tests that cover each AC
|
||||||
|
3. Classify each AC as:
|
||||||
|
- **Covered**: a test directly validates this AC (running or skipped-with-reason)
|
||||||
|
- **Not covered**: no test exists for this AC
|
||||||
|
|
||||||
|
If any AC is **Not covered**:
|
||||||
|
- This is a **BLOCKING** failure — the implementer must write the missing test before proceeding
|
||||||
|
- Re-launch the implementer with the specific ACs that need tests
|
||||||
|
- If the test cannot run in the current environment (GPU required, platform-specific, external service), the test must still exist and skip with `pytest.mark.skipif` or `pytest.skip()` explaining the prerequisite
|
||||||
|
- A skipped test counts as **Covered** — the test exists and will run when the environment allows
|
||||||
|
|
||||||
|
Only proceed to Step 9 when every AC has a corresponding test.
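
The classification rule in this step can be sketched in a few lines of Python. This is only an illustration under stated assumptions: the `AcTest` record, the `classify_acs` and `blocking_acs` helpers, and the idea that each test is already mapped to an AC identifier are all hypothetical and not part of the skill itself.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class AcTest:
    """Hypothetical record: one test mapped to one acceptance criterion."""
    name: str
    ac_id: str
    skip_reason: Optional[str] = None  # set when the test skips in this environment


def classify_acs(ac_ids: List[str], tests: List[AcTest]) -> Dict[str, str]:
    """Label every AC; a test that skips with a reason still counts as coverage."""
    covered = {t.ac_id for t in tests}
    return {ac: ("Covered" if ac in covered else "Not covered") for ac in ac_ids}


def blocking_acs(coverage: Dict[str, str]) -> List[str]:
    """ACs that need a test before the workflow may move on to code review."""
    return [ac for ac, status in coverage.items() if status == "Not covered"]
```

A non-empty result from `blocking_acs` corresponds to the BLOCKING case: the implementer is re-launched with exactly those ACs.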
|
||||||
|
|
||||||
|
### 9. Code Review
|
||||||
|
|
||||||
- Run `/code-review` skill on the batch's changed files + corresponding task specs
|
- Run `/code-review` skill on the batch's changed files + corresponding task specs
|
||||||
- The code-review skill produces a verdict: PASS, PASS_WITH_WARNINGS, or FAIL
|
- The code-review skill produces a verdict: PASS, PASS_WITH_WARNINGS, or FAIL
|
||||||
|
|
||||||
### 9. Auto-Fix Gate
|
### 10. Auto-Fix Gate
|
||||||
|
|
||||||
Auto-fix loop with bounded retries (max 2 attempts) before escalating to user:
|
Auto-fix loop with bounded retries (max 2 attempts) before escalating to user:
|
||||||
|
|
||||||
1. If verdict is **PASS** or **PASS_WITH_WARNINGS**: show findings as info, continue automatically to step 10
|
1. If verdict is **PASS** or **PASS_WITH_WARNINGS**: show findings as info, continue automatically to step 11
|
||||||
2. If verdict is **FAIL** (attempt 1 or 2):
|
2. If verdict is **FAIL** (attempt 1 or 2):
|
||||||
- Parse the code review findings (Critical and High severity items)
|
- Parse the code review findings (Critical and High severity items)
|
||||||
- For each finding, attempt an automated fix using the finding's location, description, and suggestion
|
- For each finding, attempt an automated fix using the finding's location, description, and suggestion
|
||||||
- Re-run `/code-review` on the modified files
|
- Re-run `/code-review` on the modified files
|
||||||
- If now PASS or PASS_WITH_WARNINGS → continue to step 10
|
- If now PASS or PASS_WITH_WARNINGS → continue to step 11
|
||||||
- If still FAIL → increment retry counter, repeat from (2) up to max 2 attempts
|
- If still FAIL → increment retry counter, repeat from (2) up to max 2 attempts
|
||||||
3. If still **FAIL** after 2 auto-fix attempts: present all findings to user (**BLOCKING**). User must confirm fixes or accept before proceeding.
|
3. If still **FAIL** after 2 auto-fix attempts: present all findings to user (**BLOCKING**). User must confirm fixes or accept before proceeding.
|
||||||
|
|
||||||
Track `auto_fix_attempts` count in the batch report for retrospective analysis.
|
Track `auto_fix_attempts` count in the batch report for retrospective analysis.
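
As a rough sketch of the bounded retry loop described above, assuming the review and the fix step are available as plain callables (`review` and `apply_fixes` are hypothetical stand-ins for running `/code-review` and applying the findings):

```python
from typing import Callable, Tuple

MAX_AUTO_FIX_ATTEMPTS = 2
PASSING = {"PASS", "PASS_WITH_WARNINGS"}


def auto_fix_gate(
    review: Callable[[], str],
    apply_fixes: Callable[[], None],
) -> Tuple[str, int]:
    """Review once, then allow at most two fix-and-re-review rounds.

    Returns (outcome, auto_fix_attempts); outcome is "continue" when the
    verdict passes and "escalate" when it is still FAIL after the retries.
    """
    verdict = review()
    attempts = 0
    while verdict not in PASSING and attempts < MAX_AUTO_FIX_ATTEMPTS:
        apply_fixes()          # attempt an automated fix for the findings
        attempts += 1
        verdict = review()     # re-run the review on the modified files
    outcome = "continue" if verdict in PASSING else "escalate"
    return outcome, attempts
```

The second element of the returned tuple is the value that would be recorded as `auto_fix_attempts` in the batch report.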
|
||||||
|
|
||||||
### 10. Commit and Push
|
### 11. Commit and Push
|
||||||
|
|
||||||
- After user confirms the batch (explicitly for FAIL, implicitly for PASS/PASS_WITH_WARNINGS):
|
- After user confirms the batch (explicitly for FAIL, implicitly for PASS/PASS_WITH_WARNINGS):
|
||||||
- `git add` all changed files from the batch
|
- `git add` all changed files from the batch
|
||||||
- `git commit` with a message that includes ALL task IDs (tracker IDs or numeric prefixes) of tasks implemented in the batch, followed by a summary of what was implemented. Format: `[TASK-ID-1] [TASK-ID-2] ... Summary of changes`
|
- `git commit` with a message that includes ALL task IDs (tracker IDs or numeric prefixes) of tasks implemented in the batch, followed by a summary of what was implemented. Format: `[TASK-ID-1] [TASK-ID-2] ... Summary of changes`
|
||||||
- `git push` to the remote branch
|
- `git push` to the remote branch
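
The commit message format above could be assembled as follows. The helper name `batch_commit_message` is hypothetical and used only for illustration.

```python
from typing import Sequence


def batch_commit_message(task_ids: Sequence[str], summary: str) -> str:
    """Build '[TASK-ID-1] [TASK-ID-2] ... Summary of changes'."""
    if not task_ids:
        raise ValueError("a batch commit needs at least one task ID")
    prefix = " ".join(f"[{task_id}]" for task_id in task_ids)
    return f"{prefix} {summary.strip()}"
```

For a single-task batch, `batch_commit_message(["AZ-171"], "Add TensorRT tests")` yields a message with the same shape as the title of this commit.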
|
||||||
|
|
||||||
### 11. Update Tracker Status → In Testing
|
### 12. Update Tracker Status → In Testing
|
||||||
|
|
||||||
After the batch is committed and pushed, transition the ticket status of each task in the batch to **In Testing** via the configured work item tracker. If `tracker: local`, skip this step.
|
After the batch is committed and pushed, transition the ticket status of each task in the batch to **In Testing** via the configured work item tracker. If `tracker: local`, skip this step.
|
||||||
|
|
||||||
### 12. Archive Completed Tasks
|
### 13. Archive Completed Tasks
|
||||||
|
|
||||||
Move each completed task file from `TASKS_DIR/todo/` to `TASKS_DIR/done/`.
|
Move each completed task file from `TASKS_DIR/todo/` to `TASKS_DIR/done/`.
|
||||||
|
|
||||||
### 13. Loop
|
### 14. Loop
|
||||||
|
|
||||||
- Go back to step 2 until all tasks in `todo/` are done
|
- Go back to step 2 until all tasks in `todo/` are done
|
||||||
|
|
||||||
### 14. Final Test Run
|
### 15. Final Test Run
|
||||||
|
|
||||||
- After all batches are complete, run the full test suite once
|
- After all batches are complete, run the full test suite once
|
||||||
- Read and execute `.cursor/skills/test-run/SKILL.md` (detect runner, run suite, diagnose failures, present blocking choices)
|
- Read and execute `.cursor/skills/test-run/SKILL.md` (detect runner, run suite, diagnose failures, present blocking choices)
|
||||||
@@ -177,10 +196,11 @@ After each batch, produce a structured report:
|
|||||||
|
|
||||||
## Task Results
|
## Task Results
|
||||||
|
|
||||||
| Task | Status | Files Modified | Tests | Issues |
|
| Task | Status | Files Modified | Tests | AC Coverage | Issues |
|
||||||
|------|--------|---------------|-------|--------|
|
|------|--------|---------------|-------|-------------|--------|
|
||||||
| [TRACKER-ID]_[name] | Done | [count] files | [pass/fail] | [count or None] |
|
| [TRACKER-ID]_[name] | Done | [count] files | [pass/fail] | [N/N ACs covered] | [count or None] |
|
||||||
|
|
||||||
|
## AC Test Coverage: [All covered / X of Y covered]
|
||||||
## Code Review Verdict: [PASS/FAIL/PASS_WITH_WARNINGS]
|
## Code Review Verdict: [PASS/FAIL/PASS_WITH_WARNINGS]
|
||||||
## Auto-Fix Attempts: [0/1/2]
|
## Auto-Fix Attempts: [0/1/2]
|
||||||
## Stuck Agents: [count or None]
|
## Stuck Agents: [count or None]
|
||||||
@@ -212,4 +232,4 @@ Each batch commit serves as a rollback checkpoint. If recovery is needed:
|
|||||||
- Never launch tasks whose dependencies are not yet completed
|
- Never launch tasks whose dependencies are not yet completed
|
||||||
- Never allow two parallel agents to write to the same file
|
- Never allow two parallel agents to write to the same file
|
||||||
- If a subagent fails or is flagged as stuck, stop it and report — do not let it loop indefinitely
|
- If a subagent fails or is flagged as stuck, stop it and report — do not let it loop indefinitely
|
||||||
- Always run the full test suite after all batches complete (step 14)
|
- Always run the full test suite after all batches complete (step 15)
|
||||||
|
|||||||
@@ -129,7 +129,7 @@ The `<task_slug>` is a short kebab-case name derived from the feature descriptio
|
|||||||
### Step 4: Codebase Analysis
|
### Step 4: Codebase Analysis
|
||||||
|
|
||||||
**Role**: Software architect
|
**Role**: Software architect
|
||||||
**Goal**: Determine where and how to insert the new functionality.
|
**Goal**: Determine where and how to insert the new functionality, and whether existing tests cover the new requirements.
|
||||||
|
|
||||||
1. Read the codebase documentation from DOCUMENT_DIR:
|
1. Read the codebase documentation from DOCUMENT_DIR:
|
||||||
- `architecture.md` — overall structure
|
- `architecture.md` — overall structure
|
||||||
@@ -144,6 +144,10 @@ The `<task_slug>` is a short kebab-case name derived from the feature descriptio
|
|||||||
- What new interfaces or models are needed
|
- What new interfaces or models are needed
|
||||||
- How data flows through the change
|
- How data flows through the change
|
||||||
4. If the change is complex enough, read the actual source files (not just docs) to verify insertion points
|
4. If the change is complex enough, read the actual source files (not just docs) to verify insertion points
|
||||||
|
5. **Test coverage gap analysis**: Read existing test files that cover the affected components. For each acceptance criterion from Step 1, determine whether an existing test already validates it. Classify each AC as:
|
||||||
|
- **Covered**: an existing test directly validates this behavior
|
||||||
|
- **Partially covered**: an existing test exercises the code path but doesn't assert the new requirement
|
||||||
|
- **Not covered**: no existing test validates this behavior — a new test is required
|
||||||
|
|
||||||
Present the analysis:
|
Present the analysis:
|
||||||
|
|
||||||
@@ -156,9 +160,22 @@ Present the analysis:
|
|||||||
Interface changes: [list or "None"]
|
Interface changes: [list or "None"]
|
||||||
New interfaces: [list or "None"]
|
New interfaces: [list or "None"]
|
||||||
Data flow impact: [summary]
|
Data flow impact: [summary]
|
||||||
|
─────────────────────────────────────
|
||||||
|
TEST COVERAGE GAP ANALYSIS
|
||||||
|
─────────────────────────────────────
|
||||||
|
AC-1: [Covered / Partially covered / Not covered]
|
||||||
|
[existing test name or "needs new test"]
|
||||||
|
AC-2: [Covered / Partially covered / Not covered]
|
||||||
|
[existing test name or "needs new test"]
|
||||||
|
...
|
||||||
|
─────────────────────────────────────
|
||||||
|
New tests needed: [count]
|
||||||
|
Existing tests to update: [count or "None"]
|
||||||
══════════════════════════════════════
|
══════════════════════════════════════
|
||||||
```
|
```
|
||||||
|
|
||||||
|
When gaps are found, the task spec (Step 6) MUST include the missing tests in the Scope (Included) section and the Unit/Blackbox Tests tables. Tests are not optional — if an AC is not covered by an existing test, the task must deliver a test for it.
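
The two summary counts at the bottom of the analysis template can be derived from the per-AC labels. The sketch below makes one assumption that the skill text does not state outright: each "Not covered" AC needs one new test, and each "Partially covered" AC means one existing test to update. The function name `summarize_gaps` is hypothetical.

```python
from collections import Counter
from typing import Dict, Union


def summarize_gaps(classification: Dict[str, str]) -> Dict[str, Union[int, str]]:
    """Turn per-AC labels into the two summary lines of the gap analysis."""
    counts = Counter(classification.values())
    to_update = counts["Partially covered"]  # assumed: one test to update per AC
    return {
        "new_tests_needed": counts["Not covered"],  # assumed: one new test per AC
        "existing_tests_to_update": to_update if to_update else "None",
    }
```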
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
### Step 5: Validate Assumptions
|
### Step 5: Validate Assumptions
|
||||||
|
|||||||
@@ -0,0 +1,17 @@
|
|||||||
|
# Batch Report
|
||||||
|
|
||||||
|
**Batch**: 3
|
||||||
|
**Tasks**: AZ-171_dynamic_batch_export
|
||||||
|
**Date**: 2026-03-28
|
||||||
|
|
||||||
|
## Task Results
|
||||||
|
|
||||||
|
| Task | Status | Files Modified | Tests | Issues |
|
||||||
|
|------|--------|---------------|-------|--------|
|
||||||
|
| AZ-171_dynamic_batch_export | Done | 2 files (src/exports.py, _docs/02_document/architecture.md) | 48 passed, 14 skipped, 6 errors (pre-existing) | None |
|
||||||
|
|
||||||
|
## Code Review Verdict: PASS
|
||||||
|
## Auto-Fix Attempts: 0
|
||||||
|
## Stuck Agents: None
|
||||||
|
|
||||||
|
## Next Batch: All tasks complete
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
# Implementation Report — Dynamic Batch Export
|
||||||
|
|
||||||
|
**Date**: 2026-03-28
|
||||||
|
**Epic**: AZ-164 (Code Improvements)
|
||||||
|
**Total Tasks**: 1
|
||||||
|
**Total Batches**: 1
|
||||||
|
**Commit**: 433e080
|
||||||
|
|
||||||
|
## Summary
|
||||||
|
|
||||||
|
Enabled dynamic batch size for all three model export formats (ONNX, TensorRT, CoreML) by adding `dynamic=True` to the ultralytics export calls. TensorRT max batch size set to 8.
|
||||||
|
|
||||||
|
## Tasks Implemented
|
||||||
|
|
||||||
|
| Task | Name | Complexity | Status |
|
||||||
|
|------|------|-----------|--------|
|
||||||
|
| AZ-171 | dynamic_batch_export | 2 | Done |
|
||||||
|
|
||||||
|
## Changes
|
||||||
|
|
||||||
|
| File | Change |
|
||||||
|
|------|--------|
|
||||||
|
| src/exports.py | Added `dynamic=True` to export_onnx, export_tensorrt, export_coreml; changed TensorRT batch from 4 to 8 |
|
||||||
|
| _docs/02_document/architecture.md | Updated Model Artifacts table to reflect dynamic batch support |
|
||||||
|
|
||||||
|
## Test Results
|
||||||
|
|
||||||
|
- 48 passed, 14 skipped, 6 errors (pre-existing ModuleNotFoundError for onnx package in e2e tests — environment dependency, not introduced by this change)
|
||||||
+47
-2
@@ -1,5 +1,6 @@
|
|||||||
import csv
|
import csv
|
||||||
import shutil
|
import shutil
|
||||||
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -12,12 +13,21 @@ _DATASET_LABELS = _TEST_ROOT / "data" / "labels"
|
|||||||
_ONNX_MODEL = _PROJECT_ROOT / "_docs/00_problem/input_data/azaion.onnx"
|
_ONNX_MODEL = _PROJECT_ROOT / "_docs/00_problem/input_data/azaion.onnx"
|
||||||
_CLASSES_JSON = _PROJECT_ROOT / "src" / "classes.json"
|
_CLASSES_JSON = _PROJECT_ROOT / "src" / "classes.json"
|
||||||
_CONFIG_TEST = _PROJECT_ROOT / "config.test.yaml"
|
_CONFIG_TEST = _PROJECT_ROOT / "config.test.yaml"
|
||||||
|
_MODELS_DIR = _TEST_ROOT / "models"
|
||||||
|
|
||||||
collect_ignore = ["security_test.py", "imagelabel_visualize_test.py"]
|
collect_ignore = ["security_test.py", "imagelabel_visualize_test.py"]
|
||||||
|
|
||||||
|
_E2E_MODULE = "test_training_e2e"
|
||||||
|
|
||||||
_test_results = []
|
_test_results = []
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_collection_modifyitems(items):
|
||||||
|
e2e = [i for i in items if _E2E_MODULE in i.nodeid]
|
||||||
|
rest = [i for i in items if _E2E_MODULE not in i.nodeid]
|
||||||
|
items[:] = e2e + rest
|
||||||
|
|
||||||
|
|
||||||
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
|
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
|
||||||
def pytest_runtest_makereport(item, call):
|
def pytest_runtest_makereport(item, call):
|
||||||
outcome = yield
|
outcome = yield
|
||||||
@@ -32,8 +42,7 @@ def pytest_runtest_makereport(item, call):
|
|||||||
|
|
||||||
|
|
||||||
def pytest_sessionfinish(session, exitstatus):
|
def pytest_sessionfinish(session, exitstatus):
|
||||||
if not _test_results:
|
if _test_results:
|
||||||
return
|
|
||||||
results_dir = Path(__file__).resolve().parent / "test-results"
|
results_dir = Path(__file__).resolve().parent / "test-results"
|
||||||
results_dir.mkdir(exist_ok=True)
|
results_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
@@ -43,6 +52,12 @@ def pytest_sessionfinish(session, exitstatus):
|
|||||||
for r in _test_results:
|
for r in _test_results:
|
||||||
writer.writerow([r["module"], r["name"], r["result"], f"{r['duration']:.3f}"])
|
writer.writerow([r["module"], r["name"], r["result"], f"{r['duration']:.3f}"])
|
||||||
|
|
||||||
|
import constants as c
|
||||||
|
test_config = c.Config.from_yaml(str(_CONFIG_TEST), root=str(_TEST_ROOT))
|
||||||
|
for d in (_DATASET_IMAGES, _DATASET_LABELS, test_config.datasets_dir,
|
||||||
|
test_config.corrupted_dir, str(_MODELS_DIR)):
|
||||||
|
shutil.rmtree(str(d), ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
def apply_constants_patch(monkeypatch, base: Path):
|
def apply_constants_patch(monkeypatch, base: Path):
|
||||||
import constants as c
|
import constants as c
|
||||||
@@ -157,3 +172,33 @@ def empty_label(tmp_path):
|
|||||||
p.parent.mkdir(parents=True, exist_ok=True)
|
p.parent.mkdir(parents=True, exist_ok=True)
|
||||||
p.write_text("", encoding="utf-8")
|
p.write_text("", encoding="utf-8")
|
||||||
return p
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def exported_models():
|
||||||
|
from ultralytics import YOLO
|
||||||
|
import constants as c
|
||||||
|
import exports as exports_mod
|
||||||
|
|
||||||
|
_MODELS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
pt_path = str(_MODELS_DIR / "test.pt")
|
||||||
|
YOLO("yolo11n.pt").save(pt_path)
|
||||||
|
|
||||||
|
old_config = c.config
|
||||||
|
c.config = c.Config.from_yaml(str(_CONFIG_TEST), root=str(_TEST_ROOT))
|
||||||
|
imgsz = c.config.export.onnx_imgsz
|
||||||
|
|
||||||
|
exports_mod.export_onnx(pt_path)
|
||||||
|
if sys.platform == "darwin":
|
||||||
|
exports_mod.export_coreml(pt_path)
|
||||||
|
|
||||||
|
c.config = old_config
|
||||||
|
|
||||||
|
onnx_files = list(_MODELS_DIR.glob("test*.onnx"))
|
||||||
|
return {
|
||||||
|
"onnx": str(onnx_files[0]) if onnx_files else None,
|
||||||
|
"model_dir": _MODELS_DIR,
|
||||||
|
"pt_path": pt_path,
|
||||||
|
"imgsz": imgsz,
|
||||||
|
}
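
The `pytest_collection_modifyitems` hook in the conftest change above moves the e2e training tests to the front while keeping the relative order of everything else. A minimal standalone sketch of that reordering, using a stable sort instead of two list comprehensions (the `e2e_first` name and the `SimpleNamespace` stand-ins for pytest items are assumptions for illustration only):

```python
from types import SimpleNamespace

E2E_MODULE = "test_training_e2e"


def e2e_first(items, marker=E2E_MODULE):
    """Reorder in place so items whose nodeid contains `marker` run first.

    sorted() is stable, so the order inside each group is preserved:
    False (marker present) sorts before True (marker absent).
    """
    items[:] = sorted(items, key=lambda item: marker not in item.nodeid)


# Stand-in objects with a `nodeid` attribute, mimicking collected tests.
collected = [
    SimpleNamespace(nodeid="tests/test_export.py::test_a"),
    SimpleNamespace(nodeid="tests/test_training_e2e.py::test_x"),
    SimpleNamespace(nodeid="tests/test_infrastructure.py::test_b"),
]
e2e_first(collected)
```

Assigning through `items[:]` matters here because pytest expects the hook to mutate the list it was given rather than rebind the name.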
|
||||||
|
|||||||
+86
-28
@@ -5,41 +5,23 @@ import cv2
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
import pytest
|
import pytest
|
||||||
|
import torch
|
||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
|
|
||||||
import constants as c
|
import constants as c
|
||||||
import exports as exports_mod
|
import exports as exports_mod
|
||||||
|
|
||||||
|
_HAS_TENSORRT = torch.cuda.is_available()
|
||||||
|
try:
|
||||||
|
import tensorrt
|
||||||
|
except ImportError:
|
||||||
|
_HAS_TENSORRT = False
|
||||||
|
|
||||||
_TESTS_DIR = Path(__file__).resolve().parent
|
_TESTS_DIR = Path(__file__).resolve().parent
|
||||||
_CONFIG_TEST = _TESTS_DIR.parent / "config.test.yaml"
|
_CONFIG_TEST = _TESTS_DIR.parent / "config.test.yaml"
|
||||||
_DATASET_IMAGES = _TESTS_DIR / "root" / "data" / "images"
|
_DATASET_IMAGES = _TESTS_DIR / "root" / "data" / "images"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def exported_models(tmp_path_factory):
|
|
||||||
# Arrange
|
|
||||||
tmp = tmp_path_factory.mktemp("export")
|
|
||||||
model_dir = tmp / "models"
|
|
||||||
model_dir.mkdir()
|
|
||||||
|
|
||||||
pt_path = str(model_dir / "test.pt")
|
|
||||||
YOLO("yolo11n.pt").save(pt_path)
|
|
||||||
|
|
||||||
old_config = c.config
|
|
||||||
c.config = c.Config.from_yaml(str(_CONFIG_TEST), root=str(tmp))
|
|
||||||
|
|
||||||
# Act
|
|
||||||
exports_mod.export_onnx(pt_path)
|
|
||||||
exports_mod.export_coreml(pt_path)
|
|
||||||
|
|
||||||
yield {
|
|
||||||
"onnx": str(next(model_dir.glob("*.onnx"))),
|
|
||||||
"model_dir": model_dir,
|
|
||||||
}
|
|
||||||
|
|
||||||
c.config = old_config
|
|
||||||
|
|
||||||
|
|
||||||
class TestOnnxExport:
|
class TestOnnxExport:
|
||||||
def test_onnx_file_created(self, exported_models):
|
def test_onnx_file_created(self, exported_models):
|
||||||
# Assert
|
# Assert
|
||||||
@@ -59,7 +41,7 @@ class TestOnnxExport:
|
|||||||
# Arrange
|
# Arrange
|
||||||
session = ort.InferenceSession(exported_models["onnx"], providers=["CPUExecutionProvider"])
|
session = ort.InferenceSession(exported_models["onnx"], providers=["CPUExecutionProvider"])
|
||||||
meta = session.get_inputs()[0]
|
meta = session.get_inputs()[0]
|
||||||
imgsz = c.config.export.onnx_imgsz
|
imgsz = exported_models["imgsz"]
|
||||||
imgs = sorted(_DATASET_IMAGES.glob("*.jpg"))
|
imgs = sorted(_DATASET_IMAGES.glob("*.jpg"))
|
||||||
if not imgs:
|
if not imgs:
|
||||||
pytest.skip("no test images")
|
pytest.skip("no test images")
|
||||||
@@ -77,7 +59,7 @@ class TestOnnxExport:
|
|||||||
# Arrange
|
# Arrange
|
||||||
session = ort.InferenceSession(exported_models["onnx"], providers=["CPUExecutionProvider"])
|
session = ort.InferenceSession(exported_models["onnx"], providers=["CPUExecutionProvider"])
|
||||||
meta = session.get_inputs()[0]
|
meta = session.get_inputs()[0]
|
||||||
imgsz = c.config.export.onnx_imgsz
|
imgsz = exported_models["imgsz"]
|
||||||
imgs = sorted(_DATASET_IMAGES.glob("*.jpg"))
|
imgs = sorted(_DATASET_IMAGES.glob("*.jpg"))
|
||||||
if not imgs:
|
if not imgs:
|
||||||
pytest.skip("no test images")
|
pytest.skip("no test images")
|
||||||
@@ -93,6 +75,82 @@ class TestOnnxExport:
|
|||||||
assert out[0].shape[0] == 4
|
assert out[0].shape[0] == 4
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not _HAS_TENSORRT, reason="TensorRT requires NVIDIA GPU and tensorrt package")
|
||||||
|
class TestTensorrtExport:
|
||||||
|
@pytest.fixture(scope="class")
|
||||||
|
def tensorrt_model(self, exported_models):
|
||||||
|
# Arrange
|
||||||
|
model_dir = exported_models["model_dir"]
|
||||||
|
pt_path = exported_models["pt_path"]
|
||||||
|
old_config = c.config
|
||||||
|
c.config = c.Config.from_yaml(str(_CONFIG_TEST), root=str(model_dir.parent))
|
||||||
|
|
||||||
|
# Act
|
||||||
|
exports_mod.export_tensorrt(pt_path)
|
||||||
|
|
||||||
|
c.config = old_config
|
||||||
|
engines = list(model_dir.glob("*.engine"))
|
||||||
|
yield {
|
||||||
|
"engine": str(engines[0]) if engines else None,
|
||||||
|
"model_dir": model_dir,
|
||||||
|
"imgsz": exported_models["imgsz"],
|
||||||
|
}
|
||||||
|
|
||||||
|
for e in model_dir.glob("*.engine"):
|
||||||
|
e.unlink(missing_ok=True)
|
||||||
|
|
||||||
|
def test_tensorrt_engine_created(self, tensorrt_model):
|
||||||
|
# Assert
|
||||||
|
assert tensorrt_model["engine"] is not None
|
||||||
|
p = Path(tensorrt_model["engine"])
|
||||||
|
assert p.exists()
|
||||||
|
assert p.stat().st_size > 0
|
||||||
|
|
||||||
|
def test_tensorrt_inference_batch_1(self, tensorrt_model):
|
||||||
|
# Arrange
|
||||||
|
assert tensorrt_model["engine"] is not None
|
||||||
|
imgs = sorted(_DATASET_IMAGES.glob("*.jpg"))
|
||||||
|
if not imgs:
|
||||||
|
pytest.skip("no test images")
|
||||||
|
model = YOLO(tensorrt_model["engine"])
|
||||||
|
|
||||||
|
# Act
|
||||||
|
results = model.predict(source=str(imgs[0]), imgsz=tensorrt_model["imgsz"], verbose=False)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].boxes is not None
|
||||||
|
|
||||||
|
def test_tensorrt_inference_batch_multiple(self, tensorrt_model):
|
||||||
|
# Arrange
|
||||||
|
assert tensorrt_model["engine"] is not None
|
||||||
|
imgs = sorted(_DATASET_IMAGES.glob("*.jpg"))
|
||||||
|
if len(imgs) < 4:
|
||||||
|
pytest.skip("need at least 4 test images")
|
||||||
|
model = YOLO(tensorrt_model["engine"])
|
||||||
|
|
||||||
|
# Act
|
||||||
|
results = model.predict(source=[str(p) for p in imgs[:4]], imgsz=tensorrt_model["imgsz"], verbose=False)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert len(results) == 4
|
||||||
|
|
||||||
|
def test_tensorrt_inference_batch_max(self, tensorrt_model):
|
||||||
|
# Arrange
|
||||||
|
assert tensorrt_model["engine"] is not None
|
||||||
|
imgs = sorted(_DATASET_IMAGES.glob("*.jpg"))
|
||||||
|
if not imgs:
|
||||||
|
pytest.skip("no test images")
|
||||||
|
model = YOLO(tensorrt_model["engine"])
|
||||||
|
sources = [str(imgs[0])] * 8
|
||||||
|
|
||||||
|
# Act
|
||||||
|
results = model.predict(source=sources, imgsz=tensorrt_model["imgsz"], verbose=False)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert len(results) == 8
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(sys.platform != "darwin", reason="CoreML requires macOS")
|
@pytest.mark.skipif(sys.platform != "darwin", reason="CoreML requires macOS")
|
||||||
class TestCoremlExport:
|
class TestCoremlExport:
|
||||||
def test_coreml_package_created(self, exported_models):
|
def test_coreml_package_created(self, exported_models):
|
||||||
@@ -117,7 +175,7 @@ class TestCoremlExport:
|
|||||||
model = YOLO(str(pkgs[0]))
|
model = YOLO(str(pkgs[0]))
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
results = model.predict(source=str(imgs[0]), imgsz=c.config.export.onnx_imgsz, verbose=False)
|
results = model.predict(source=str(imgs[0]), imgsz=exported_models["imgsz"], verbose=False)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(results) == 1
|
assert len(results) == 1
|
||||||
|
|||||||
@@ -3,12 +3,14 @@ import constants as c
|
|||||||
|
|
||||||
def test_fixture_images_dir_has_jpegs(fixture_images_dir):
|
def test_fixture_images_dir_has_jpegs(fixture_images_dir):
|
||||||
jpgs = list(fixture_images_dir.glob("*.jpg"))
|
jpgs = list(fixture_images_dir.glob("*.jpg"))
|
||||||
assert len(jpgs) == 20
|
assert len(jpgs) > 0
|
||||||
|
|
||||||
|
|
||||||
def test_fixture_labels_dir_has_yolo_labels(fixture_labels_dir):
|
def test_fixture_labels_dir_has_yolo_labels(fixture_labels_dir, fixture_images_dir):
|
||||||
txts = list(fixture_labels_dir.glob("*.txt"))
|
txts = list(fixture_labels_dir.glob("*.txt"))
|
||||||
assert len(txts) == 20
|
jpgs = list(fixture_images_dir.glob("*.jpg"))
|
||||||
|
assert len(txts) > 0
|
||||||
|
assert len(txts) == len(jpgs)
|
||||||
|
|
||||||
|
|
||||||
def test_fixture_onnx_model_bytes(fixture_onnx_model):
|
def test_fixture_onnx_model_bytes(fixture_onnx_model):
|
||||||
|
|||||||
@@ -62,11 +62,6 @@ def e2e_result():
|
|||||||
"linked_count": linked_count,
|
"linked_count": linked_count,
|
||||||
}
|
}
|
||||||
|
|
||||||
shutil.rmtree(str(dst_images), ignore_errors=True)
|
|
||||||
shutil.rmtree(str(dst_labels), ignore_errors=True)
|
|
||||||
shutil.rmtree(c.config.datasets_dir, ignore_errors=True)
|
|
||||||
shutil.rmtree(c.config.models_dir, ignore_errors=True)
|
|
||||||
shutil.rmtree(c.config.corrupted_dir, ignore_errors=True)
|
|
||||||
c.config = old_config
|
c.config = old_config
|
||||||
|
|
||||||
|
|
||||||
|
|||||||