mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 22:16:31 +00:00
8baa96978b
- Updated the detection image endpoint to require a channel ID for event streaming. - Introduced a new endpoint for streaming detection events, allowing clients to receive real-time updates. - Enhanced the internal buffering mechanism for detection events to manage multiple channels. - Refactored the inference module to support the new event handling structure. Made-with: Cursor
340 lines
9.1 KiB
Python
340 lines
9.1 KiB
Python
import json
|
|
import os
|
|
import random
|
|
import threading
|
|
import time
|
|
import uuid
|
|
from contextlib import contextmanager
|
|
from pathlib import Path
|
|
|
|
import jwt as pyjwt
|
|
import pytest
|
|
import requests
|
|
import sseclient
|
|
from pytest import ExitCode
|
|
|
|
|
|
def pytest_collection_modifyitems(items):
    """Move the init-phase test classes to the front of the run order.

    Items whose node id contains ``Step01PreInit`` or ``Step02LazyInit`` are
    placed first; every other item follows. The relative order inside each
    group is preserved, and ``items`` is reordered in place as pytest requires.
    """
    early_markers = ("Step01PreInit", "Step02LazyInit")

    def _is_early(item):
        return any(marker in item.nodeid for marker in early_markers)

    # Two stable filtering passes keep the original order within each group.
    front = [item for item in items if _is_early(item)]
    back = [item for item in items if not _is_early(item)]
    items[:] = front + back
|
|
|
|
|
|
@pytest.hookimpl(trylast=True)
def pytest_sessionfinish(session, exitstatus):
    """Report a run that collected no tests as successful.

    Registered with ``trylast=True`` so it runs after other
    ``pytest_sessionfinish`` implementations. It rewrites the session exit
    status from ``NO_TESTS_COLLECTED`` to ``OK``.
    """
    # ExitCode is an IntEnum, so this comparison also matches a plain integer
    # 5; a separate literal 5 is unnecessary.
    if exitstatus == ExitCode.NO_TESTS_COLLECTED:
        session.exitstatus = ExitCode.OK
|
|
|
|
|
|
class _SessionWithBase(requests.Session):
    """``requests.Session`` that resolves relative URLs against a base URL.

    URLs starting with ``http://`` or ``https://`` are used unchanged. Any
    other URL is appended to ``base`` with exactly one ``/`` between them.
    Each request uses ``default_timeout`` unless the caller passes ``timeout``.
    """

    def __init__(self, base: str, default_timeout: float = 30):
        super().__init__()
        # Drop trailing slashes so joining with a "/path" never yields "//".
        self._base = base.rstrip("/")
        self._default_timeout = default_timeout

    def request(self, method, url, *args, **kwargs):
        if not url.startswith(("http://", "https://")):
            if not url.startswith("/"):
                url = "/" + url
            url = self._base + url
        if "timeout" not in kwargs:
            kwargs["timeout"] = self._default_timeout
        return super().request(method, url, *args, **kwargs)
|
|
|
|
|
|
@pytest.fixture(scope="session")
def base_url():
    """Base URL of the detections service, from ``BASE_URL`` if set."""
    default = "http://detections:8080"
    return os.getenv("BASE_URL", default)
|
|
|
|
|
|
@pytest.fixture(scope="session")
def http_client(base_url):
    """Session-wide HTTP client rooted at ``base_url`` with a 30 s default timeout."""
    return _SessionWithBase(base_url, default_timeout=30)
|
|
|
|
|
|
@pytest.fixture(scope="session")
def jwt_secret():
    """JWT signing secret from ``JWT_SECRET``; an empty string when unset."""
    return os.getenv("JWT_SECRET", "")
|
|
|
|
|
|
@pytest.fixture(scope="session")
def jwt_token(jwt_secret):
    """Mint an HS256 JWT for subject ``test-user`` that expires in one hour.

    Returns an empty string when ``jwt_secret`` is empty.
    """
    if jwt_secret:
        claims = {"sub": "test-user", "exp": int(time.time()) + 3600}
        return pyjwt.encode(claims, jwt_secret, algorithm="HS256")
    return ""
|
|
|
|
|
|
@pytest.fixture(scope="session")
def auth_headers(jwt_token):
    """Bearer ``Authorization`` header, or an empty dict when there is no token."""
    if not jwt_token:
        return {}
    return {"Authorization": f"Bearer {jwt_token}"}
