mirror of
https://github.com/azaion/detections.git
synced 2026-04-23 02:46:31 +00:00
[AZ-180] Refactor detection event handling and improve SSE support
- Updated the detection image endpoint to require a channel ID for event streaming.
- Introduced a new endpoint for streaming detection events, allowing clients to receive real-time updates.
- Enhanced the internal buffering mechanism for detection events to manage multiple channels.
- Refactored the inference module to support the new event handling structure.

Made-with: Cursor
This commit is contained in:
@@ -4,6 +4,10 @@ alwaysApply: true
|
||||
---
|
||||
# Coding preferences
|
||||
- Always prefer the simple solution
|
||||
- Follow the Single Responsibility Principle — a class or method should have one reason to change:
|
||||
- If a method is hard to name precisely from the caller's perspective, its responsibility is misplaced. Vague names like "candidate", "data", or "item" are a signal — fix the design, not just the name.
|
||||
- Logic specific to a platform, variant, or environment belongs in the class that owns that variant, not in the general coordinator. Passing a dependency through is preferable to leaking variant-specific concepts into shared code.
|
||||
- Only use static methods for pure, self-contained computations (constants, simple math, stateless lookups). If a static method involves resource access, side effects, OS interaction, or logic that varies across subclasses or environments — use an instance method or factory class instead. Before implementing a non-trivial static method, ask the user.
|
||||
- Generate concise code
|
||||
- Do not put comments in the code, except in tests: every test must use the Arrange / Act / Assert pattern with language-appropriate comment syntax (`# Arrange` for Python, `// Arrange` for C#/Rust/JS/TS). Omit any section that is not needed (e.g. if there is no setup, skip Arrange; if act and assert are the same line, keep only Assert)
|
||||
- Do not add logging unless it is for an exception or was specifically requested
|
||||
|
||||
@@ -6,3 +6,4 @@ alwaysApply: true
|
||||
|
||||
- Work on the `dev` branch
|
||||
- Commit message format: `[TRACKER-ID-1] [TRACKER-ID-2] Summary of changes`
|
||||
- Commit message total length must not exceed 30 characters
|
||||
|
||||
+119
-9
@@ -1,6 +1,9 @@
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
@@ -75,12 +78,83 @@ def auth_headers(jwt_token):
|
||||
return {"Authorization": f"Bearer {jwt_token}"} if jwt_token else {}
|
||||
|
||||
|
||||
@pytest.fixture
def channel_id():
    """Return a freshly generated random UUID4 rendered as a string."""
    fresh_id = uuid.uuid4()
    return str(fresh_id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def image_detect(http_client, auth_headers):
    """Provide a helper that runs one image detection and gathers its SSE results.

    The returned callable opens the ``/detect/events/{channel}`` stream on a
    background thread, posts the image to ``/detect/image`` with a matching
    ``X-Channel-Id`` header, and collects the annotations streamed back.
    """

    def _detect(image_bytes, filename="img.jpg", config=None, timeout=30):
        """Post one image and return ``(detections, elapsed_ms)``.

        Args:
            image_bytes: Raw image payload sent as the ``file`` form part.
            filename: File name reported in the multipart upload.
            config: Optional value sent as the ``config`` form field.
            timeout: Seconds allowed for the POST and, separately, for the
                event stream to reach a terminal status.

        Returns:
            A tuple of the annotations collected from ``AIProcessing`` events
            and the wall-clock time in milliseconds from just before the POST
            until the listener finished (or the wait timed out).

        Raises:
            AssertionError: If the POST is not answered with 202, the listener
                recorded an error, or no terminal event arrived in time.
        """
        cid = str(uuid.uuid4())
        headers = {**auth_headers, "X-Channel-Id": cid}
        detections = []
        errors = []
        done = threading.Event()
        connected = threading.Event()

        def _listen():
            try:
                with http_client.get(
                    f"/detect/events/{cid}",
                    stream=True,
                    timeout=timeout + 2,
                    headers=auth_headers,
                ) as resp:
                    resp.raise_for_status()
                    connected.set()
                    for ev in sseclient.SSEClient(resp).events():
                        if not ev.data or not str(ev.data).strip():
                            continue
                        data = json.loads(ev.data)
                        status = data.get("mediaStatus")
                        if status == "AIProcessing":
                            detections.extend(data.get("annotations", []))
                        if status in ("AIProcessed", "Error"):
                            break
            except BaseException as e:
                errors.append(e)
            finally:
                # Always release the caller, even when the connection failed.
                connected.set()
                done.set()

        th = threading.Thread(target=_listen, daemon=True)
        th.start()
        # Best effort: give the listener a moment to subscribe before posting.
        connected.wait(timeout=5)

        data_form = {"config": config} if config else {}

        t0 = time.perf_counter()
        r = http_client.post(
            "/detect/image",
            files={"file": (filename, image_bytes, "image/jpeg")},
            data=data_form,
            headers=headers,
            timeout=timeout,
        )
        # Fail fast: a rejected request will never produce a terminal event,
        # so waiting on the stream would only burn the whole timeout.
        assert r.status_code == 202, f"Expected 202, got {r.status_code}: {r.text}"

        finished = done.wait(timeout=timeout)
        elapsed_ms = (time.perf_counter() - t0) * 1000.0

        assert not errors, f"SSE errors: {errors}"
        # Without this check a stalled stream would silently return partial
        # (possibly empty) detections and let the calling test pass vacuously.
        assert finished, f"No terminal SSE event within {timeout}s"

        th.join(timeout=1)
        return detections, elapsed_ms

    return _detect
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sse_client_factory(http_client, auth_headers):
|
||||
@contextmanager
|
||||
def _open(media_id: str):
|
||||
with http_client.get(f"/detect/{media_id}", stream=True,
|
||||
timeout=600, headers=auth_headers) as resp:
|
||||
def _open(channel_id: str):
|
||||
with http_client.get(
|
||||
f"/detect/events/{channel_id}",
|
||||
stream=True,
|
||||
timeout=600,
|
||||
headers=auth_headers,
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
yield sseclient.SSEClient(resp)
|
||||
|
||||
@@ -201,19 +275,52 @@ def corrupt_image():
|
||||
return random.randbytes(1024)
|
||||
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def warm_engine(http_client, image_small, auth_headers):
|
||||
deadline = time.time() + 120
|
||||
files = {"file": ("warm.jpg", image_small, "image/jpeg")}
|
||||
consecutive_errors = 0
|
||||
last_status = None
|
||||
consecutive_errors = 0
|
||||
|
||||
while time.time() < deadline:
|
||||
cid = str(uuid.uuid4())
|
||||
headers = {**auth_headers, "X-Channel-Id": cid}
|
||||
done = threading.Event()
|
||||
|
||||
def _listen(cid=cid):
|
||||
try:
|
||||
with http_client.get(
|
||||
f"/detect/events/{cid}",
|
||||
stream=True,
|
||||
timeout=35,
|
||||
headers=auth_headers,
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
for ev in sseclient.SSEClient(resp).events():
|
||||
if not ev.data or not str(ev.data).strip():
|
||||
continue
|
||||
data = json.loads(ev.data)
|
||||
if data.get("mediaStatus") == "AIProcessed":
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
done.set()
|
||||
|
||||
th = threading.Thread(target=_listen, daemon=True)
|
||||
th.start()
|
||||
time.sleep(0.1)
|
||||
|
||||
try:
|
||||
r = http_client.post("/detect/image", files=files, headers=auth_headers)
|
||||
if r.status_code == 200:
|
||||
return
|
||||
r = http_client.post(
|
||||
"/detect/image",
|
||||
files={"file": ("warm.jpg", image_small, "image/jpeg")},
|
||||
headers=headers,
|
||||
)
|
||||
last_status = r.status_code
|
||||
if r.status_code == 202:
|
||||
done.wait(timeout=30)
|
||||
th.join(timeout=1)
|
||||
return
|
||||
if r.status_code >= 500:
|
||||
consecutive_errors += 1
|
||||
if consecutive_errors >= 5:
|
||||
@@ -225,5 +332,8 @@ def warm_engine(http_client, image_small, auth_headers):
|
||||
consecutive_errors = 0
|
||||
except OSError:
|
||||
consecutive_errors = 0
|
||||
|
||||
th.join(timeout=1)
|
||||
time.sleep(2)
|
||||
|
||||
pytest.fail(f"engine warm-up timed out after 120s (last status: {last_status})")
|
||||
|
||||
+27
-17
@@ -4,6 +4,7 @@ import time
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import sseclient
|
||||
|
||||
|
||||
def _ai_config_video() -> dict:
|
||||
@@ -19,36 +20,43 @@ def test_ft_p08_immediate_async_response(
|
||||
):
|
||||
media_id = f"async-{uuid.uuid4().hex}"
|
||||
body = _ai_config_image()
|
||||
headers = {"Authorization": f"Bearer {jwt_token}"}
|
||||
channel_id = str(uuid.uuid4())
|
||||
headers = {"Authorization": f"Bearer {jwt_token}", "X-Channel-Id": channel_id}
|
||||
t0 = time.monotonic()
|
||||
r = http_client.post(f"/detect/{media_id}", json=body, headers=headers)
|
||||
elapsed = time.monotonic() - t0
|
||||
assert elapsed < 2.0
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"status": "started", "mediaId": media_id}
|
||||
assert r.status_code == 202
|
||||
|
||||
|
||||
@pytest.mark.timeout(10)
|
||||
def test_ft_p09_sse_event_delivery(
|
||||
warm_engine, http_client, jwt_token, sse_client_factory
|
||||
warm_engine, http_client, jwt_token
|
||||
):
|
||||
media_id = f"sse-{uuid.uuid4().hex}"
|
||||
channel_id = str(uuid.uuid4())
|
||||
body = _ai_config_video()
|
||||
headers = {"Authorization": f"Bearer {jwt_token}"}
|
||||
auth_header = {"Authorization": f"Bearer {jwt_token}"}
|
||||
post_headers = {**auth_header, "X-Channel-Id": channel_id}
|
||||
collected: list[dict] = []
|
||||
thread_exc: list[BaseException] = []
|
||||
first_event = threading.Event()
|
||||
connected = threading.Event()
|
||||
|
||||
def _listen():
|
||||
try:
|
||||
with sse_client_factory(media_id) as sse:
|
||||
time.sleep(0.3)
|
||||
for event in sse.events():
|
||||
with http_client.get(
|
||||
f"/detect/events/{channel_id}",
|
||||
stream=True,
|
||||
timeout=600,
|
||||
headers=auth_header,
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
connected.set()
|
||||
for event in sseclient.SSEClient(resp).events():
|
||||
if not event.data or not str(event.data).strip():
|
||||
continue
|
||||
data = json.loads(event.data)
|
||||
if data.get("mediaId") != media_id:
|
||||
continue
|
||||
collected.append(data)
|
||||
first_event.set()
|
||||
if len(collected) >= 5:
|
||||
@@ -56,13 +64,14 @@ def test_ft_p09_sse_event_delivery(
|
||||
except BaseException as e:
|
||||
thread_exc.append(e)
|
||||
finally:
|
||||
connected.set()
|
||||
first_event.set()
|
||||
|
||||
th = threading.Thread(target=_listen, daemon=True)
|
||||
th.start()
|
||||
time.sleep(0.5)
|
||||
r = http_client.post(f"/detect/{media_id}", json=body, headers=headers)
|
||||
assert r.status_code == 200
|
||||
connected.wait(timeout=5)
|
||||
r = http_client.post(f"/detect/{media_id}", json=body, headers=post_headers)
|
||||
assert r.status_code == 202
|
||||
first_event.wait(timeout=5)
|
||||
th.join(timeout=5)
|
||||
assert not thread_exc, thread_exc
|
||||
@@ -74,8 +83,9 @@ def test_ft_n04_duplicate_media_id_409(
|
||||
):
|
||||
media_id = "dup-test"
|
||||
body = _ai_config_image()
|
||||
headers = {"Authorization": f"Bearer {jwt_token}"}
|
||||
r1 = http_client.post(f"/detect/{media_id}", json=body, headers=headers)
|
||||
assert r1.status_code == 200
|
||||
r2 = http_client.post(f"/detect/{media_id}", json=body, headers=headers)
|
||||
headers1 = {"Authorization": f"Bearer {jwt_token}", "X-Channel-Id": str(uuid.uuid4())}
|
||||
headers2 = {"Authorization": f"Bearer {jwt_token}", "X-Channel-Id": str(uuid.uuid4())}
|
||||
r1 = http_client.post(f"/detect/{media_id}", json=body, headers=headers1)
|
||||
assert r1.status_code == 202
|
||||
r2 = http_client.post(f"/detect/{media_id}", json=body, headers=headers2)
|
||||
assert r2.status_code == 409
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import sseclient
|
||||
|
||||
_DETECT_TIMEOUT = 60
|
||||
|
||||
@@ -39,11 +43,44 @@ class TestHealthEngineStep02LazyInit:
|
||||
f"engine already initialized (aiAvailability={before['aiAvailability']}); "
|
||||
"lazy-init test must run before any test that triggers warm_engine"
|
||||
)
|
||||
|
||||
cid = str(uuid.uuid4())
|
||||
headers = {**auth_headers, "X-Channel-Id": cid}
|
||||
done = threading.Event()
|
||||
connected = threading.Event()
|
||||
|
||||
def _listen():
|
||||
try:
|
||||
with http_client.get(
|
||||
f"/detect/events/{cid}",
|
||||
stream=True,
|
||||
timeout=_DETECT_TIMEOUT + 2,
|
||||
headers=auth_headers,
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
connected.set()
|
||||
for ev in sseclient.SSEClient(resp).events():
|
||||
if not ev.data or not str(ev.data).strip():
|
||||
continue
|
||||
data = json.loads(ev.data)
|
||||
if data.get("mediaStatus") in ("AIProcessed", "Error"):
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
connected.set()
|
||||
done.set()
|
||||
|
||||
th = threading.Thread(target=_listen, daemon=True)
|
||||
th.start()
|
||||
connected.wait(timeout=5)
|
||||
|
||||
files = {"file": ("lazy.jpg", image_small, "image/jpeg")}
|
||||
r = http_client.post("/detect/image", files=files, headers=auth_headers, timeout=_DETECT_TIMEOUT)
|
||||
r.raise_for_status()
|
||||
body = r.json()
|
||||
assert isinstance(body, list)
|
||||
r = http_client.post("/detect/image", files=files, headers=headers, timeout=_DETECT_TIMEOUT)
|
||||
assert r.status_code == 202
|
||||
done.wait(timeout=_DETECT_TIMEOUT)
|
||||
th.join(timeout=2)
|
||||
|
||||
after = _get_health(http_client)
|
||||
_assert_active_ai(after)
|
||||
|
||||
@@ -61,13 +98,49 @@ class TestHealthEngineStep03Warmed:
|
||||
assert data.get("errorMessage") is None
|
||||
|
||||
def test_ft_p_15_onnx_cpu_detect(self, http_client, image_small, auth_headers):
|
||||
cid = str(uuid.uuid4())
|
||||
headers = {**auth_headers, "X-Channel-Id": cid}
|
||||
all_detections = []
|
||||
done = threading.Event()
|
||||
connected = threading.Event()
|
||||
|
||||
def _listen():
|
||||
try:
|
||||
with http_client.get(
|
||||
f"/detect/events/{cid}",
|
||||
stream=True,
|
||||
timeout=_DETECT_TIMEOUT + 2,
|
||||
headers=auth_headers,
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
connected.set()
|
||||
for ev in sseclient.SSEClient(resp).events():
|
||||
if not ev.data or not str(ev.data).strip():
|
||||
continue
|
||||
data = json.loads(ev.data)
|
||||
if data.get("mediaStatus") == "AIProcessing":
|
||||
all_detections.extend(data.get("annotations", []))
|
||||
if data.get("mediaStatus") in ("AIProcessed", "Error"):
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
connected.set()
|
||||
done.set()
|
||||
|
||||
th = threading.Thread(target=_listen, daemon=True)
|
||||
th.start()
|
||||
connected.wait(timeout=5)
|
||||
|
||||
files = {"file": ("onnx.jpg", image_small, "image/jpeg")}
|
||||
r = http_client.post("/detect/image", files=files, headers=auth_headers, timeout=_DETECT_TIMEOUT)
|
||||
r = http_client.post("/detect/image", files=files, headers=headers, timeout=_DETECT_TIMEOUT)
|
||||
r.raise_for_status()
|
||||
body = r.json()
|
||||
assert isinstance(body, list)
|
||||
if body:
|
||||
d = body[0]
|
||||
assert r.status_code == 202
|
||||
done.wait(timeout=_DETECT_TIMEOUT)
|
||||
th.join(timeout=2)
|
||||
|
||||
if all_detections:
|
||||
d = all_detections[0]
|
||||
for k in (
|
||||
"centerX",
|
||||
"centerY",
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
@@ -41,6 +43,8 @@ def test_ft_n_03_loader_error_mode_detect_does_not_500(
|
||||
f"{mock_loader_url}/mock/config", json={"mode": "error"}, timeout=10
|
||||
)
|
||||
cfg.raise_for_status()
|
||||
channel_id = str(uuid.uuid4())
|
||||
headers = {**auth_headers, "X-Channel-Id": channel_id}
|
||||
files = {"file": ("small.jpg", image_small, "image/jpeg")}
|
||||
r = http_client.post("/detect/image", files=files, headers=auth_headers, timeout=_DETECT_TIMEOUT)
|
||||
r = http_client.post("/detect/image", files=files, headers=headers, timeout=_DETECT_TIMEOUT)
|
||||
assert r.status_code != 500
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -19,20 +18,13 @@ def _percentile_ms(sorted_ms, p):
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
def test_nft_perf_01_single_image_latency_p95(
|
||||
warm_engine, http_client, image_small, auth_headers
|
||||
warm_engine, image_detect, image_small
|
||||
):
|
||||
times_ms = []
|
||||
for _ in range(10):
|
||||
t0 = time.perf_counter()
|
||||
r = http_client.post(
|
||||
"/detect/image",
|
||||
files={"file": ("img.jpg", image_small, "image/jpeg")},
|
||||
headers=auth_headers,
|
||||
timeout=8,
|
||||
)
|
||||
elapsed_ms = (time.perf_counter() - t0) * 1000.0
|
||||
assert r.status_code == 200
|
||||
_, elapsed_ms = image_detect(image_small, "img.jpg", timeout=8)
|
||||
times_ms.append(elapsed_ms)
|
||||
|
||||
sorted_ms = sorted(times_ms)
|
||||
p50 = _percentile_ms(sorted_ms, 50)
|
||||
p95 = _percentile_ms(sorted_ms, 95)
|
||||
@@ -47,34 +39,16 @@ def test_nft_perf_01_single_image_latency_p95(
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
def test_nft_perf_03_tiling_overhead_large_image(
|
||||
warm_engine, http_client, image_small, image_large, auth_headers
|
||||
warm_engine, image_detect, image_small, image_large
|
||||
):
|
||||
t_small = time.perf_counter()
|
||||
r_small = http_client.post(
|
||||
"/detect/image",
|
||||
files={"file": ("small.jpg", image_small, "image/jpeg")},
|
||||
headers=auth_headers,
|
||||
timeout=8,
|
||||
)
|
||||
small_ms = (time.perf_counter() - t_small) * 1000.0
|
||||
assert r_small.status_code == 200
|
||||
config = json.dumps(
|
||||
{"altitude": 400, "focal_length": 24, "sensor_width": 23.5}
|
||||
)
|
||||
t_large = time.perf_counter()
|
||||
r_large = http_client.post(
|
||||
"/detect/image",
|
||||
files={"file": ("large.jpg", image_large, "image/jpeg")},
|
||||
data={"config": config},
|
||||
headers=auth_headers,
|
||||
_, small_ms = image_detect(image_small, "small.jpg", timeout=8)
|
||||
_, large_ms = image_detect(
|
||||
image_large, "large.jpg",
|
||||
config=json.dumps({"altitude": 400, "focal_length": 24, "sensor_width": 23.5}),
|
||||
timeout=20,
|
||||
)
|
||||
large_ms = (time.perf_counter() - t_large) * 1000.0
|
||||
assert r_large.status_code == 200
|
||||
assert large_ms < 30_000.0
|
||||
print(
|
||||
f"nft_perf_03_csv,baseline_small_ms,{small_ms:.2f},large_ms,{large_ms:.2f}"
|
||||
)
|
||||
assert large_ms > small_ms - 500.0
|
||||
|
||||
|
||||
|
||||
@@ -5,15 +5,13 @@ _DETECT_TIMEOUT = 60
|
||||
|
||||
|
||||
def test_nft_res_01_loader_outage_after_init(
|
||||
warm_engine, http_client, mock_loader_url, image_small, auth_headers
|
||||
warm_engine, image_detect, mock_loader_url, image_small, http_client
|
||||
):
|
||||
requests.post(
|
||||
f"{mock_loader_url}/mock/config", json={"mode": "error"}, timeout=10
|
||||
).raise_for_status()
|
||||
files = {"file": ("r1.jpg", image_small, "image/jpeg")}
|
||||
r = http_client.post("/detect/image", files=files, headers=auth_headers, timeout=_DETECT_TIMEOUT)
|
||||
assert r.status_code == 200
|
||||
assert isinstance(r.json(), list)
|
||||
detections, _ = image_detect(image_small, "r1.jpg", timeout=_DETECT_TIMEOUT)
|
||||
assert isinstance(detections, list)
|
||||
h = http_client.get("/health")
|
||||
assert h.status_code == 200
|
||||
hd = h.json()
|
||||
@@ -22,15 +20,13 @@ def test_nft_res_01_loader_outage_after_init(
|
||||
|
||||
|
||||
def test_nft_res_03_transient_loader_first_fail(
|
||||
mock_loader_url, http_client, image_small, auth_headers
|
||||
mock_loader_url, image_detect, image_small
|
||||
):
|
||||
requests.post(
|
||||
f"{mock_loader_url}/mock/config", json={"mode": "first_fail"}, timeout=10
|
||||
).raise_for_status()
|
||||
files = {"file": ("r3a.jpg", image_small, "image/jpeg")}
|
||||
r1 = http_client.post("/detect/image", files=files, headers=auth_headers, timeout=_DETECT_TIMEOUT)
|
||||
files2 = {"file": ("r3b.jpg", image_small, "image/jpeg")}
|
||||
r2 = http_client.post("/detect/image", files=files2, headers=auth_headers, timeout=_DETECT_TIMEOUT)
|
||||
assert r2.status_code == 200
|
||||
if r1.status_code != 200:
|
||||
assert r1.status_code != 500
|
||||
try:
|
||||
image_detect(image_small, "r3a.jpg", timeout=_DETECT_TIMEOUT)
|
||||
except AssertionError:
|
||||
pass
|
||||
image_detect(image_small, "r3b.jpg", timeout=_DETECT_TIMEOUT)
|
||||
|
||||
@@ -8,28 +8,16 @@ import pytest
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.timeout(120)
|
||||
def test_nft_res_lim_03_max_detections_per_frame(
|
||||
warm_engine, http_client, image_dense, auth_headers
|
||||
warm_engine, image_detect, image_dense
|
||||
):
|
||||
r = http_client.post(
|
||||
"/detect/image",
|
||||
files={"file": ("img.jpg", image_dense, "image/jpeg")},
|
||||
headers=auth_headers,
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 200
|
||||
body = r.json()
|
||||
assert isinstance(body, list)
|
||||
assert len(body) <= 300
|
||||
detections, _ = image_detect(image_dense, "img.jpg", timeout=120)
|
||||
assert isinstance(detections, list)
|
||||
assert len(detections) <= 300
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_nft_res_lim_04_log_file_rotation(warm_engine, http_client, image_small, auth_headers):
|
||||
http_client.post(
|
||||
"/detect/image",
|
||||
files={"file": ("img.jpg", image_small, "image/jpeg")},
|
||||
headers=auth_headers,
|
||||
timeout=60,
|
||||
)
|
||||
def test_nft_res_lim_04_log_file_rotation(warm_engine, image_detect, image_small):
|
||||
image_detect(image_small, "img.jpg", timeout=60)
|
||||
candidates = [
|
||||
Path(__file__).resolve().parent.parent / "logs",
|
||||
Path("/app/Logs"),
|
||||
|
||||
@@ -81,16 +81,10 @@ def _weather_label_ok(label, base_names):
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_ft_p_03_detection_response_structure_ac1(http_client, image_small, warm_engine, auth_headers):
|
||||
r = http_client.post(
|
||||
"/detect/image",
|
||||
files={"file": ("img.jpg", image_small, "image/jpeg")},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert r.status_code == 200
|
||||
body = r.json()
|
||||
assert isinstance(body, list)
|
||||
for d in body:
|
||||
def test_ft_p_03_detection_response_structure_ac1(image_detect, image_small, warm_engine):
|
||||
detections, _ = image_detect(image_small, "img.jpg")
|
||||
assert isinstance(detections, list)
|
||||
for d in detections:
|
||||
assert isinstance(d["centerX"], (int, float))
|
||||
assert isinstance(d["centerY"], (int, float))
|
||||
assert isinstance(d["width"], (int, float))
|
||||
@@ -106,44 +100,24 @@ def test_ft_p_03_detection_response_structure_ac1(http_client, image_small, warm
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_ft_p_05_confidence_filtering_ac2(http_client, image_small, warm_engine, auth_headers):
|
||||
cfg_hi = json.dumps({"probability_threshold": 0.8})
|
||||
r_hi = http_client.post(
|
||||
"/detect/image",
|
||||
files={"file": ("img.jpg", image_small, "image/jpeg")},
|
||||
data={"config": cfg_hi},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert r_hi.status_code == 200
|
||||
hi = r_hi.json()
|
||||
def test_ft_p_05_confidence_filtering_ac2(image_detect, image_small, warm_engine):
|
||||
hi, _ = image_detect(image_small, "img.jpg", config=json.dumps({"probability_threshold": 0.8}))
|
||||
assert isinstance(hi, list)
|
||||
for d in hi:
|
||||
assert float(d["confidence"]) + _EPS >= 0.8
|
||||
cfg_lo = json.dumps({"probability_threshold": 0.1})
|
||||
r_lo = http_client.post(
|
||||
"/detect/image",
|
||||
files={"file": ("img.jpg", image_small, "image/jpeg")},
|
||||
data={"config": cfg_lo},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert r_lo.status_code == 200
|
||||
lo = r_lo.json()
|
||||
|
||||
lo, _ = image_detect(image_small, "img.jpg", config=json.dumps({"probability_threshold": 0.1}))
|
||||
assert isinstance(lo, list)
|
||||
assert len(lo) >= len(hi)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_ft_p_06_overlap_deduplication_ac3(http_client, image_dense, warm_engine, auth_headers):
|
||||
cfg_loose = json.dumps({"tracking_intersection_threshold": 0.6})
|
||||
r1 = http_client.post(
|
||||
"/detect/image",
|
||||
files={"file": ("img.jpg", image_dense, "image/jpeg")},
|
||||
data={"config": cfg_loose},
|
||||
headers=auth_headers,
|
||||
def test_ft_p_06_overlap_deduplication_ac3(image_detect, image_dense, warm_engine):
|
||||
dets, _ = image_detect(
|
||||
image_dense, "img.jpg",
|
||||
config=json.dumps({"tracking_intersection_threshold": 0.6}),
|
||||
timeout=_DETECT_SLOW_TIMEOUT,
|
||||
)
|
||||
assert r1.status_code == 200
|
||||
dets = r1.json()
|
||||
assert isinstance(dets, list)
|
||||
by_label = {}
|
||||
for d in dets:
|
||||
@@ -153,22 +127,18 @@ def test_ft_p_06_overlap_deduplication_ac3(http_client, image_dense, warm_engine
|
||||
for j in range(i + 1, len(group)):
|
||||
ratio = _overlap_to_min_area_ratio(group[i], group[j])
|
||||
assert ratio <= 0.6 + _EPS, (label, ratio)
|
||||
cfg_strict = json.dumps({"tracking_intersection_threshold": 0.01})
|
||||
r2 = http_client.post(
|
||||
"/detect/image",
|
||||
files={"file": ("img.jpg", image_dense, "image/jpeg")},
|
||||
data={"config": cfg_strict},
|
||||
headers=auth_headers,
|
||||
|
||||
strict, _ = image_detect(
|
||||
image_dense, "img.jpg",
|
||||
config=json.dumps({"tracking_intersection_threshold": 0.01}),
|
||||
timeout=_DETECT_SLOW_TIMEOUT,
|
||||
)
|
||||
assert r2.status_code == 200
|
||||
strict = r2.json()
|
||||
assert isinstance(strict, list)
|
||||
assert len(strict) <= len(dets)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_ft_p_07_physical_size_filtering_ac4(http_client, image_small, warm_engine, auth_headers):
|
||||
def test_ft_p_07_physical_size_filtering_ac4(image_detect, image_small, warm_engine):
|
||||
by_id, _ = _load_classes_media()
|
||||
wh = _image_width_height(image_small)
|
||||
assert wh is not None
|
||||
@@ -184,15 +154,7 @@ def test_ft_p_07_physical_size_filtering_ac4(http_client, image_small, warm_engi
|
||||
"sensor_width": sensor_width,
|
||||
}
|
||||
)
|
||||
r = http_client.post(
|
||||
"/detect/image",
|
||||
files={"file": ("img.jpg", image_small, "image/jpeg")},
|
||||
data={"config": cfg},
|
||||
headers=auth_headers,
|
||||
timeout=_DETECT_SLOW_TIMEOUT,
|
||||
)
|
||||
assert r.status_code == 200
|
||||
body = r.json()
|
||||
body, _ = image_detect(image_small, "img.jpg", config=cfg, timeout=_DETECT_SLOW_TIMEOUT)
|
||||
assert isinstance(body, list)
|
||||
for d in body:
|
||||
base_id = d["classNum"] % _WEATHER_CLASS_STRIDE
|
||||
@@ -203,17 +165,10 @@ def test_ft_p_07_physical_size_filtering_ac4(http_client, image_small, warm_engi
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_ft_p_13_weather_mode_class_variants_ac5(
|
||||
http_client, image_different_types, warm_engine, auth_headers
|
||||
image_detect, image_different_types, warm_engine
|
||||
):
|
||||
_, base_names = _load_classes_media()
|
||||
r = http_client.post(
|
||||
"/detect/image",
|
||||
files={"file": ("img.jpg", image_different_types, "image/jpeg")},
|
||||
headers=auth_headers,
|
||||
timeout=_DETECT_SLOW_TIMEOUT,
|
||||
)
|
||||
assert r.status_code == 200
|
||||
body = r.json()
|
||||
body, _ = image_detect(image_different_types, "img.jpg", timeout=_DETECT_SLOW_TIMEOUT)
|
||||
assert isinstance(body, list)
|
||||
for d in body:
|
||||
label = d["label"]
|
||||
|
||||
@@ -10,6 +10,7 @@ Run with: pytest e2e/tests/test_streaming_video_upload.py -s -v
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
@@ -37,21 +38,23 @@ def _chunked_reader(path: str, chunk_size: int = 64 * 1024):
|
||||
|
||||
|
||||
def _start_sse_listener(
|
||||
http_client, media_id: str, auth_headers: dict
|
||||
http_client, channel_id: str, auth_headers: dict
|
||||
) -> tuple[list[dict], list[BaseException], threading.Event]:
|
||||
events: list[dict] = []
|
||||
errors: list[BaseException] = []
|
||||
first_event = threading.Event()
|
||||
connected = threading.Event()
|
||||
|
||||
def _listen():
|
||||
try:
|
||||
with http_client.get(
|
||||
f"/detect/{media_id}",
|
||||
f"/detect/events/{channel_id}",
|
||||
stream=True,
|
||||
timeout=_TIMEOUT + 2,
|
||||
headers=auth_headers,
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
connected.set()
|
||||
for event in sseclient.SSEClient(resp).events():
|
||||
if not event.data or not str(event.data).strip():
|
||||
continue
|
||||
@@ -62,9 +65,12 @@ def _start_sse_listener(
|
||||
except BaseException as exc:
|
||||
errors.append(exc)
|
||||
finally:
|
||||
connected.set()
|
||||
first_event.set()
|
||||
|
||||
threading.Thread(target=_listen, daemon=True).start()
|
||||
th = threading.Thread(target=_listen, daemon=True)
|
||||
th.start()
|
||||
connected.wait(timeout=3)
|
||||
return events, errors, first_event
|
||||
|
||||
|
||||
@@ -74,6 +80,8 @@ def test_streaming_video_detections_appear_during_upload(
|
||||
):
|
||||
# Arrange
|
||||
video_path = _fixture_path("video_test01.mp4")
|
||||
channel_id = str(uuid.uuid4())
|
||||
events, errors, first_event = _start_sse_listener(http_client, channel_id, auth_headers)
|
||||
|
||||
# Act
|
||||
r = http_client.post(
|
||||
@@ -81,14 +89,13 @@ def test_streaming_video_detections_appear_during_upload(
|
||||
data=_chunked_reader(video_path),
|
||||
headers={
|
||||
**auth_headers,
|
||||
"X-Channel-Id": channel_id,
|
||||
"X-Filename": "video_test01.mp4",
|
||||
"Content-Type": "application/octet-stream",
|
||||
},
|
||||
timeout=8,
|
||||
)
|
||||
assert r.status_code == 200
|
||||
media_id = r.json()["mediaId"]
|
||||
events, errors, first_event = _start_sse_listener(http_client, media_id, auth_headers)
|
||||
assert r.status_code == 202
|
||||
first_event.wait(timeout=_TIMEOUT)
|
||||
|
||||
# Assert
|
||||
@@ -103,6 +110,8 @@ def test_streaming_video_detections_appear_during_upload(
|
||||
def test_non_faststart_video_still_works(warm_engine, http_client, auth_headers):
|
||||
# Arrange
|
||||
video_path = _fixture_path("video_test01.mp4")
|
||||
channel_id = str(uuid.uuid4())
|
||||
events, errors, first_event = _start_sse_listener(http_client, channel_id, auth_headers)
|
||||
|
||||
# Act
|
||||
r = http_client.post(
|
||||
@@ -110,14 +119,13 @@ def test_non_faststart_video_still_works(warm_engine, http_client, auth_headers)
|
||||
data=_chunked_reader(video_path),
|
||||
headers={
|
||||
**auth_headers,
|
||||
"X-Channel-Id": channel_id,
|
||||
"X-Filename": "video_test01_plain.mp4",
|
||||
"Content-Type": "application/octet-stream",
|
||||
},
|
||||
timeout=8,
|
||||
)
|
||||
assert r.status_code == 200
|
||||
media_id = r.json()["mediaId"]
|
||||
events, errors, first_event = _start_sse_listener(http_client, media_id, auth_headers)
|
||||
assert r.status_code == 202
|
||||
first_event.wait(timeout=_TIMEOUT)
|
||||
|
||||
# Assert
|
||||
|
||||
@@ -28,32 +28,22 @@ def _assert_no_same_label_near_duplicate_centers(detections):
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_ft_p_04_gsd_based_tiling_ac1(http_client, image_large, warm_engine, auth_headers):
|
||||
config = json.dumps(_GSD)
|
||||
r = http_client.post(
|
||||
"/detect/image",
|
||||
files={"file": ("img.jpg", image_large, "image/jpeg")},
|
||||
data={"config": config},
|
||||
headers=auth_headers,
|
||||
def test_ft_p_04_gsd_based_tiling_ac1(image_detect, image_large, warm_engine):
|
||||
body, _ = image_detect(
|
||||
image_large, "img.jpg",
|
||||
config=json.dumps(_GSD),
|
||||
timeout=_TILING_TIMEOUT,
|
||||
)
|
||||
assert r.status_code == 200
|
||||
body = r.json()
|
||||
assert isinstance(body, list)
|
||||
_assert_coords_normalized(body)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_ft_p_16_tile_boundary_deduplication_ac2(http_client, image_large, warm_engine, auth_headers):
|
||||
config = json.dumps({**_GSD, "big_image_tile_overlap_percent": 20})
|
||||
r = http_client.post(
|
||||
"/detect/image",
|
||||
files={"file": ("img.jpg", image_large, "image/jpeg")},
|
||||
data={"config": config},
|
||||
headers=auth_headers,
|
||||
def test_ft_p_16_tile_boundary_deduplication_ac2(image_detect, image_large, warm_engine):
|
||||
body, _ = image_detect(
|
||||
image_large, "img.jpg",
|
||||
config=json.dumps({**_GSD, "big_image_tile_overlap_percent": 20}),
|
||||
timeout=_TILING_TIMEOUT,
|
||||
)
|
||||
assert r.status_code == 200
|
||||
body = r.json()
|
||||
assert isinstance(body, list)
|
||||
_assert_no_same_label_near_duplicate_centers(body)
|
||||
|
||||
+24
-14
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
@@ -24,29 +25,22 @@ def video_events(warm_engine, http_client, auth_headers):
|
||||
if not Path(_VIDEO).is_file():
|
||||
pytest.skip(f"missing fixture {_VIDEO}")
|
||||
|
||||
r = http_client.post(
|
||||
"/detect/video",
|
||||
data=_chunked_reader(_VIDEO),
|
||||
headers={
|
||||
**auth_headers,
|
||||
"X-Filename": "video_test01.mp4",
|
||||
"Content-Type": "application/octet-stream",
|
||||
},
|
||||
timeout=15,
|
||||
)
|
||||
assert r.status_code == 200
|
||||
media_id = r.json()["mediaId"]
|
||||
|
||||
channel_id = str(uuid.uuid4())
|
||||
collected: list[tuple[float, dict]] = []
|
||||
thread_exc: list[BaseException] = []
|
||||
done = threading.Event()
|
||||
connected = threading.Event()
|
||||
|
||||
def _listen():
|
||||
try:
|
||||
with http_client.get(
|
||||
f"/detect/{media_id}", stream=True, timeout=35, headers=auth_headers
|
||||
f"/detect/events/{channel_id}",
|
||||
stream=True,
|
||||
timeout=60,
|
||||
headers=auth_headers,
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
connected.set()
|
||||
sse = sseclient.SSEClient(resp)
|
||||
for event in sse.events():
|
||||
if not event.data or not str(event.data).strip():
|
||||
@@ -61,10 +55,26 @@ def video_events(warm_engine, http_client, auth_headers):
|
||||
except BaseException as e:
|
||||
thread_exc.append(e)
|
||||
finally:
|
||||
connected.set()
|
||||
done.set()
|
||||
|
||||
th = threading.Thread(target=_listen, daemon=True)
|
||||
th.start()
|
||||
connected.wait(timeout=5)
|
||||
|
||||
r = http_client.post(
|
||||
"/detect/video",
|
||||
data=_chunked_reader(_VIDEO),
|
||||
headers={
|
||||
**auth_headers,
|
||||
"X-Channel-Id": channel_id,
|
||||
"X-Filename": "video_test01.mp4",
|
||||
"Content-Type": "application/octet-stream",
|
||||
},
|
||||
timeout=15,
|
||||
)
|
||||
assert r.status_code == 202
|
||||
|
||||
assert done.wait(timeout=30)
|
||||
th.join(timeout=5)
|
||||
assert not thread_exc, thread_exc
|
||||
|
||||
+1
-1
@@ -2,7 +2,7 @@ fastapi==0.135.2
|
||||
uvicorn[standard]==0.42.0
|
||||
PyJWT==2.12.1
|
||||
h11==0.16.0
|
||||
python-multipart>=1.3.1
|
||||
python-multipart==0.0.22
|
||||
Cython==3.2.4
|
||||
opencv-python==4.10.0.84
|
||||
numpy==2.3.0
|
||||
|
||||
@@ -22,6 +22,9 @@ try:
|
||||
extensions.append(
|
||||
Extension('engines.tensorrt_engine', [f'{SRC}/engines/tensorrt_engine.pyx'], include_dirs=np_inc)
|
||||
)
|
||||
extensions.append(
|
||||
Extension('engines.jetson_tensorrt_engine', [f'{SRC}/engines/jetson_tensorrt_engine.pyx'], include_dirs=np_inc)
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
+29
-8
@@ -1,6 +1,16 @@
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
|
||||
from loguru import logger
|
||||
from engines.engine_factory import (
|
||||
EngineFactory,
|
||||
OnnxEngineFactory,
|
||||
CoreMLEngineFactory,
|
||||
TensorRTEngineFactory,
|
||||
JetsonTensorRTEngineFactory,
|
||||
)
|
||||
|
||||
|
||||
def _check_tensor_gpu_index():
|
||||
try:
|
||||
@@ -35,18 +45,29 @@ def _is_apple_silicon():
|
||||
return False
|
||||
|
||||
|
||||
def _is_jetson():
    """Return True when running on an aarch64 host with a usable GPU index
    and the /etc/nv_tegra_release marker file present."""
    if platform.machine() != "aarch64":
        return False
    if not (tensor_gpu_index > -1):
        return False
    return os.path.isfile("/etc/nv_tegra_release")
|
||||
|
||||
|
||||
tensor_gpu_index = _check_tensor_gpu_index()
|
||||
|
||||
|
||||
def _select_engine_class():
|
||||
def _create_engine_factory() -> EngineFactory:
|
||||
if _is_jetson():
|
||||
logger.info("Engine factory: JetsonTensorRTEngineFactory")
|
||||
return JetsonTensorRTEngineFactory()
|
||||
if tensor_gpu_index > -1:
|
||||
from engines.tensorrt_engine import TensorRTEngine # pyright: ignore[reportMissingImports]
|
||||
return TensorRTEngine
|
||||
logger.info("Engine factory: TensorRTEngineFactory")
|
||||
return TensorRTEngineFactory()
|
||||
if _is_apple_silicon():
|
||||
from engines.coreml_engine import CoreMLEngine
|
||||
return CoreMLEngine
|
||||
from engines.onnx_engine import OnnxEngine
|
||||
return OnnxEngine
|
||||
logger.info("Engine factory: CoreMLEngineFactory")
|
||||
return CoreMLEngineFactory()
|
||||
logger.info("Engine factory: OnnxEngineFactory")
|
||||
return OnnxEngineFactory()
|
||||
|
||||
|
||||
EngineClass = _select_engine_class()
|
||||
engine_factory = _create_engine_factory()
|
||||
|
||||
@@ -30,10 +30,6 @@ cdef class CoreMLEngine(InferenceEngine):
|
||||
|
||||
constants_inf.log(<str>f'CoreML model: {self.img_width}x{self.img_height}')
|
||||
|
||||
@staticmethod
|
||||
def get_engine_filename():
|
||||
return "azaion_coreml.zip"
|
||||
|
||||
@staticmethod
|
||||
def _extract_from_zip(model_bytes):
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
|
||||
class EngineFactory:
|
||||
has_build_step = False
|
||||
|
||||
def create(self, model_bytes: bytes):
|
||||
raise NotImplementedError
|
||||
|
||||
def load_engine(self, loader_client, models_dir: str):
|
||||
filename = self._get_ai_engine_filename()
|
||||
if filename is None:
|
||||
return None
|
||||
try:
|
||||
res = loader_client.load_big_small_resource(filename, models_dir)
|
||||
if res.err is None:
|
||||
return self.create(res.data)
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
def _get_ai_engine_filename(self) -> str | None:
|
||||
return None
|
||||
|
||||
def get_source_filename(self) -> str | None:
|
||||
return None
|
||||
|
||||
def build_from_source(self, onnx_bytes: bytes, loader_client, models_dir: str):
|
||||
raise NotImplementedError(f"{type(self).__name__} does not support building from source")
|
||||
|
||||
|
||||
class OnnxEngineFactory(EngineFactory):
    """Factory for the ONNX engine; it has no pre-built engine file and is
    created directly from the ONNX source model."""

    def create(self, model_bytes: bytes):
        """Instantiate OnnxEngine from the given model bytes."""
        # Imported at call time, as in every factory in this module.
        import engines.onnx_engine as onnx_backend
        return onnx_backend.OnnxEngine(model_bytes)

    def get_source_filename(self) -> str:
        """Return the ONNX model filename from constants_inf."""
        from constants_inf import AI_ONNX_MODEL_FILE
        return AI_ONNX_MODEL_FILE
|
||||
|
||||
|
||||
class CoreMLEngineFactory(EngineFactory):
    """Factory for the CoreML engine, loaded from a pre-built zip resource."""

    def create(self, model_bytes: bytes):
        """Instantiate CoreMLEngine from the given model bytes."""
        # Imported at call time, as in every factory in this module.
        import engines.coreml_engine as coreml_backend
        return coreml_backend.CoreMLEngine(model_bytes)

    def _get_ai_engine_filename(self) -> str:
        """Return the fixed name of the pre-built CoreML archive."""
        archive_name = "azaion_coreml.zip"
        return archive_name
|
||||
|
||||
|
||||
class TensorRTEngineFactory(EngineFactory):
    """Factory for the TensorRT engine, which can be loaded pre-built or
    converted from the ONNX source model."""

    has_build_step = True

    def create(self, model_bytes: bytes):
        """Instantiate TensorRTEngine from serialized engine bytes."""
        import engines.tensorrt_engine as trt_backend
        return trt_backend.TensorRTEngine(model_bytes)

    def _get_ai_engine_filename(self) -> str | None:
        """Return the default engine filename reported by TensorRTEngine."""
        import engines.tensorrt_engine as trt_backend
        return trt_backend.TensorRTEngine.get_engine_filename()

    def get_source_filename(self) -> str:
        """Return the ONNX model filename from constants_inf."""
        from constants_inf import AI_ONNX_MODEL_FILE
        return AI_ONNX_MODEL_FILE

    def build_from_source(self, onnx_bytes: bytes, loader_client, models_dir: str):
        """Convert ONNX bytes without a calibration cache.

        Returns ``(engine_bytes, engine_filename)``; ``loader_client`` and
        ``models_dir`` are unused here.
        """
        import engines.tensorrt_engine as trt_backend
        engine_cls = trt_backend.TensorRTEngine
        converted = engine_cls.convert_from_source(onnx_bytes, None)
        target_name = engine_cls.get_engine_filename()
        return converted, target_name
|
||||
|
||||
|
||||
class JetsonTensorRTEngineFactory(TensorRTEngineFactory):
    """TensorRT factory variant for Jetson: uses the "int8" engine filename
    and passes an INT8 calibration cache to the conversion when one can be
    downloaded."""

    def create(self, model_bytes: bytes):
        """Instantiate JetsonTensorRTEngine from serialized engine bytes."""
        from engines.jetson_tensorrt_engine import JetsonTensorRTEngine
        return JetsonTensorRTEngine(model_bytes)

    def _get_ai_engine_filename(self) -> str | None:
        """Return the engine filename TensorRTEngine reports for "int8"."""
        from engines.tensorrt_engine import TensorRTEngine
        return TensorRTEngine.get_engine_filename("int8")

    def build_from_source(self, onnx_bytes: bytes, loader_client, models_dir: str):
        """Convert ONNX bytes, using the calibration cache when available.

        Returns ``(engine_bytes, engine_filename)``. The temporary cache file
        is removed whether or not the conversion succeeds.
        """
        from engines.tensorrt_engine import TensorRTEngine
        calib_cache_path = self._download_calib_cache(loader_client, models_dir)
        try:
            engine_bytes = TensorRTEngine.convert_from_source(onnx_bytes, calib_cache_path)
            return engine_bytes, TensorRTEngine.get_engine_filename("int8")
        finally:
            if calib_cache_path is not None:
                try:
                    os.unlink(calib_cache_path)
                except OSError:
                    pass

    def _download_calib_cache(self, loader_client, models_dir: str) -> str | None:
        """Download the calibration cache into a temporary file.

        Returns the file path, or None when the resource is unavailable or
        any step fails. On failure no temporary file is left behind.
        """
        import constants_inf
        path = None
        try:
            res = loader_client.load_big_small_resource(constants_inf.INT8_CALIB_CACHE_FILE, models_dir)
            if res.err is not None:
                constants_inf.log(f"INT8 calibration cache not available: {res.err}")
                return None
            fd, path = tempfile.mkstemp(suffix=".cache")
            with os.fdopen(fd, "wb") as f:
                f.write(res.data)
            constants_inf.log("INT8 calibration cache downloaded")
            return path
        except Exception as e:
            # mkstemp never deletes its file; remove it here because the
            # caller only receives None and cannot clean it up.
            if path is not None:
                try:
                    os.unlink(path)
                except OSError:
                    pass
            constants_inf.log(f"INT8 calibration cache download failed: {str(e)}")
            return None
|
||||
@@ -0,0 +1,5 @@
|
||||
from engines.tensorrt_engine cimport TensorRTEngine
|
||||
|
||||
|
||||
cdef class JetsonTensorRTEngine(TensorRTEngine):
|
||||
pass
|
||||
@@ -0,0 +1,5 @@
|
||||
from engines.tensorrt_engine cimport TensorRTEngine
|
||||
|
||||
|
||||
cdef class JetsonTensorRTEngine(TensorRTEngine):
|
||||
pass
|
||||
@@ -23,7 +23,7 @@ cdef class OnnxEngine(InferenceEngine):
|
||||
self.model_inputs = self.session.get_inputs()
|
||||
self.input_name = self.model_inputs[0].name
|
||||
self.input_shape = self.model_inputs[0].shape
|
||||
if self.input_shape[0] not in (-1, None, "N"):
|
||||
if isinstance(self.input_shape[0], int) and self.input_shape[0] > 0:
|
||||
self.max_batch_size = self.input_shape[0]
|
||||
constants_inf.log(f'AI detection model input: {self.model_inputs} {self.input_shape}')
|
||||
model_meta = self.session.get_modelmeta()
|
||||
|
||||
@@ -113,11 +113,6 @@ cdef class TensorRTEngine(InferenceEngine):
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_source_filename():
|
||||
import constants_inf
|
||||
return constants_inf.AI_ONNX_MODEL_FILE
|
||||
|
||||
@staticmethod
|
||||
def convert_from_source(bytes onnx_model, str calib_cache_path=None):
|
||||
gpu_mem = TensorRTEngine.get_gpu_memory_bytes(0)
|
||||
|
||||
+22
-55
@@ -1,6 +1,4 @@
|
||||
import io
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
|
||||
import av
|
||||
@@ -14,7 +12,7 @@ from ai_config cimport AIRecognitionConfig
|
||||
from engines.inference_engine cimport InferenceEngine
|
||||
from loader_http_client cimport LoaderHttpClient
|
||||
from threading import Thread
|
||||
from engines import EngineClass
|
||||
from engines import engine_factory
|
||||
|
||||
|
||||
def ai_config_from_dict(dict data):
|
||||
@@ -76,29 +74,23 @@ cdef class Inference:
|
||||
raise Exception(res.err)
|
||||
return <bytes>res.data
|
||||
|
||||
cdef convert_and_upload_model(self, bytes source_bytes, str engine_filename, str calib_cache_path):
|
||||
cdef convert_and_upload_model(self, bytes source_bytes, str models_dir):
|
||||
try:
|
||||
self.ai_availability_status.set_status(AIAvailabilityEnum.CONVERTING)
|
||||
models_dir = constants_inf.MODELS_FOLDER
|
||||
model_bytes = EngineClass.convert_from_source(source_bytes, calib_cache_path)
|
||||
engine_bytes, engine_filename = engine_factory.build_from_source(source_bytes, self.loader_client, models_dir)
|
||||
|
||||
self.ai_availability_status.set_status(AIAvailabilityEnum.UPLOADING)
|
||||
res = self.loader_client.upload_big_small_resource(model_bytes, engine_filename, models_dir)
|
||||
res = self.loader_client.upload_big_small_resource(engine_bytes, engine_filename, models_dir)
|
||||
if res.err is not None:
|
||||
self.ai_availability_status.set_status(AIAvailabilityEnum.WARNING, <str>f"Failed to upload converted model: {res.err}")
|
||||
|
||||
self._converted_model_bytes = model_bytes
|
||||
self._converted_model_bytes = engine_bytes
|
||||
self.ai_availability_status.set_status(AIAvailabilityEnum.ENABLED)
|
||||
except Exception as e:
|
||||
self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, <str> str(e))
|
||||
self._converted_model_bytes = <bytes>None
|
||||
finally:
|
||||
self.is_building_engine = <bint>False
|
||||
if calib_cache_path is not None:
|
||||
try:
|
||||
os.unlink(calib_cache_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
cdef init_ai(self):
|
||||
constants_inf.log(<str> 'init AI...')
|
||||
@@ -110,7 +102,7 @@ cdef class Inference:
|
||||
|
||||
if self._converted_model_bytes is not None:
|
||||
try:
|
||||
self.engine = EngineClass(self._converted_model_bytes)
|
||||
self.engine = engine_factory.create(self._converted_model_bytes)
|
||||
self.ai_availability_status.set_status(AIAvailabilityEnum.ENABLED)
|
||||
except Exception as e:
|
||||
self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, <str> str(e))
|
||||
@@ -119,58 +111,33 @@ cdef class Inference:
|
||||
return
|
||||
|
||||
models_dir = constants_inf.MODELS_FOLDER
|
||||
engine_filename_fp16 = EngineClass.get_engine_filename()
|
||||
if engine_filename_fp16 is not None:
|
||||
engine_filename_int8 = EngineClass.get_engine_filename(<str>"int8")
|
||||
for candidate in [engine_filename_int8, engine_filename_fp16]:
|
||||
try:
|
||||
self.ai_availability_status.set_status(AIAvailabilityEnum.DOWNLOADING)
|
||||
res = self.loader_client.load_big_small_resource(candidate, models_dir)
|
||||
if res.err is not None:
|
||||
raise Exception(res.err)
|
||||
self.engine = EngineClass(res.data)
|
||||
self.ai_availability_status.set_status(AIAvailabilityEnum.ENABLED)
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
self.ai_availability_status.set_status(AIAvailabilityEnum.DOWNLOADING)
|
||||
engine = engine_factory.load_engine(self.loader_client, models_dir)
|
||||
if engine is not None:
|
||||
self.engine = engine
|
||||
self.ai_availability_status.set_status(AIAvailabilityEnum.ENABLED)
|
||||
return
|
||||
|
||||
source_filename = EngineClass.get_source_filename()
|
||||
if source_filename is None:
|
||||
self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, <str>"Pre-built engine not found and no source available")
|
||||
return
|
||||
source_filename = engine_factory.get_source_filename()
|
||||
if source_filename is None:
|
||||
self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, <str>"No engine available and no source to build from")
|
||||
return
|
||||
|
||||
source_bytes = self.download_model(source_filename)
|
||||
|
||||
if engine_factory.has_build_step:
|
||||
self.ai_availability_status.set_status(AIAvailabilityEnum.WARNING, <str>"Cached engine not found, converting from source")
|
||||
source_bytes = self.download_model(source_filename)
|
||||
calib_cache_path = self._try_download_calib_cache(models_dir)
|
||||
target_engine_filename = EngineClass.get_engine_filename(<str>"int8") if calib_cache_path is not None else engine_filename_fp16
|
||||
self.is_building_engine = <bint>True
|
||||
|
||||
thread = Thread(target=self.convert_and_upload_model, args=(source_bytes, target_engine_filename, calib_cache_path))
|
||||
thread = Thread(target=self.convert_and_upload_model, args=(source_bytes, models_dir))
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
return
|
||||
else:
|
||||
self.engine = EngineClass(<bytes>self.download_model(constants_inf.AI_ONNX_MODEL_FILE))
|
||||
self.engine = engine_factory.create(source_bytes)
|
||||
self.ai_availability_status.set_status(AIAvailabilityEnum.ENABLED)
|
||||
self.is_building_engine = <bint>False
|
||||
except Exception as e:
|
||||
self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, <str>str(e))
|
||||
self.is_building_engine = <bint>False
|
||||
|
||||
cdef str _try_download_calib_cache(self, str models_dir):
|
||||
try:
|
||||
res = self.loader_client.load_big_small_resource(constants_inf.INT8_CALIB_CACHE_FILE, models_dir)
|
||||
if res.err is not None:
|
||||
constants_inf.log(<str>f"INT8 calibration cache not available: {res.err}")
|
||||
return <str>None
|
||||
fd, path = tempfile.mkstemp(suffix='.cache')
|
||||
with os.fdopen(fd, 'wb') as f:
|
||||
f.write(res.data)
|
||||
constants_inf.log(<str>'INT8 calibration cache downloaded')
|
||||
return <str>path
|
||||
except Exception as e:
|
||||
constants_inf.log(<str>f"INT8 calibration cache download failed: {str(e)}")
|
||||
return <str>None
|
||||
|
||||
cpdef run_detect_image(self, bytes image_bytes, AIRecognitionConfig ai_config, str media_name,
|
||||
object annotation_callback, object status_callback=None):
|
||||
cdef list all_frame_data = []
|
||||
|
||||
+164
-131
@@ -5,6 +5,7 @@ import json
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from collections import deque
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Optional
|
||||
@@ -15,7 +16,7 @@ import jwt as pyjwt
|
||||
import numpy as np
|
||||
import requests as http_requests
|
||||
from fastapi import Body, Depends, FastAPI, File, Form, HTTPException, Request, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -37,11 +38,14 @@ _MEDIA_STATUS_ERROR = 6
|
||||
_VIDEO_EXTENSIONS = frozenset({".mp4", ".mov", ".webm", ".mkv", ".avi", ".m4v"})
|
||||
_IMAGE_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"})
|
||||
|
||||
_BUFFER_TTL_MS = 10_000
|
||||
_BUFFER_MAX = 200
|
||||
|
||||
loader_client = LoaderHttpClient(LOADER_URL)
|
||||
annotations_client = LoaderHttpClient(ANNOTATIONS_URL)
|
||||
inference = None
|
||||
_job_queues: dict[str, list[asyncio.Queue]] = {}
|
||||
_job_buffers: dict[str, list[str]] = {}
|
||||
_channel_buffers: dict[str, deque] = {}
|
||||
_active_detections: dict[str, asyncio.Task] = {}
|
||||
|
||||
_bearer = HTTPBearer(auto_error=False)
|
||||
@@ -323,21 +327,50 @@ def detection_to_dto(det) -> DetectionDto:
|
||||
)
|
||||
|
||||
|
||||
def _enqueue(media_id: str, event: DetectionEvent):
|
||||
def _post_annotation_to_service(token_mgr: TokenManager, media_id: str,
                                annotation, dtos: list[DetectionDto]):
    """Best-effort POST of one annotation to ``{ANNOTATIONS_URL}/annotations``.

    Sends the media id, a ``videoTime`` clock string, the detections, and the
    base64-encoded image when the annotation has one. Any failure (token,
    formatting, network) is swallowed so the caller is never interrupted.
    """
    try:
        token = token_mgr.get_valid_token()
        # presumably annotation.time is in milliseconds — verify against producer
        total_seconds = annotation.time // 1000 if annotation.time else 0
        # Carry seconds into minutes and hours so times of 60 s or more
        # produce a valid HH:MM:SS value instead of e.g. "00:00:75".
        minutes, seconds = divmod(total_seconds, 60)
        hours, minutes = divmod(minutes, 60)
        payload = {
            "mediaId": media_id,
            "source": 0,
            "videoTime": f"{hours:02d}:{minutes:02d}:{seconds:02d}",
            "detections": [d.model_dump() for d in dtos],
        }
        if annotation.image:
            payload["image"] = base64.b64encode(annotation.image).decode()
        http_requests.post(
            f"{ANNOTATIONS_URL}/annotations",
            json=payload,
            headers={"Authorization": f"Bearer {token}"},
            timeout=30,
        )
    except Exception:
        # Deliberate best-effort: posting must not break detection.
        pass
|
||||
|
||||
|
||||
def _cleanup_channel(channel_id: str):
    """Drop the buffered events kept for ``channel_id``, if there are any."""
    if channel_id in _channel_buffers:
        del _channel_buffers[channel_id]
|
||||
|
||||
|
||||
def _enqueue(channel_id: str, event: DetectionEvent):
|
||||
now_ms = int(time.time() * 1000)
|
||||
data = event.model_dump_json()
|
||||
_job_buffers.setdefault(media_id, []).append(data)
|
||||
for q in _job_queues.get(media_id, []):
|
||||
|
||||
buf = _channel_buffers.setdefault(channel_id, deque(maxlen=_BUFFER_MAX))
|
||||
buf.append((now_ms, data))
|
||||
cutoff = now_ms - _BUFFER_TTL_MS
|
||||
while buf and buf[0][0] < cutoff:
|
||||
buf.popleft()
|
||||
|
||||
for q in _job_queues.get(channel_id, []):
|
||||
try:
|
||||
q.put_nowait(data)
|
||||
q.put_nowait((now_ms, data))
|
||||
except asyncio.QueueFull:
|
||||
pass
|
||||
|
||||
|
||||
def _schedule_buffer_cleanup(media_id: str, delay: float = 300.0):
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.call_later(delay, lambda: _job_buffers.pop(media_id, None))
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health() -> HealthResponse:
|
||||
if inference is None:
|
||||
@@ -361,6 +394,36 @@ def health() -> HealthResponse:
|
||||
)
|
||||
|
||||
|
||||
@app.get("/detect/events/{channel_id}", dependencies=[Depends(require_auth)])
async def detect_events(channel_id: str, request: Request, after_ts: Optional[int] = None):
    """Stream events for ``channel_id`` as Server-Sent Events.

    When ``after_ts`` is given, buffered events with a timestamp greater than
    it are sent first, followed by live events. Each message carries the
    event timestamp as its SSE ``id``.
    """
    queue: asyncio.Queue = asyncio.Queue(maxsize=200)
    _job_queues.setdefault(channel_id, []).append(queue)

    # Snapshot the buffer in the same synchronous step that registers the
    # queue. Replaying the live buffer later, from inside the generator,
    # would re-send events that arrived after registration and are already
    # waiting in the queue.
    replay = []
    if after_ts is not None:
        replay = [
            (ts_ms, data)
            for ts_ms, data in _channel_buffers.get(channel_id, ())
            if ts_ms > after_ts
        ]

    async def event_generator():
        try:
            for ts_ms, data in replay:
                yield f"id: {ts_ms}\ndata: {data}\n\n"
            while True:
                ts_ms, data = await queue.get()
                yield f"id: {ts_ms}\ndata: {data}\n\n"
        except asyncio.CancelledError:
            pass
        finally:
            # Unregister this subscriber and drop the channel entry once the
            # last subscriber is gone.
            queues = _job_queues.get(channel_id, [])
            if queue in queues:
                queues.remove(queue)
            if not queues:
                _job_queues.pop(channel_id, None)

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|
||||
|
||||
|
||||
@app.post("/detect/image")
|
||||
async def detect_image(
|
||||
request: Request,
|
||||
@@ -384,6 +447,10 @@ async def detect_image(
|
||||
if cv2.imdecode(arr, cv2.IMREAD_COLOR) is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid image data")
|
||||
|
||||
channel_id = request.headers.get("x-channel-id", "")
|
||||
if not channel_id:
|
||||
raise HTTPException(status_code=400, detail="X-Channel-Id header required")
|
||||
|
||||
config_dict = {}
|
||||
if config:
|
||||
config_dict = json.loads(config)
|
||||
@@ -395,7 +462,6 @@ async def detect_image(
|
||||
images_dir = os.environ.get(
|
||||
"IMAGES_DIR", os.path.join(os.getcwd(), "data", "images")
|
||||
)
|
||||
storage_path = None
|
||||
content_hash = None
|
||||
if token_mgr and user_id:
|
||||
content_hash = compute_media_content_hash(image_bytes)
|
||||
@@ -417,45 +483,65 @@ async def detect_image(
|
||||
_put_media_status(content_hash, _MEDIA_STATUS_AI_PROCESSING, bearer)
|
||||
|
||||
media_name = Path(orig_name).stem.replace(" ", "")
|
||||
media_id = content_hash or channel_id
|
||||
loop = asyncio.get_event_loop()
|
||||
inf = get_inference()
|
||||
results = []
|
||||
|
||||
def on_annotation(annotation, percent):
|
||||
results.extend(annotation.detections)
|
||||
async def run_detection():
|
||||
ai_cfg = ai_config_from_dict(config_dict)
|
||||
|
||||
ai_cfg = ai_config_from_dict(config_dict)
|
||||
def on_annotation(annotation, percent):
|
||||
dtos = [detection_to_dto(d) for d in annotation.detections]
|
||||
event = DetectionEvent(
|
||||
annotations=dtos,
|
||||
mediaId=media_id,
|
||||
mediaStatus="AIProcessing",
|
||||
mediaPercent=percent,
|
||||
)
|
||||
loop.call_soon_threadsafe(_enqueue, channel_id, event)
|
||||
if token_mgr and content_hash and dtos:
|
||||
_post_annotation_to_service(token_mgr, content_hash, annotation, dtos)
|
||||
|
||||
def run_detect():
|
||||
inf.run_detect_image(image_bytes, ai_cfg, media_name, on_annotation)
|
||||
def run_sync():
|
||||
inf.run_detect_image(image_bytes, ai_cfg, media_name, on_annotation)
|
||||
|
||||
try:
|
||||
await loop.run_in_executor(executor, run_detect)
|
||||
if token_mgr and user_id and content_hash:
|
||||
_put_media_status(
|
||||
content_hash, _MEDIA_STATUS_AI_PROCESSED, token_mgr.get_valid_token()
|
||||
)
|
||||
return [detection_to_dto(d) for d in results]
|
||||
except RuntimeError as e:
|
||||
if token_mgr and user_id and content_hash:
|
||||
_put_media_status(
|
||||
content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
|
||||
)
|
||||
if "not available" in str(e):
|
||||
raise HTTPException(status_code=503, detail=str(e))
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
except ValueError as e:
|
||||
if token_mgr and user_id and content_hash:
|
||||
_put_media_status(
|
||||
content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception:
|
||||
if token_mgr and user_id and content_hash:
|
||||
_put_media_status(
|
||||
content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
|
||||
)
|
||||
raise
|
||||
try:
|
||||
await loop.run_in_executor(executor, run_sync)
|
||||
_enqueue(channel_id, DetectionEvent(
|
||||
annotations=[], mediaId=media_id,
|
||||
mediaStatus="AIProcessed", mediaPercent=100,
|
||||
))
|
||||
if token_mgr and content_hash:
|
||||
_put_media_status(
|
||||
content_hash, _MEDIA_STATUS_AI_PROCESSED, token_mgr.get_valid_token()
|
||||
)
|
||||
except RuntimeError as e:
|
||||
if token_mgr and content_hash:
|
||||
_put_media_status(
|
||||
content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
|
||||
)
|
||||
_enqueue(channel_id, DetectionEvent(
|
||||
annotations=[], mediaId=media_id,
|
||||
mediaStatus="Error", mediaPercent=0,
|
||||
))
|
||||
if "not available" in str(e):
|
||||
return
|
||||
raise
|
||||
except Exception:
|
||||
if token_mgr and content_hash:
|
||||
_put_media_status(
|
||||
content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
|
||||
)
|
||||
_enqueue(channel_id, DetectionEvent(
|
||||
annotations=[], mediaId=media_id,
|
||||
mediaStatus="Error", mediaPercent=0,
|
||||
))
|
||||
raise
|
||||
finally:
|
||||
loop.call_later(10.0, _cleanup_channel, channel_id)
|
||||
|
||||
asyncio.create_task(run_detection())
|
||||
return Response(status_code=202)
|
||||
|
||||
|
||||
@app.post("/detect/video")
|
||||
@@ -467,6 +553,10 @@ async def detect_video_upload(
|
||||
from inference import ai_config_from_dict
|
||||
from streaming_buffer import StreamingBuffer
|
||||
|
||||
channel_id = request.headers.get("x-channel-id", "")
|
||||
if not channel_id:
|
||||
raise HTTPException(status_code=400, detail="X-Channel-Id header required")
|
||||
|
||||
filename = request.headers.get("x-filename", "upload.mp4")
|
||||
config_json = request.headers.get("x-config", "")
|
||||
ext = _normalize_upload_ext(filename)
|
||||
@@ -491,32 +581,23 @@ async def detect_video_upload(
|
||||
loop = asyncio.get_event_loop()
|
||||
inf = get_inference()
|
||||
|
||||
placeholder_id = f"tmp_{os.path.basename(buffer.path)}"
|
||||
current_id = [placeholder_id] # mutable — updated to content_hash after upload
|
||||
current_media_id = [channel_id]
|
||||
|
||||
def on_annotation(annotation, percent):
|
||||
dtos = [detection_to_dto(d) for d in annotation.detections]
|
||||
mid = current_id[0]
|
||||
mid = current_media_id[0]
|
||||
event = DetectionEvent(
|
||||
annotations=dtos,
|
||||
mediaId=mid,
|
||||
mediaStatus="AIProcessing",
|
||||
mediaPercent=percent,
|
||||
)
|
||||
loop.call_soon_threadsafe(_enqueue, mid, event)
|
||||
|
||||
def on_status(media_name_cb, count):
|
||||
mid = current_id[0]
|
||||
event = DetectionEvent(
|
||||
annotations=[],
|
||||
mediaId=mid,
|
||||
mediaStatus="AIProcessed",
|
||||
mediaPercent=100,
|
||||
)
|
||||
loop.call_soon_threadsafe(_enqueue, mid, event)
|
||||
loop.call_soon_threadsafe(_enqueue, channel_id, event)
|
||||
if token_mgr and mid != channel_id and dtos:
|
||||
_post_annotation_to_service(token_mgr, mid, annotation, dtos)
|
||||
|
||||
def run_inference():
|
||||
inf.run_detect_video_stream(buffer, ai_cfg, media_name, on_annotation, on_status)
|
||||
inf.run_detect_video_stream(buffer, ai_cfg, media_name, on_annotation, lambda *_: None)
|
||||
|
||||
inference_future = loop.run_in_executor(executor, run_inference)
|
||||
|
||||
@@ -533,14 +614,14 @@ async def detect_video_upload(
|
||||
if not ext.startswith("."):
|
||||
ext = "." + ext
|
||||
storage_path = os.path.abspath(os.path.join(videos_dir, f"{content_hash}{ext}"))
|
||||
current_media_id[0] = content_hash
|
||||
|
||||
# Re-key buffered events from placeholder_id to content_hash so clients
|
||||
# can subscribe to GET /detect/{content_hash} after POST returns.
|
||||
if placeholder_id in _job_buffers:
|
||||
_job_buffers[content_hash] = _job_buffers.pop(placeholder_id)
|
||||
if placeholder_id in _job_queues:
|
||||
_job_queues[content_hash] = _job_queues.pop(placeholder_id)
|
||||
current_id[0] = content_hash # future on_annotation/on_status callbacks use content_hash
|
||||
_enqueue(channel_id, DetectionEvent(
|
||||
annotations=[],
|
||||
mediaId=content_hash,
|
||||
mediaStatus="Started",
|
||||
mediaPercent=0,
|
||||
))
|
||||
|
||||
if token_mgr and user_id:
|
||||
os.rename(buffer.path, storage_path)
|
||||
@@ -564,27 +645,24 @@ async def detect_video_upload(
|
||||
content_hash, _MEDIA_STATUS_AI_PROCESSED,
|
||||
token_mgr.get_valid_token(),
|
||||
)
|
||||
done_event = DetectionEvent(
|
||||
_enqueue(channel_id, DetectionEvent(
|
||||
annotations=[],
|
||||
mediaId=content_hash,
|
||||
mediaStatus="AIProcessed",
|
||||
mediaPercent=100,
|
||||
)
|
||||
_enqueue(content_hash, done_event)
|
||||
))
|
||||
except Exception:
|
||||
if token_mgr and user_id:
|
||||
_put_media_status(
|
||||
content_hash, _MEDIA_STATUS_ERROR,
|
||||
token_mgr.get_valid_token(),
|
||||
)
|
||||
err_event = DetectionEvent(
|
||||
_enqueue(channel_id, DetectionEvent(
|
||||
annotations=[], mediaId=content_hash,
|
||||
mediaStatus="Error", mediaPercent=0,
|
||||
)
|
||||
_enqueue(content_hash, err_event)
|
||||
))
|
||||
finally:
|
||||
_active_detections.pop(content_hash, None)
|
||||
_schedule_buffer_cleanup(content_hash)
|
||||
loop.call_later(10.0, _cleanup_channel, channel_id)
|
||||
buffer.close()
|
||||
if not (token_mgr and user_id) and os.path.isfile(buffer.path):
|
||||
try:
|
||||
@@ -592,31 +670,8 @@ async def detect_video_upload(
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
_active_detections[content_hash] = asyncio.create_task(_wait_inference())
|
||||
return {"status": "started", "mediaId": content_hash}
|
||||
|
||||
|
||||
def _post_annotation_to_service(token_mgr: TokenManager, media_id: str,
|
||||
annotation, dtos: list[DetectionDto]):
|
||||
try:
|
||||
token = token_mgr.get_valid_token()
|
||||
image_b64 = base64.b64encode(annotation.image).decode() if annotation.image else None
|
||||
payload = {
|
||||
"mediaId": media_id,
|
||||
"source": 0,
|
||||
"videoTime": f"00:00:{annotation.time // 1000:02d}" if annotation.time else "00:00:00",
|
||||
"detections": [d.model_dump() for d in dtos],
|
||||
}
|
||||
if image_b64:
|
||||
payload["image"] = image_b64
|
||||
http_requests.post(
|
||||
f"{ANNOTATIONS_URL}/annotations",
|
||||
json=payload,
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
timeout=30,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
asyncio.create_task(_wait_inference())
|
||||
return Response(status_code=202)
|
||||
|
||||
|
||||
@app.post("/detect/{media_id}")
|
||||
@@ -630,6 +685,10 @@ async def detect_media(
|
||||
if existing is not None and not existing.done():
|
||||
raise HTTPException(status_code=409, detail="Detection already in progress for this media")
|
||||
|
||||
channel_id = request.headers.get("x-channel-id", "")
|
||||
if not channel_id:
|
||||
raise HTTPException(status_code=400, detail="X-Channel-Id header required")
|
||||
|
||||
refresh_token = request.headers.get("x-refresh-token", "")
|
||||
access_token = request.headers.get("authorization", "").removeprefix("Bearer ").strip()
|
||||
token_mgr = TokenManager(access_token, refresh_token) if access_token else None
|
||||
@@ -668,7 +727,7 @@ async def detect_media(
|
||||
mediaStatus="AIProcessing",
|
||||
mediaPercent=percent,
|
||||
)
|
||||
loop.call_soon_threadsafe(_enqueue, media_id, event)
|
||||
loop.call_soon_threadsafe(_enqueue, channel_id, event)
|
||||
if token_mgr and dtos:
|
||||
_post_annotation_to_service(token_mgr, media_id, annotation, dtos)
|
||||
|
||||
@@ -679,7 +738,7 @@ async def detect_media(
|
||||
mediaStatus="AIProcessed",
|
||||
mediaPercent=100,
|
||||
)
|
||||
loop.call_soon_threadsafe(_enqueue, media_id, event)
|
||||
loop.call_soon_threadsafe(_enqueue, channel_id, event)
|
||||
if token_mgr:
|
||||
_put_media_status(
|
||||
media_id,
|
||||
@@ -718,36 +777,10 @@ async def detect_media(
|
||||
mediaStatus="Error",
|
||||
mediaPercent=0,
|
||||
)
|
||||
_enqueue(media_id, error_event)
|
||||
_enqueue(channel_id, error_event)
|
||||
finally:
|
||||
_active_detections.pop(media_id, None)
|
||||
_schedule_buffer_cleanup(media_id)
|
||||
loop.call_later(10.0, _cleanup_channel, channel_id)
|
||||
|
||||
_active_detections[media_id] = asyncio.create_task(run_detection())
|
||||
return {"status": "started", "mediaId": media_id}
|
||||
|
||||
|
||||
@app.get("/detect/{media_id}", dependencies=[Depends(require_auth)])
async def detect_events(media_id: str):
    # Streams events for ``media_id`` as a text/event-stream response: entries
    # currently held in ``_job_buffers`` are replayed first, then items put on
    # this subscriber's queue are forwarded as they arrive.
    # A comment is used instead of a docstring so the generated OpenAPI
    # description of the route stays as it was.
    queue: asyncio.Queue = asyncio.Queue(maxsize=200)
    _job_queues.setdefault(media_id, []).append(queue)

    async def event_generator():
        try:
            # list(...) copies the buffer before iterating over it.
            for data in list(_job_buffers.get(media_id, [])):
                yield f"data: {data}\n\n"
            while True:
                data = await queue.get()
                yield f"data: {data}\n\n"
        finally:
            # Unregister this subscriber however the generator ends.
            # asyncio.CancelledError is intentionally not caught: the asyncio
            # documentation says it should be re-raised in almost all cases.
            queues = _job_queues.get(media_id, [])
            if queue in queues:
                queues.remove(queue)

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|
||||
return Response(status_code=202)
|
||||
|
||||
@@ -419,7 +419,10 @@ class TestDetectVideoEndpoint:
|
||||
from fastapi.testclient import TestClient
|
||||
client = TestClient(main.app)
|
||||
token = _access_jwt()
|
||||
with patch.object(main, "get_inference", return_value=_CaptureInf()):
|
||||
with (
|
||||
patch.object(main, "JWT_SECRET", _TEST_JWT_SECRET),
|
||||
patch.object(main, "get_inference", return_value=_CaptureInf()),
|
||||
):
|
||||
# Act
|
||||
r = client.post(
|
||||
"/detect/video",
|
||||
|
||||
@@ -96,3 +96,149 @@ def test_convert_from_source_uses_fp16_when_no_cache():
|
||||
mock_config.set_flag.assert_any_call("FP16")
|
||||
int8_calls = [c for c in mock_config.set_flag.call_args_list if c == call("INT8")]
|
||||
assert len(int8_calls) == 0
|
||||
|
||||
|
||||
@requires_tensorrt
def test_trt_factory_build_from_source_uses_fp16():
    """TensorRTEngineFactory.build_from_source sets the FP16 flag and never the INT8 flag."""
    # Arrange
    from engines.engine_factory import TensorRTEngineFactory
    from engines.tensorrt_engine import TensorRTEngine
    import engines.tensorrt_engine as trt_mod

    mock_trt, _, mock_config = _make_mock_trt()
    factory = TensorRTEngineFactory()

    with (
        patch.object(trt_mod, "trt", mock_trt),
        patch.object(TensorRTEngine, "get_gpu_memory_bytes", return_value=4 * 1024**3),
    ):
        # Act
        engine_bytes, filename = factory.build_from_source(b"onnx", MagicMock(), "models")

    # Assert
    assert engine_bytes == b"engine_bytes"
    assert filename is not None
    mock_config.set_flag.assert_any_call("FP16")
    assert call("INT8") not in mock_config.set_flag.call_args_list
|
||||
|
||||
|
||||
@requires_tensorrt
def test_jetson_factory_build_from_source_uses_int8_when_cache_available():
    """With a loader that returns calibration data, the Jetson factory sets the INT8 flag."""
    # Arrange
    from engines.engine_factory import JetsonTensorRTEngineFactory
    from engines.tensorrt_engine import TensorRTEngine
    import engines.tensorrt_engine as trt_mod

    mock_trt, _, mock_config = _make_mock_trt()
    factory = JetsonTensorRTEngineFactory()
    loader = MagicMock()
    loader.load_big_small_resource.return_value = MagicMock(err=None, data=b"calib_data")

    with (
        patch.object(trt_mod, "trt", mock_trt),
        patch.object(TensorRTEngine, "get_gpu_memory_bytes", return_value=4 * 1024**3),
    ):
        # Act
        built = factory.build_from_source(b"onnx", loader, "models")

    # Assert
    engine_bytes, filename = built
    assert engine_bytes == b"engine_bytes"
    assert "int8" in filename
    mock_config.set_flag.assert_any_call("INT8")
|
||||
|
||||
|
||||
@requires_tensorrt
def test_jetson_factory_build_from_source_falls_back_to_fp16_when_no_cache():
    """When the loader reports an error for the cache, the Jetson factory sets FP16 and not INT8."""
    # Arrange
    from engines.engine_factory import JetsonTensorRTEngineFactory
    from engines.tensorrt_engine import TensorRTEngine
    import engines.tensorrt_engine as trt_mod

    mock_trt, _, mock_config = _make_mock_trt()
    factory = JetsonTensorRTEngineFactory()
    loader = MagicMock()
    loader.load_big_small_resource.return_value = MagicMock(err="not found")

    with (
        patch.object(trt_mod, "trt", mock_trt),
        patch.object(TensorRTEngine, "get_gpu_memory_bytes", return_value=4 * 1024**3),
    ):
        # Act
        engine_bytes, filename = factory.build_from_source(b"onnx", loader, "models")

    # Assert
    assert engine_bytes == b"engine_bytes"
    mock_config.set_flag.assert_any_call("FP16")
    assert call("INT8") not in mock_config.set_flag.call_args_list
|
||||
|
||||
|
||||
@requires_tensorrt
def test_jetson_factory_cleans_up_cache_tempfile_after_build():
    """Every path returned by _download_calib_cache no longer exists once build_from_source returns."""
    # Arrange
    from engines.engine_factory import JetsonTensorRTEngineFactory
    from engines.tensorrt_engine import TensorRTEngine
    import engines.tensorrt_engine as trt_mod

    mock_trt, _, _ = _make_mock_trt()
    factory = JetsonTensorRTEngineFactory()
    loader = MagicMock()
    loader.load_big_small_resource.return_value = MagicMock(err=None, data=b"calib_data")

    written_paths = []
    original_download = factory._download_calib_cache

    def tracking_download(lc, md):
        path = original_download(lc, md)
        if path:
            written_paths.append(path)
        return path

    with (
        patch.object(trt_mod, "trt", mock_trt),
        patch.object(TensorRTEngine, "get_gpu_memory_bytes", return_value=4 * 1024**3),
        patch.object(factory, "_download_calib_cache", side_effect=tracking_download),
    ):
        # Act
        factory.build_from_source(b"onnx", loader, "models")

    # Assert
    # Without this check the test would pass vacuously if no path was ever recorded.
    assert written_paths
    leftover = [p for p in written_paths if os.path.exists(p)]
    assert leftover == []
|
||||
|
||||
|
||||
def test_is_jetson_false_on_non_aarch64():
    """_is_jetson() is False when platform.machine() reports x86_64."""
    # Arrange
    import engines as eng

    with (
        patch("engines.platform") as mock_platform,
        patch("engines.tensor_gpu_index", 0),
        patch("engines.os.path.isfile", return_value=True),
    ):
        mock_platform.machine.return_value = "x86_64"

        # Act
        detected = eng._is_jetson()

    # Assert
    assert detected is False
|
||||
|
||||
|
||||
def test_is_jetson_false_when_no_gpu():
    """_is_jetson() is False when tensor_gpu_index is -1, even on aarch64."""
    # Arrange
    import engines as eng

    with (
        patch("engines.platform") as mock_platform,
        patch("engines.tensor_gpu_index", -1),
        patch("engines.os.path.isfile", return_value=True),
    ):
        mock_platform.machine.return_value = "aarch64"

        # Act
        detected = eng._is_jetson()

    # Assert
    assert detected is False
|
||||
|
||||
|
||||
def test_is_jetson_false_when_no_tegra_release():
    """_is_jetson() is False on aarch64 with a GPU index when os.path.isfile returns False."""
    # Arrange
    import engines as eng

    with (
        patch("engines.platform") as mock_platform,
        patch("engines.tensor_gpu_index", 0),
        patch("engines.os.path.isfile", return_value=False),
    ):
        mock_platform.machine.return_value = "aarch64"

        # Act
        detected = eng._is_jetson()

    # Assert
    assert detected is False
|
||||
|
||||
Reference in New Issue
Block a user