mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 08:46:32 +00:00
27f4aceb52
- Updated the `Inference` class to replace the `get_onnx_engine_bytes` method with `download_model`, allowing for dynamic model loading based on a specified filename. - Modified the `convert_and_upload_model` method to accept `source_bytes` instead of `onnx_engine_bytes`, enhancing flexibility in model conversion. - Introduced a new property `engine_name` to the `Inference` class for better access to engine details. - Adjusted the `AIRecognitionConfig` structure to include a new method pointer `from_dict`, improving configuration handling. - Updated various test cases to reflect changes in model paths and timeout settings, ensuring consistency and reliability in testing.
227 lines · 6.6 KiB · Python
import csv
|
|
import json
|
|
import os
|
|
import threading
|
|
import time
|
|
import uuid
|
|
|
|
import pytest
|
|
|
|
RESULTS_DIR = os.environ.get("RESULTS_DIR", "/results")
|
|
|
|
|
|
def _base_ai_body(video_path: str) -> dict:
|
|
return {
|
|
"probability_threshold": 0.25,
|
|
"frame_period_recognition": 4,
|
|
"frame_recognition_seconds": 2,
|
|
"tracking_distance_confidence": 0.0,
|
|
"tracking_probability_increase": 0.0,
|
|
"tracking_intersection_threshold": 0.6,
|
|
"altitude": 400.0,
|
|
"focal_length": 24.0,
|
|
"sensor_width": 23.5,
|
|
"paths": [video_path],
|
|
}
|
|
|
|
|
|
def _save_events_csv(video_path: str, events: list[dict]):
|
|
stem = os.path.splitext(os.path.basename(video_path))[0]
|
|
path = os.path.join(RESULTS_DIR, f"{stem}_detections.csv")
|
|
rows = []
|
|
for ev in events:
|
|
base = {
|
|
"mediaId": ev.get("mediaId", ""),
|
|
"mediaStatus": ev.get("mediaStatus", ""),
|
|
"mediaPercent": ev.get("mediaPercent", ""),
|
|
}
|
|
anns = ev.get("annotations") or []
|
|
if anns:
|
|
for det in anns:
|
|
rows.append({**base, **det})
|
|
else:
|
|
rows.append(base)
|
|
if not rows:
|
|
return
|
|
fieldnames = list(rows[0].keys())
|
|
for r in rows[1:]:
|
|
for k in r:
|
|
if k not in fieldnames:
|
|
fieldnames.append(k)
|
|
with open(path, "w", newline="") as f:
|
|
writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore")
|
|
writer.writeheader()
|
|
writer.writerows(rows)
|
|
|
|
|
|
def _run_async_video_sse(
    http_client,
    jwt_token,
    sse_client_factory,
    media_id: str,
    body: dict,
    *,
    timed: bool = False,
    wait_s: float = 900.0,
):
    """Kick off async video detection and collect this media's SSE events.

    A background thread subscribes to the SSE stream *before* the POST is
    issued so early events are not missed. Events are filtered by mediaId
    and collected until the terminal event (AIProcessed at 100%) arrives.

    Args:
        http_client: HTTP client used to POST /detect/{media_id}.
        jwt_token: bearer token for the Authorization header.
        sse_client_factory: context-manager factory yielding an SSE client.
        media_id: identifier the service echoes back in each event.
        body: detection request body (see _base_ai_body).
        timed: when True, collect (monotonic_timestamp, event) tuples
            instead of bare event dicts.
        wait_s: maximum seconds to wait for the terminal event.

    Returns:
        The list of collected events (dicts, or (float, dict) tuples
        when timed is True).
    """
    video_path = (body.get("paths") or [""])[0]
    collected: list = []
    raw_events: list[dict] = []
    # Exceptions raised inside the listener thread are parked here and
    # re-asserted on the main thread after the wait.
    thread_exc: list[BaseException] = []
    done = threading.Event()

    def _listen():
        try:
            with sse_client_factory() as sse:
                time.sleep(0.3)  # give the SSE connection a moment to settle
                for event in sse.events():
                    # Skip keep-alive / empty frames.
                    if not event.data or not str(event.data).strip():
                        continue
                    data = json.loads(event.data)
                    # Only events for our media id.
                    if data.get("mediaId") != media_id:
                        continue
                    raw_events.append(data)
                    if timed:
                        collected.append((time.monotonic(), data))
                    else:
                        collected.append(data)
                    # Terminal event: processing finished.
                    if (
                        data.get("mediaStatus") == "AIProcessed"
                        and data.get("mediaPercent") == 100
                    ):
                        break
        except BaseException as e:
            thread_exc.append(e)
        finally:
            # Best-effort: persist whatever was received, even on failure.
            if video_path and raw_events:
                try:
                    _save_events_csv(video_path, raw_events)
                except Exception:
                    pass
            done.set()

    th = threading.Thread(target=_listen, daemon=True)
    th.start()
    time.sleep(0.5)  # let the listener subscribe before triggering the job
    r = http_client.post(
        f"/detect/{media_id}",
        json=body,
        headers={"Authorization": f"Bearer {jwt_token}"},
    )
    assert r.status_code == 200
    assert r.json() == {"status": "started", "mediaId": media_id}
    assert done.wait(timeout=wait_s)
    th.join(timeout=5)
    assert not thread_exc, thread_exc
    return collected