|
|
|
|
|
|
@pytest.fixture
def channel_id():
    """Fresh random UUID4 string for each test that requests it."""
    return f"{uuid.uuid4()}"
|
|
|
|
|
|
@pytest.fixture(scope="session")
def image_detect(http_client, auth_headers):
    """Return a helper that submits one image and collects its detections.

    The helper:

    1. Subscribes via SSE to ``/detect/events/<cid>`` for a fresh channel id.
    2. POSTs the image to ``/detect/image`` with that id in ``X-Channel-Id``.
    3. Gathers the ``annotations`` of every ``AIProcessing`` event until an
       ``AIProcessed`` or ``Error`` event arrives or the stream ends.

    Helper args:
        image_bytes: raw payload, uploaded as ``image/jpeg``.
        filename: multipart filename for the upload.
        config: optional value sent as the ``config`` form field (omitted if falsy).
        timeout: seconds allowed for the POST and for the event stream to finish.

    Helper returns:
        ``(detections, elapsed_ms)``: the collected annotation objects, and the
        wall time in milliseconds from the POST until the stream finished.

    Helper raises:
        AssertionError: the SSE subscription failed, the POST did not return
        202, or the stream did not finish within ``timeout``.
    """

    def _detect(image_bytes, filename="img.jpg", config=None, timeout=30):
        cid = str(uuid.uuid4())
        headers = {**auth_headers, "X-Channel-Id": cid}
        detections = []
        errors = []
        done = threading.Event()
        connected = threading.Event()

        def _listen():
            try:
                with http_client.get(
                    f"/detect/events/{cid}",
                    stream=True,
                    timeout=timeout + 2,
                    headers=auth_headers,
                ) as resp:
                    resp.raise_for_status()
                    connected.set()
                    for ev in sseclient.SSEClient(resp).events():
                        if not ev.data or not str(ev.data).strip():
                            continue  # skip events with empty/blank data
                        data = json.loads(ev.data)
                        status = data.get("mediaStatus")
                        if status == "AIProcessing":
                            # `or []` also tolerates an explicit null value.
                            detections.extend(data.get("annotations") or [])
                        if status in ("AIProcessed", "Error"):
                            break
            except BaseException as e:
                # Exceptions do not propagate out of a thread; record them so
                # the calling thread can assert on them.
                errors.append(e)
            finally:
                # Always release waiters, even if the subscription failed.
                connected.set()
                done.set()

        th = threading.Thread(target=_listen, daemon=True)
        th.start()
        # Give the subscription a chance to be established before the POST so
        # early events are not missed; proceed anyway after 5 s.
        connected.wait(timeout=5)
        # Fail fast if the subscription already failed: no events can arrive,
        # so posting and waiting would only burn `timeout` seconds.
        assert not errors, f"SSE errors: {errors}"

        data_form = {}
        if config:
            data_form["config"] = config

        t0 = time.perf_counter()
        r = http_client.post(
            "/detect/image",
            files={"file": (filename, image_bytes, "image/jpeg")},
            data=data_form,
            headers=headers,
            timeout=timeout,
        )
        # Check the POST before waiting: a rejected request produces no
        # terminal event, so waiting first would stall for the full timeout.
        assert r.status_code == 202, f"Expected 202, got {r.status_code}: {r.text}"
        finished = done.wait(timeout=timeout)
        elapsed_ms = (time.perf_counter() - t0) * 1000.0

        assert not errors, f"SSE errors: {errors}"
        # Without this check, partial results would be returned while the
        # listener thread may still be appending to `detections`.
        assert finished, (
            f"SSE stream for channel {cid} did not finish within {timeout}s"
        )

        th.join(timeout=1)
        return detections, elapsed_ms

    return _detect
|
|
|
|
|
|
@pytest.fixture
def sse_client_factory(http_client, auth_headers):
    """Provide a context-manager factory for raw SSE subscriptions.

    ``with sse_client_factory(channel_id) as client:`` opens a streaming GET on
    ``/detect/events/<channel_id>`` (600 s timeout, auth headers attached). It
    raises ``requests.HTTPError`` on a 4xx/5xx response and otherwise yields an
    ``sseclient.SSEClient`` over the response. The response is closed when the
    ``with`` block exits.
    """

    @contextmanager
    def _open(channel_id: str):
        stream = http_client.get(
            f"/detect/events/{channel_id}",
            stream=True,
            timeout=600,
            headers=auth_headers,
        )
        with stream as response:
            response.raise_for_status()
            yield sseclient.SSEClient(response)

    return _open
