[AZ-180] Refactor detection event handling and improve SSE support

- Updated the detection image endpoint to require a channel ID for event streaming.
- Introduced a new endpoint for streaming detection events, allowing clients to receive real-time updates.
- Enhanced the internal buffering mechanism for detection events to manage multiple channels.
- Refactored the inference module to support the new event handling structure.

Made-with: Cursor
This commit is contained in:
Oleksandr Bezdieniezhnykh
2026-04-03 02:42:05 +03:00
parent 2c35e59a77
commit 8baa96978b
26 changed files with 819 additions and 413 deletions
+4
View File
@@ -4,6 +4,10 @@ alwaysApply: true
--- ---
# Coding preferences # Coding preferences
- Always prefer a simple solution - Always prefer a simple solution
- Follow the Single Responsibility Principle — a class or method should have one reason to change:
- If a method is hard to name precisely from the caller's perspective, its responsibility is misplaced. Vague names like "candidate", "data", or "item" are a signal — fix the design, not just the name.
- Logic specific to a platform, variant, or environment belongs in the class that owns that variant, not in the general coordinator. Passing a dependency through is preferable to leaking variant-specific concepts into shared code.
- Only use static methods for pure, self-contained computations (constants, simple math, stateless lookups). If a static method involves resource access, side effects, OS interaction, or logic that varies across subclasses or environments — use an instance method or factory class instead. Before implementing a non-trivial static method, ask the user.
- Generate concise code - Generate concise code
- Do not put comments in the code, except in tests: every test must use the Arrange / Act / Assert pattern with language-appropriate comment syntax (`# Arrange` for Python, `// Arrange` for C#/Rust/JS/TS). Omit any section that is not needed (e.g. if there is no setup, skip Arrange; if act and assert are the same line, keep only Assert) - Do not put comments in the code, except in tests: every test must use the Arrange / Act / Assert pattern with language-appropriate comment syntax (`# Arrange` for Python, `// Arrange` for C#/Rust/JS/TS). Omit any section that is not needed (e.g. if there is no setup, skip Arrange; if act and assert are the same line, keep only Assert)
- Do not add logging unless it is for an exception or was specifically requested - Do not add logging unless it is for an exception or was specifically requested
+1
View File
@@ -6,3 +6,4 @@ alwaysApply: true
- Work on the `dev` branch - Work on the `dev` branch
- Commit message format: `[TRACKER-ID-1] [TRACKER-ID-2] Summary of changes` - Commit message format: `[TRACKER-ID-1] [TRACKER-ID-2] Summary of changes`
- Commit message total length must not exceed 30 characters
+119 -9
View File
@@ -1,6 +1,9 @@
import json
import os import os
import random import random
import threading
import time import time
import uuid
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
@@ -75,12 +78,83 @@ def auth_headers(jwt_token):
return {"Authorization": f"Bearer {jwt_token}"} if jwt_token else {} return {"Authorization": f"Bearer {jwt_token}"} if jwt_token else {}
@pytest.fixture
def channel_id():
    """Return a freshly generated UUID4 string to use as an SSE channel id."""
    fresh_id = uuid.uuid4()
    return str(fresh_id)
@pytest.fixture(scope="session")
def image_detect(http_client, auth_headers):
    """Provide a helper that runs one image detection and gathers its SSE results.

    The returned callable has the signature
    ``_detect(image_bytes, filename="img.jpg", config=None, timeout=30)``.
    It subscribes to ``/detect/events/{channel}`` on a fresh channel id, posts
    the image to ``/detect/image`` with the matching ``X-Channel-Id`` header,
    and returns ``(detections, elapsed_ms)``:

    * ``detections`` is the concatenation of the ``annotations`` lists of every
      event whose ``mediaStatus`` is ``AIProcessing``;
    * ``elapsed_ms`` is measured from just before the POST until the listener
      finished (or the wait timed out).

    Raises:
        AssertionError: if the listener failed, the POST did not return 202,
            or the stream did not reach ``AIProcessed``/``Error`` in ``timeout``
            seconds.
    """

    def _detect(image_bytes, filename="img.jpg", config=None, timeout=30):
        cid = str(uuid.uuid4())
        headers = {**auth_headers, "X-Channel-Id": cid}
        detections = []
        errors = []
        done = threading.Event()
        connected = threading.Event()

        def _listen():
            try:
                with http_client.get(
                    f"/detect/events/{cid}",
                    stream=True,
                    timeout=timeout + 2,
                    headers=auth_headers,
                ) as resp:
                    resp.raise_for_status()
                    connected.set()
                    for ev in sseclient.SSEClient(resp).events():
                        if not ev.data or not str(ev.data).strip():
                            continue
                        data = json.loads(ev.data)
                        status = data.get("mediaStatus")
                        if status == "AIProcessing":
                            detections.extend(data.get("annotations", []))
                        if status in ("AIProcessed", "Error"):
                            break
            except BaseException as e:
                errors.append(e)
            finally:
                # Always release both waiters so the caller never blocks on a
                # listener that has already died.
                connected.set()
                done.set()

        th = threading.Thread(target=_listen, daemon=True)
        th.start()
        connected.wait(timeout=5)
        # A listener that failed while connecting cannot receive any events;
        # fail now instead of posting and waiting for the full timeout.
        assert not errors, f"SSE errors: {errors}"

        data_form = {"config": config} if config else {}
        t0 = time.perf_counter()
        r = http_client.post(
            "/detect/image",
            files={"file": (filename, image_bytes, "image/jpeg")},
            data=data_form,
            headers=headers,
            timeout=timeout,
        )
        # Check the status before waiting: a rejected request produces no
        # terminal event, so waiting first would stall for `timeout` seconds.
        assert r.status_code == 202, f"Expected 202, got {r.status_code}: {r.text}"

        finished = done.wait(timeout=timeout)
        elapsed_ms = (time.perf_counter() - t0) * 1000.0
        th.join(timeout=1)
        assert not errors, f"SSE errors: {errors}"
        # Without this check a stalled stream would silently yield partial results.
        assert finished, f"SSE stream did not finish within {timeout}s"
        return detections, elapsed_ms

    return _detect
@pytest.fixture @pytest.fixture
def sse_client_factory(http_client, auth_headers): def sse_client_factory(http_client, auth_headers):
@contextmanager @contextmanager
def _open(media_id: str): def _open(channel_id: str):
with http_client.get(f"/detect/{media_id}", stream=True, with http_client.get(
timeout=600, headers=auth_headers) as resp: f"/detect/events/{channel_id}",
stream=True,
timeout=600,
headers=auth_headers,
) as resp:
resp.raise_for_status() resp.raise_for_status()
yield sseclient.SSEClient(resp) yield sseclient.SSEClient(resp)
@@ -201,19 +275,52 @@ def corrupt_image():
return random.randbytes(1024) return random.randbytes(1024)
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def warm_engine(http_client, image_small, auth_headers): def warm_engine(http_client, image_small, auth_headers):
deadline = time.time() + 120 deadline = time.time() + 120
files = {"file": ("warm.jpg", image_small, "image/jpeg")}
consecutive_errors = 0
last_status = None last_status = None
consecutive_errors = 0
while time.time() < deadline: while time.time() < deadline:
cid = str(uuid.uuid4())
headers = {**auth_headers, "X-Channel-Id": cid}
done = threading.Event()
def _listen(cid=cid):
try: try:
r = http_client.post("/detect/image", files=files, headers=auth_headers) with http_client.get(
if r.status_code == 200: f"/detect/events/{cid}",
return stream=True,
timeout=35,
headers=auth_headers,
) as resp:
resp.raise_for_status()
for ev in sseclient.SSEClient(resp).events():
if not ev.data or not str(ev.data).strip():
continue
data = json.loads(ev.data)
if data.get("mediaStatus") == "AIProcessed":
break
except Exception:
pass
finally:
done.set()
th = threading.Thread(target=_listen, daemon=True)
th.start()
time.sleep(0.1)
try:
r = http_client.post(
"/detect/image",
files={"file": ("warm.jpg", image_small, "image/jpeg")},
headers=headers,
)
last_status = r.status_code last_status = r.status_code
if r.status_code == 202:
done.wait(timeout=30)
th.join(timeout=1)
return
if r.status_code >= 500: if r.status_code >= 500:
consecutive_errors += 1 consecutive_errors += 1
if consecutive_errors >= 5: if consecutive_errors >= 5:
@@ -225,5 +332,8 @@ def warm_engine(http_client, image_small, auth_headers):
consecutive_errors = 0 consecutive_errors = 0
except OSError: except OSError:
consecutive_errors = 0 consecutive_errors = 0
th.join(timeout=1)
time.sleep(2) time.sleep(2)
pytest.fail(f"engine warm-up timed out after 120s (last status: {last_status})") pytest.fail(f"engine warm-up timed out after 120s (last status: {last_status})")
+27 -17
View File
@@ -4,6 +4,7 @@ import time
import uuid import uuid
import pytest import pytest
import sseclient
def _ai_config_video() -> dict: def _ai_config_video() -> dict:
@@ -19,36 +20,43 @@ def test_ft_p08_immediate_async_response(
): ):
media_id = f"async-{uuid.uuid4().hex}" media_id = f"async-{uuid.uuid4().hex}"
body = _ai_config_image() body = _ai_config_image()
headers = {"Authorization": f"Bearer {jwt_token}"} channel_id = str(uuid.uuid4())
headers = {"Authorization": f"Bearer {jwt_token}", "X-Channel-Id": channel_id}
t0 = time.monotonic() t0 = time.monotonic()
r = http_client.post(f"/detect/{media_id}", json=body, headers=headers) r = http_client.post(f"/detect/{media_id}", json=body, headers=headers)
elapsed = time.monotonic() - t0 elapsed = time.monotonic() - t0
assert elapsed < 2.0 assert elapsed < 2.0
assert r.status_code == 200 assert r.status_code == 202
assert r.json() == {"status": "started", "mediaId": media_id}
@pytest.mark.timeout(10) @pytest.mark.timeout(10)
def test_ft_p09_sse_event_delivery( def test_ft_p09_sse_event_delivery(
warm_engine, http_client, jwt_token, sse_client_factory warm_engine, http_client, jwt_token
): ):
media_id = f"sse-{uuid.uuid4().hex}" media_id = f"sse-{uuid.uuid4().hex}"
channel_id = str(uuid.uuid4())
body = _ai_config_video() body = _ai_config_video()
headers = {"Authorization": f"Bearer {jwt_token}"} auth_header = {"Authorization": f"Bearer {jwt_token}"}
post_headers = {**auth_header, "X-Channel-Id": channel_id}
collected: list[dict] = [] collected: list[dict] = []
thread_exc: list[BaseException] = [] thread_exc: list[BaseException] = []
first_event = threading.Event() first_event = threading.Event()
connected = threading.Event()
def _listen(): def _listen():
try: try:
with sse_client_factory(media_id) as sse: with http_client.get(
time.sleep(0.3) f"/detect/events/{channel_id}",
for event in sse.events(): stream=True,
timeout=600,
headers=auth_header,
) as resp:
resp.raise_for_status()
connected.set()
for event in sseclient.SSEClient(resp).events():
if not event.data or not str(event.data).strip(): if not event.data or not str(event.data).strip():
continue continue
data = json.loads(event.data) data = json.loads(event.data)
if data.get("mediaId") != media_id:
continue
collected.append(data) collected.append(data)
first_event.set() first_event.set()
if len(collected) >= 5: if len(collected) >= 5:
@@ -56,13 +64,14 @@ def test_ft_p09_sse_event_delivery(
except BaseException as e: except BaseException as e:
thread_exc.append(e) thread_exc.append(e)
finally: finally:
connected.set()
first_event.set() first_event.set()
th = threading.Thread(target=_listen, daemon=True) th = threading.Thread(target=_listen, daemon=True)
th.start() th.start()
time.sleep(0.5) connected.wait(timeout=5)
r = http_client.post(f"/detect/{media_id}", json=body, headers=headers) r = http_client.post(f"/detect/{media_id}", json=body, headers=post_headers)
assert r.status_code == 200 assert r.status_code == 202
first_event.wait(timeout=5) first_event.wait(timeout=5)
th.join(timeout=5) th.join(timeout=5)
assert not thread_exc, thread_exc assert not thread_exc, thread_exc
@@ -74,8 +83,9 @@ def test_ft_n04_duplicate_media_id_409(
): ):
media_id = "dup-test" media_id = "dup-test"
body = _ai_config_image() body = _ai_config_image()
headers = {"Authorization": f"Bearer {jwt_token}"} headers1 = {"Authorization": f"Bearer {jwt_token}", "X-Channel-Id": str(uuid.uuid4())}
r1 = http_client.post(f"/detect/{media_id}", json=body, headers=headers) headers2 = {"Authorization": f"Bearer {jwt_token}", "X-Channel-Id": str(uuid.uuid4())}
assert r1.status_code == 200 r1 = http_client.post(f"/detect/{media_id}", json=body, headers=headers1)
r2 = http_client.post(f"/detect/{media_id}", json=body, headers=headers) assert r1.status_code == 202
r2 = http_client.post(f"/detect/{media_id}", json=body, headers=headers2)
assert r2.status_code == 409 assert r2.status_code == 409
+82 -9
View File
@@ -1,6 +1,10 @@
import json
import threading
import time import time
import uuid
import pytest import pytest
import sseclient
_DETECT_TIMEOUT = 60 _DETECT_TIMEOUT = 60
@@ -39,11 +43,44 @@ class TestHealthEngineStep02LazyInit:
f"engine already initialized (aiAvailability={before['aiAvailability']}); " f"engine already initialized (aiAvailability={before['aiAvailability']}); "
"lazy-init test must run before any test that triggers warm_engine" "lazy-init test must run before any test that triggers warm_engine"
) )
cid = str(uuid.uuid4())
headers = {**auth_headers, "X-Channel-Id": cid}
done = threading.Event()
connected = threading.Event()
def _listen():
try:
with http_client.get(
f"/detect/events/{cid}",
stream=True,
timeout=_DETECT_TIMEOUT + 2,
headers=auth_headers,
) as resp:
resp.raise_for_status()
connected.set()
for ev in sseclient.SSEClient(resp).events():
if not ev.data or not str(ev.data).strip():
continue
data = json.loads(ev.data)
if data.get("mediaStatus") in ("AIProcessed", "Error"):
break
except Exception:
pass
finally:
connected.set()
done.set()
th = threading.Thread(target=_listen, daemon=True)
th.start()
connected.wait(timeout=5)
files = {"file": ("lazy.jpg", image_small, "image/jpeg")} files = {"file": ("lazy.jpg", image_small, "image/jpeg")}
r = http_client.post("/detect/image", files=files, headers=auth_headers, timeout=_DETECT_TIMEOUT) r = http_client.post("/detect/image", files=files, headers=headers, timeout=_DETECT_TIMEOUT)
r.raise_for_status() assert r.status_code == 202
body = r.json() done.wait(timeout=_DETECT_TIMEOUT)
assert isinstance(body, list) th.join(timeout=2)
after = _get_health(http_client) after = _get_health(http_client)
_assert_active_ai(after) _assert_active_ai(after)
@@ -61,13 +98,49 @@ class TestHealthEngineStep03Warmed:
assert data.get("errorMessage") is None assert data.get("errorMessage") is None
def test_ft_p_15_onnx_cpu_detect(self, http_client, image_small, auth_headers): def test_ft_p_15_onnx_cpu_detect(self, http_client, image_small, auth_headers):
cid = str(uuid.uuid4())
headers = {**auth_headers, "X-Channel-Id": cid}
all_detections = []
done = threading.Event()
connected = threading.Event()
def _listen():
try:
with http_client.get(
f"/detect/events/{cid}",
stream=True,
timeout=_DETECT_TIMEOUT + 2,
headers=auth_headers,
) as resp:
resp.raise_for_status()
connected.set()
for ev in sseclient.SSEClient(resp).events():
if not ev.data or not str(ev.data).strip():
continue
data = json.loads(ev.data)
if data.get("mediaStatus") == "AIProcessing":
all_detections.extend(data.get("annotations", []))
if data.get("mediaStatus") in ("AIProcessed", "Error"):
break
except Exception:
pass
finally:
connected.set()
done.set()
th = threading.Thread(target=_listen, daemon=True)
th.start()
connected.wait(timeout=5)
files = {"file": ("onnx.jpg", image_small, "image/jpeg")} files = {"file": ("onnx.jpg", image_small, "image/jpeg")}
r = http_client.post("/detect/image", files=files, headers=auth_headers, timeout=_DETECT_TIMEOUT) r = http_client.post("/detect/image", files=files, headers=headers, timeout=_DETECT_TIMEOUT)
r.raise_for_status() r.raise_for_status()
body = r.json() assert r.status_code == 202
assert isinstance(body, list) done.wait(timeout=_DETECT_TIMEOUT)
if body: th.join(timeout=2)
d = body[0]
if all_detections:
d = all_detections[0]
for k in ( for k in (
"centerX", "centerX",
"centerY", "centerY",
+5 -1
View File
@@ -1,3 +1,5 @@
import uuid
import pytest import pytest
import requests import requests
@@ -41,6 +43,8 @@ def test_ft_n_03_loader_error_mode_detect_does_not_500(
f"{mock_loader_url}/mock/config", json={"mode": "error"}, timeout=10 f"{mock_loader_url}/mock/config", json={"mode": "error"}, timeout=10
) )
cfg.raise_for_status() cfg.raise_for_status()
channel_id = str(uuid.uuid4())
headers = {**auth_headers, "X-Channel-Id": channel_id}
files = {"file": ("small.jpg", image_small, "image/jpeg")} files = {"file": ("small.jpg", image_small, "image/jpeg")}
r = http_client.post("/detect/image", files=files, headers=auth_headers, timeout=_DETECT_TIMEOUT) r = http_client.post("/detect/image", files=files, headers=headers, timeout=_DETECT_TIMEOUT)
assert r.status_code != 500 assert r.status_code != 500
+8 -34
View File
@@ -1,5 +1,4 @@
import json import json
import time
import pytest import pytest
@@ -19,20 +18,13 @@ def _percentile_ms(sorted_ms, p):
@pytest.mark.timeout(60) @pytest.mark.timeout(60)
def test_nft_perf_01_single_image_latency_p95( def test_nft_perf_01_single_image_latency_p95(
warm_engine, http_client, image_small, auth_headers warm_engine, image_detect, image_small
): ):
times_ms = [] times_ms = []
for _ in range(10): for _ in range(10):
t0 = time.perf_counter() _, elapsed_ms = image_detect(image_small, "img.jpg", timeout=8)
r = http_client.post(
"/detect/image",
files={"file": ("img.jpg", image_small, "image/jpeg")},
headers=auth_headers,
timeout=8,
)
elapsed_ms = (time.perf_counter() - t0) * 1000.0
assert r.status_code == 200
times_ms.append(elapsed_ms) times_ms.append(elapsed_ms)
sorted_ms = sorted(times_ms) sorted_ms = sorted(times_ms)
p50 = _percentile_ms(sorted_ms, 50) p50 = _percentile_ms(sorted_ms, 50)
p95 = _percentile_ms(sorted_ms, 95) p95 = _percentile_ms(sorted_ms, 95)
@@ -47,34 +39,16 @@ def test_nft_perf_01_single_image_latency_p95(
@pytest.mark.timeout(60) @pytest.mark.timeout(60)
def test_nft_perf_03_tiling_overhead_large_image( def test_nft_perf_03_tiling_overhead_large_image(
warm_engine, http_client, image_small, image_large, auth_headers warm_engine, image_detect, image_small, image_large
): ):
t_small = time.perf_counter() _, small_ms = image_detect(image_small, "small.jpg", timeout=8)
r_small = http_client.post( _, large_ms = image_detect(
"/detect/image", image_large, "large.jpg",
files={"file": ("small.jpg", image_small, "image/jpeg")}, config=json.dumps({"altitude": 400, "focal_length": 24, "sensor_width": 23.5}),
headers=auth_headers,
timeout=8,
)
small_ms = (time.perf_counter() - t_small) * 1000.0
assert r_small.status_code == 200
config = json.dumps(
{"altitude": 400, "focal_length": 24, "sensor_width": 23.5}
)
t_large = time.perf_counter()
r_large = http_client.post(
"/detect/image",
files={"file": ("large.jpg", image_large, "image/jpeg")},
data={"config": config},
headers=auth_headers,
timeout=20, timeout=20,
) )
large_ms = (time.perf_counter() - t_large) * 1000.0
assert r_large.status_code == 200
assert large_ms < 30_000.0 assert large_ms < 30_000.0
print( print(
f"nft_perf_03_csv,baseline_small_ms,{small_ms:.2f},large_ms,{large_ms:.2f}" f"nft_perf_03_csv,baseline_small_ms,{small_ms:.2f},large_ms,{large_ms:.2f}"
) )
assert large_ms > small_ms - 500.0 assert large_ms > small_ms - 500.0
+9 -13
View File
@@ -5,15 +5,13 @@ _DETECT_TIMEOUT = 60
def test_nft_res_01_loader_outage_after_init( def test_nft_res_01_loader_outage_after_init(
warm_engine, http_client, mock_loader_url, image_small, auth_headers warm_engine, image_detect, mock_loader_url, image_small, http_client
): ):
requests.post( requests.post(
f"{mock_loader_url}/mock/config", json={"mode": "error"}, timeout=10 f"{mock_loader_url}/mock/config", json={"mode": "error"}, timeout=10
).raise_for_status() ).raise_for_status()
files = {"file": ("r1.jpg", image_small, "image/jpeg")} detections, _ = image_detect(image_small, "r1.jpg", timeout=_DETECT_TIMEOUT)
r = http_client.post("/detect/image", files=files, headers=auth_headers, timeout=_DETECT_TIMEOUT) assert isinstance(detections, list)
assert r.status_code == 200
assert isinstance(r.json(), list)
h = http_client.get("/health") h = http_client.get("/health")
assert h.status_code == 200 assert h.status_code == 200
hd = h.json() hd = h.json()
@@ -22,15 +20,13 @@ def test_nft_res_01_loader_outage_after_init(
def test_nft_res_03_transient_loader_first_fail( def test_nft_res_03_transient_loader_first_fail(
mock_loader_url, http_client, image_small, auth_headers mock_loader_url, image_detect, image_small
): ):
requests.post( requests.post(
f"{mock_loader_url}/mock/config", json={"mode": "first_fail"}, timeout=10 f"{mock_loader_url}/mock/config", json={"mode": "first_fail"}, timeout=10
).raise_for_status() ).raise_for_status()
files = {"file": ("r3a.jpg", image_small, "image/jpeg")} try:
r1 = http_client.post("/detect/image", files=files, headers=auth_headers, timeout=_DETECT_TIMEOUT) image_detect(image_small, "r3a.jpg", timeout=_DETECT_TIMEOUT)
files2 = {"file": ("r3b.jpg", image_small, "image/jpeg")} except AssertionError:
r2 = http_client.post("/detect/image", files=files2, headers=auth_headers, timeout=_DETECT_TIMEOUT) pass
assert r2.status_code == 200 image_detect(image_small, "r3b.jpg", timeout=_DETECT_TIMEOUT)
if r1.status_code != 200:
assert r1.status_code != 500
+6 -18
View File
@@ -8,28 +8,16 @@ import pytest
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.timeout(120) @pytest.mark.timeout(120)
def test_nft_res_lim_03_max_detections_per_frame( def test_nft_res_lim_03_max_detections_per_frame(
warm_engine, http_client, image_dense, auth_headers warm_engine, image_detect, image_dense
): ):
r = http_client.post( detections, _ = image_detect(image_dense, "img.jpg", timeout=120)
"/detect/image", assert isinstance(detections, list)
files={"file": ("img.jpg", image_dense, "image/jpeg")}, assert len(detections) <= 300
headers=auth_headers,
timeout=120,
)
assert r.status_code == 200
body = r.json()
assert isinstance(body, list)
assert len(body) <= 300
@pytest.mark.slow @pytest.mark.slow
def test_nft_res_lim_04_log_file_rotation(warm_engine, http_client, image_small, auth_headers): def test_nft_res_lim_04_log_file_rotation(warm_engine, image_detect, image_small):
http_client.post( image_detect(image_small, "img.jpg", timeout=60)
"/detect/image",
files={"file": ("img.jpg", image_small, "image/jpeg")},
headers=auth_headers,
timeout=60,
)
candidates = [ candidates = [
Path(__file__).resolve().parent.parent / "logs", Path(__file__).resolve().parent.parent / "logs",
Path("/app/Logs"), Path("/app/Logs"),
+20 -65
View File
@@ -81,16 +81,10 @@ def _weather_label_ok(label, base_names):
@pytest.mark.slow @pytest.mark.slow
def test_ft_p_03_detection_response_structure_ac1(http_client, image_small, warm_engine, auth_headers): def test_ft_p_03_detection_response_structure_ac1(image_detect, image_small, warm_engine):
r = http_client.post( detections, _ = image_detect(image_small, "img.jpg")
"/detect/image", assert isinstance(detections, list)
files={"file": ("img.jpg", image_small, "image/jpeg")}, for d in detections:
headers=auth_headers,
)
assert r.status_code == 200
body = r.json()
assert isinstance(body, list)
for d in body:
assert isinstance(d["centerX"], (int, float)) assert isinstance(d["centerX"], (int, float))
assert isinstance(d["centerY"], (int, float)) assert isinstance(d["centerY"], (int, float))
assert isinstance(d["width"], (int, float)) assert isinstance(d["width"], (int, float))
@@ -106,44 +100,24 @@ def test_ft_p_03_detection_response_structure_ac1(http_client, image_small, warm
@pytest.mark.slow @pytest.mark.slow
def test_ft_p_05_confidence_filtering_ac2(http_client, image_small, warm_engine, auth_headers): def test_ft_p_05_confidence_filtering_ac2(image_detect, image_small, warm_engine):
cfg_hi = json.dumps({"probability_threshold": 0.8}) hi, _ = image_detect(image_small, "img.jpg", config=json.dumps({"probability_threshold": 0.8}))
r_hi = http_client.post(
"/detect/image",
files={"file": ("img.jpg", image_small, "image/jpeg")},
data={"config": cfg_hi},
headers=auth_headers,
)
assert r_hi.status_code == 200
hi = r_hi.json()
assert isinstance(hi, list) assert isinstance(hi, list)
for d in hi: for d in hi:
assert float(d["confidence"]) + _EPS >= 0.8 assert float(d["confidence"]) + _EPS >= 0.8
cfg_lo = json.dumps({"probability_threshold": 0.1})
r_lo = http_client.post( lo, _ = image_detect(image_small, "img.jpg", config=json.dumps({"probability_threshold": 0.1}))
"/detect/image",
files={"file": ("img.jpg", image_small, "image/jpeg")},
data={"config": cfg_lo},
headers=auth_headers,
)
assert r_lo.status_code == 200
lo = r_lo.json()
assert isinstance(lo, list) assert isinstance(lo, list)
assert len(lo) >= len(hi) assert len(lo) >= len(hi)
@pytest.mark.slow @pytest.mark.slow
def test_ft_p_06_overlap_deduplication_ac3(http_client, image_dense, warm_engine, auth_headers): def test_ft_p_06_overlap_deduplication_ac3(image_detect, image_dense, warm_engine):
cfg_loose = json.dumps({"tracking_intersection_threshold": 0.6}) dets, _ = image_detect(
r1 = http_client.post( image_dense, "img.jpg",
"/detect/image", config=json.dumps({"tracking_intersection_threshold": 0.6}),
files={"file": ("img.jpg", image_dense, "image/jpeg")},
data={"config": cfg_loose},
headers=auth_headers,
timeout=_DETECT_SLOW_TIMEOUT, timeout=_DETECT_SLOW_TIMEOUT,
) )
assert r1.status_code == 200
dets = r1.json()
assert isinstance(dets, list) assert isinstance(dets, list)
by_label = {} by_label = {}
for d in dets: for d in dets:
@@ -153,22 +127,18 @@ def test_ft_p_06_overlap_deduplication_ac3(http_client, image_dense, warm_engine
for j in range(i + 1, len(group)): for j in range(i + 1, len(group)):
ratio = _overlap_to_min_area_ratio(group[i], group[j]) ratio = _overlap_to_min_area_ratio(group[i], group[j])
assert ratio <= 0.6 + _EPS, (label, ratio) assert ratio <= 0.6 + _EPS, (label, ratio)
cfg_strict = json.dumps({"tracking_intersection_threshold": 0.01})
r2 = http_client.post( strict, _ = image_detect(
"/detect/image", image_dense, "img.jpg",
files={"file": ("img.jpg", image_dense, "image/jpeg")}, config=json.dumps({"tracking_intersection_threshold": 0.01}),
data={"config": cfg_strict},
headers=auth_headers,
timeout=_DETECT_SLOW_TIMEOUT, timeout=_DETECT_SLOW_TIMEOUT,
) )
assert r2.status_code == 200
strict = r2.json()
assert isinstance(strict, list) assert isinstance(strict, list)
assert len(strict) <= len(dets) assert len(strict) <= len(dets)
@pytest.mark.slow @pytest.mark.slow
def test_ft_p_07_physical_size_filtering_ac4(http_client, image_small, warm_engine, auth_headers): def test_ft_p_07_physical_size_filtering_ac4(image_detect, image_small, warm_engine):
by_id, _ = _load_classes_media() by_id, _ = _load_classes_media()
wh = _image_width_height(image_small) wh = _image_width_height(image_small)
assert wh is not None assert wh is not None
@@ -184,15 +154,7 @@ def test_ft_p_07_physical_size_filtering_ac4(http_client, image_small, warm_engi
"sensor_width": sensor_width, "sensor_width": sensor_width,
} }
) )
r = http_client.post( body, _ = image_detect(image_small, "img.jpg", config=cfg, timeout=_DETECT_SLOW_TIMEOUT)
"/detect/image",
files={"file": ("img.jpg", image_small, "image/jpeg")},
data={"config": cfg},
headers=auth_headers,
timeout=_DETECT_SLOW_TIMEOUT,
)
assert r.status_code == 200
body = r.json()
assert isinstance(body, list) assert isinstance(body, list)
for d in body: for d in body:
base_id = d["classNum"] % _WEATHER_CLASS_STRIDE base_id = d["classNum"] % _WEATHER_CLASS_STRIDE
@@ -203,17 +165,10 @@ def test_ft_p_07_physical_size_filtering_ac4(http_client, image_small, warm_engi
@pytest.mark.slow @pytest.mark.slow
def test_ft_p_13_weather_mode_class_variants_ac5( def test_ft_p_13_weather_mode_class_variants_ac5(
http_client, image_different_types, warm_engine, auth_headers image_detect, image_different_types, warm_engine
): ):
_, base_names = _load_classes_media() _, base_names = _load_classes_media()
r = http_client.post( body, _ = image_detect(image_different_types, "img.jpg", timeout=_DETECT_SLOW_TIMEOUT)
"/detect/image",
files={"file": ("img.jpg", image_different_types, "image/jpeg")},
headers=auth_headers,
timeout=_DETECT_SLOW_TIMEOUT,
)
assert r.status_code == 200
body = r.json()
assert isinstance(body, list) assert isinstance(body, list)
for d in body: for d in body:
label = d["label"] label = d["label"]
+17 -9
View File
@@ -10,6 +10,7 @@ Run with: pytest e2e/tests/test_streaming_video_upload.py -s -v
import json import json
import threading import threading
import time import time
import uuid
from pathlib import Path from pathlib import Path
import pytest import pytest
@@ -37,21 +38,23 @@ def _chunked_reader(path: str, chunk_size: int = 64 * 1024):
def _start_sse_listener( def _start_sse_listener(
http_client, media_id: str, auth_headers: dict http_client, channel_id: str, auth_headers: dict
) -> tuple[list[dict], list[BaseException], threading.Event]: ) -> tuple[list[dict], list[BaseException], threading.Event]:
events: list[dict] = [] events: list[dict] = []
errors: list[BaseException] = [] errors: list[BaseException] = []
first_event = threading.Event() first_event = threading.Event()
connected = threading.Event()
def _listen(): def _listen():
try: try:
with http_client.get( with http_client.get(
f"/detect/{media_id}", f"/detect/events/{channel_id}",
stream=True, stream=True,
timeout=_TIMEOUT + 2, timeout=_TIMEOUT + 2,
headers=auth_headers, headers=auth_headers,
) as resp: ) as resp:
resp.raise_for_status() resp.raise_for_status()
connected.set()
for event in sseclient.SSEClient(resp).events(): for event in sseclient.SSEClient(resp).events():
if not event.data or not str(event.data).strip(): if not event.data or not str(event.data).strip():
continue continue
@@ -62,9 +65,12 @@ def _start_sse_listener(
except BaseException as exc: except BaseException as exc:
errors.append(exc) errors.append(exc)
finally: finally:
connected.set()
first_event.set() first_event.set()
threading.Thread(target=_listen, daemon=True).start() th = threading.Thread(target=_listen, daemon=True)
th.start()
connected.wait(timeout=3)
return events, errors, first_event return events, errors, first_event
@@ -74,6 +80,8 @@ def test_streaming_video_detections_appear_during_upload(
): ):
# Arrange # Arrange
video_path = _fixture_path("video_test01.mp4") video_path = _fixture_path("video_test01.mp4")
channel_id = str(uuid.uuid4())
events, errors, first_event = _start_sse_listener(http_client, channel_id, auth_headers)
# Act # Act
r = http_client.post( r = http_client.post(
@@ -81,14 +89,13 @@ def test_streaming_video_detections_appear_during_upload(
data=_chunked_reader(video_path), data=_chunked_reader(video_path),
headers={ headers={
**auth_headers, **auth_headers,
"X-Channel-Id": channel_id,
"X-Filename": "video_test01.mp4", "X-Filename": "video_test01.mp4",
"Content-Type": "application/octet-stream", "Content-Type": "application/octet-stream",
}, },
timeout=8, timeout=8,
) )
assert r.status_code == 200 assert r.status_code == 202
media_id = r.json()["mediaId"]
events, errors, first_event = _start_sse_listener(http_client, media_id, auth_headers)
first_event.wait(timeout=_TIMEOUT) first_event.wait(timeout=_TIMEOUT)
# Assert # Assert
@@ -103,6 +110,8 @@ def test_streaming_video_detections_appear_during_upload(
def test_non_faststart_video_still_works(warm_engine, http_client, auth_headers): def test_non_faststart_video_still_works(warm_engine, http_client, auth_headers):
# Arrange # Arrange
video_path = _fixture_path("video_test01.mp4") video_path = _fixture_path("video_test01.mp4")
channel_id = str(uuid.uuid4())
events, errors, first_event = _start_sse_listener(http_client, channel_id, auth_headers)
# Act # Act
r = http_client.post( r = http_client.post(
@@ -110,14 +119,13 @@ def test_non_faststart_video_still_works(warm_engine, http_client, auth_headers)
data=_chunked_reader(video_path), data=_chunked_reader(video_path),
headers={ headers={
**auth_headers, **auth_headers,
"X-Channel-Id": channel_id,
"X-Filename": "video_test01_plain.mp4", "X-Filename": "video_test01_plain.mp4",
"Content-Type": "application/octet-stream", "Content-Type": "application/octet-stream",
}, },
timeout=8, timeout=8,
) )
assert r.status_code == 200 assert r.status_code == 202
media_id = r.json()["mediaId"]
events, errors, first_event = _start_sse_listener(http_client, media_id, auth_headers)
first_event.wait(timeout=_TIMEOUT) first_event.wait(timeout=_TIMEOUT)
# Assert # Assert
+8 -18
View File
@@ -28,32 +28,22 @@ def _assert_no_same_label_near_duplicate_centers(detections):
@pytest.mark.slow @pytest.mark.slow
def test_ft_p_04_gsd_based_tiling_ac1(http_client, image_large, warm_engine, auth_headers): def test_ft_p_04_gsd_based_tiling_ac1(image_detect, image_large, warm_engine):
config = json.dumps(_GSD) body, _ = image_detect(
r = http_client.post( image_large, "img.jpg",
"/detect/image", config=json.dumps(_GSD),
files={"file": ("img.jpg", image_large, "image/jpeg")},
data={"config": config},
headers=auth_headers,
timeout=_TILING_TIMEOUT, timeout=_TILING_TIMEOUT,
) )
assert r.status_code == 200
body = r.json()
assert isinstance(body, list) assert isinstance(body, list)
_assert_coords_normalized(body) _assert_coords_normalized(body)
@pytest.mark.slow @pytest.mark.slow
def test_ft_p_16_tile_boundary_deduplication_ac2(http_client, image_large, warm_engine, auth_headers): def test_ft_p_16_tile_boundary_deduplication_ac2(image_detect, image_large, warm_engine):
config = json.dumps({**_GSD, "big_image_tile_overlap_percent": 20}) body, _ = image_detect(
r = http_client.post( image_large, "img.jpg",
"/detect/image", config=json.dumps({**_GSD, "big_image_tile_overlap_percent": 20}),
files={"file": ("img.jpg", image_large, "image/jpeg")},
data={"config": config},
headers=auth_headers,
timeout=_TILING_TIMEOUT, timeout=_TILING_TIMEOUT,
) )
assert r.status_code == 200
body = r.json()
assert isinstance(body, list) assert isinstance(body, list)
_assert_no_same_label_near_duplicate_centers(body) _assert_no_same_label_near_duplicate_centers(body)
+24 -14
View File
@@ -1,6 +1,7 @@
import json import json
import threading import threading
import time import time
import uuid
from pathlib import Path from pathlib import Path
import pytest import pytest
@@ -24,29 +25,22 @@ def video_events(warm_engine, http_client, auth_headers):
if not Path(_VIDEO).is_file(): if not Path(_VIDEO).is_file():
pytest.skip(f"missing fixture {_VIDEO}") pytest.skip(f"missing fixture {_VIDEO}")
r = http_client.post( channel_id = str(uuid.uuid4())
"/detect/video",
data=_chunked_reader(_VIDEO),
headers={
**auth_headers,
"X-Filename": "video_test01.mp4",
"Content-Type": "application/octet-stream",
},
timeout=15,
)
assert r.status_code == 200
media_id = r.json()["mediaId"]
collected: list[tuple[float, dict]] = [] collected: list[tuple[float, dict]] = []
thread_exc: list[BaseException] = [] thread_exc: list[BaseException] = []
done = threading.Event() done = threading.Event()
connected = threading.Event()
def _listen(): def _listen():
try: try:
with http_client.get( with http_client.get(
f"/detect/{media_id}", stream=True, timeout=35, headers=auth_headers f"/detect/events/{channel_id}",
stream=True,
timeout=60,
headers=auth_headers,
) as resp: ) as resp:
resp.raise_for_status() resp.raise_for_status()
connected.set()
sse = sseclient.SSEClient(resp) sse = sseclient.SSEClient(resp)
for event in sse.events(): for event in sse.events():
if not event.data or not str(event.data).strip(): if not event.data or not str(event.data).strip():
@@ -61,10 +55,26 @@ def video_events(warm_engine, http_client, auth_headers):
except BaseException as e: except BaseException as e:
thread_exc.append(e) thread_exc.append(e)
finally: finally:
connected.set()
done.set() done.set()
th = threading.Thread(target=_listen, daemon=True) th = threading.Thread(target=_listen, daemon=True)
th.start() th.start()
connected.wait(timeout=5)
r = http_client.post(
"/detect/video",
data=_chunked_reader(_VIDEO),
headers={
**auth_headers,
"X-Channel-Id": channel_id,
"X-Filename": "video_test01.mp4",
"Content-Type": "application/octet-stream",
},
timeout=15,
)
assert r.status_code == 202
assert done.wait(timeout=30) assert done.wait(timeout=30)
th.join(timeout=5) th.join(timeout=5)
assert not thread_exc, thread_exc assert not thread_exc, thread_exc
+1 -1
View File
@@ -2,7 +2,7 @@ fastapi==0.135.2
uvicorn[standard]==0.42.0 uvicorn[standard]==0.42.0
PyJWT==2.12.1 PyJWT==2.12.1
h11==0.16.0 h11==0.16.0
python-multipart>=1.3.1 python-multipart==0.0.22
Cython==3.2.4 Cython==3.2.4
opencv-python==4.10.0.84 opencv-python==4.10.0.84
numpy==2.3.0 numpy==2.3.0
+3
View File
@@ -22,6 +22,9 @@ try:
extensions.append( extensions.append(
Extension('engines.tensorrt_engine', [f'{SRC}/engines/tensorrt_engine.pyx'], include_dirs=np_inc) Extension('engines.tensorrt_engine', [f'{SRC}/engines/tensorrt_engine.pyx'], include_dirs=np_inc)
) )
extensions.append(
Extension('engines.jetson_tensorrt_engine', [f'{SRC}/engines/jetson_tensorrt_engine.pyx'], include_dirs=np_inc)
)
except ImportError: except ImportError:
pass pass
+29 -8
View File
@@ -1,6 +1,16 @@
import os
import platform import platform
import sys import sys
from loguru import logger
from engines.engine_factory import (
EngineFactory,
OnnxEngineFactory,
CoreMLEngineFactory,
TensorRTEngineFactory,
JetsonTensorRTEngineFactory,
)
def _check_tensor_gpu_index(): def _check_tensor_gpu_index():
try: try:
@@ -35,18 +45,29 @@ def _is_apple_silicon():
return False return False
def _is_jetson():
    """Report whether this host is treated as a Jetson device.

    All three checks must pass: the machine architecture is ``aarch64``, the
    module-level ``tensor_gpu_index`` is above -1, and the file
    ``/etc/nv_tegra_release`` exists.
    """
    if platform.machine() != "aarch64":
        return False
    if not tensor_gpu_index > -1:
        return False
    return os.path.isfile("/etc/nv_tegra_release")
tensor_gpu_index = _check_tensor_gpu_index() tensor_gpu_index = _check_tensor_gpu_index()
def _select_engine_class(): def _create_engine_factory() -> EngineFactory:
if _is_jetson():
logger.info("Engine factory: JetsonTensorRTEngineFactory")
return JetsonTensorRTEngineFactory()
if tensor_gpu_index > -1: if tensor_gpu_index > -1:
from engines.tensorrt_engine import TensorRTEngine # pyright: ignore[reportMissingImports] logger.info("Engine factory: TensorRTEngineFactory")
return TensorRTEngine return TensorRTEngineFactory()
if _is_apple_silicon(): if _is_apple_silicon():
from engines.coreml_engine import CoreMLEngine logger.info("Engine factory: CoreMLEngineFactory")
return CoreMLEngine return CoreMLEngineFactory()
from engines.onnx_engine import OnnxEngine logger.info("Engine factory: OnnxEngineFactory")
return OnnxEngine return OnnxEngineFactory()
EngineClass = _select_engine_class() engine_factory = _create_engine_factory()
-4
View File
@@ -30,10 +30,6 @@ cdef class CoreMLEngine(InferenceEngine):
constants_inf.log(<str>f'CoreML model: {self.img_width}x{self.img_height}') constants_inf.log(<str>f'CoreML model: {self.img_width}x{self.img_height}')
@staticmethod
def get_engine_filename():
return "azaion_coreml.zip"
@staticmethod @staticmethod
def _extract_from_zip(model_bytes): def _extract_from_zip(model_bytes):
tmpdir = tempfile.mkdtemp() tmpdir = tempfile.mkdtemp()
+109
View File
@@ -0,0 +1,109 @@
import os
import tempfile
class EngineFactory:
has_build_step = False
def create(self, model_bytes: bytes):
raise NotImplementedError
def load_engine(self, loader_client, models_dir: str):
filename = self._get_ai_engine_filename()
if filename is None:
return None
try:
res = loader_client.load_big_small_resource(filename, models_dir)
if res.err is None:
return self.create(res.data)
except Exception:
pass
return None
def _get_ai_engine_filename(self) -> str | None:
return None
def get_source_filename(self) -> str | None:
return None
def build_from_source(self, onnx_bytes: bytes, loader_client, models_dir: str):
raise NotImplementedError(f"{type(self).__name__} does not support building from source")
class OnnxEngineFactory(EngineFactory):
    """Factory for the ONNX backend; engines are created straight from the ONNX model."""

    def get_source_filename(self) -> str:
        """Return the ONNX model filename configured in ``constants_inf``."""
        from constants_inf import AI_ONNX_MODEL_FILE

        return AI_ONNX_MODEL_FILE

    def create(self, model_bytes: bytes):
        """Build an ``OnnxEngine`` from ``model_bytes`` (imported lazily)."""
        from engines.onnx_engine import OnnxEngine as engine_cls

        return engine_cls(model_bytes)
class CoreMLEngineFactory(EngineFactory):
    """Factory for the CoreML backend, which loads a pre-built zipped model."""

    def _get_ai_engine_filename(self) -> str:
        """Return the fixed filename of the pre-built CoreML archive."""
        return "azaion_coreml.zip"

    def create(self, model_bytes: bytes):
        """Build a ``CoreMLEngine`` from ``model_bytes`` (imported lazily)."""
        from engines.coreml_engine import CoreMLEngine as engine_cls

        return engine_cls(model_bytes)
class TensorRTEngineFactory(EngineFactory):
    """Factory for the TensorRT backend, which converts the ONNX source into an engine."""

    has_build_step = True

    def get_source_filename(self) -> str:
        """Return the ONNX model filename configured in ``constants_inf``."""
        from constants_inf import AI_ONNX_MODEL_FILE

        return AI_ONNX_MODEL_FILE

    def _get_ai_engine_filename(self) -> str | None:
        """Return the engine filename reported by ``TensorRTEngine.get_engine_filename()``."""
        from engines.tensorrt_engine import TensorRTEngine as trt

        return trt.get_engine_filename()

    def create(self, model_bytes: bytes):
        """Build a ``TensorRTEngine`` from serialized engine bytes."""
        from engines.tensorrt_engine import TensorRTEngine as trt

        return trt(model_bytes)

    def build_from_source(self, onnx_bytes: bytes, loader_client, models_dir: str):
        """Convert ``onnx_bytes`` without a calibration cache.

        Returns a ``(engine_bytes, engine_filename)`` tuple; ``loader_client``
        and ``models_dir`` are accepted for interface parity and not used here.
        """
        from engines.tensorrt_engine import TensorRTEngine as trt

        converted = trt.convert_from_source(onnx_bytes, None)
        target_name = trt.get_engine_filename()
        return converted, target_name
class JetsonTensorRTEngineFactory(TensorRTEngineFactory):
    """TensorRT factory variant for Jetson that targets the "int8" engine file.

    Building from source first tries to download an INT8 calibration cache and
    hands its temporary path to the converter.
    """

    def create(self, model_bytes: bytes):
        """Build a ``JetsonTensorRTEngine`` from serialized engine bytes."""
        from engines.jetson_tensorrt_engine import JetsonTensorRTEngine
        return JetsonTensorRTEngine(model_bytes)

    def _get_ai_engine_filename(self) -> str | None:
        """Return the engine filename for the "int8" variant."""
        from engines.tensorrt_engine import TensorRTEngine
        return TensorRTEngine.get_engine_filename("int8")

    def build_from_source(self, onnx_bytes: bytes, loader_client, models_dir: str):
        """Convert ``onnx_bytes`` using the calibration cache when one is available.

        Returns a ``(engine_bytes, engine_filename)`` tuple. The temporary
        calibration cache file is removed afterwards whether or not the
        conversion succeeds.
        """
        from engines.tensorrt_engine import TensorRTEngine
        calib_cache_path = self._download_calib_cache(loader_client, models_dir)
        try:
            engine_bytes = TensorRTEngine.convert_from_source(onnx_bytes, calib_cache_path)
            # NOTE(review): the "int8" filename is returned even when no
            # calibration cache was obtained (calib_cache_path is None) —
            # confirm the engine is still meant to be stored under that name.
            return engine_bytes, TensorRTEngine.get_engine_filename("int8")
        finally:
            if calib_cache_path is not None:
                try:
                    os.unlink(calib_cache_path)
                except Exception:
                    pass

    def _download_calib_cache(self, loader_client, models_dir: str) -> str | None:
        """Download the INT8 calibration cache into a temporary file.

        Returns the temp file path, or ``None`` when the loader reports an
        error or anything raises. A temp file that was created before the
        failure is deleted, so a failed write no longer leaves it on disk.
        """
        import constants_inf
        path = None
        try:
            res = loader_client.load_big_small_resource(constants_inf.INT8_CALIB_CACHE_FILE, models_dir)
            if res.err is not None:
                constants_inf.log(f"INT8 calibration cache not available: {res.err}")
                return None
            fd, path = tempfile.mkstemp(suffix=".cache")
            with os.fdopen(fd, "wb") as f:
                f.write(res.data)
            constants_inf.log("INT8 calibration cache downloaded")
            return path
        except Exception as e:
            if path is not None:
                # mkstemp leaves deletion to the caller; nobody else knows this path.
                try:
                    os.unlink(path)
                except OSError:
                    pass
            constants_inf.log(f"INT8 calibration cache download failed: {str(e)}")
            return None
+5
View File
@@ -0,0 +1,5 @@
from engines.tensorrt_engine cimport TensorRTEngine
# Jetson-specific engine type; it adds no members of its own beyond TensorRTEngine here.
cdef class JetsonTensorRTEngine(TensorRTEngine):
pass
+5
View File
@@ -0,0 +1,5 @@
from engines.tensorrt_engine cimport TensorRTEngine
# Jetson-specific engine type; it adds no members of its own beyond TensorRTEngine here.
cdef class JetsonTensorRTEngine(TensorRTEngine):
pass
+1 -1
View File
@@ -23,7 +23,7 @@ cdef class OnnxEngine(InferenceEngine):
self.model_inputs = self.session.get_inputs() self.model_inputs = self.session.get_inputs()
self.input_name = self.model_inputs[0].name self.input_name = self.model_inputs[0].name
self.input_shape = self.model_inputs[0].shape self.input_shape = self.model_inputs[0].shape
if self.input_shape[0] not in (-1, None, "N"): if isinstance(self.input_shape[0], int) and self.input_shape[0] > 0:
self.max_batch_size = self.input_shape[0] self.max_batch_size = self.input_shape[0]
constants_inf.log(f'AI detection model input: {self.model_inputs} {self.input_shape}') constants_inf.log(f'AI detection model input: {self.model_inputs} {self.input_shape}')
model_meta = self.session.get_modelmeta() model_meta = self.session.get_modelmeta()
-5
View File
@@ -113,11 +113,6 @@ cdef class TensorRTEngine(InferenceEngine):
except Exception: except Exception:
return None return None
@staticmethod
def get_source_filename():
import constants_inf
return constants_inf.AI_ONNX_MODEL_FILE
@staticmethod @staticmethod
def convert_from_source(bytes onnx_model, str calib_cache_path=None): def convert_from_source(bytes onnx_model, str calib_cache_path=None):
gpu_mem = TensorRTEngine.get_gpu_memory_bytes(0) gpu_mem = TensorRTEngine.get_gpu_memory_bytes(0)
+18 -51
View File
@@ -1,6 +1,4 @@
import io import io
import os
import tempfile
import threading import threading
import av import av
@@ -14,7 +12,7 @@ from ai_config cimport AIRecognitionConfig
from engines.inference_engine cimport InferenceEngine from engines.inference_engine cimport InferenceEngine
from loader_http_client cimport LoaderHttpClient from loader_http_client cimport LoaderHttpClient
from threading import Thread from threading import Thread
from engines import EngineClass from engines import engine_factory
def ai_config_from_dict(dict data): def ai_config_from_dict(dict data):
@@ -76,29 +74,23 @@ cdef class Inference:
raise Exception(res.err) raise Exception(res.err)
return <bytes>res.data return <bytes>res.data
cdef convert_and_upload_model(self, bytes source_bytes, str engine_filename, str calib_cache_path): cdef convert_and_upload_model(self, bytes source_bytes, str models_dir):
try: try:
self.ai_availability_status.set_status(AIAvailabilityEnum.CONVERTING) self.ai_availability_status.set_status(AIAvailabilityEnum.CONVERTING)
models_dir = constants_inf.MODELS_FOLDER engine_bytes, engine_filename = engine_factory.build_from_source(source_bytes, self.loader_client, models_dir)
model_bytes = EngineClass.convert_from_source(source_bytes, calib_cache_path)
self.ai_availability_status.set_status(AIAvailabilityEnum.UPLOADING) self.ai_availability_status.set_status(AIAvailabilityEnum.UPLOADING)
res = self.loader_client.upload_big_small_resource(model_bytes, engine_filename, models_dir) res = self.loader_client.upload_big_small_resource(engine_bytes, engine_filename, models_dir)
if res.err is not None: if res.err is not None:
self.ai_availability_status.set_status(AIAvailabilityEnum.WARNING, <str>f"Failed to upload converted model: {res.err}") self.ai_availability_status.set_status(AIAvailabilityEnum.WARNING, <str>f"Failed to upload converted model: {res.err}")
self._converted_model_bytes = model_bytes self._converted_model_bytes = engine_bytes
self.ai_availability_status.set_status(AIAvailabilityEnum.ENABLED) self.ai_availability_status.set_status(AIAvailabilityEnum.ENABLED)
except Exception as e: except Exception as e:
self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, <str> str(e)) self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, <str> str(e))
self._converted_model_bytes = <bytes>None self._converted_model_bytes = <bytes>None
finally: finally:
self.is_building_engine = <bint>False self.is_building_engine = <bint>False
if calib_cache_path is not None:
try:
os.unlink(calib_cache_path)
except Exception:
pass
cdef init_ai(self): cdef init_ai(self):
constants_inf.log(<str> 'init AI...') constants_inf.log(<str> 'init AI...')
@@ -110,7 +102,7 @@ cdef class Inference:
if self._converted_model_bytes is not None: if self._converted_model_bytes is not None:
try: try:
self.engine = EngineClass(self._converted_model_bytes) self.engine = engine_factory.create(self._converted_model_bytes)
self.ai_availability_status.set_status(AIAvailabilityEnum.ENABLED) self.ai_availability_status.set_status(AIAvailabilityEnum.ENABLED)
except Exception as e: except Exception as e:
self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, <str> str(e)) self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, <str> str(e))
@@ -119,58 +111,33 @@ cdef class Inference:
return return
models_dir = constants_inf.MODELS_FOLDER models_dir = constants_inf.MODELS_FOLDER
engine_filename_fp16 = EngineClass.get_engine_filename()
if engine_filename_fp16 is not None:
engine_filename_int8 = EngineClass.get_engine_filename(<str>"int8")
for candidate in [engine_filename_int8, engine_filename_fp16]:
try:
self.ai_availability_status.set_status(AIAvailabilityEnum.DOWNLOADING) self.ai_availability_status.set_status(AIAvailabilityEnum.DOWNLOADING)
res = self.loader_client.load_big_small_resource(candidate, models_dir) engine = engine_factory.load_engine(self.loader_client, models_dir)
if res.err is not None: if engine is not None:
raise Exception(res.err) self.engine = engine
self.engine = EngineClass(res.data)
self.ai_availability_status.set_status(AIAvailabilityEnum.ENABLED) self.ai_availability_status.set_status(AIAvailabilityEnum.ENABLED)
return return
except Exception:
pass
source_filename = EngineClass.get_source_filename() source_filename = engine_factory.get_source_filename()
if source_filename is None: if source_filename is None:
self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, <str>"Pre-built engine not found and no source available") self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, <str>"No engine available and no source to build from")
return return
self.ai_availability_status.set_status(AIAvailabilityEnum.WARNING, <str>"Cached engine not found, converting from source")
source_bytes = self.download_model(source_filename)
calib_cache_path = self._try_download_calib_cache(models_dir)
target_engine_filename = EngineClass.get_engine_filename(<str>"int8") if calib_cache_path is not None else engine_filename_fp16
self.is_building_engine = <bint>True
thread = Thread(target=self.convert_and_upload_model, args=(source_bytes, target_engine_filename, calib_cache_path)) source_bytes = self.download_model(source_filename)
if engine_factory.has_build_step:
self.ai_availability_status.set_status(AIAvailabilityEnum.WARNING, <str>"Cached engine not found, converting from source")
self.is_building_engine = <bint>True
thread = Thread(target=self.convert_and_upload_model, args=(source_bytes, models_dir))
thread.daemon = True thread.daemon = True
thread.start() thread.start()
return
else: else:
self.engine = EngineClass(<bytes>self.download_model(constants_inf.AI_ONNX_MODEL_FILE)) self.engine = engine_factory.create(source_bytes)
self.ai_availability_status.set_status(AIAvailabilityEnum.ENABLED) self.ai_availability_status.set_status(AIAvailabilityEnum.ENABLED)
self.is_building_engine = <bint>False
except Exception as e: except Exception as e:
self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, <str>str(e)) self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, <str>str(e))
self.is_building_engine = <bint>False self.is_building_engine = <bint>False
cdef str _try_download_calib_cache(self, str models_dir):
    # Download the INT8 calibration cache into a temporary file and return its
    # path. Returns None when the loader reports an error or anything raises.
    # A temp file created before the failure is deleted so it does not leak.
    path = None
    try:
        res = self.loader_client.load_big_small_resource(constants_inf.INT8_CALIB_CACHE_FILE, models_dir)
        if res.err is not None:
            constants_inf.log(<str>f"INT8 calibration cache not available: {res.err}")
            return <str>None
        fd, path = tempfile.mkstemp(suffix='.cache')
        with os.fdopen(fd, 'wb') as f:
            f.write(res.data)
        constants_inf.log(<str>'INT8 calibration cache downloaded')
        return <str>path
    except Exception as e:
        if path is not None:
            # mkstemp leaves deletion to the caller; nobody else knows this path.
            try:
                os.unlink(path)
            except OSError:
                pass
        constants_inf.log(<str>f"INT8 calibration cache download failed: {str(e)}")
        return <str>None