|
|
|
|
|
|
def _assert_detection_dto(d: dict) -> None:
|
|
assert isinstance(d["centerX"], (int, float))
|
|
assert isinstance(d["centerY"], (int, float))
|
|
assert isinstance(d["width"], (int, float))
|
|
assert isinstance(d["height"], (int, float))
|
|
assert 0.0 <= float(d["centerX"]) <= 1.0
|
|
assert 0.0 <= float(d["centerY"]) <= 1.0
|
|
assert 0.0 <= float(d["width"]) <= 1.0
|
|
assert 0.0 <= float(d["height"]) <= 1.0
|
|
assert isinstance(d["classNum"], int)
|
|
assert isinstance(d["label"], str)
|
|
assert isinstance(d["confidence"], (int, float))
|
|
assert 0.0 <= float(d["confidence"]) <= 1.0
|
|
|
|
|
|
@pytest.mark.skip(reason="Single video run — covered by test_ft_p09_sse_event_delivery")
@pytest.mark.slow
@pytest.mark.timeout(900)
def test_ft_p_10_frame_sampling_ac1(
    warm_engine,
    http_client,
    jwt_token,
    video_short_path,
    sse_client_factory,
):
    """AC1: with frame_period_recognition=4 the run emits multiple progress
    events and terminates with an AIProcessed event at 100%."""
    request_body = _base_ai_body(video_short_path)
    request_body["frame_period_recognition"] = 4
    events = _run_async_video_sse(
        http_client,
        jwt_token,
        sse_client_factory,
        f"video-{uuid.uuid4().hex}",
        request_body,
    )
    in_progress = [e for e in events if e.get("mediaStatus") == "AIProcessing"]
    assert len(in_progress) >= 2
    last = events[-1]
    assert last.get("mediaStatus") == "AIProcessed"
    assert last.get("mediaPercent") == 100
|
|
|
|
|
|
@pytest.mark.skip(reason="Single video run — covered by test_ft_p09_sse_event_delivery")
@pytest.mark.slow
@pytest.mark.timeout(900)
def test_ft_p_11_annotation_interval_ac2(
    warm_engine,
    http_client,
    jwt_token,
    video_short_path,
    sse_client_factory,
):
    """AC2: with frame_recognition_seconds=2, progress events arrive in
    non-decreasing time order and the run ends at AIProcessed / 100%."""
    media_id = f"video-{uuid.uuid4().hex}"
    request_body = _base_ai_body(video_short_path)
    request_body["frame_recognition_seconds"] = 2
    timeline = _run_async_video_sse(
        http_client,
        jwt_token,
        sse_client_factory,
        media_id,
        request_body,
        timed=True,
    )
    progress = [pair for pair in timeline if pair[1].get("mediaStatus") == "AIProcessing"]
    assert len(progress) >= 2
    stamps = [t for t, _ in progress]
    deltas = [later - earlier for earlier, later in zip(stamps, stamps[1:])]
    assert all(delta >= 0.0 for delta in deltas)
    last_event = timeline[-1][1]
    assert last_event.get("mediaStatus") == "AIProcessed"
    assert last_event.get("mediaPercent") == 100
|
|
|
|
|
|
@pytest.mark.skip(reason="Single video run — covered by test_ft_p09_sse_event_delivery")
@pytest.mark.slow
@pytest.mark.timeout(900)
def test_ft_p_12_movement_tracking_ac3(
    warm_engine,
    http_client,
    jwt_token,
    video_short_path,
    sse_client_factory,
):
    """AC3: with tracking enabled, every annotation is a well-formed
    detection DTO and the run ends at AIProcessed / 100%."""
    media_id = f"video-{uuid.uuid4().hex}"
    request_body = _base_ai_body(video_short_path)
    request_body["tracking_distance_confidence"] = 0.1
    request_body["tracking_probability_increase"] = 0.1
    events = _run_async_video_sse(
        http_client,
        jwt_token,
        sse_client_factory,
        media_id,
        request_body,
    )
    for event in events:
        annotations = event.get("annotations")
        if annotations:
            assert isinstance(annotations, list)
            for detection in annotations:
                _assert_detection_dto(detection)
    last = events[-1]
    assert last.get("mediaStatus") == "AIProcessed"
    assert last.get("mediaPercent") == 100
|