|
|
|
|
|
|
@pytest.fixture(scope="session")
def mock_loader_url():
    """Base URL of the mock loader service, from ``MOCK_LOADER_URL`` if set."""
    default = "http://mock-loader:8080"
    return os.getenv("MOCK_LOADER_URL", default)
|
|
|
|
|
|
@pytest.fixture(scope="session")
def mock_annotations_url():
    """Base URL of the mock annotations service, from ``MOCK_ANNOTATIONS_URL`` if set."""
    default = "http://mock-annotations:8081"
    return os.getenv("MOCK_ANNOTATIONS_URL", default)
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
def wait_for_services(base_url, mock_loader_url, mock_annotations_url):
    """Block the session until the service and both mocks answer HTTP 200.

    Each polling round probes the detections ``/health`` endpoint and each
    mock's ``/mock/status``. Rounds are separated by a 2 s sleep. The session
    fails if not every probe succeeds within 120 s.
    """
    probes = (
        f"{base_url}/health",
        f"{mock_loader_url}/mock/status",
        f"{mock_annotations_url}/mock/status",
    )

    def _ready(url):
        try:
            return requests.get(url, timeout=5).status_code == 200
        except OSError:
            # requests' connection/timeout exceptions derive from OSError.
            return False

    give_up_at = time.time() + 120
    while time.time() < give_up_at:
        # all() short-circuits: a round stops at the first unready probe.
        if all(_ready(url) for url in probes):
            return
        time.sleep(2)
    pytest.fail("services not ready within 120s")
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def reset_mocks(mock_loader_url, mock_annotations_url):
    """POST ``/mock/reset`` to the loader mock, then the annotations mock, before each test.

    Response status codes are not checked; exceptions from the requests are
    not caught here.
    """
    for mock_url in (mock_loader_url, mock_annotations_url):
        requests.post(f"{mock_url}/mock/reset", timeout=10)
    yield
|
|
|
|
|
|
def _media_dir() -> Path:
|
|
return Path(os.environ.get("MEDIA_DIR", "/media"))
|
|
|
|
|
|
def _read_media(name: str) -> bytes:
    """Return the bytes of media file *name*; skip the current test if it is missing."""
    path = _media_dir() / name
    if path.is_file():
        return path.read_bytes()
    # pytest.skip raises, so a missing file never falls through.
    pytest.skip(f"missing {path}")
|
|
|
|
|
|
@pytest.fixture(scope="session")
def media_dir():
    """Test media directory as a plain string path."""
    return os.fspath(_media_dir())
|
|
|
|
|
|
@pytest.fixture(scope="session")
def image_small():
    """Bytes of ``image_small.jpg`` (the requesting test is skipped if it is absent)."""
    payload = _read_media("image_small.jpg")
    return payload
|
|
|
|
|
|
@pytest.fixture(scope="session")
def image_large():
    """Bytes of ``image_large.JPG`` (note the upper-case extension on disk)."""
    payload = _read_media("image_large.JPG")
    return payload
|
|
|
|
|
|
@pytest.fixture(scope="session")
def image_dense():
    """Bytes of ``image_dense01.jpg`` (the requesting test is skipped if it is absent)."""
    payload = _read_media("image_dense01.jpg")
    return payload
|
|
|
|
|
|
@pytest.fixture(scope="session")
def image_dense_02():
    """Bytes of ``image_dense02.jpg`` (the requesting test is skipped if it is absent)."""
    payload = _read_media("image_dense02.jpg")
    return payload
|
|
|
|
|
|
@pytest.fixture(scope="session")
def image_different_types():
    """Bytes of ``image_different_types.jpg`` (the requesting test is skipped if it is absent)."""
    payload = _read_media("image_different_types.jpg")
    return payload
|
|
|
|
|
|
@pytest.fixture(scope="session")
def image_empty_scene():
    """Bytes of ``image_empty_scene.jpg`` (the requesting test is skipped if it is absent)."""
    payload = _read_media("image_empty_scene.jpg")
    return payload