cpdef run_detect_image(self, bytes image_bytes, AIRecognitionConfig ai_config, str media_name, cpdef run_detect_image(self, bytes image_bytes, AIRecognitionConfig ai_config, str media_name,
object annotation_callback, object status_callback=None): object annotation_callback, object status_callback=None):
cdef list all_frame_data = [] cdef list all_frame_data = []
+153 -120
View File
@@ -5,6 +5,7 @@ import json
import os import os
import tempfile import tempfile
import time import time
from collections import deque
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from pathlib import Path from pathlib import Path
from typing import Annotated, Optional from typing import Annotated, Optional
@@ -15,7 +16,7 @@ import jwt as pyjwt
import numpy as np import numpy as np
import requests as http_requests import requests as http_requests
from fastapi import Body, Depends, FastAPI, File, Form, HTTPException, Request, UploadFile from fastapi import Body, Depends, FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.responses import StreamingResponse from fastapi.responses import Response, StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel from pydantic import BaseModel
@@ -37,11 +38,14 @@ _MEDIA_STATUS_ERROR = 6
_VIDEO_EXTENSIONS = frozenset({".mp4", ".mov", ".webm", ".mkv", ".avi", ".m4v"}) _VIDEO_EXTENSIONS = frozenset({".mp4", ".mov", ".webm", ".mkv", ".avi", ".m4v"})
_IMAGE_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"}) _IMAGE_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"})
_BUFFER_TTL_MS = 10_000
_BUFFER_MAX = 200
loader_client = LoaderHttpClient(LOADER_URL) loader_client = LoaderHttpClient(LOADER_URL)
annotations_client = LoaderHttpClient(ANNOTATIONS_URL) annotations_client = LoaderHttpClient(ANNOTATIONS_URL)
inference = None inference = None
_job_queues: dict[str, list[asyncio.Queue]] = {} _job_queues: dict[str, list[asyncio.Queue]] = {}
_job_buffers: dict[str, list[str]] = {} _channel_buffers: dict[str, deque] = {}
_active_detections: dict[str, asyncio.Task] = {} _active_detections: dict[str, asyncio.Task] = {}
_bearer = HTTPBearer(auto_error=False) _bearer = HTTPBearer(auto_error=False)
@@ -323,19 +327,48 @@ def detection_to_dto(det) -> DetectionDto:
) )
def _enqueue(media_id: str, event: DetectionEvent): def _post_annotation_to_service(token_mgr: TokenManager, media_id: str,
data = event.model_dump_json() annotation, dtos: list[DetectionDto]):
_job_buffers.setdefault(media_id, []).append(data)
for q in _job_queues.get(media_id, []):
try: try:
q.put_nowait(data) token = token_mgr.get_valid_token()
except asyncio.QueueFull: image_b64 = base64.b64encode(annotation.image).decode() if annotation.image else None
payload = {
"mediaId": media_id,
"source": 0,
"videoTime": f"00:00:{annotation.time // 1000:02d}" if annotation.time else "00:00:00",
"detections": [d.model_dump() for d in dtos],
}
if image_b64:
payload["image"] = image_b64
http_requests.post(
f"{ANNOTATIONS_URL}/annotations",
json=payload,
headers={"Authorization": f"Bearer {token}"},
timeout=30,
)
except Exception:
pass pass
def _schedule_buffer_cleanup(media_id: str, delay: float = 300.0): def _cleanup_channel(channel_id: str):
loop = asyncio.get_event_loop() _channel_buffers.pop(channel_id, None)
loop.call_later(delay, lambda: _job_buffers.pop(media_id, None))
def _enqueue(channel_id: str, event: DetectionEvent):
    """Record ``event`` in the channel's replay buffer and push it to live subscribers.

    The event is serialized once and stored with the current wall-clock time in
    milliseconds. The buffer is capped at ``_BUFFER_MAX`` entries, and entries
    older than ``_BUFFER_TTL_MS`` are dropped from the front on every call.
    """
    stamp_ms = int(time.time() * 1000)
    payload = event.model_dump_json()
    entry = (stamp_ms, payload)

    history = _channel_buffers.setdefault(channel_id, deque(maxlen=_BUFFER_MAX))
    history.append(entry)

    oldest_allowed = stamp_ms - _BUFFER_TTL_MS
    while history and history[0][0] < oldest_allowed:
        history.popleft()

    for subscriber in _job_queues.get(channel_id, []):
        try:
            subscriber.put_nowait(entry)
        except asyncio.QueueFull:
            # A subscriber whose queue is full simply misses this event.
            pass
