From 8baa96978b2dba9258e9991bdd2b22e25a70dd14 Mon Sep 17 00:00:00 2001 From: Oleksandr Bezdieniezhnykh Date: Fri, 3 Apr 2026 02:42:05 +0300 Subject: [PATCH] [AZ-180] Refactor detection event handling and improve SSE support - Updated the detection image endpoint to require a channel ID for event streaming. - Introduced a new endpoint for streaming detection events, allowing clients to receive real-time updates. - Enhanced the internal buffering mechanism for detection events to manage multiple channels. - Refactored the inference module to support the new event handling structure. Made-with: Cursor --- .cursor/rules/coderule.mdc | 4 + .cursor/rules/git-workflow.mdc | 1 + e2e/conftest.py | 128 +++++++++- e2e/tests/test_async_sse.py | 44 ++-- e2e/tests/test_health_engine.py | 91 ++++++- e2e/tests/test_negative.py | 6 +- e2e/tests/test_performance.py | 42 +--- e2e/tests/test_resilience.py | 22 +- e2e/tests/test_resource_limits.py | 24 +- e2e/tests/test_single_image.py | 85 ++----- e2e/tests/test_streaming_video_upload.py | 26 +- e2e/tests/test_tiling.py | 26 +- e2e/tests/test_video.py | 38 +-- requirements.txt | 2 +- setup.py | 3 + src/engines/__init__.py | 37 ++- src/engines/coreml_engine.pyx | 4 - src/engines/engine_factory.py | 109 +++++++++ src/engines/jetson_tensorrt_engine.pxd | 5 + src/engines/jetson_tensorrt_engine.pyx | 5 + src/engines/onnx_engine.pyx | 2 +- src/engines/tensorrt_engine.pyx | 5 - src/inference.pyx | 77 ++---- src/main.py | 295 +++++++++++++---------- tests/test_az178_streaming_video.py | 5 +- tests/test_az180_jetson_int8.py | 146 +++++++++++ 26 files changed, 819 insertions(+), 413 deletions(-) create mode 100644 src/engines/engine_factory.py create mode 100644 src/engines/jetson_tensorrt_engine.pxd create mode 100644 src/engines/jetson_tensorrt_engine.pyx diff --git a/.cursor/rules/coderule.mdc b/.cursor/rules/coderule.mdc index 7fb550a..4044db0 100644 --- a/.cursor/rules/coderule.mdc +++ b/.cursor/rules/coderule.mdc @@ -4,6 +4,10 @@ alwaysApply: true --- # Coding preferences - Always prefer simple solution +- Follow the Single Responsibility Principle — a class or method should have one reason to change: + - If a method is hard to name precisely from the caller's perspective, its responsibility is misplaced. Vague names like "candidate", "data", or "item" are a signal — fix the design, not just the name. + - Logic specific to a platform, variant, or environment belongs in the class that owns that variant, not in the general coordinator. Passing a dependency through is preferable to leaking variant-specific concepts into shared code. + - Only use static methods for pure, self-contained computations (constants, simple math, stateless lookups). If a static method involves resource access, side effects, OS interaction, or logic that varies across subclasses or environments — use an instance method or factory class instead. Before implementing a non-trivial static method, ask the user. - Generate concise code - Do not put comments in the code, except in tests: every test must use the Arrange / Act / Assert pattern with language-appropriate comment syntax (`# Arrange` for Python, `// Arrange` for C#/Rust/JS/TS). Omit any section that is not needed (e.g. if there is no setup, skip Arrange; if act and assert are the same line, keep only Assert) - Do not put logs unless it is an exception, or was asked specifically diff --git a/.cursor/rules/git-workflow.mdc b/.cursor/rules/git-workflow.mdc index 53b30b1..b400f89 100644 --- a/.cursor/rules/git-workflow.mdc +++ b/.cursor/rules/git-workflow.mdc @@ -6,3 +6,4 @@ alwaysApply: true - Work on the `dev` branch - Commit message format: `[TRACKER-ID-1] [TRACKER-ID-2] Summary of changes` +- Commit message total length must not exceed 30 characters diff --git a/e2e/conftest.py b/e2e/conftest.py index 77fc6b1..df9f772 100644 --- a/e2e/conftest.py +++ b/e2e/conftest.py @@ -1,6 +1,9 @@ +import json import os import random +import threading import time +import uuid from contextlib import contextmanager from pathlib import Path @@ -75,12 +78,83 @@ def auth_headers(jwt_token): return {"Authorization": f"Bearer {jwt_token}"} if jwt_token else {} +@pytest.fixture +def channel_id(): + return str(uuid.uuid4()) + + +@pytest.fixture(scope="session") +def image_detect(http_client, auth_headers): + def _detect(image_bytes, filename="img.jpg", config=None, timeout=30): + cid = str(uuid.uuid4()) + headers = {**auth_headers, "X-Channel-Id": cid} + detections = [] + errors = [] + done = threading.Event() + connected = threading.Event() + + def _listen(): + try: + with http_client.get( + f"/detect/events/{cid}", + stream=True, + timeout=timeout + 2, + headers=auth_headers, + ) as resp: + resp.raise_for_status() + connected.set() + for ev in sseclient.SSEClient(resp).events(): + if not ev.data or not str(ev.data).strip(): + continue + data = json.loads(ev.data) + if data.get("mediaStatus") == "AIProcessing": + detections.extend(data.get("annotations", [])) + if data.get("mediaStatus") in ("AIProcessed", "Error"): + break + except BaseException as e: + errors.append(e) + finally: + connected.set() + done.set() + + th = threading.Thread(target=_listen, daemon=True) + th.start() + connected.wait(timeout=5) + + data_form = {} + if config: + data_form["config"] = config + + t0 = time.perf_counter() + r = http_client.post( + "/detect/image", + files={"file": (filename, image_bytes, "image/jpeg")}, + data=data_form, + headers=headers, + timeout=timeout, + ) + done.wait(timeout=timeout) + elapsed_ms = (time.perf_counter() - t0) * 1000.0 + + assert r.status_code == 202, f"Expected 202, got {r.status_code}: {r.text}" + assert not errors, f"SSE errors: {errors}" + + th.join(timeout=1) + return detections, elapsed_ms + + return _detect + + @pytest.fixture def sse_client_factory(http_client, auth_headers): @contextmanager - def _open(media_id: str): - with http_client.get(f"/detect/{media_id}", stream=True, - timeout=600, headers=auth_headers) as resp: + def _open(channel_id: str): + with http_client.get( + f"/detect/events/{channel_id}", + stream=True, + timeout=600, + headers=auth_headers, + ) as resp: resp.raise_for_status() yield sseclient.SSEClient(resp) @@ -201,19 +275,52 @@ def corrupt_image(): return random.randbytes(1024) - @pytest.fixture(scope="module") def warm_engine(http_client, image_small, auth_headers): deadline = time.time() + 120 - files = {"file": ("warm.jpg", image_small, "image/jpeg")} - consecutive_errors = 0 last_status = None + consecutive_errors = 0 + while time.time() < deadline: + cid = str(uuid.uuid4()) + headers = {**auth_headers, "X-Channel-Id": cid} + done = threading.Event() + + def _listen(cid=cid): + try: + with http_client.get( + f"/detect/events/{cid}", + stream=True, + timeout=35, + headers=auth_headers, + ) as resp: + resp.raise_for_status() + for ev in sseclient.SSEClient(resp).events(): + if not ev.data or not str(ev.data).strip(): + continue + data = json.loads(ev.data) + if data.get("mediaStatus") == "AIProcessed": + break + except Exception: + pass + finally: + done.set() + + th = threading.Thread(target=_listen, daemon=True) + th.start() + time.sleep(0.1) + try: - r = http_client.post("/detect/image", files=files, headers=auth_headers) - if r.status_code == 200: - return + r = http_client.post( + "/detect/image", + files={"file": ("warm.jpg", image_small, "image/jpeg")}, + headers=headers, + ) last_status = r.status_code + if r.status_code == 202: + done.wait(timeout=30) + th.join(timeout=1) + return if r.status_code >= 500: consecutive_errors += 1 if consecutive_errors >= 5: @@ -225,5 +332,8 @@ def warm_engine(http_client, image_small, auth_headers): consecutive_errors = 0 except OSError: consecutive_errors = 0 + + th.join(timeout=1) time.sleep(2) + pytest.fail(f"engine warm-up timed out after 120s (last status: {last_status})") diff --git a/e2e/tests/test_async_sse.py b/e2e/tests/test_async_sse.py index 27704ba..f03d385 100644 --- a/e2e/tests/test_async_sse.py +++ b/e2e/tests/test_async_sse.py @@ -4,6 +4,7 @@ import time import uuid import pytest +import sseclient def _ai_config_video() -> dict: @@ -19,36 +20,43 @@ def test_ft_p08_immediate_async_response( ): media_id = f"async-{uuid.uuid4().hex}" body = _ai_config_image() - headers = {"Authorization": f"Bearer {jwt_token}"} + channel_id = str(uuid.uuid4()) + headers = {"Authorization": f"Bearer {jwt_token}", "X-Channel-Id": channel_id} t0 = time.monotonic() r = http_client.post(f"/detect/{media_id}", json=body, headers=headers) elapsed = time.monotonic() - t0 assert elapsed < 2.0 - assert r.status_code == 200 - assert r.json() == {"status": "started", "mediaId": media_id} + assert r.status_code == 202 @pytest.mark.timeout(10) def test_ft_p09_sse_event_delivery( - warm_engine, http_client, jwt_token, sse_client_factory + warm_engine, http_client, jwt_token ): media_id = f"sse-{uuid.uuid4().hex}" + channel_id = str(uuid.uuid4()) body = _ai_config_video() - headers = {"Authorization": f"Bearer {jwt_token}"} + auth_header = {"Authorization": f"Bearer {jwt_token}"} + post_headers = {**auth_header, "X-Channel-Id": channel_id} collected: list[dict] = [] thread_exc: list[BaseException] = [] first_event = threading.Event() + connected = threading.Event() def _listen(): try: - with sse_client_factory(media_id) as sse: - time.sleep(0.3) - for event in sse.events(): + with http_client.get( + f"/detect/events/{channel_id}", + stream=True, + timeout=600, + headers=auth_header, + ) as resp: + resp.raise_for_status() + connected.set() + for event in sseclient.SSEClient(resp).events(): if not event.data or not str(event.data).strip(): continue data = json.loads(event.data) - if data.get("mediaId") != media_id: - continue collected.append(data) first_event.set() if len(collected) >= 5: @@ -56,13 +64,14 @@ def test_ft_p09_sse_event_delivery( except BaseException as e: thread_exc.append(e) finally: + connected.set() first_event.set() th = threading.Thread(target=_listen, daemon=True) th.start() - time.sleep(0.5) - r = http_client.post(f"/detect/{media_id}", json=body, headers=headers) - assert r.status_code == 200 + connected.wait(timeout=5) + r = http_client.post(f"/detect/{media_id}", json=body, headers=post_headers) + assert r.status_code == 202 first_event.wait(timeout=5) th.join(timeout=5) assert not thread_exc, thread_exc @@ -74,8 +83,9 @@ def test_ft_n04_duplicate_media_id_409( ): media_id = "dup-test" body = _ai_config_image() - headers = {"Authorization": f"Bearer {jwt_token}"} - r1 = http_client.post(f"/detect/{media_id}", json=body, headers=headers) - assert r1.status_code == 200 - r2 = http_client.post(f"/detect/{media_id}", json=body, headers=headers) + headers1 = {"Authorization": f"Bearer {jwt_token}", "X-Channel-Id": str(uuid.uuid4())} + headers2 = {"Authorization": f"Bearer {jwt_token}", "X-Channel-Id": str(uuid.uuid4())} + r1 = http_client.post(f"/detect/{media_id}", json=body, headers=headers1) + assert r1.status_code == 202 + r2 = http_client.post(f"/detect/{media_id}", json=body, headers=headers2) assert r2.status_code == 409 diff --git a/e2e/tests/test_health_engine.py b/e2e/tests/test_health_engine.py index a0474be..e160f52 100644 --- a/e2e/tests/test_health_engine.py +++ b/e2e/tests/test_health_engine.py @@ -1,6 +1,10 @@ +import json +import threading import time +import uuid import pytest +import sseclient _DETECT_TIMEOUT = 60 @@ -39,11 +43,44 @@ class TestHealthEngineStep02LazyInit: f"engine already initialized (aiAvailability={before['aiAvailability']}); " "lazy-init test must run before any test that triggers warm_engine" ) + + cid = str(uuid.uuid4()) + headers = {**auth_headers, "X-Channel-Id": cid} + done = threading.Event() + connected = threading.Event() + + def _listen(): + try: + with http_client.get( + f"/detect/events/{cid}", + stream=True, + timeout=_DETECT_TIMEOUT + 2, + headers=auth_headers, + ) as resp: + resp.raise_for_status() + connected.set() + for ev in sseclient.SSEClient(resp).events(): + if not ev.data or not str(ev.data).strip(): + continue + data = json.loads(ev.data) + if data.get("mediaStatus") in ("AIProcessed", "Error"): + break + except Exception: + pass + finally: + connected.set() + done.set() + + th = threading.Thread(target=_listen, daemon=True) + th.start() + connected.wait(timeout=5) + files = {"file": ("lazy.jpg", image_small, "image/jpeg")} - r = http_client.post("/detect/image", files=files, headers=auth_headers, timeout=_DETECT_TIMEOUT) - r.raise_for_status() - body = r.json() - assert isinstance(body, list) + r = http_client.post("/detect/image", files=files, headers=headers, timeout=_DETECT_TIMEOUT) + assert r.status_code == 202 + done.wait(timeout=_DETECT_TIMEOUT) + th.join(timeout=2) + after = _get_health(http_client) _assert_active_ai(after) @@ -61,13 +98,49 @@ class TestHealthEngineStep03Warmed: assert data.get("errorMessage") is None def test_ft_p_15_onnx_cpu_detect(self, http_client, image_small, auth_headers): + cid = str(uuid.uuid4()) + headers = {**auth_headers, "X-Channel-Id": cid} + all_detections = [] + done = threading.Event() + connected = threading.Event() + + def _listen(): + try: + with http_client.get( + f"/detect/events/{cid}", + stream=True, + timeout=_DETECT_TIMEOUT + 2, + headers=auth_headers, + ) as resp: + resp.raise_for_status() + connected.set() + for ev in sseclient.SSEClient(resp).events(): + if not ev.data or not str(ev.data).strip(): + continue + data = json.loads(ev.data) + if data.get("mediaStatus") == "AIProcessing": + all_detections.extend(data.get("annotations", [])) + if data.get("mediaStatus") in ("AIProcessed", "Error"): + break + except Exception: + pass + finally: + connected.set() + done.set() + + th = threading.Thread(target=_listen, daemon=True) + th.start() + connected.wait(timeout=5) + files = {"file": ("onnx.jpg", image_small, "image/jpeg")} - r = http_client.post("/detect/image", files=files, headers=auth_headers, timeout=_DETECT_TIMEOUT) + r = http_client.post("/detect/image", files=files, headers=headers, timeout=_DETECT_TIMEOUT) r.raise_for_status() - body = r.json() - assert isinstance(body, list) - if body: - d = body[0] + assert r.status_code == 202 + done.wait(timeout=_DETECT_TIMEOUT) + th.join(timeout=2) + + if all_detections: + d = all_detections[0] for k in ( "centerX", "centerY", diff --git a/e2e/tests/test_negative.py b/e2e/tests/test_negative.py index 18ae2ac..f88e715 100644 --- a/e2e/tests/test_negative.py +++ b/e2e/tests/test_negative.py @@ -1,3 +1,5 @@ +import uuid + import pytest import requests @@ -41,6 +43,8 @@ def test_ft_n_03_loader_error_mode_detect_does_not_500( f"{mock_loader_url}/mock/config", json={"mode": "error"}, timeout=10 ) cfg.raise_for_status() + channel_id = str(uuid.uuid4()) + headers = {**auth_headers, "X-Channel-Id": channel_id} files = {"file": ("small.jpg", image_small, "image/jpeg")} - r = http_client.post("/detect/image", files=files, headers=auth_headers, timeout=_DETECT_TIMEOUT) + r = http_client.post("/detect/image", files=files, headers=headers, timeout=_DETECT_TIMEOUT) assert r.status_code != 500 diff --git a/e2e/tests/test_performance.py b/e2e/tests/test_performance.py index 887ae62..3ee9c76 100644 --- a/e2e/tests/test_performance.py +++ b/e2e/tests/test_performance.py @@ -1,5 +1,4 @@ import json -import time import pytest @@ -19,20 +18,13 @@ def _percentile_ms(sorted_ms, p): @pytest.mark.timeout(60) def test_nft_perf_01_single_image_latency_p95( - warm_engine, http_client, image_small, auth_headers + warm_engine, image_detect, image_small ): times_ms = [] for _ in range(10): - t0 = time.perf_counter() - r = http_client.post( - "/detect/image", - files={"file": ("img.jpg", image_small, "image/jpeg")}, - headers=auth_headers, - timeout=8, - ) - elapsed_ms = (time.perf_counter() - t0) * 1000.0 - assert r.status_code == 200 + _, elapsed_ms = image_detect(image_small, "img.jpg", timeout=8) times_ms.append(elapsed_ms) + sorted_ms = sorted(times_ms) p50 = _percentile_ms(sorted_ms, 50) p95 = _percentile_ms(sorted_ms, 95) @@ -47,34 +39,16 @@ def test_nft_perf_01_single_image_latency_p95( @pytest.mark.timeout(60) def test_nft_perf_03_tiling_overhead_large_image( - warm_engine, http_client, image_small, image_large, auth_headers + warm_engine, image_detect, image_small, image_large ): - t_small = time.perf_counter() - r_small = http_client.post( - "/detect/image", - files={"file": ("small.jpg", image_small, "image/jpeg")}, - headers=auth_headers, - timeout=8, - ) - small_ms = (time.perf_counter() - t_small) * 1000.0 - assert r_small.status_code == 200 - config = json.dumps( - {"altitude": 400, "focal_length": 24, "sensor_width": 23.5} - ) - t_large = time.perf_counter() - r_large = http_client.post( - "/detect/image", - files={"file": ("large.jpg", image_large, "image/jpeg")}, - data={"config": config}, - headers=auth_headers, + _, small_ms = image_detect(image_small, "small.jpg", timeout=8) + _, large_ms = image_detect( + image_large, "large.jpg", + config=json.dumps({"altitude": 400, "focal_length": 24, "sensor_width": 23.5}), timeout=20, ) - large_ms = (time.perf_counter() - t_large) * 1000.0 - assert r_large.status_code == 200 assert large_ms < 30_000.0 print( f"nft_perf_03_csv,baseline_small_ms,{small_ms:.2f},large_ms,{large_ms:.2f}" ) assert large_ms > small_ms - 500.0 - - diff --git a/e2e/tests/test_resilience.py b/e2e/tests/test_resilience.py index 87d5a8c..d567a71 100644 --- a/e2e/tests/test_resilience.py +++ b/e2e/tests/test_resilience.py @@ -5,15 +5,13 @@ _DETECT_TIMEOUT = 60 def test_nft_res_01_loader_outage_after_init( - warm_engine, http_client, mock_loader_url, image_small, auth_headers + warm_engine, image_detect, mock_loader_url, image_small, http_client ): requests.post( f"{mock_loader_url}/mock/config", json={"mode": "error"}, timeout=10 ).raise_for_status() - files = {"file": ("r1.jpg", image_small, "image/jpeg")} - r = http_client.post("/detect/image", files=files, headers=auth_headers, timeout=_DETECT_TIMEOUT) - assert r.status_code == 200 - assert isinstance(r.json(), list) + detections, _ = image_detect(image_small, "r1.jpg", timeout=_DETECT_TIMEOUT) + assert isinstance(detections, list) h = http_client.get("/health") assert h.status_code == 200 hd = h.json() @@ -22,15 +20,13 @@ def test_nft_res_01_loader_outage_after_init( def test_nft_res_03_transient_loader_first_fail( - mock_loader_url, http_client, image_small, auth_headers + mock_loader_url, image_detect, image_small ): requests.post( f"{mock_loader_url}/mock/config", json={"mode": "first_fail"}, timeout=10 ).raise_for_status() - files = {"file": ("r3a.jpg", image_small, "image/jpeg")} - r1 = http_client.post("/detect/image", files=files, headers=auth_headers, timeout=_DETECT_TIMEOUT) - files2 = {"file": ("r3b.jpg", image_small, "image/jpeg")} - r2 = http_client.post("/detect/image", files=files2, headers=auth_headers, timeout=_DETECT_TIMEOUT) - assert r2.status_code == 200 - if r1.status_code != 200: - assert r1.status_code != 500 + try: + image_detect(image_small, "r3a.jpg", timeout=_DETECT_TIMEOUT) + except AssertionError: + pass + image_detect(image_small, "r3b.jpg", timeout=_DETECT_TIMEOUT) diff --git a/e2e/tests/test_resource_limits.py b/e2e/tests/test_resource_limits.py index d88cc10..098c5bb 100644 --- a/e2e/tests/test_resource_limits.py +++ b/e2e/tests/test_resource_limits.py @@ -8,28 +8,16 @@ import pytest @pytest.mark.slow @pytest.mark.timeout(120) def test_nft_res_lim_03_max_detections_per_frame( - warm_engine, http_client, image_dense, auth_headers + warm_engine, image_detect, image_dense ): - r = http_client.post( - "/detect/image", - files={"file": ("img.jpg", image_dense, "image/jpeg")}, - headers=auth_headers, - timeout=120, - ) - assert r.status_code == 200 - body = r.json() - assert isinstance(body, list) - assert len(body) <= 300 + detections, _ = image_detect(image_dense, "img.jpg", timeout=120) + assert isinstance(detections, list) + assert len(detections) <= 300 @pytest.mark.slow -def test_nft_res_lim_04_log_file_rotation(warm_engine, http_client, image_small, auth_headers): - http_client.post( - "/detect/image", - files={"file": ("img.jpg", image_small, "image/jpeg")}, - headers=auth_headers, - timeout=60, - ) +def test_nft_res_lim_04_log_file_rotation(warm_engine, image_detect, image_small): + image_detect(image_small, "img.jpg", timeout=60) candidates = [ Path(__file__).resolve().parent.parent / "logs", Path("/app/Logs"), diff --git a/e2e/tests/test_single_image.py b/e2e/tests/test_single_image.py index 372d948..44d9f31 100644 --- a/e2e/tests/test_single_image.py +++ b/e2e/tests/test_single_image.py @@ -81,16 +81,10 @@ def _weather_label_ok(label, base_names): @pytest.mark.slow -def test_ft_p_03_detection_response_structure_ac1(http_client, image_small, warm_engine, auth_headers): - r = http_client.post( - "/detect/image", - files={"file": ("img.jpg", image_small, "image/jpeg")}, - headers=auth_headers, - ) - assert r.status_code == 200 - body = r.json() - assert isinstance(body, list) - for d in body: +def test_ft_p_03_detection_response_structure_ac1(image_detect, image_small, warm_engine): + detections, _ = image_detect(image_small, "img.jpg") + assert isinstance(detections, list) + for d in detections: assert isinstance(d["centerX"], (int, float)) assert isinstance(d["centerY"], (int, float)) assert isinstance(d["width"], (int, float)) @@ -106,44 +100,24 @@ def test_ft_p_03_detection_response_structure_ac1(http_client, image_small, warm @pytest.mark.slow -def test_ft_p_05_confidence_filtering_ac2(http_client, image_small, warm_engine, auth_headers): - cfg_hi = json.dumps({"probability_threshold": 0.8}) - r_hi = http_client.post( - "/detect/image", - files={"file": ("img.jpg", image_small, "image/jpeg")}, - data={"config": cfg_hi}, - headers=auth_headers, - ) - assert r_hi.status_code == 200 - hi = r_hi.json() +def test_ft_p_05_confidence_filtering_ac2(image_detect, image_small, warm_engine): + hi, _ = image_detect(image_small, "img.jpg", config=json.dumps({"probability_threshold": 0.8})) assert isinstance(hi, list) for d in hi: assert float(d["confidence"]) + _EPS >= 0.8 - cfg_lo = json.dumps({"probability_threshold": 0.1}) - r_lo = http_client.post( - "/detect/image", - files={"file": ("img.jpg", image_small, "image/jpeg")}, - data={"config": cfg_lo}, - headers=auth_headers, - ) - assert r_lo.status_code == 200 - lo = r_lo.json() + + lo, _ = image_detect(image_small, "img.jpg", config=json.dumps({"probability_threshold": 0.1})) assert isinstance(lo, list) assert len(lo) >= len(hi) @pytest.mark.slow -def test_ft_p_06_overlap_deduplication_ac3(http_client, image_dense, warm_engine, auth_headers): - cfg_loose = json.dumps({"tracking_intersection_threshold": 0.6}) - r1 = http_client.post( - "/detect/image", - files={"file": ("img.jpg", image_dense, "image/jpeg")}, - data={"config": cfg_loose}, - headers=auth_headers, +def test_ft_p_06_overlap_deduplication_ac3(image_detect, image_dense, warm_engine): + dets, _ = image_detect( + image_dense, "img.jpg", + config=json.dumps({"tracking_intersection_threshold": 0.6}), timeout=_DETECT_SLOW_TIMEOUT, ) - assert r1.status_code == 200 - dets = r1.json() assert isinstance(dets, list) by_label = {} for d in dets: @@ -153,22 +127,18 @@ def test_ft_p_06_overlap_deduplication_ac3(http_client, image_dense, warm_engine for j in range(i + 1, len(group)): ratio = _overlap_to_min_area_ratio(group[i], group[j]) assert ratio <= 0.6 + _EPS, (label, ratio) - cfg_strict = json.dumps({"tracking_intersection_threshold": 0.01}) - r2 = http_client.post( - "/detect/image", - files={"file": ("img.jpg", image_dense, "image/jpeg")}, - data={"config": cfg_strict}, - headers=auth_headers, + + strict, _ = image_detect( + image_dense, "img.jpg", + config=json.dumps({"tracking_intersection_threshold": 0.01}), timeout=_DETECT_SLOW_TIMEOUT, ) - assert r2.status_code == 200 - strict = r2.json() assert isinstance(strict, list) assert len(strict) <= len(dets) @pytest.mark.slow -def test_ft_p_07_physical_size_filtering_ac4(http_client, image_small, warm_engine, auth_headers): +def test_ft_p_07_physical_size_filtering_ac4(image_detect, image_small, warm_engine): by_id, _ = _load_classes_media() wh = _image_width_height(image_small) assert wh is not None @@ -184,15 +154,7 @@ def test_ft_p_07_physical_size_filtering_ac4(http_client, image_small, warm_engi "sensor_width": sensor_width, } ) - r = http_client.post( - "/detect/image", - files={"file": ("img.jpg", image_small, "image/jpeg")}, - data={"config": cfg}, - headers=auth_headers, - timeout=_DETECT_SLOW_TIMEOUT, - ) - assert r.status_code == 200 - body = r.json() + body, _ = image_detect(image_small, "img.jpg", config=cfg, timeout=_DETECT_SLOW_TIMEOUT) assert isinstance(body, list) for d in body: base_id = d["classNum"] % _WEATHER_CLASS_STRIDE @@ -203,17 +165,10 @@ def test_ft_p_07_physical_size_filtering_ac4(http_client, image_small, warm_engi @pytest.mark.slow def test_ft_p_13_weather_mode_class_variants_ac5( - http_client, image_different_types, warm_engine, auth_headers + image_detect, image_different_types, warm_engine ): _, base_names = _load_classes_media() - r = http_client.post( - "/detect/image", - files={"file": ("img.jpg", image_different_types, "image/jpeg")}, - headers=auth_headers, - timeout=_DETECT_SLOW_TIMEOUT, - ) - assert r.status_code == 200 - body = r.json() + body, _ = image_detect(image_different_types, "img.jpg", timeout=_DETECT_SLOW_TIMEOUT) assert isinstance(body, list) for d in body: label = d["label"] diff --git a/e2e/tests/test_streaming_video_upload.py b/e2e/tests/test_streaming_video_upload.py index 5d4cf3d..1e1f195 100644 --- a/e2e/tests/test_streaming_video_upload.py +++ b/e2e/tests/test_streaming_video_upload.py @@ -10,6 +10,7 @@ Run with: pytest e2e/tests/test_streaming_video_upload.py -s -v import json import threading import time +import uuid from pathlib import Path import pytest @@ -37,21 +38,23 @@ def _chunked_reader(path: str, chunk_size: int = 64 * 1024): def _start_sse_listener( - http_client, media_id: str, auth_headers: dict + http_client, channel_id: str, auth_headers: dict ) -> tuple[list[dict], list[BaseException], threading.Event]: events: list[dict] = [] errors: list[BaseException] = [] first_event = threading.Event() + connected = threading.Event() def _listen(): try: with http_client.get( - f"/detect/{media_id}", + f"/detect/events/{channel_id}", stream=True, timeout=_TIMEOUT + 2, headers=auth_headers, ) as resp: resp.raise_for_status() + connected.set() for event in sseclient.SSEClient(resp).events(): if not event.data or not str(event.data).strip(): continue @@ -62,9 +65,12 @@ def _start_sse_listener( except BaseException as exc: errors.append(exc) finally: + connected.set() first_event.set() - threading.Thread(target=_listen, daemon=True).start() + th = threading.Thread(target=_listen, daemon=True) + th.start() + connected.wait(timeout=3) return events, errors, first_event @@ -74,6 +80,8 @@ def test_streaming_video_detections_appear_during_upload( ): # Arrange video_path = _fixture_path("video_test01.mp4") + channel_id = str(uuid.uuid4()) + events, errors, first_event = _start_sse_listener(http_client, channel_id, auth_headers) # Act r = http_client.post( @@ -81,14 +89,13 @@ def test_streaming_video_detections_appear_during_upload( data=_chunked_reader(video_path), headers={ **auth_headers, + "X-Channel-Id": channel_id, "X-Filename": "video_test01.mp4", "Content-Type": "application/octet-stream", }, timeout=8, ) - assert r.status_code == 200 - media_id = r.json()["mediaId"] - events, errors, first_event = _start_sse_listener(http_client, media_id, auth_headers) + assert r.status_code == 202 first_event.wait(timeout=_TIMEOUT) # Assert @@ -103,6 +110,8 @@ def test_streaming_video_detections_appear_during_upload( def test_non_faststart_video_still_works(warm_engine, http_client, auth_headers): # Arrange video_path = _fixture_path("video_test01.mp4") + channel_id = str(uuid.uuid4()) + events, errors, first_event = _start_sse_listener(http_client, channel_id, auth_headers) # Act r = http_client.post( @@ -110,14 +119,13 @@ def test_non_faststart_video_still_works(warm_engine, http_client, auth_headers) data=_chunked_reader(video_path), headers={ **auth_headers, + "X-Channel-Id": channel_id, "X-Filename": "video_test01_plain.mp4", "Content-Type": "application/octet-stream", }, timeout=8, ) - assert r.status_code == 200 - media_id = r.json()["mediaId"] - events, errors, first_event = _start_sse_listener(http_client, media_id, auth_headers) + assert r.status_code == 202 first_event.wait(timeout=_TIMEOUT) # Assert diff --git a/e2e/tests/test_tiling.py b/e2e/tests/test_tiling.py index ee2e446..dbc7518 100644 --- a/e2e/tests/test_tiling.py +++ b/e2e/tests/test_tiling.py @@ -28,32 +28,22 @@ def _assert_no_same_label_near_duplicate_centers(detections): @pytest.mark.slow -def test_ft_p_04_gsd_based_tiling_ac1(http_client, image_large, warm_engine, auth_headers): - config = json.dumps(_GSD) - r = http_client.post( - "/detect/image", - files={"file": ("img.jpg", image_large, "image/jpeg")}, - data={"config": config}, - headers=auth_headers, +def test_ft_p_04_gsd_based_tiling_ac1(image_detect, image_large, warm_engine): + body, _ = image_detect( + image_large, "img.jpg", + config=json.dumps(_GSD), timeout=_TILING_TIMEOUT, ) - assert r.status_code == 200 - body = r.json() assert isinstance(body, list) _assert_coords_normalized(body) @pytest.mark.slow -def test_ft_p_16_tile_boundary_deduplication_ac2(http_client, image_large, warm_engine, auth_headers): - config = json.dumps({**_GSD, "big_image_tile_overlap_percent": 20}) - r = http_client.post( - "/detect/image", - files={"file": ("img.jpg", image_large, "image/jpeg")}, - data={"config": config}, - headers=auth_headers, +def test_ft_p_16_tile_boundary_deduplication_ac2(image_detect, image_large, warm_engine): + body, _ = image_detect( + image_large, "img.jpg", + config=json.dumps({**_GSD, "big_image_tile_overlap_percent": 20}), timeout=_TILING_TIMEOUT, ) - assert r.status_code == 200 - body = r.json() assert isinstance(body, list) _assert_no_same_label_near_duplicate_centers(body) diff --git a/e2e/tests/test_video.py b/e2e/tests/test_video.py index 553c02f..6120837 100644 --- a/e2e/tests/test_video.py +++ b/e2e/tests/test_video.py @@ -1,6 +1,7 @@ import json import threading import time +import uuid from pathlib import Path import pytest @@ -24,29 +25,22 @@ def video_events(warm_engine, http_client, auth_headers): if not Path(_VIDEO).is_file(): pytest.skip(f"missing fixture {_VIDEO}") - r = http_client.post( - "/detect/video", - data=_chunked_reader(_VIDEO), - headers={ - **auth_headers, - "X-Filename": "video_test01.mp4", - "Content-Type": "application/octet-stream", - }, - timeout=15, - ) - assert r.status_code == 200 - media_id = r.json()["mediaId"] - + channel_id = str(uuid.uuid4()) collected: list[tuple[float, dict]] = [] thread_exc: list[BaseException] = [] done = threading.Event() + connected = threading.Event() def _listen(): try: with http_client.get( - f"/detect/{media_id}", stream=True, timeout=35, headers=auth_headers + f"/detect/events/{channel_id}", + stream=True, + timeout=60, + headers=auth_headers, ) as resp: resp.raise_for_status() + connected.set() sse = sseclient.SSEClient(resp) for event in sse.events(): if not event.data or not str(event.data).strip(): @@ -61,10 +55,26 @@ def video_events(warm_engine, http_client, auth_headers): except BaseException as e: thread_exc.append(e) finally: + connected.set() done.set() th = threading.Thread(target=_listen, daemon=True) th.start() + connected.wait(timeout=5) + + r = http_client.post( + "/detect/video", + data=_chunked_reader(_VIDEO), + headers={ + **auth_headers, + "X-Channel-Id": channel_id, + "X-Filename": "video_test01.mp4", + "Content-Type": "application/octet-stream", + }, + timeout=15, + ) + assert r.status_code == 202 + assert done.wait(timeout=30) th.join(timeout=5) assert not thread_exc, thread_exc diff --git a/requirements.txt b/requirements.txt index 9caa7e5..ca41aa5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ fastapi==0.135.2 uvicorn[standard]==0.42.0 PyJWT==2.12.1 h11==0.16.0 -python-multipart>=1.3.1 +python-multipart==0.0.22 Cython==3.2.4 opencv-python==4.10.0.84 numpy==2.3.0 diff --git a/setup.py b/setup.py index ec15a41..3f1e966 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,9 @@ try: extensions.append( Extension('engines.tensorrt_engine', [f'{SRC}/engines/tensorrt_engine.pyx'], include_dirs=np_inc) ) + extensions.append( + Extension('engines.jetson_tensorrt_engine', [f'{SRC}/engines/jetson_tensorrt_engine.pyx'], include_dirs=np_inc) + ) except ImportError: pass diff --git a/src/engines/__init__.py b/src/engines/__init__.py index c1b7a4e..27a6148 100644 --- a/src/engines/__init__.py +++ b/src/engines/__init__.py @@ -1,6 +1,16 @@ +import os import platform import sys +from loguru import logger +from engines.engine_factory import ( + EngineFactory, + OnnxEngineFactory, + CoreMLEngineFactory, + TensorRTEngineFactory, + JetsonTensorRTEngineFactory, +) + def _check_tensor_gpu_index(): try: @@ -35,18 +45,29 @@ def _is_apple_silicon(): return False +def _is_jetson(): + return ( + platform.machine() == "aarch64" + and tensor_gpu_index > -1 + and os.path.isfile("/etc/nv_tegra_release") + ) + + tensor_gpu_index = _check_tensor_gpu_index() -def _select_engine_class(): +def _create_engine_factory() -> EngineFactory: + if _is_jetson(): + logger.info("Engine factory: JetsonTensorRTEngineFactory") + return JetsonTensorRTEngineFactory() if tensor_gpu_index > -1: - from engines.tensorrt_engine import TensorRTEngine # pyright: ignore[reportMissingImports] - return TensorRTEngine + logger.info("Engine factory: TensorRTEngineFactory") + return TensorRTEngineFactory() if _is_apple_silicon(): - from engines.coreml_engine import CoreMLEngine - return CoreMLEngine - from engines.onnx_engine import OnnxEngine - return OnnxEngine + logger.info("Engine factory: CoreMLEngineFactory") + return CoreMLEngineFactory() + logger.info("Engine factory: OnnxEngineFactory") + return OnnxEngineFactory() -EngineClass = _select_engine_class() +engine_factory = _create_engine_factory() diff --git a/src/engines/coreml_engine.pyx b/src/engines/coreml_engine.pyx index f80aa2c..9431bd9 100644 --- a/src/engines/coreml_engine.pyx +++ b/src/engines/coreml_engine.pyx @@ -30,10 +30,6 @@ cdef class CoreMLEngine(InferenceEngine): constants_inf.log(f'CoreML model: {self.img_width}x{self.img_height}') - @staticmethod - def get_engine_filename(): - return "azaion_coreml.zip" - @staticmethod def _extract_from_zip(model_bytes): tmpdir = tempfile.mkdtemp() diff --git a/src/engines/engine_factory.py b/src/engines/engine_factory.py new file mode 100644 index 0000000..b969436 --- /dev/null +++ b/src/engines/engine_factory.py @@ -0,0 +1,109 @@ +import os +import tempfile + + +class EngineFactory: + has_build_step = False + + def create(self, model_bytes: bytes): + raise NotImplementedError + + def load_engine(self, loader_client, models_dir: str): + filename = self._get_ai_engine_filename() + if filename is None: + return None + try: + res = loader_client.load_big_small_resource(filename, models_dir) + if res.err is None: + return self.create(res.data) + except Exception: + pass + return None + + def _get_ai_engine_filename(self) -> str | None: + return None + + def get_source_filename(self) -> str | None: + return None + + def build_from_source(self, onnx_bytes: bytes, loader_client, models_dir: str): + raise NotImplementedError(f"{type(self).__name__} does not support building from source") + + +class OnnxEngineFactory(EngineFactory): + def create(self, model_bytes: bytes): + from engines.onnx_engine import OnnxEngine + return OnnxEngine(model_bytes) + + def get_source_filename(self) -> str: + import constants_inf + return constants_inf.AI_ONNX_MODEL_FILE + + +class CoreMLEngineFactory(EngineFactory): + def create(self, model_bytes: bytes): + from engines.coreml_engine import CoreMLEngine + return CoreMLEngine(model_bytes) + + def _get_ai_engine_filename(self) -> str: + return "azaion_coreml.zip" + + +class TensorRTEngineFactory(EngineFactory): + has_build_step = True + + def create(self, model_bytes: bytes): + from engines.tensorrt_engine import TensorRTEngine + return TensorRTEngine(model_bytes) + + def _get_ai_engine_filename(self) -> str | None: + from engines.tensorrt_engine import TensorRTEngine + return TensorRTEngine.get_engine_filename() + + def get_source_filename(self) -> str: + import constants_inf + return constants_inf.AI_ONNX_MODEL_FILE + + def build_from_source(self, onnx_bytes: bytes, loader_client, models_dir: str): + from engines.tensorrt_engine import TensorRTEngine + engine_bytes = TensorRTEngine.convert_from_source(onnx_bytes, None) + return engine_bytes, TensorRTEngine.get_engine_filename() + + +class JetsonTensorRTEngineFactory(TensorRTEngineFactory): + def create(self, model_bytes: bytes): + from engines.jetson_tensorrt_engine import JetsonTensorRTEngine + return JetsonTensorRTEngine(model_bytes) + + def _get_ai_engine_filename(self) -> str | None: + from engines.tensorrt_engine import TensorRTEngine + return TensorRTEngine.get_engine_filename("int8") + + def build_from_source(self, onnx_bytes: bytes, loader_client, models_dir: str): + from engines.tensorrt_engine import TensorRTEngine + calib_cache_path = self._download_calib_cache(loader_client, models_dir) + try: + engine_bytes = TensorRTEngine.convert_from_source(onnx_bytes, calib_cache_path) + return engine_bytes, TensorRTEngine.get_engine_filename("int8") + finally: + if calib_cache_path is not None: + try: + os.unlink(calib_cache_path) + except Exception: + pass + + def _download_calib_cache(self, loader_client, models_dir: str) -> str | None: + import constants_inf + try: + res = loader_client.load_big_small_resource(constants_inf.INT8_CALIB_CACHE_FILE, models_dir) + if res.err is not None: + constants_inf.log(f"INT8 calibration cache not available: {res.err}") + return None + fd, path = tempfile.mkstemp(suffix=".cache") + with os.fdopen(fd, "wb") as f: + f.write(res.data) + constants_inf.log("INT8 calibration cache downloaded") + return path + except Exception as e: + constants_inf.log(f"INT8 calibration cache download failed: {str(e)}") + return None diff --git a/src/engines/jetson_tensorrt_engine.pxd b/src/engines/jetson_tensorrt_engine.pxd new file mode 100644 index 0000000..3f77e1e --- /dev/null +++ b/src/engines/jetson_tensorrt_engine.pxd @@ -0,0 +1,5 @@ +from engines.tensorrt_engine cimport TensorRTEngine + + +cdef class JetsonTensorRTEngine(TensorRTEngine): + pass diff --git a/src/engines/jetson_tensorrt_engine.pyx b/src/engines/jetson_tensorrt_engine.pyx new file mode 100644 index 0000000..3f77e1e --- /dev/null +++ b/src/engines/jetson_tensorrt_engine.pyx @@ -0,0 +1,5 @@ +from engines.tensorrt_engine cimport TensorRTEngine + + +cdef class JetsonTensorRTEngine(TensorRTEngine): + pass diff --git a/src/engines/onnx_engine.pyx b/src/engines/onnx_engine.pyx index e539edc..0065f69 100644 --- a/src/engines/onnx_engine.pyx +++ b/src/engines/onnx_engine.pyx @@ -23,7 +23,7 @@ cdef class OnnxEngine(InferenceEngine): self.model_inputs = self.session.get_inputs() self.input_name = self.model_inputs[0].name self.input_shape = self.model_inputs[0].shape - if self.input_shape[0] not in (-1, None, "N"): + if isinstance(self.input_shape[0], int) and self.input_shape[0] > 0: self.max_batch_size = self.input_shape[0] constants_inf.log(f'AI detection model input: {self.model_inputs} {self.input_shape}') model_meta = self.session.get_modelmeta() diff --git a/src/engines/tensorrt_engine.pyx b/src/engines/tensorrt_engine.pyx index b1dba55..3d05cd6 100644 --- a/src/engines/tensorrt_engine.pyx +++ b/src/engines/tensorrt_engine.pyx @@ -113,11 +113,6 @@ cdef class TensorRTEngine(InferenceEngine): except Exception: return None - @staticmethod - def get_source_filename(): - import constants_inf - return constants_inf.AI_ONNX_MODEL_FILE - @staticmethod def convert_from_source(bytes onnx_model, str calib_cache_path=None): gpu_mem = TensorRTEngine.get_gpu_memory_bytes(0) diff --git a/src/inference.pyx b/src/inference.pyx index 715decd..0d0e355 100644 --- a/src/inference.pyx +++ b/src/inference.pyx @@ -1,6 +1,4 @@ import io -import os -import tempfile import threading import av @@ -14,7 +12,7 @@ from ai_config cimport AIRecognitionConfig from engines.inference_engine cimport InferenceEngine from loader_http_client cimport LoaderHttpClient from threading import Thread -from engines import EngineClass +from engines import engine_factory def ai_config_from_dict(dict data): @@ -76,29 +74,23 @@ cdef class Inference: raise Exception(res.err) return res.data - cdef convert_and_upload_model(self, bytes source_bytes, str engine_filename, str calib_cache_path): + cdef convert_and_upload_model(self, bytes source_bytes, str models_dir): try: self.ai_availability_status.set_status(AIAvailabilityEnum.CONVERTING) - models_dir = constants_inf.MODELS_FOLDER - model_bytes = EngineClass.convert_from_source(source_bytes, calib_cache_path) + engine_bytes, engine_filename = engine_factory.build_from_source(source_bytes, self.loader_client, models_dir) self.ai_availability_status.set_status(AIAvailabilityEnum.UPLOADING) - res = self.loader_client.upload_big_small_resource(model_bytes, engine_filename, models_dir) + res = self.loader_client.upload_big_small_resource(engine_bytes, engine_filename, models_dir) if res.err is not None: self.ai_availability_status.set_status(AIAvailabilityEnum.WARNING, f"Failed to upload converted model: {res.err}") - self._converted_model_bytes = model_bytes + self._converted_model_bytes = engine_bytes self.ai_availability_status.set_status(AIAvailabilityEnum.ENABLED) except Exception as e: self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, str(e)) self._converted_model_bytes = None finally: self.is_building_engine = False - if calib_cache_path is not None: - try: - os.unlink(calib_cache_path) - except Exception: - pass cdef init_ai(self): constants_inf.log( 'init AI...') @@ -110,7 +102,7 @@ cdef class Inference: if self._converted_model_bytes is not None: try: - self.engine = EngineClass(self._converted_model_bytes) + self.engine = engine_factory.create(self._converted_model_bytes) self.ai_availability_status.set_status(AIAvailabilityEnum.ENABLED) except Exception as e: self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, str(e)) @@ -119,58 +111,33 @@ cdef class Inference: return models_dir = constants_inf.MODELS_FOLDER - engine_filename_fp16 = EngineClass.get_engine_filename() - if engine_filename_fp16 is not None: - engine_filename_int8 = EngineClass.get_engine_filename("int8") - for candidate in [engine_filename_int8, engine_filename_fp16]: - try: - self.ai_availability_status.set_status(AIAvailabilityEnum.DOWNLOADING) - res = self.loader_client.load_big_small_resource(candidate, models_dir) - if res.err is not None: - raise Exception(res.err) - self.engine = EngineClass(res.data) - self.ai_availability_status.set_status(AIAvailabilityEnum.ENABLED) - return - except Exception: - pass + self.ai_availability_status.set_status(AIAvailabilityEnum.DOWNLOADING) + engine = engine_factory.load_engine(self.loader_client, models_dir) + if engine is not None: + self.engine = engine + self.ai_availability_status.set_status(AIAvailabilityEnum.ENABLED) + return - source_filename = EngineClass.get_source_filename() - if source_filename is None: - self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, "Pre-built engine not found and no source available") - return + source_filename = engine_factory.get_source_filename() + if source_filename is None: + self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, "No engine available and no source to build from") + return + + source_bytes = self.download_model(source_filename) + + if engine_factory.has_build_step: self.ai_availability_status.set_status(AIAvailabilityEnum.WARNING, "Cached engine not found, converting from source") - source_bytes = self.download_model(source_filename) - calib_cache_path = self._try_download_calib_cache(models_dir) - target_engine_filename = EngineClass.get_engine_filename("int8") if calib_cache_path is not None else engine_filename_fp16 self.is_building_engine = True - - thread = Thread(target=self.convert_and_upload_model, args=(source_bytes, target_engine_filename, calib_cache_path)) + thread = Thread(target=self.convert_and_upload_model, args=(source_bytes, models_dir)) thread.daemon = True thread.start() - return else: - self.engine = EngineClass(self.download_model(constants_inf.AI_ONNX_MODEL_FILE)) + self.engine = engine_factory.create(source_bytes) self.ai_availability_status.set_status(AIAvailabilityEnum.ENABLED) - self.is_building_engine = False except Exception as e: self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, str(e)) self.is_building_engine = False - cdef str _try_download_calib_cache(self, str models_dir): - try: - res = self.loader_client.load_big_small_resource(constants_inf.INT8_CALIB_CACHE_FILE, models_dir) - if res.err is not None: - constants_inf.log(f"INT8 calibration cache not available: {res.err}") - return None - fd, path = tempfile.mkstemp(suffix='.cache') - with os.fdopen(fd, 'wb') as f: - f.write(res.data) - constants_inf.log('INT8 calibration cache downloaded') - return path - except Exception as e: - constants_inf.log(f"INT8 calibration cache download failed: {str(e)}") - return None - cpdef run_detect_image(self, bytes image_bytes, AIRecognitionConfig ai_config, str media_name, object annotation_callback, object status_callback=None): cdef list all_frame_data = [] diff --git a/src/main.py b/src/main.py index 68aad2f..ba56089 100644 --- a/src/main.py +++ b/src/main.py @@ -5,6 +5,7 @@ import json import os import tempfile import time +from collections import deque from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Annotated, Optional @@ -15,7 +16,7 @@ import jwt as pyjwt import numpy as np import requests as http_requests from fastapi import Body, Depends, FastAPI, File, Form, HTTPException, Request, UploadFile -from fastapi.responses import StreamingResponse +from fastapi.responses import Response, StreamingResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from pydantic import BaseModel @@ -37,11 +38,14 @@ _MEDIA_STATUS_ERROR = 6 _VIDEO_EXTENSIONS = frozenset({".mp4", ".mov", ".webm", ".mkv", ".avi", ".m4v"}) _IMAGE_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"}) +_BUFFER_TTL_MS = 10_000 +_BUFFER_MAX = 200 + loader_client = LoaderHttpClient(LOADER_URL) annotations_client = LoaderHttpClient(ANNOTATIONS_URL) inference = None _job_queues: dict[str, list[asyncio.Queue]] = {} -_job_buffers: dict[str, list[str]] = {} +_channel_buffers: dict[str, deque] = {} _active_detections: dict[str, asyncio.Task] = {} _bearer = HTTPBearer(auto_error=False) @@ -323,21 +327,50 @@ def detection_to_dto(det) -> DetectionDto: ) -def _enqueue(media_id: str, event: DetectionEvent): +def _post_annotation_to_service(token_mgr: TokenManager, media_id: str, + annotation, dtos: list[DetectionDto]): + try: + token = token_mgr.get_valid_token() + image_b64 = base64.b64encode(annotation.image).decode() if annotation.image else None + payload = { + "mediaId": media_id, + "source": 0, + "videoTime": f"00:00:{annotation.time // 1000:02d}" if annotation.time else "00:00:00", + "detections": [d.model_dump() for d in dtos], + } + if image_b64: + payload["image"] = image_b64 + http_requests.post( + f"{ANNOTATIONS_URL}/annotations", + json=payload, + headers={"Authorization": f"Bearer {token}"}, + timeout=30, + ) + except Exception: + pass + + +def _cleanup_channel(channel_id: str): + _channel_buffers.pop(channel_id, None) + + +def _enqueue(channel_id: str, event: DetectionEvent): + now_ms = int(time.time() * 1000) data = event.model_dump_json() - _job_buffers.setdefault(media_id, []).append(data) - for q in _job_queues.get(media_id, []): + + buf = _channel_buffers.setdefault(channel_id, deque(maxlen=_BUFFER_MAX)) + buf.append((now_ms, data)) + cutoff = now_ms - _BUFFER_TTL_MS + while buf and buf[0][0] < cutoff: + buf.popleft() + + for q in _job_queues.get(channel_id, []): try: - q.put_nowait(data) + q.put_nowait((now_ms, data)) except asyncio.QueueFull: pass -def _schedule_buffer_cleanup(media_id: str, delay: float = 300.0): - loop = asyncio.get_event_loop() - loop.call_later(delay, lambda: _job_buffers.pop(media_id, None)) - - @app.get("/health") def health() -> HealthResponse: if inference is None: @@ -361,6 +394,36 @@ def health() -> HealthResponse: ) +@app.get("/detect/events/{channel_id}", dependencies=[Depends(require_auth)]) +async def detect_events(channel_id: str, request: Request, after_ts: Optional[int] = None): + queue: asyncio.Queue = asyncio.Queue(maxsize=200) + _job_queues.setdefault(channel_id, []).append(queue) + + async def event_generator(): + try: + if after_ts is not None: + for ts_ms, data in list(_channel_buffers.get(channel_id, [])): + if ts_ms > after_ts: + yield f"id: {ts_ms}\ndata: {data}\n\n" + while True: + ts_ms, data = await queue.get() + yield f"id: {ts_ms}\ndata: {data}\n\n" + except asyncio.CancelledError: + pass + finally: + queues = _job_queues.get(channel_id, []) + if queue in queues: + queues.remove(queue) + if not queues: + _job_queues.pop(channel_id, None) + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, + ) + + @app.post("/detect/image") async def detect_image( request: Request, @@ -384,6 +447,10 @@ async def detect_image( if cv2.imdecode(arr, cv2.IMREAD_COLOR) is None: raise HTTPException(status_code=400, detail="Invalid image data") + channel_id = request.headers.get("x-channel-id", "") + if not channel_id: + raise HTTPException(status_code=400, detail="X-Channel-Id header required") + config_dict = {} if config: config_dict = json.loads(config) @@ -395,7 +462,6 @@ async def detect_image( images_dir = os.environ.get( "IMAGES_DIR", os.path.join(os.getcwd(), "data", "images") ) - storage_path = None content_hash = None if token_mgr and user_id: content_hash = compute_media_content_hash(image_bytes) @@ -417,45 +483,65 @@ async def detect_image( _put_media_status(content_hash, _MEDIA_STATUS_AI_PROCESSING, bearer) media_name = Path(orig_name).stem.replace(" ", "") + media_id = content_hash or channel_id loop = asyncio.get_event_loop() inf = get_inference() - results = [] - def on_annotation(annotation, percent): - results.extend(annotation.detections) + async def run_detection(): + ai_cfg = ai_config_from_dict(config_dict) - ai_cfg = ai_config_from_dict(config_dict) + def on_annotation(annotation, percent): + dtos = [detection_to_dto(d) for d in annotation.detections] + event = DetectionEvent( + annotations=dtos, + mediaId=media_id, + mediaStatus="AIProcessing", + mediaPercent=percent, + ) + loop.call_soon_threadsafe(_enqueue, channel_id, event) + if token_mgr and content_hash and dtos: + _post_annotation_to_service(token_mgr, content_hash, annotation, dtos) - def run_detect(): - inf.run_detect_image(image_bytes, ai_cfg, media_name, on_annotation) + def run_sync(): + inf.run_detect_image(image_bytes, ai_cfg, media_name, on_annotation) - try: - await loop.run_in_executor(executor, run_detect) - if token_mgr and user_id and content_hash: - _put_media_status( - content_hash, _MEDIA_STATUS_AI_PROCESSED, token_mgr.get_valid_token() - ) - return [detection_to_dto(d) for d in results] - except RuntimeError as e: - if token_mgr and user_id and content_hash: - _put_media_status( - content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token() - ) - if "not available" in str(e): - raise HTTPException(status_code=503, detail=str(e)) - raise HTTPException(status_code=422, detail=str(e)) - except ValueError as e: - if token_mgr and user_id and content_hash: - _put_media_status( - content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token() - ) - raise HTTPException(status_code=400, detail=str(e)) - except Exception: - if token_mgr and user_id and content_hash: - _put_media_status( - content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token() - ) - raise + try: + await loop.run_in_executor(executor, run_sync) + _enqueue(channel_id, DetectionEvent( + annotations=[], mediaId=media_id, + mediaStatus="AIProcessed", mediaPercent=100, + )) + if token_mgr and content_hash: + _put_media_status( + content_hash, _MEDIA_STATUS_AI_PROCESSED, token_mgr.get_valid_token() + ) + except RuntimeError as e: + if token_mgr and content_hash: + _put_media_status( + content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token() + ) + _enqueue(channel_id, DetectionEvent( + annotations=[], mediaId=media_id, + mediaStatus="Error", mediaPercent=0, + )) + if "not available" in str(e): + return + raise + except Exception: + if token_mgr and content_hash: + _put_media_status( + content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token() + ) + _enqueue(channel_id, DetectionEvent( + annotations=[], mediaId=media_id, + mediaStatus="Error", mediaPercent=0, + )) + raise + finally: + loop.call_later(10.0, _cleanup_channel, channel_id) + + asyncio.create_task(run_detection()) + return Response(status_code=202) @app.post("/detect/video") @@ -467,6 +553,10 @@ async def detect_video_upload( from inference import ai_config_from_dict from streaming_buffer import StreamingBuffer + channel_id = request.headers.get("x-channel-id", "") + if not channel_id: + raise HTTPException(status_code=400, detail="X-Channel-Id header required") + filename = request.headers.get("x-filename", "upload.mp4") config_json = request.headers.get("x-config", "") ext = _normalize_upload_ext(filename) @@ -491,32 +581,23 @@ async def detect_video_upload( loop = asyncio.get_event_loop() inf = get_inference() - placeholder_id = f"tmp_{os.path.basename(buffer.path)}" - current_id = [placeholder_id] # mutable — updated to content_hash after upload + current_media_id = [channel_id] def on_annotation(annotation, percent): dtos = [detection_to_dto(d) for d in annotation.detections] - mid = current_id[0] + mid = current_media_id[0] event = DetectionEvent( annotations=dtos, mediaId=mid, mediaStatus="AIProcessing", mediaPercent=percent, ) - loop.call_soon_threadsafe(_enqueue, mid, event) - - def on_status(media_name_cb, count): - mid = current_id[0] - event = DetectionEvent( - annotations=[], - mediaId=mid, - mediaStatus="AIProcessed", - mediaPercent=100, - ) - loop.call_soon_threadsafe(_enqueue, mid, event) + loop.call_soon_threadsafe(_enqueue, channel_id, event) + if token_mgr and mid != channel_id and dtos: + _post_annotation_to_service(token_mgr, mid, annotation, dtos) def run_inference(): - inf.run_detect_video_stream(buffer, ai_cfg, media_name, on_annotation, on_status) + inf.run_detect_video_stream(buffer, ai_cfg, media_name, on_annotation, lambda *_: None) inference_future = loop.run_in_executor(executor, run_inference) @@ -533,14 +614,14 @@ async def detect_video_upload( if not ext.startswith("."): ext = "." + ext storage_path = os.path.abspath(os.path.join(videos_dir, f"{content_hash}{ext}")) + current_media_id[0] = content_hash - # Re-key buffered events from placeholder_id to content_hash so clients - # can subscribe to GET /detect/{content_hash} after POST returns. - if placeholder_id in _job_buffers: - _job_buffers[content_hash] = _job_buffers.pop(placeholder_id) - if placeholder_id in _job_queues: - _job_queues[content_hash] = _job_queues.pop(placeholder_id) - current_id[0] = content_hash # future on_annotation/on_status callbacks use content_hash + _enqueue(channel_id, DetectionEvent( + annotations=[], + mediaId=content_hash, + mediaStatus="Started", + mediaPercent=0, + )) if token_mgr and user_id: os.rename(buffer.path, storage_path) @@ -564,27 +645,24 @@ async def detect_video_upload( content_hash, _MEDIA_STATUS_AI_PROCESSED, token_mgr.get_valid_token(), ) - done_event = DetectionEvent( + _enqueue(channel_id, DetectionEvent( annotations=[], mediaId=content_hash, mediaStatus="AIProcessed", mediaPercent=100, - ) - _enqueue(content_hash, done_event) + )) except Exception: if token_mgr and user_id: _put_media_status( content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token(), ) - err_event = DetectionEvent( + _enqueue(channel_id, DetectionEvent( annotations=[], mediaId=content_hash, mediaStatus="Error", mediaPercent=0, - ) - _enqueue(content_hash, err_event) + )) finally: - _active_detections.pop(content_hash, None) - _schedule_buffer_cleanup(content_hash) + loop.call_later(10.0, _cleanup_channel, channel_id) buffer.close() if not (token_mgr and user_id) and os.path.isfile(buffer.path): try: @@ -592,31 +670,8 @@ async def detect_video_upload( except OSError: pass - _active_detections[content_hash] = asyncio.create_task(_wait_inference()) - return {"status": "started", "mediaId": content_hash} - - -def _post_annotation_to_service(token_mgr: TokenManager, media_id: str, - annotation, dtos: list[DetectionDto]): - try: - token = token_mgr.get_valid_token() - image_b64 = base64.b64encode(annotation.image).decode() if annotation.image else None - payload = { - "mediaId": media_id, - "source": 0, - "videoTime": f"00:00:{annotation.time // 1000:02d}" if annotation.time else "00:00:00", - "detections": [d.model_dump() for d in dtos], - } - if image_b64: - payload["image"] = image_b64 - http_requests.post( - f"{ANNOTATIONS_URL}/annotations", - json=payload, - headers={"Authorization": f"Bearer {token}"}, - timeout=30, - ) - except Exception: - pass + asyncio.create_task(_wait_inference()) + return Response(status_code=202) @app.post("/detect/{media_id}") @@ -630,6 +685,10 @@ async def detect_media( if existing is not None and not existing.done(): raise HTTPException(status_code=409, detail="Detection already in progress for this media") + channel_id = request.headers.get("x-channel-id", "") + if not channel_id: + raise HTTPException(status_code=400, detail="X-Channel-Id header required") + refresh_token = request.headers.get("x-refresh-token", "") access_token = request.headers.get("authorization", "").removeprefix("Bearer ").strip() token_mgr = TokenManager(access_token, refresh_token) if access_token else None @@ -668,7 +727,7 @@ async def detect_media( mediaStatus="AIProcessing", mediaPercent=percent, ) - loop.call_soon_threadsafe(_enqueue, media_id, event) + loop.call_soon_threadsafe(_enqueue, channel_id, event) if token_mgr and dtos: _post_annotation_to_service(token_mgr, media_id, annotation, dtos) @@ -679,7 +738,7 @@ async def detect_media( mediaStatus="AIProcessed", mediaPercent=100, ) - loop.call_soon_threadsafe(_enqueue, media_id, event) + loop.call_soon_threadsafe(_enqueue, channel_id, event) if token_mgr: _put_media_status( media_id, @@ -718,36 +777,10 @@ async def detect_media( mediaStatus="Error", mediaPercent=0, ) - _enqueue(media_id, error_event) + _enqueue(channel_id, error_event) finally: _active_detections.pop(media_id, None) - _schedule_buffer_cleanup(media_id) + loop.call_later(10.0, _cleanup_channel, channel_id) _active_detections[media_id] = asyncio.create_task(run_detection()) - return {"status": "started", "mediaId": media_id} - - -@app.get("/detect/{media_id}", dependencies=[Depends(require_auth)]) -async def detect_events(media_id: str): - queue: asyncio.Queue = asyncio.Queue(maxsize=200) - _job_queues.setdefault(media_id, []).append(queue) - - async def event_generator(): - try: - for data in list(_job_buffers.get(media_id, [])): - yield f"data: {data}\n\n" - while True: - data = await queue.get() - yield f"data: {data}\n\n" - except asyncio.CancelledError: - pass - finally: - queues = _job_queues.get(media_id, []) - if queue in queues: - queues.remove(queue) - - return StreamingResponse( - event_generator(), - media_type="text/event-stream", - headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, - ) + return Response(status_code=202) diff --git a/tests/test_az178_streaming_video.py b/tests/test_az178_streaming_video.py index a526cae..c3601c3 100644 --- a/tests/test_az178_streaming_video.py +++ b/tests/test_az178_streaming_video.py @@ -419,7 +419,10 @@ class TestDetectVideoEndpoint: from fastapi.testclient import TestClient client = TestClient(main.app) token = _access_jwt() - with patch.object(main, "get_inference", return_value=_CaptureInf()): + with ( + patch.object(main, "JWT_SECRET", _TEST_JWT_SECRET), + patch.object(main, "get_inference", return_value=_CaptureInf()), + ): # Act r = client.post( "/detect/video", diff --git a/tests/test_az180_jetson_int8.py b/tests/test_az180_jetson_int8.py index a67faad..9c988c6 100644 --- a/tests/test_az180_jetson_int8.py +++ b/tests/test_az180_jetson_int8.py @@ -96,3 +96,149 @@ def test_convert_from_source_uses_fp16_when_no_cache(): mock_config.set_flag.assert_any_call("FP16") int8_calls = [c for c in mock_config.set_flag.call_args_list if c == call("INT8")] assert len(int8_calls) == 0 + + +@requires_tensorrt +def test_trt_factory_build_from_source_uses_fp16(): + # Arrange + from engines.engine_factory import TensorRTEngineFactory + from engines.tensorrt_engine import TensorRTEngine + import engines.tensorrt_engine as trt_mod + + mock_trt, mock_builder, mock_config = _make_mock_trt() + factory = TensorRTEngineFactory() + + with patch.object(trt_mod, "trt", mock_trt), \ + patch.object(TensorRTEngine, "get_gpu_memory_bytes", return_value=4 * 1024**3): + # Act + engine_bytes, filename = factory.build_from_source(b"onnx", MagicMock(), "models") + + # Assert + assert engine_bytes == b"engine_bytes" + assert filename is not None + int8_calls = [c for c in mock_config.set_flag.call_args_list if c == call("INT8")] + assert len(int8_calls) == 0 + + +@requires_tensorrt +def test_jetson_factory_build_from_source_uses_int8_when_cache_available(): + # Arrange + from engines.engine_factory import JetsonTensorRTEngineFactory + from engines.tensorrt_engine import TensorRTEngine + import engines.tensorrt_engine as trt_mod + + mock_trt, mock_builder, mock_config = _make_mock_trt() + factory = JetsonTensorRTEngineFactory() + loader = MagicMock() + result = MagicMock() + result.err = None + result.data = b"calib_data" + loader.load_big_small_resource.return_value = result + + with patch.object(trt_mod, "trt", mock_trt), \ + patch.object(TensorRTEngine, "get_gpu_memory_bytes", return_value=4 * 1024**3): + # Act + engine_bytes, filename = factory.build_from_source(b"onnx", loader, "models") + + # Assert + assert engine_bytes == b"engine_bytes" + assert "int8" in filename + mock_config.set_flag.assert_any_call("INT8") + + +@requires_tensorrt +def test_jetson_factory_build_from_source_falls_back_to_fp16_when_no_cache(): + # Arrange + from engines.engine_factory import JetsonTensorRTEngineFactory + from engines.tensorrt_engine import TensorRTEngine + import engines.tensorrt_engine as trt_mod + + mock_trt, mock_builder, mock_config = _make_mock_trt() + factory = JetsonTensorRTEngineFactory() + loader = MagicMock() + result = MagicMock() + result.err = "not found" + loader.load_big_small_resource.return_value = result + + with patch.object(trt_mod, "trt", mock_trt), \ + patch.object(TensorRTEngine, "get_gpu_memory_bytes", return_value=4 * 1024**3): + # Act + engine_bytes, filename = factory.build_from_source(b"onnx", loader, "models") + + # Assert + assert engine_bytes == b"engine_bytes" + int8_calls = [c for c in mock_config.set_flag.call_args_list if c == call("INT8")] + assert len(int8_calls) == 0 + + +@requires_tensorrt +def test_jetson_factory_cleans_up_cache_tempfile_after_build(): + # Arrange + from engines.engine_factory import JetsonTensorRTEngineFactory + from engines.tensorrt_engine import TensorRTEngine + import engines.tensorrt_engine as trt_mod + + mock_trt, _, _ = _make_mock_trt() + factory = JetsonTensorRTEngineFactory() + loader = MagicMock() + result = MagicMock() + result.err = None + result.data = b"calib_data" + loader.load_big_small_resource.return_value = result + + written_paths = [] + original_download = factory._download_calib_cache + + def tracking_download(lc, md): + path = original_download(lc, md) + if path: + written_paths.append(path) + return path + + with patch.object(trt_mod, "trt", mock_trt), \ + patch.object(TensorRTEngine, "get_gpu_memory_bytes", return_value=4 * 1024**3), \ + patch.object(factory, "_download_calib_cache", side_effect=tracking_download): + factory.build_from_source(b"onnx", loader, "models") + + # Assert: temp file was deleted after build + for p in written_paths: + assert not os.path.exists(p) + + +def test_is_jetson_false_on_non_aarch64(): + # Arrange + import engines as eng + + with patch("engines.platform") as mock_platform, \ + patch("engines.tensor_gpu_index", 0), \ + patch("engines.os.path.isfile", return_value=True): + mock_platform.machine.return_value = "x86_64" + + # Assert + assert eng._is_jetson() is False + + +def test_is_jetson_false_when_no_gpu(): + # Arrange + import engines as eng + + with patch("engines.platform") as mock_platform, \ + patch("engines.tensor_gpu_index", -1), \ + patch("engines.os.path.isfile", return_value=True): + mock_platform.machine.return_value = "aarch64" + + # Assert + assert eng._is_jetson() is False + + +def test_is_jetson_false_when_no_tegra_release(): + # Arrange + import engines as eng + + with patch("engines.platform") as mock_platform, \ + patch("engines.tensor_gpu_index", 0), \ + patch("engines.os.path.isfile", return_value=False): + mock_platform.machine.return_value = "aarch64" + + # Assert + assert eng._is_jetson() is False