|
|
|
|
|
|
@pytest.fixture(scope="session")
def video_short_path():
    """String path to ``video_test01.mp4``; existence is not checked."""
    return os.fspath(_media_dir().joinpath("video_test01.mp4"))
|
|
|
|
|
|
@pytest.fixture(scope="session")
def video_short_02_path():
    """String path to ``video_short02.mp4``; existence is not checked."""
    return os.fspath(_media_dir().joinpath("video_short02.mp4"))
|
|
|
|
|
|
@pytest.fixture(scope="session")
def video_long_path():
    """String path to ``video_long03.mp4``; existence is not checked."""
    return os.fspath(_media_dir().joinpath("video_long03.mp4"))
|
|
|
|
|
|
@pytest.fixture(scope="session")
def empty_image():
    """A zero-length upload payload."""
    return bytes()
|
|
|
|
|
|
@pytest.fixture(scope="session")
def corrupt_image():
    """1024 deterministic pseudo-random bytes used as a corrupt image payload.

    A private ``random.Random(42)`` produces exactly the same bytes as
    ``random.seed(42); random.randbytes(1024)``. Unlike that form, it does not
    reseed the module-level generator shared by the rest of the process.
    """
    return random.Random(42).randbytes(1024)
|
|
|
|
|
|
@pytest.fixture(scope="module")
def warm_engine(http_client, image_small, auth_headers):
    """Push one image through the pipeline so later tests hit a warm engine.

    Each attempt subscribes to a fresh event channel and POSTs ``image_small``
    to ``/detect/image``. On HTTP 202 it waits up to 30 s for the job's
    ``AIProcessed`` or ``Error`` event and then returns.

    Failed attempts are retried after a 2 s pause:

    * Connection errors and non-5xx responses reset the 5xx streak.
    * Five consecutive 5xx responses fail immediately.
    * 120 s without a 202 fails with a timeout.
    """
    deadline = time.time() + 120
    last_status = None
    consecutive_errors = 0

    while time.time() < deadline:
        cid = str(uuid.uuid4())
        headers = {**auth_headers, "X-Channel-Id": cid}
        done = threading.Event()
        connected = threading.Event()

        # Bind this iteration's objects as defaults: an abandoned listener
        # from an earlier attempt must not signal a later attempt's `done`
        # (a plain closure would see the rebound loop variable).
        def _listen(cid=cid, done=done, connected=connected):
            try:
                with http_client.get(
                    f"/detect/events/{cid}",
                    stream=True,
                    timeout=35,
                    headers=auth_headers,
                ) as resp:
                    resp.raise_for_status()
                    connected.set()
                    for ev in sseclient.SSEClient(resp).events():
                        if not ev.data or not str(ev.data).strip():
                            continue
                        data = json.loads(ev.data)
                        # Stop on either terminal status; a failed warm-up job
                        # should not keep the fixture waiting.
                        if data.get("mediaStatus") in ("AIProcessed", "Error"):
                            break
            except Exception:
                # Best effort: warm-up only needs the engine to have run, not
                # a clean event stream.
                pass
            finally:
                connected.set()
                done.set()

        th = threading.Thread(target=_listen, daemon=True)
        th.start()
        # Wait (bounded) for the subscription instead of a fixed short sleep,
        # so the job's events are not published before anyone is listening.
        connected.wait(timeout=5)

        try:
            r = http_client.post(
                "/detect/image",
                files={"file": ("warm.jpg", image_small, "image/jpeg")},
                headers=headers,
            )
            last_status = r.status_code
            if r.status_code == 202:
                done.wait(timeout=30)
                th.join(timeout=1)
                return
            if r.status_code >= 500:
                consecutive_errors += 1
                if consecutive_errors >= 5:
                    pytest.fail(
                        f"engine warm-up aborted: {consecutive_errors} consecutive "
                        f"HTTP {last_status} errors — server is broken, not starting up"
                    )
            else:
                consecutive_errors = 0
        except OSError:
            # Connection-level failure (requests exceptions derive from
            # OSError): reset the 5xx streak and keep retrying.
            consecutive_errors = 0

        th.join(timeout=1)
        time.sleep(2)

    pytest.fail(f"engine warm-up timed out after 120s (last status: {last_status})")
|