@app.get("/health") @app.get("/health")
@@ -361,6 +394,36 @@ def health() -> HealthResponse:
) )
@app.get("/detect/events/{channel_id}", dependencies=[Depends(require_auth)])
async def detect_events(channel_id: str, request: Request, after_ts: Optional[int] = None):
    """Stream detection events for ``channel_id`` as ``text/event-stream``.

    A bounded queue is registered for the channel before the response starts.
    When ``after_ts`` is given, buffered events with a timestamp greater than
    it are replayed first; after that the stream forwards whatever arrives on
    the queue. Each frame carries the event timestamp as its ``id``.
    """
    subscriber: asyncio.Queue = asyncio.Queue(maxsize=200)
    _job_queues.setdefault(channel_id, []).append(subscriber)

    def _frame(ts_ms, payload):
        return f"id: {ts_ms}\ndata: {payload}\n\n"

    async def event_generator():
        try:
            if after_ts is not None:
                # NOTE(review): the queue is registered before this replay, so an
                # event enqueued in between may be sent from both the buffer and
                # the queue — confirm clients tolerate duplicates.
                backlog = list(_channel_buffers.get(channel_id, []))
                for ts_ms, payload in backlog:
                    if ts_ms > after_ts:
                        yield _frame(ts_ms, payload)
            while True:
                ts_ms, payload = await subscriber.get()
                yield _frame(ts_ms, payload)
        except asyncio.CancelledError:
            pass
        finally:
            # Unregister this subscriber and drop the channel entry once it is empty.
            listeners = _job_queues.get(channel_id, [])
            if subscriber in listeners:
                listeners.remove(subscriber)
            if not listeners:
                _job_queues.pop(channel_id, None)

    sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
@app.post("/detect/image") @app.post("/detect/image")
async def detect_image( async def detect_image(
request: Request, request: Request,
@@ -384,6 +447,10 @@ async def detect_image(
if cv2.imdecode(arr, cv2.IMREAD_COLOR) is None: if cv2.imdecode(arr, cv2.IMREAD_COLOR) is None:
raise HTTPException(status_code=400, detail="Invalid image data") raise HTTPException(status_code=400, detail="Invalid image data")
channel_id = request.headers.get("x-channel-id", "")
if not channel_id:
raise HTTPException(status_code=400, detail="X-Channel-Id header required")
config_dict = {} config_dict = {}
if config: if config:
config_dict = json.loads(config) config_dict = json.loads(config)
@@ -395,7 +462,6 @@ async def detect_image(
images_dir = os.environ.get( images_dir = os.environ.get(
"IMAGES_DIR", os.path.join(os.getcwd(), "data", "images") "IMAGES_DIR", os.path.join(os.getcwd(), "data", "images")
) )
storage_path = None
content_hash = None content_hash = None
if token_mgr and user_id: if token_mgr and user_id:
content_hash = compute_media_content_hash(image_bytes) content_hash = compute_media_content_hash(image_bytes)
@@ -417,45 +483,65 @@ async def detect_image(
_put_media_status(content_hash, _MEDIA_STATUS_AI_PROCESSING, bearer) _put_media_status(content_hash, _MEDIA_STATUS_AI_PROCESSING, bearer)
media_name = Path(orig_name).stem.replace(" ", "") media_name = Path(orig_name).stem.replace(" ", "")
media_id = content_hash or channel_id
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
inf = get_inference() inf = get_inference()
results = []
def on_annotation(annotation, percent):
results.extend(annotation.detections)
async def run_detection():
ai_cfg = ai_config_from_dict(config_dict) ai_cfg = ai_config_from_dict(config_dict)
def run_detect(): def on_annotation(annotation, percent):
dtos = [detection_to_dto(d) for d in annotation.detections]
event = DetectionEvent(
annotations=dtos,
mediaId=media_id,
mediaStatus="AIProcessing",
mediaPercent=percent,
)
loop.call_soon_threadsafe(_enqueue, channel_id, event)
if token_mgr and content_hash and dtos:
_post_annotation_to_service(token_mgr, content_hash, annotation, dtos)
def run_sync():
inf.run_detect_image(image_bytes, ai_cfg, media_name, on_annotation) inf.run_detect_image(image_bytes, ai_cfg, media_name, on_annotation)
try: try:
await loop.run_in_executor(executor, run_detect) await loop.run_in_executor(executor, run_sync)
if token_mgr and user_id and content_hash: _enqueue(channel_id, DetectionEvent(
annotations=[], mediaId=media_id,
mediaStatus="AIProcessed", mediaPercent=100,
))
if token_mgr and content_hash:
_put_media_status( _put_media_status(
content_hash, _MEDIA_STATUS_AI_PROCESSED, token_mgr.get_valid_token() content_hash, _MEDIA_STATUS_AI_PROCESSED, token_mgr.get_valid_token()
) )
return [detection_to_dto(d) for d in results]
except RuntimeError as e: except RuntimeError as e:
if token_mgr and user_id and content_hash: if token_mgr and content_hash:
_put_media_status( _put_media_status(
content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token() content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
) )
_enqueue(channel_id, DetectionEvent(
annotations=[], mediaId=media_id,
mediaStatus="Error", mediaPercent=0,
))
if "not available" in str(e): if "not available" in str(e):
raise HTTPException(status_code=503, detail=str(e)) return
raise HTTPException(status_code=422, detail=str(e))
except ValueError as e:
if token_mgr and user_id and content_hash:
_put_media_status(
content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
)
raise HTTPException(status_code=400, detail=str(e))
except Exception:
if token_mgr and user_id and content_hash:
_put_media_status(
content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
)
raise raise
except Exception:
if token_mgr and content_hash:
_put_media_status(
content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
)
_enqueue(channel_id, DetectionEvent(
annotations=[], mediaId=media_id,
mediaStatus="Error", mediaPercent=0,
))
raise
finally:
loop.call_later(10.0, _cleanup_channel, channel_id)
asyncio.create_task(run_detection())
return Response(status_code=202)
@app.post("/detect/video") @app.post("/detect/video")
@@ -467,6 +553,10 @@ async def detect_video_upload(
from inference import ai_config_from_dict from inference import ai_config_from_dict
from streaming_buffer import StreamingBuffer from streaming_buffer import StreamingBuffer
channel_id = request.headers.get("x-channel-id", "")
if not channel_id:
raise HTTPException(status_code=400, detail="X-Channel-Id header required")
filename = request.headers.get("x-filename", "upload.mp4") filename = request.headers.get("x-filename", "upload.mp4")
config_json = request.headers.get("x-config", "") config_json = request.headers.get("x-config", "")
ext = _normalize_upload_ext(filename) ext = _normalize_upload_ext(filename)
@@ -491,32 +581,23 @@ async def detect_video_upload(
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
inf = get_inference() inf = get_inference()
placeholder_id = f"tmp_{os.path.basename(buffer.path)}" current_media_id = [channel_id]
current_id = [placeholder_id] # mutable — updated to content_hash after upload
def on_annotation(annotation, percent): def on_annotation(annotation, percent):
dtos = [detection_to_dto(d) for d in annotation.detections] dtos = [detection_to_dto(d) for d in annotation.detections]
mid = current_id[0] mid = current_media_id[0]
event = DetectionEvent( event = DetectionEvent(
annotations=dtos, annotations=dtos,
mediaId=mid, mediaId=mid,
mediaStatus="AIProcessing", mediaStatus="AIProcessing",
mediaPercent=percent, mediaPercent=percent,
) )
loop.call_soon_threadsafe(_enqueue, mid, event) loop.call_soon_threadsafe(_enqueue, channel_id, event)
if token_mgr and mid != channel_id and dtos:
def on_status(media_name_cb, count): _post_annotation_to_service(token_mgr, mid, annotation, dtos)
mid = current_id[0]
event = DetectionEvent(
annotations=[],
mediaId=mid,
mediaStatus="AIProcessed",
mediaPercent=100,
)
loop.call_soon_threadsafe(_enqueue, mid, event)
def run_inference(): def run_inference():
inf.run_detect_video_stream(buffer, ai_cfg, media_name, on_annotation, on_status) inf.run_detect_video_stream(buffer, ai_cfg, media_name, on_annotation, lambda *_: None)
inference_future = loop.run_in_executor(executor, run_inference) inference_future = loop.run_in_executor(executor, run_inference)
@@ -533,14 +614,14 @@ async def detect_video_upload(
if not ext.startswith("."): if not ext.startswith("."):
ext = "." + ext ext = "." + ext
storage_path = os.path.abspath(os.path.join(videos_dir, f"{content_hash}{ext}")) storage_path = os.path.abspath(os.path.join(videos_dir, f"{content_hash}{ext}"))
current_media_id[0] = content_hash
# Re-key buffered events from placeholder_id to content_hash so clients _enqueue(channel_id, DetectionEvent(
# can subscribe to GET /detect/{content_hash} after POST returns. annotations=[],
if placeholder_id in _job_buffers: mediaId=content_hash,
_job_buffers[content_hash] = _job_buffers.pop(placeholder_id) mediaStatus="Started",
if placeholder_id in _job_queues: mediaPercent=0,
_job_queues[content_hash] = _job_queues.pop(placeholder_id) ))
current_id[0] = content_hash # future on_annotation/on_status callbacks use content_hash
if token_mgr and user_id: if token_mgr and user_id:
os.rename(buffer.path, storage_path) os.rename(buffer.path, storage_path)
@@ -564,27 +645,24 @@ async def detect_video_upload(
content_hash, _MEDIA_STATUS_AI_PROCESSED, content_hash, _MEDIA_STATUS_AI_PROCESSED,
token_mgr.get_valid_token(), token_mgr.get_valid_token(),
) )
done_event = DetectionEvent( _enqueue(channel_id, DetectionEvent(
annotations=[], annotations=[],
mediaId=content_hash, mediaId=content_hash,
mediaStatus="AIProcessed", mediaStatus="AIProcessed",
mediaPercent=100, mediaPercent=100,
) ))
_enqueue(content_hash, done_event)
except Exception: except Exception:
if token_mgr and user_id: if token_mgr and user_id:
_put_media_status( _put_media_status(
content_hash, _MEDIA_STATUS_ERROR, content_hash, _MEDIA_STATUS_ERROR,
token_mgr.get_valid_token(), token_mgr.get_valid_token(),
) )
err_event = DetectionEvent( _enqueue(channel_id, DetectionEvent(
annotations=[], mediaId=content_hash, annotations=[], mediaId=content_hash,
mediaStatus="Error", mediaPercent=0, mediaStatus="Error", mediaPercent=0,
) ))
_enqueue(content_hash, err_event)
finally: finally:
_active_detections.pop(content_hash, None) loop.call_later(10.0, _cleanup_channel, channel_id)
_schedule_buffer_cleanup(content_hash)
buffer.close() buffer.close()
if not (token_mgr and user_id) and os.path.isfile(buffer.path): if not (token_mgr and user_id) and os.path.isfile(buffer.path):
try: try:
@@ -592,31 +670,8 @@ async def detect_video_upload(
except OSError: except OSError:
pass pass
_active_detections[content_hash] = asyncio.create_task(_wait_inference()) asyncio.create_task(_wait_inference())
return {"status": "started", "mediaId": content_hash} return Response(status_code=202)
def _post_annotation_to_service(token_mgr: TokenManager, media_id: str,
                                annotation, dtos: list[DetectionDto]):
    """Best-effort POST of one annotation and its detections to the annotations service.

    Args:
        token_mgr: supplies the bearer token via ``get_valid_token()``.
        media_id: sent as ``mediaId`` in the payload.
        annotation: object read for ``time`` and ``image``.
        dtos: detections serialized with ``model_dump()``.

    Any exception (token refresh, formatting, HTTP) is swallowed on purpose so
    that a failed post does not interrupt the caller; nothing is returned.
    """
    try:
        token = token_mgr.get_valid_token()
        # annotation.time is divided by 1000 to get seconds — presumably it is in
        # milliseconds; verify against the producer.
        total_seconds = int(annotation.time) // 1000 if annotation.time else 0
        # Split into HH:MM:SS; the previous "00:00:{seconds}" form produced
        # values such as "00:00:75" once a timestamp passed 59 seconds.
        hours, remainder = divmod(total_seconds, 3600)
        minutes, seconds = divmod(remainder, 60)
        payload = {
            "mediaId": media_id,
            "source": 0,
            "videoTime": f"{hours:02d}:{minutes:02d}:{seconds:02d}",
            "detections": [d.model_dump() for d in dtos],
        }
        if annotation.image:
            payload["image"] = base64.b64encode(annotation.image).decode()
        http_requests.post(
            f"{ANNOTATIONS_URL}/annotations",
            json=payload,
            headers={"Authorization": f"Bearer {token}"},
            timeout=30,
        )
    except Exception:
        pass
@app.post("/detect/{media_id}") @app.post("/detect/{media_id}")
@@ -630,6 +685,10 @@ async def detect_media(
if existing is not None and not existing.done(): if existing is not None and not existing.done():
raise HTTPException(status_code=409, detail="Detection already in progress for this media") raise HTTPException(status_code=409, detail="Detection already in progress for this media")
channel_id = request.headers.get("x-channel-id", "")
if not channel_id:
raise HTTPException(status_code=400, detail="X-Channel-Id header required")
refresh_token = request.headers.get("x-refresh-token", "") refresh_token = request.headers.get("x-refresh-token", "")
access_token = request.headers.get("authorization", "").removeprefix("Bearer ").strip() access_token = request.headers.get("authorization", "").removeprefix("Bearer ").strip()
token_mgr = TokenManager(access_token, refresh_token) if access_token else None token_mgr = TokenManager(access_token, refresh_token) if access_token else None
@@ -668,7 +727,7 @@ async def detect_media(
mediaStatus="AIProcessing", mediaStatus="AIProcessing",
mediaPercent=percent, mediaPercent=percent,
) )
loop.call_soon_threadsafe(_enqueue, media_id, event) loop.call_soon_threadsafe(_enqueue, channel_id, event)
if token_mgr and dtos: if token_mgr and dtos:
_post_annotation_to_service(token_mgr, media_id, annotation, dtos) _post_annotation_to_service(token_mgr, media_id, annotation, dtos)
@@ -679,7 +738,7 @@ async def detect_media(
mediaStatus="AIProcessed", mediaStatus="AIProcessed",
mediaPercent=100, mediaPercent=100,
) )
loop.call_soon_threadsafe(_enqueue, media_id, event) loop.call_soon_threadsafe(_enqueue, channel_id, event)
if token_mgr: if token_mgr:
_put_media_status( _put_media_status(
media_id, media_id,
@@ -718,36 +777,10 @@ async def detect_media(
mediaStatus="Error", mediaStatus="Error",
mediaPercent=0, mediaPercent=0,
) )
_enqueue(media_id, error_event) _enqueue(channel_id, error_event)
finally: finally:
_active_detections.pop(media_id, None) _active_detections.pop(media_id, None)
_schedule_buffer_cleanup(media_id) loop.call_later(10.0, _cleanup_channel, channel_id)
_active_detections[media_id] = asyncio.create_task(run_detection()) _active_detections[media_id] = asyncio.create_task(run_detection())
return {"status": "started", "mediaId": media_id} return Response(status_code=202)
@app.get("/detect/{media_id}", dependencies=[Depends(require_auth)])
async def detect_events(media_id: str):
    """Stream detection events for ``media_id`` as Server-Sent Events.

    A bounded queue is registered in ``_job_queues`` under ``media_id``.
    The response first replays whatever is currently stored in
    ``_job_buffers`` for that id, then forwards every item that arrives on
    the queue. The queue is unregistered when the stream ends.
    """
    subscriber: asyncio.Queue = asyncio.Queue(maxsize=200)
    _job_queues.setdefault(media_id, []).append(subscriber)

    async def stream():
        try:
            # Snapshot the buffer so replay is unaffected by later appends.
            backlog = list(_job_buffers.get(media_id, []))
            for payload in backlog:
                yield f"data: {payload}\n\n"
            while True:
                payload = await subscriber.get()
                yield f"data: {payload}\n\n"
        except asyncio.CancelledError:
            # Cancellation just ends the stream; cleanup happens below.
            pass
        finally:
            registered = _job_queues.get(media_id, [])
            if subscriber in registered:
                registered.remove(subscriber)

    sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
    return StreamingResponse(
        stream(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
+4 -1
View File
@@ -419,7 +419,10 @@ class TestDetectVideoEndpoint:
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
client = TestClient(main.app) client = TestClient(main.app)
token = _access_jwt() token = _access_jwt()
with patch.object(main, "get_inference", return_value=_CaptureInf()): with (
patch.object(main, "JWT_SECRET", _TEST_JWT_SECRET),
patch.object(main, "get_inference", return_value=_CaptureInf()),
):
# Act # Act
r = client.post( r = client.post(
"/detect/video", "/detect/video",
+146
View File
@@ -96,3 +96,149 @@ def test_convert_from_source_uses_fp16_when_no_cache():
mock_config.set_flag.assert_any_call("FP16") mock_config.set_flag.assert_any_call("FP16")
int8_calls = [c for c in mock_config.set_flag.call_args_list if c == call("INT8")] int8_calls = [c for c in mock_config.set_flag.call_args_list if c == call("INT8")]
assert len(int8_calls) == 0 assert len(int8_calls) == 0
@requires_tensorrt
def test_trt_factory_build_from_source_uses_fp16():
    """TensorRTEngineFactory.build_from_source returns an engine without setting INT8."""
    # Arrange
    from engines.engine_factory import TensorRTEngineFactory
    from engines.tensorrt_engine import TensorRTEngine
    import engines.tensorrt_engine as trt_mod

    mock_trt, _, mock_config = _make_mock_trt()
    factory = TensorRTEngineFactory()
    gpu_memory = 4 * 1024**3

    with (
        patch.object(trt_mod, "trt", mock_trt),
        patch.object(TensorRTEngine, "get_gpu_memory_bytes", return_value=gpu_memory),
    ):
        # Act
        engine_bytes, filename = factory.build_from_source(b"onnx", MagicMock(), "models")

    # Assert
    assert engine_bytes == b"engine_bytes"
    assert filename is not None
    assert not any(c == call("INT8") for c in mock_config.set_flag.call_args_list)
@requires_tensorrt
def test_jetson_factory_build_from_source_uses_int8_when_cache_available():
    """With calibration data available, the Jetson factory sets the INT8 flag."""
    # Arrange
    from engines.engine_factory import JetsonTensorRTEngineFactory
    from engines.tensorrt_engine import TensorRTEngine
    import engines.tensorrt_engine as trt_mod

    mock_trt, _, mock_config = _make_mock_trt()
    factory = JetsonTensorRTEngineFactory()
    loader = MagicMock()
    loader.load_big_small_resource.return_value = MagicMock(err=None, data=b"calib_data")
    gpu_memory = 4 * 1024**3

    with (
        patch.object(trt_mod, "trt", mock_trt),
        patch.object(TensorRTEngine, "get_gpu_memory_bytes", return_value=gpu_memory),
    ):
        # Act
        engine_bytes, filename = factory.build_from_source(b"onnx", loader, "models")

    # Assert
    assert engine_bytes == b"engine_bytes"
    assert "int8" in filename
    mock_config.set_flag.assert_any_call("INT8")
@requires_tensorrt
def test_jetson_factory_build_from_source_falls_back_to_fp16_when_no_cache():
    """When the calibration resource load reports an error, INT8 is never set."""
    # Arrange
    from engines.engine_factory import JetsonTensorRTEngineFactory
    from engines.tensorrt_engine import TensorRTEngine
    import engines.tensorrt_engine as trt_mod

    mock_trt, _, mock_config = _make_mock_trt()
    factory = JetsonTensorRTEngineFactory()
    loader = MagicMock()
    loader.load_big_small_resource.return_value = MagicMock(err="not found")
    gpu_memory = 4 * 1024**3

    with (
        patch.object(trt_mod, "trt", mock_trt),
        patch.object(TensorRTEngine, "get_gpu_memory_bytes", return_value=gpu_memory),
    ):
        # Act
        engine_bytes, _ = factory.build_from_source(b"onnx", loader, "models")

    # Assert
    assert engine_bytes == b"engine_bytes"
    assert not any(c == call("INT8") for c in mock_config.set_flag.call_args_list)
@requires_tensorrt
def test_jetson_factory_cleans_up_cache_tempfile_after_build():
    """Every path returned by ``_download_calib_cache`` is gone after the build."""
    # Arrange
    from engines.engine_factory import JetsonTensorRTEngineFactory
    from engines.tensorrt_engine import TensorRTEngine
    import engines.tensorrt_engine as trt_mod

    mock_trt, _, _ = _make_mock_trt()
    factory = JetsonTensorRTEngineFactory()
    loader = MagicMock()
    result = MagicMock()
    result.err = None
    result.data = b"calib_data"
    loader.load_big_small_resource.return_value = result

    written_paths = []
    original_download = factory._download_calib_cache

    def tracking_download(lc, md):
        path = original_download(lc, md)
        if path:
            written_paths.append(path)
        return path

    with patch.object(trt_mod, "trt", mock_trt), \
            patch.object(TensorRTEngine, "get_gpu_memory_bytes", return_value=4 * 1024**3), \
            patch.object(factory, "_download_calib_cache", side_effect=tracking_download):
        # Act
        factory.build_from_source(b"onnx", loader, "models")

    # Assert
    # Guard against a vacuous pass: with no recorded path the loop below
    # would check nothing and the cleanup would go unverified.
    assert written_paths, "calibration cache was never written, cleanup not exercised"
    for p in written_paths:
        assert not os.path.exists(p)
def test_is_jetson_false_on_non_aarch64():
    """``_is_jetson`` is False on x86_64 even with a GPU index and the file check passing."""
    # Arrange
    import engines as eng

    with (
        patch("engines.platform") as fake_platform,
        patch("engines.tensor_gpu_index", 0),
        patch("engines.os.path.isfile", return_value=True),
    ):
        fake_platform.machine.return_value = "x86_64"

        # Assert
        assert eng._is_jetson() is False
def test_is_jetson_false_when_no_gpu():
    """``_is_jetson`` is False on aarch64 when ``tensor_gpu_index`` is -1."""
    # Arrange
    import engines as eng

    with (
        patch("engines.platform") as fake_platform,
        patch("engines.tensor_gpu_index", -1),
        patch("engines.os.path.isfile", return_value=True),
    ):
        fake_platform.machine.return_value = "aarch64"

        # Assert
        assert eng._is_jetson() is False
def test_is_jetson_false_when_no_tegra_release():
    """``_is_jetson`` is False on aarch64 with a GPU when the file check fails."""
    # Arrange
    import engines as eng

    with (
        patch("engines.platform") as fake_platform,
        patch("engines.tensor_gpu_index", 0),
        patch("engines.os.path.isfile", return_value=False),
    ):
        fake_platform.machine.return_value = "aarch64"

        # Assert
        assert eng._is_jetson() is False