mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 05:26:32 +00:00
[AZ-178] Fix Critical/High security findings: auth, CVEs, non-root containers, per-job SSE
- Pin all deps; h11==0.16.0 (CVE-2025-43859), python-multipart>=1.3.1 (CVE-2026-28356), PyJWT==2.12.1
- Add HMAC JWT verification (require_auth FastAPI dependency, JWT_SECRET-gated)
- Fix TokenManager._refresh() to use ADMIN_API_URL instead of ANNOTATIONS_URL
- Rename POST /detect → POST /detect/image (image-only, rejects video files)
- Replace global SSE stream with per-job SSE: GET /detect/{media_id} with event replay buffer
- Apply require_auth to all 4 protected endpoints
- Fix on_annotation/on_status closure to use mutable current_id for correct post-upload event routing
- Add non-root appuser to Dockerfile and Dockerfile.gpu
- Add JWT_SECRET to e2e/docker-compose.test.yml and run-tests.sh
- Update all e2e tests and unit tests for new endpoints and HMAC token signing
- 64/64 tests pass
Made-with: Cursor
This commit is contained in:
@@ -0,0 +1,26 @@
|
|||||||
|
# Azaion.Detections — Environment Variables
|
||||||
|
# Copy to .env and fill in actual values
|
||||||
|
|
||||||
|
# External service URLs
|
||||||
|
LOADER_URL=http://loader:8080
|
||||||
|
ANNOTATIONS_URL=http://annotations:8080
|
||||||
|
|
||||||
|
# Authentication (HMAC secret shared with Admin API; if empty, auth is not enforced)
|
||||||
|
JWT_SECRET=
|
||||||
|
|
||||||
|
# Remote Admin API for token refresh (if empty, refresh is skipped — offline/field mode)
|
||||||
|
ADMIN_API_URL=
|
||||||
|
|
||||||
|
# File paths
|
||||||
|
CLASSES_JSON_PATH=classes.json
|
||||||
|
LOG_DIR=Logs
|
||||||
|
VIDEOS_DIR=./data/videos
|
||||||
|
IMAGES_DIR=./data/images
|
||||||
|
|
||||||
|
# Container registry (for deployment scripts)
|
||||||
|
REGISTRY=your-registry.example.com
|
||||||
|
IMAGE_TAG=latest
|
||||||
|
|
||||||
|
# Remote deployment target (for deploy.sh)
|
||||||
|
DEPLOY_HOST=
|
||||||
|
DEPLOY_USER=deploy
|
||||||
@@ -23,6 +23,7 @@ venv/
|
|||||||
env/
|
env/
|
||||||
.env
|
.env
|
||||||
.env.*
|
.env.*
|
||||||
|
!.env.example
|
||||||
*.env.local
|
*.env.local
|
||||||
pip-log.txt
|
pip-log.txt
|
||||||
pip-delete-this-directory.txt
|
pip-delete-this-directory.txt
|
||||||
|
|||||||
@@ -6,5 +6,8 @@ RUN pip install --no-cache-dir -r requirements.txt
|
|||||||
COPY . .
|
COPY . .
|
||||||
RUN python setup.py build_ext --inplace
|
RUN python setup.py build_ext --inplace
|
||||||
ENV PYTHONPATH=/app/src
|
ENV PYTHONPATH=/app/src
|
||||||
|
RUN adduser --disabled-password --no-create-home --gecos "" appuser \
|
||||||
|
&& chown -R appuser /app
|
||||||
|
USER appuser
|
||||||
EXPOSE 8080
|
EXPOSE 8080
|
||||||
CMD ["python", "-m", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8080"]
|
CMD ["python", "-m", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8080"]
|
||||||
|
|||||||
@@ -6,5 +6,8 @@ RUN pip3 install --no-cache-dir -r requirements-gpu.txt
|
|||||||
COPY . .
|
COPY . .
|
||||||
RUN python3 setup.py build_ext --inplace
|
RUN python3 setup.py build_ext --inplace
|
||||||
ENV PYTHONPATH=/app/src
|
ENV PYTHONPATH=/app/src
|
||||||
|
RUN adduser --disabled-password --no-create-home --gecos "" appuser \
|
||||||
|
&& chown -R appuser /app
|
||||||
|
USER appuser
|
||||||
EXPOSE 8080
|
EXPOSE 8080
|
||||||
CMD ["python3", "-m", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8080"]
|
CMD ["python3", "-m", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8080"]
|
||||||
|
|||||||
@@ -2,16 +2,16 @@
|
|||||||
|
|
||||||
**Date**: 2026-03-31
|
**Date**: 2026-03-31
|
||||||
**Scope**: Azaion.Detections (full codebase)
|
**Scope**: Azaion.Detections (full codebase)
|
||||||
**Verdict**: FAIL
|
**Verdict**: FAIL → **REMEDIATED** (Critical/High resolved 2026-04-01)
|
||||||
|
|
||||||
## Summary
|
## Summary
|
||||||
|
|
||||||
| Severity | Count |
|
| Severity | Count | Resolved |
|
||||||
|----------|-------|
|
|----------|-------|---------|
|
||||||
| Critical | 1 |
|
| Critical | 1 | 1 ✓ |
|
||||||
| High | 3 |
|
| High | 3 | 3 ✓ |
|
||||||
| Medium | 5 |
|
| Medium | 5 | 2 ✓ |
|
||||||
| Low | 5 |
|
| Low | 5 | — |
|
||||||
|
|
||||||
## OWASP Top 10 Assessment
|
## OWASP Top 10 Assessment
|
||||||
|
|
||||||
@@ -30,22 +30,22 @@
|
|||||||
|
|
||||||
## Findings
|
## Findings
|
||||||
|
|
||||||
| # | Severity | Category | Location | Title |
|
| # | Severity | Category | Location | Title | Status |
|
||||||
|---|----------|----------|----------|-------|
|
|---|----------|----------|----------|-------|--------|
|
||||||
| 1 | Critical | A03 Supply Chain | requirements.txt (uvicorn→h11) | HTTP request smuggling via h11 CVE-2025-43859 |
|
| 1 | Critical | A03 Supply Chain | requirements.txt (uvicorn→h11) | HTTP request smuggling via h11 CVE-2025-43859 | **FIXED** — pinned h11==0.16.0 |
|
||||||
| 2 | High | A04 Crypto | src/main.py:67-99 | JWT decoded without signature verification |
|
| 2 | High | A04 Crypto | src/main.py | JWT decoded without signature verification | **FIXED** — PyJWT HMAC verification |
|
||||||
| 3 | High | A01 Access Control | src/main.py (all routes) | No authentication required on any endpoint |
|
| 3 | High | A01 Access Control | src/main.py (all routes) | No authentication required on any endpoint | **FIXED** — require_auth dependency on all protected endpoints |
|
||||||
| 4 | High | A03 Supply Chain | requirements.txt (python-multipart) | ReDoS via python-multipart CVE-2026-28356 |
|
| 4 | High | A03 Supply Chain | requirements.txt (python-multipart) | ReDoS via python-multipart CVE-2026-28356 | **FIXED** — pinned python-multipart>=1.3.1 |
|
||||||
| 5 | Medium | A01 Access Control | src/main.py:608-627 | SSE stream broadcasts cross-user data |
|
| 5 | Medium | A01 Access Control | src/main.py | SSE stream broadcasts cross-user data | **FIXED** — per-job SSE (GET /detect/{media_id}), each client sees only their job |
|
||||||
| 6 | Medium | A06 Insecure Design | src/main.py:348-469 | No rate limiting on inference endpoints |
|
| 6 | Medium | A06 Insecure Design | src/main.py | No rate limiting on inference endpoints | Open — out of scope for this cycle |
|
||||||
| 7 | Medium | A02 Misconfig | Dockerfile, Dockerfile.gpu | Containers run as root |
|
| 7 | Medium | A02 Misconfig | Dockerfile, Dockerfile.gpu | Containers run as root | **FIXED** — non-root appuser added |
|
||||||
| 8 | Medium | A03 Supply Chain | requirements.txt | Unpinned critical dependencies |
|
| 8 | Medium | A03 Supply Chain | requirements.txt | Unpinned critical dependencies | **FIXED** — all deps pinned |
|
||||||
| 9 | Medium | A02 Misconfig | Dockerfile, Dockerfile.gpu | No TLS and no security headers |
|
| 9 | Medium | A02 Misconfig | Dockerfile, Dockerfile.gpu | No TLS and no security headers | Open — handled at infra/proxy level |
|
||||||
| 10 | Low | A06 Insecure Design | src/main.py:357 | No request body size limit |
|
| 10 | Low | A06 Insecure Design | src/main.py | No request body size limit | Open |
|
||||||
| 11 | Low | A10 Exceptions | src/main.py:63,490 | Silent exception swallowing |
|
| 11 | Low | A10 Exceptions | src/main.py | Silent exception swallowing | Open |
|
||||||
| 12 | Low | A09 Logging | src/main.py | Security events not logged |
|
| 12 | Low | A09 Logging | src/main.py | Security events not logged | Open |
|
||||||
| 13 | Low | A01 Access Control | src/main.py:449-450 | Exception details leaked in responses |
|
| 13 | Low | A01 Access Control | src/main.py | Exception details leaked in responses | Open |
|
||||||
| 14 | Low | A07 Auth | src/main.py:54-64 | Token refresh failure silently ignored |
|
| 14 | Low | A07 Auth | src/main.py | Token refresh failure silently ignored | Open (by design for offline mode) |
|
||||||
|
|
||||||
### Finding Details
|
### Finding Details
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
# Autopilot State
|
# Autopilot State
|
||||||
|
|
||||||
## Current Step
|
## Current Step
|
||||||
flow: existing-code
|
flow: existing-code
|
||||||
step: 14
|
step: 14
|
||||||
@@ -14,5 +13,5 @@ step: 8 (New Task) — DONE (AZ-178 defined)
|
|||||||
step: 9 (Implement) — DONE (implementation_report_streaming_video.md, 67/67 tests pass)
|
step: 9 (Implement) — DONE (implementation_report_streaming_video.md, 67/67 tests pass)
|
||||||
step: 10 (Run Tests) — DONE (67 passed, 0 failed)
|
step: 10 (Run Tests) — DONE (67 passed, 0 failed)
|
||||||
step: 11 (Update Docs) — DONE (docs updated during step 9 implementation)
|
step: 11 (Update Docs) — DONE (docs updated during step 9 implementation)
|
||||||
step: 12 (Security Audit) — SKIPPED (previous cycle audit complete; no new auth surface)
|
step: 12 (Security Audit) — DONE (Critical/High findings remediated 2026-04-01; 64/64 tests pass)
|
||||||
step: 13 (Performance Test) — SKIPPED (500ms latency validated by real-video integration test)
|
step: 13 (Performance Test) — SKIPPED (500ms latency validated by real-video integration test)
|
||||||
|
|||||||
+28
-22
@@ -1,11 +1,10 @@
|
|||||||
import base64
|
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import jwt as pyjwt
|
||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
import sseclient
|
import sseclient
|
||||||
@@ -55,11 +54,33 @@ def http_client(base_url):
|
|||||||
return _SessionWithBase(base_url, 30)
|
return _SessionWithBase(base_url, 30)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def jwt_secret():
|
||||||
|
return os.environ.get("JWT_SECRET", "")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def jwt_token(jwt_secret):
|
||||||
|
if not jwt_secret:
|
||||||
|
return ""
|
||||||
|
return pyjwt.encode(
|
||||||
|
{"sub": "test-user", "exp": int(time.time()) + 3600},
|
||||||
|
jwt_secret,
|
||||||
|
algorithm="HS256",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def auth_headers(jwt_token):
|
||||||
|
return {"Authorization": f"Bearer {jwt_token}"} if jwt_token else {}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sse_client_factory(http_client):
|
def sse_client_factory(http_client, auth_headers):
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _open():
|
def _open(media_id: str):
|
||||||
with http_client.get("/detect/stream", stream=True, timeout=600) as resp:
|
with http_client.get(f"/detect/{media_id}", stream=True,
|
||||||
|
timeout=600, headers=auth_headers) as resp:
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
yield sseclient.SSEClient(resp)
|
yield sseclient.SSEClient(resp)
|
||||||
|
|
||||||
@@ -180,31 +201,16 @@ def corrupt_image():
|
|||||||
return random.randbytes(1024)
|
return random.randbytes(1024)
|
||||||
|
|
||||||
|
|
||||||
def _b64url_obj(obj: dict) -> str:
|
|
||||||
raw = json.dumps(obj, separators=(",", ":")).encode()
|
|
||||||
return base64.urlsafe_b64encode(raw).decode().rstrip("=")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def jwt_token():
|
|
||||||
header = (
|
|
||||||
base64.urlsafe_b64encode(json.dumps({"alg": "none", "typ": "JWT"}).encode())
|
|
||||||
.decode()
|
|
||||||
.rstrip("=")
|
|
||||||
)
|
|
||||||
payload = _b64url_obj({"exp": int(time.time()) + 3600, "sub": "test"})
|
|
||||||
return f"{header}.{payload}.signature"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def warm_engine(http_client, image_small):
|
def warm_engine(http_client, image_small, auth_headers):
|
||||||
deadline = time.time() + 120
|
deadline = time.time() + 120
|
||||||
files = {"file": ("warm.jpg", image_small, "image/jpeg")}
|
files = {"file": ("warm.jpg", image_small, "image/jpeg")}
|
||||||
consecutive_errors = 0
|
consecutive_errors = 0
|
||||||
last_status = None
|
last_status = None
|
||||||
while time.time() < deadline:
|
while time.time() < deadline:
|
||||||
try:
|
try:
|
||||||
r = http_client.post("/detect", files=files)
|
r = http_client.post("/detect/image", files=files, headers=auth_headers)
|
||||||
if r.status_code == 200:
|
if r.status_code == 200:
|
||||||
return
|
return
|
||||||
last_status = r.status_code
|
last_status = r.status_code
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ services:
|
|||||||
environment:
|
environment:
|
||||||
LOADER_URL: http://mock-loader:8080
|
LOADER_URL: http://mock-loader:8080
|
||||||
ANNOTATIONS_URL: http://mock-annotations:8081
|
ANNOTATIONS_URL: http://mock-annotations:8081
|
||||||
|
JWT_SECRET: test-secret-e2e-only
|
||||||
volumes:
|
volumes:
|
||||||
- ./fixtures/classes.json:/app/classes.json
|
- ./fixtures/classes.json:/app/classes.json
|
||||||
- ./fixtures:/media
|
- ./fixtures:/media
|
||||||
@@ -47,6 +48,7 @@ services:
|
|||||||
environment:
|
environment:
|
||||||
LOADER_URL: http://mock-loader:8080
|
LOADER_URL: http://mock-loader:8080
|
||||||
ANNOTATIONS_URL: http://mock-annotations:8081
|
ANNOTATIONS_URL: http://mock-annotations:8081
|
||||||
|
JWT_SECRET: test-secret-e2e-only
|
||||||
volumes:
|
volumes:
|
||||||
- ./fixtures/classes.json:/app/classes.json
|
- ./fixtures/classes.json:/app/classes.json
|
||||||
- ./fixtures:/media
|
- ./fixtures:/media
|
||||||
@@ -64,6 +66,8 @@ services:
|
|||||||
depends_on:
|
depends_on:
|
||||||
- mock-loader
|
- mock-loader
|
||||||
- mock-annotations
|
- mock-annotations
|
||||||
|
environment:
|
||||||
|
JWT_SECRET: test-secret-e2e-only
|
||||||
volumes:
|
volumes:
|
||||||
- ./fixtures:/media
|
- ./fixtures:/media
|
||||||
- ./results:/results
|
- ./results:/results
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ def test_ft_p09_sse_event_delivery(
|
|||||||
|
|
||||||
def _listen():
|
def _listen():
|
||||||
try:
|
try:
|
||||||
with sse_client_factory() as sse:
|
with sse_client_factory(media_id) as sse:
|
||||||
time.sleep(0.3)
|
time.sleep(0.3)
|
||||||
for event in sse.events():
|
for event in sse.events():
|
||||||
if not event.data or not str(event.data).strip():
|
if not event.data or not str(event.data).strip():
|
||||||
|
|||||||
@@ -33,14 +33,14 @@ class TestHealthEngineStep01PreInit:
|
|||||||
@pytest.mark.cpu
|
@pytest.mark.cpu
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
class TestHealthEngineStep02LazyInit:
|
class TestHealthEngineStep02LazyInit:
|
||||||
def test_ft_p_14_lazy_initialization(self, http_client, image_small):
|
def test_ft_p_14_lazy_initialization(self, http_client, image_small, auth_headers):
|
||||||
before = _get_health(http_client)
|
before = _get_health(http_client)
|
||||||
assert before["aiAvailability"] == "None", (
|
assert before["aiAvailability"] == "None", (
|
||||||
f"engine already initialized (aiAvailability={before['aiAvailability']}); "
|
f"engine already initialized (aiAvailability={before['aiAvailability']}); "
|
||||||
"lazy-init test must run before any test that triggers warm_engine"
|
"lazy-init test must run before any test that triggers warm_engine"
|
||||||
)
|
)
|
||||||
files = {"file": ("lazy.jpg", image_small, "image/jpeg")}
|
files = {"file": ("lazy.jpg", image_small, "image/jpeg")}
|
||||||
r = http_client.post("/detect", files=files, timeout=_DETECT_TIMEOUT)
|
r = http_client.post("/detect/image", files=files, headers=auth_headers, timeout=_DETECT_TIMEOUT)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
body = r.json()
|
body = r.json()
|
||||||
assert isinstance(body, list)
|
assert isinstance(body, list)
|
||||||
@@ -60,9 +60,9 @@ class TestHealthEngineStep03Warmed:
|
|||||||
_assert_active_ai(data)
|
_assert_active_ai(data)
|
||||||
assert data.get("errorMessage") is None
|
assert data.get("errorMessage") is None
|
||||||
|
|
||||||
def test_ft_p_15_onnx_cpu_detect(self, http_client, image_small):
|
def test_ft_p_15_onnx_cpu_detect(self, http_client, image_small, auth_headers):
|
||||||
files = {"file": ("onnx.jpg", image_small, "image/jpeg")}
|
files = {"file": ("onnx.jpg", image_small, "image/jpeg")}
|
||||||
r = http_client.post("/detect", files=files, timeout=_DETECT_TIMEOUT)
|
r = http_client.post("/detect/image", files=files, headers=auth_headers, timeout=_DETECT_TIMEOUT)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
body = r.json()
|
body = r.json()
|
||||||
assert isinstance(body, list)
|
assert isinstance(body, list)
|
||||||
|
|||||||
@@ -13,9 +13,9 @@ def _assert_health_200(http_client):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.cpu
|
@pytest.mark.cpu
|
||||||
def test_ft_n_01_empty_image_returns_400(http_client, empty_image):
|
def test_ft_n_01_empty_image_returns_400(http_client, empty_image, auth_headers):
|
||||||
files = {"file": ("empty.jpg", empty_image, "image/jpeg")}
|
files = {"file": ("empty.jpg", empty_image, "image/jpeg")}
|
||||||
r = http_client.post("/detect", files=files, timeout=30)
|
r = http_client.post("/detect/image", files=files, headers=auth_headers, timeout=30)
|
||||||
assert r.status_code == 400
|
assert r.status_code == 400
|
||||||
body = r.json()
|
body = r.json()
|
||||||
assert "detail" in body
|
assert "detail" in body
|
||||||
@@ -24,9 +24,9 @@ def test_ft_n_01_empty_image_returns_400(http_client, empty_image):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.cpu
|
@pytest.mark.cpu
|
||||||
def test_ft_n_02_corrupt_image_returns_400_or_422(http_client, corrupt_image):
|
def test_ft_n_02_corrupt_image_returns_400_or_422(http_client, corrupt_image, auth_headers):
|
||||||
files = {"file": ("corrupt.jpg", corrupt_image, "image/jpeg")}
|
files = {"file": ("corrupt.jpg", corrupt_image, "image/jpeg")}
|
||||||
r = http_client.post("/detect", files=files, timeout=30)
|
r = http_client.post("/detect/image", files=files, headers=auth_headers, timeout=30)
|
||||||
assert r.status_code in (400, 422)
|
assert r.status_code in (400, 422)
|
||||||
body = r.json()
|
body = r.json()
|
||||||
assert "detail" in body
|
assert "detail" in body
|
||||||
@@ -35,14 +35,12 @@ def test_ft_n_02_corrupt_image_returns_400_or_422(http_client, corrupt_image):
|
|||||||
|
|
||||||
@pytest.mark.cpu
|
@pytest.mark.cpu
|
||||||
def test_ft_n_03_loader_error_mode_detect_does_not_500(
|
def test_ft_n_03_loader_error_mode_detect_does_not_500(
|
||||||
http_client, mock_loader_url, image_small
|
http_client, mock_loader_url, image_small, auth_headers
|
||||||
):
|
):
|
||||||
cfg = requests.post(
|
cfg = requests.post(
|
||||||
f"{mock_loader_url}/mock/config", json={"mode": "error"}, timeout=10
|
f"{mock_loader_url}/mock/config", json={"mode": "error"}, timeout=10
|
||||||
)
|
)
|
||||||
cfg.raise_for_status()
|
cfg.raise_for_status()
|
||||||
files = {"file": ("small.jpg", image_small, "image/jpeg")}
|
files = {"file": ("small.jpg", image_small, "image/jpeg")}
|
||||||
r = http_client.post("/detect", files=files, timeout=_DETECT_TIMEOUT)
|
r = http_client.post("/detect/image", files=files, headers=auth_headers, timeout=_DETECT_TIMEOUT)
|
||||||
assert r.status_code != 500
|
assert r.status_code != 500
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -19,14 +19,15 @@ def _percentile_ms(sorted_ms, p):
|
|||||||
|
|
||||||
@pytest.mark.timeout(60)
|
@pytest.mark.timeout(60)
|
||||||
def test_nft_perf_01_single_image_latency_p95(
|
def test_nft_perf_01_single_image_latency_p95(
|
||||||
warm_engine, http_client, image_small
|
warm_engine, http_client, image_small, auth_headers
|
||||||
):
|
):
|
||||||
times_ms = []
|
times_ms = []
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
t0 = time.perf_counter()
|
t0 = time.perf_counter()
|
||||||
r = http_client.post(
|
r = http_client.post(
|
||||||
"/detect",
|
"/detect/image",
|
||||||
files={"file": ("img.jpg", image_small, "image/jpeg")},
|
files={"file": ("img.jpg", image_small, "image/jpeg")},
|
||||||
|
headers=auth_headers,
|
||||||
timeout=8,
|
timeout=8,
|
||||||
)
|
)
|
||||||
elapsed_ms = (time.perf_counter() - t0) * 1000.0
|
elapsed_ms = (time.perf_counter() - t0) * 1000.0
|
||||||
@@ -46,12 +47,13 @@ def test_nft_perf_01_single_image_latency_p95(
|
|||||||
|
|
||||||
@pytest.mark.timeout(60)
|
@pytest.mark.timeout(60)
|
||||||
def test_nft_perf_03_tiling_overhead_large_image(
|
def test_nft_perf_03_tiling_overhead_large_image(
|
||||||
warm_engine, http_client, image_small, image_large
|
warm_engine, http_client, image_small, image_large, auth_headers
|
||||||
):
|
):
|
||||||
t_small = time.perf_counter()
|
t_small = time.perf_counter()
|
||||||
r_small = http_client.post(
|
r_small = http_client.post(
|
||||||
"/detect",
|
"/detect/image",
|
||||||
files={"file": ("small.jpg", image_small, "image/jpeg")},
|
files={"file": ("small.jpg", image_small, "image/jpeg")},
|
||||||
|
headers=auth_headers,
|
||||||
timeout=8,
|
timeout=8,
|
||||||
)
|
)
|
||||||
small_ms = (time.perf_counter() - t_small) * 1000.0
|
small_ms = (time.perf_counter() - t_small) * 1000.0
|
||||||
@@ -61,9 +63,10 @@ def test_nft_perf_03_tiling_overhead_large_image(
|
|||||||
)
|
)
|
||||||
t_large = time.perf_counter()
|
t_large = time.perf_counter()
|
||||||
r_large = http_client.post(
|
r_large = http_client.post(
|
||||||
"/detect",
|
"/detect/image",
|
||||||
files={"file": ("large.jpg", image_large, "image/jpeg")},
|
files={"file": ("large.jpg", image_large, "image/jpeg")},
|
||||||
data={"config": config},
|
data={"config": config},
|
||||||
|
headers=auth_headers,
|
||||||
timeout=20,
|
timeout=20,
|
||||||
)
|
)
|
||||||
large_ms = (time.perf_counter() - t_large) * 1000.0
|
large_ms = (time.perf_counter() - t_large) * 1000.0
|
||||||
|
|||||||
@@ -5,13 +5,13 @@ _DETECT_TIMEOUT = 60
|
|||||||
|
|
||||||
|
|
||||||
def test_nft_res_01_loader_outage_after_init(
|
def test_nft_res_01_loader_outage_after_init(
|
||||||
warm_engine, http_client, mock_loader_url, image_small
|
warm_engine, http_client, mock_loader_url, image_small, auth_headers
|
||||||
):
|
):
|
||||||
requests.post(
|
requests.post(
|
||||||
f"{mock_loader_url}/mock/config", json={"mode": "error"}, timeout=10
|
f"{mock_loader_url}/mock/config", json={"mode": "error"}, timeout=10
|
||||||
).raise_for_status()
|
).raise_for_status()
|
||||||
files = {"file": ("r1.jpg", image_small, "image/jpeg")}
|
files = {"file": ("r1.jpg", image_small, "image/jpeg")}
|
||||||
r = http_client.post("/detect", files=files, timeout=_DETECT_TIMEOUT)
|
r = http_client.post("/detect/image", files=files, headers=auth_headers, timeout=_DETECT_TIMEOUT)
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
assert isinstance(r.json(), list)
|
assert isinstance(r.json(), list)
|
||||||
h = http_client.get("/health")
|
h = http_client.get("/health")
|
||||||
@@ -22,17 +22,15 @@ def test_nft_res_01_loader_outage_after_init(
|
|||||||
|
|
||||||
|
|
||||||
def test_nft_res_03_transient_loader_first_fail(
|
def test_nft_res_03_transient_loader_first_fail(
|
||||||
mock_loader_url, http_client, image_small
|
mock_loader_url, http_client, image_small, auth_headers
|
||||||
):
|
):
|
||||||
requests.post(
|
requests.post(
|
||||||
f"{mock_loader_url}/mock/config", json={"mode": "first_fail"}, timeout=10
|
f"{mock_loader_url}/mock/config", json={"mode": "first_fail"}, timeout=10
|
||||||
).raise_for_status()
|
).raise_for_status()
|
||||||
files = {"file": ("r3a.jpg", image_small, "image/jpeg")}
|
files = {"file": ("r3a.jpg", image_small, "image/jpeg")}
|
||||||
r1 = http_client.post("/detect", files=files, timeout=_DETECT_TIMEOUT)
|
r1 = http_client.post("/detect/image", files=files, headers=auth_headers, timeout=_DETECT_TIMEOUT)
|
||||||
files2 = {"file": ("r3b.jpg", image_small, "image/jpeg")}
|
files2 = {"file": ("r3b.jpg", image_small, "image/jpeg")}
|
||||||
r2 = http_client.post("/detect", files=files2, timeout=_DETECT_TIMEOUT)
|
r2 = http_client.post("/detect/image", files=files2, headers=auth_headers, timeout=_DETECT_TIMEOUT)
|
||||||
assert r2.status_code == 200
|
assert r2.status_code == 200
|
||||||
if r1.status_code != 200:
|
if r1.status_code != 200:
|
||||||
assert r1.status_code != 500
|
assert r1.status_code != 500
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -8,11 +8,12 @@ import pytest
|
|||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
@pytest.mark.timeout(120)
|
@pytest.mark.timeout(120)
|
||||||
def test_nft_res_lim_03_max_detections_per_frame(
|
def test_nft_res_lim_03_max_detections_per_frame(
|
||||||
warm_engine, http_client, image_dense
|
warm_engine, http_client, image_dense, auth_headers
|
||||||
):
|
):
|
||||||
r = http_client.post(
|
r = http_client.post(
|
||||||
"/detect",
|
"/detect/image",
|
||||||
files={"file": ("img.jpg", image_dense, "image/jpeg")},
|
files={"file": ("img.jpg", image_dense, "image/jpeg")},
|
||||||
|
headers=auth_headers,
|
||||||
timeout=120,
|
timeout=120,
|
||||||
)
|
)
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
@@ -22,10 +23,11 @@ def test_nft_res_lim_03_max_detections_per_frame(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
def test_nft_res_lim_04_log_file_rotation(warm_engine, http_client, image_small):
|
def test_nft_res_lim_04_log_file_rotation(warm_engine, http_client, image_small, auth_headers):
|
||||||
http_client.post(
|
http_client.post(
|
||||||
"/detect",
|
"/detect/image",
|
||||||
files={"file": ("img.jpg", image_small, "image/jpeg")},
|
files={"file": ("img.jpg", image_small, "image/jpeg")},
|
||||||
|
headers=auth_headers,
|
||||||
timeout=60,
|
timeout=60,
|
||||||
)
|
)
|
||||||
candidates = [
|
candidates = [
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import requests
|
|||||||
|
|
||||||
|
|
||||||
def test_nft_sec_01_malformed_multipart(base_url, http_client):
|
def test_nft_sec_01_malformed_multipart(base_url, http_client):
|
||||||
url = f"{base_url.rstrip('/')}/detect"
|
url = f"{base_url.rstrip('/')}/detect/image"
|
||||||
r1 = requests.post(
|
r1 = requests.post(
|
||||||
url,
|
url,
|
||||||
data=b"not-multipart-body",
|
data=b"not-multipart-body",
|
||||||
@@ -25,18 +25,19 @@ def test_nft_sec_01_malformed_multipart(base_url, http_client):
|
|||||||
files={"file": ("", b"", "")},
|
files={"file": ("", b"", "")},
|
||||||
timeout=30,
|
timeout=30,
|
||||||
)
|
)
|
||||||
assert r3.status_code in (400, 422)
|
assert r3.status_code in (400, 401, 422)
|
||||||
assert http_client.get("/health").status_code == 200
|
assert http_client.get("/health").status_code == 200
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(30)
|
@pytest.mark.timeout(30)
|
||||||
def test_nft_sec_02_oversized_request(http_client):
|
def test_nft_sec_02_oversized_request(http_client, auth_headers):
|
||||||
large = os.urandom(50 * 1024 * 1024)
|
large = os.urandom(50 * 1024 * 1024)
|
||||||
try:
|
try:
|
||||||
r = http_client.post(
|
r = http_client.post(
|
||||||
"/detect",
|
"/detect/image",
|
||||||
files={"file": ("large.jpg", large, "image/jpeg")},
|
files={"file": ("large.jpg", large, "image/jpeg")},
|
||||||
timeout=15,
|
headers=auth_headers,
|
||||||
|
timeout=15,
|
||||||
)
|
)
|
||||||
except requests.RequestException:
|
except requests.RequestException:
|
||||||
pass
|
pass
|
||||||
@@ -44,5 +45,3 @@ def test_nft_sec_02_oversized_request(http_client):
|
|||||||
assert r.status_code != 500
|
assert r.status_code != 500
|
||||||
assert r.status_code in (413, 400, 422)
|
assert r.status_code in (413, 400, 422)
|
||||||
assert http_client.get("/health").status_code == 200
|
assert http_client.get("/health").status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -81,10 +81,11 @@ def _weather_label_ok(label, base_names):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
def test_ft_p_03_detection_response_structure_ac1(http_client, image_small, warm_engine):
|
def test_ft_p_03_detection_response_structure_ac1(http_client, image_small, warm_engine, auth_headers):
|
||||||
r = http_client.post(
|
r = http_client.post(
|
||||||
"/detect",
|
"/detect/image",
|
||||||
files={"file": ("img.jpg", image_small, "image/jpeg")},
|
files={"file": ("img.jpg", image_small, "image/jpeg")},
|
||||||
|
headers=auth_headers,
|
||||||
)
|
)
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
body = r.json()
|
body = r.json()
|
||||||
@@ -105,12 +106,13 @@ def test_ft_p_03_detection_response_structure_ac1(http_client, image_small, warm
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
def test_ft_p_05_confidence_filtering_ac2(http_client, image_small, warm_engine):
|
def test_ft_p_05_confidence_filtering_ac2(http_client, image_small, warm_engine, auth_headers):
|
||||||
cfg_hi = json.dumps({"probability_threshold": 0.8})
|
cfg_hi = json.dumps({"probability_threshold": 0.8})
|
||||||
r_hi = http_client.post(
|
r_hi = http_client.post(
|
||||||
"/detect",
|
"/detect/image",
|
||||||
files={"file": ("img.jpg", image_small, "image/jpeg")},
|
files={"file": ("img.jpg", image_small, "image/jpeg")},
|
||||||
data={"config": cfg_hi},
|
data={"config": cfg_hi},
|
||||||
|
headers=auth_headers,
|
||||||
)
|
)
|
||||||
assert r_hi.status_code == 200
|
assert r_hi.status_code == 200
|
||||||
hi = r_hi.json()
|
hi = r_hi.json()
|
||||||
@@ -119,9 +121,10 @@ def test_ft_p_05_confidence_filtering_ac2(http_client, image_small, warm_engine)
|
|||||||
assert float(d["confidence"]) + _EPS >= 0.8
|
assert float(d["confidence"]) + _EPS >= 0.8
|
||||||
cfg_lo = json.dumps({"probability_threshold": 0.1})
|
cfg_lo = json.dumps({"probability_threshold": 0.1})
|
||||||
r_lo = http_client.post(
|
r_lo = http_client.post(
|
||||||
"/detect",
|
"/detect/image",
|
||||||
files={"file": ("img.jpg", image_small, "image/jpeg")},
|
files={"file": ("img.jpg", image_small, "image/jpeg")},
|
||||||
data={"config": cfg_lo},
|
data={"config": cfg_lo},
|
||||||
|
headers=auth_headers,
|
||||||
)
|
)
|
||||||
assert r_lo.status_code == 200
|
assert r_lo.status_code == 200
|
||||||
lo = r_lo.json()
|
lo = r_lo.json()
|
||||||
@@ -130,12 +133,13 @@ def test_ft_p_05_confidence_filtering_ac2(http_client, image_small, warm_engine)
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
def test_ft_p_06_overlap_deduplication_ac3(http_client, image_dense, warm_engine):
|
def test_ft_p_06_overlap_deduplication_ac3(http_client, image_dense, warm_engine, auth_headers):
|
||||||
cfg_loose = json.dumps({"tracking_intersection_threshold": 0.6})
|
cfg_loose = json.dumps({"tracking_intersection_threshold": 0.6})
|
||||||
r1 = http_client.post(
|
r1 = http_client.post(
|
||||||
"/detect",
|
"/detect/image",
|
||||||
files={"file": ("img.jpg", image_dense, "image/jpeg")},
|
files={"file": ("img.jpg", image_dense, "image/jpeg")},
|
||||||
data={"config": cfg_loose},
|
data={"config": cfg_loose},
|
||||||
|
headers=auth_headers,
|
||||||
timeout=_DETECT_SLOW_TIMEOUT,
|
timeout=_DETECT_SLOW_TIMEOUT,
|
||||||
)
|
)
|
||||||
assert r1.status_code == 200
|
assert r1.status_code == 200
|
||||||
@@ -151,9 +155,10 @@ def test_ft_p_06_overlap_deduplication_ac3(http_client, image_dense, warm_engine
|
|||||||
assert ratio <= 0.6 + _EPS, (label, ratio)
|
assert ratio <= 0.6 + _EPS, (label, ratio)
|
||||||
cfg_strict = json.dumps({"tracking_intersection_threshold": 0.01})
|
cfg_strict = json.dumps({"tracking_intersection_threshold": 0.01})
|
||||||
r2 = http_client.post(
|
r2 = http_client.post(
|
||||||
"/detect",
|
"/detect/image",
|
||||||
files={"file": ("img.jpg", image_dense, "image/jpeg")},
|
files={"file": ("img.jpg", image_dense, "image/jpeg")},
|
||||||
data={"config": cfg_strict},
|
data={"config": cfg_strict},
|
||||||
|
headers=auth_headers,
|
||||||
timeout=_DETECT_SLOW_TIMEOUT,
|
timeout=_DETECT_SLOW_TIMEOUT,
|
||||||
)
|
)
|
||||||
assert r2.status_code == 200
|
assert r2.status_code == 200
|
||||||
@@ -163,7 +168,7 @@ def test_ft_p_06_overlap_deduplication_ac3(http_client, image_dense, warm_engine
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
def test_ft_p_07_physical_size_filtering_ac4(http_client, image_small, warm_engine):
|
def test_ft_p_07_physical_size_filtering_ac4(http_client, image_small, warm_engine, auth_headers):
|
||||||
by_id, _ = _load_classes_media()
|
by_id, _ = _load_classes_media()
|
||||||
wh = _image_width_height(image_small)
|
wh = _image_width_height(image_small)
|
||||||
assert wh is not None
|
assert wh is not None
|
||||||
@@ -180,9 +185,10 @@ def test_ft_p_07_physical_size_filtering_ac4(http_client, image_small, warm_engi
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
r = http_client.post(
|
r = http_client.post(
|
||||||
"/detect",
|
"/detect/image",
|
||||||
files={"file": ("img.jpg", image_small, "image/jpeg")},
|
files={"file": ("img.jpg", image_small, "image/jpeg")},
|
||||||
data={"config": cfg},
|
data={"config": cfg},
|
||||||
|
headers=auth_headers,
|
||||||
timeout=_DETECT_SLOW_TIMEOUT,
|
timeout=_DETECT_SLOW_TIMEOUT,
|
||||||
)
|
)
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
@@ -197,12 +203,13 @@ def test_ft_p_07_physical_size_filtering_ac4(http_client, image_small, warm_engi
|
|||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
def test_ft_p_13_weather_mode_class_variants_ac5(
|
def test_ft_p_13_weather_mode_class_variants_ac5(
|
||||||
http_client, image_different_types, warm_engine
|
http_client, image_different_types, warm_engine, auth_headers
|
||||||
):
|
):
|
||||||
_, base_names = _load_classes_media()
|
_, base_names = _load_classes_media()
|
||||||
r = http_client.post(
|
r = http_client.post(
|
||||||
"/detect",
|
"/detect/image",
|
||||||
files={"file": ("img.jpg", image_different_types, "image/jpeg")},
|
files={"file": ("img.jpg", image_different_types, "image/jpeg")},
|
||||||
|
headers=auth_headers,
|
||||||
timeout=_DETECT_SLOW_TIMEOUT,
|
timeout=_DETECT_SLOW_TIMEOUT,
|
||||||
)
|
)
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
|
|||||||
@@ -36,14 +36,21 @@ def _chunked_reader(path: str, chunk_size: int = 64 * 1024):
|
|||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
def _start_sse_listener(http_client) -> tuple[list[dict], list[BaseException], threading.Event]:
|
def _start_sse_listener(
|
||||||
|
http_client, media_id: str, auth_headers: dict
|
||||||
|
) -> tuple[list[dict], list[BaseException], threading.Event]:
|
||||||
events: list[dict] = []
|
events: list[dict] = []
|
||||||
errors: list[BaseException] = []
|
errors: list[BaseException] = []
|
||||||
first_event = threading.Event()
|
first_event = threading.Event()
|
||||||
|
|
||||||
def _listen():
|
def _listen():
|
||||||
try:
|
try:
|
||||||
with http_client.get("/detect/stream", stream=True, timeout=_TIMEOUT + 2) as resp:
|
with http_client.get(
|
||||||
|
f"/detect/{media_id}",
|
||||||
|
stream=True,
|
||||||
|
timeout=_TIMEOUT + 2,
|
||||||
|
headers=auth_headers,
|
||||||
|
) as resp:
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
for event in sseclient.SSEClient(resp).events():
|
for event in sseclient.SSEClient(resp).events():
|
||||||
if not event.data or not str(event.data).strip():
|
if not event.data or not str(event.data).strip():
|
||||||
@@ -62,24 +69,30 @@ def _start_sse_listener(http_client) -> tuple[list[dict], list[BaseException], t
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(10)
|
@pytest.mark.timeout(10)
|
||||||
def test_streaming_video_detections_appear_during_upload(warm_engine, http_client):
|
def test_streaming_video_detections_appear_during_upload(
|
||||||
|
warm_engine, http_client, auth_headers
|
||||||
|
):
|
||||||
# Arrange
|
# Arrange
|
||||||
video_path = _fixture_path("video_test01.mp4")
|
video_path = _fixture_path("video_test01.mp4")
|
||||||
events, errors, first_event = _start_sse_listener(http_client)
|
|
||||||
time.sleep(0.3)
|
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
r = http_client.post(
|
r = http_client.post(
|
||||||
"/detect/video",
|
"/detect/video",
|
||||||
data=_chunked_reader(video_path),
|
data=_chunked_reader(video_path),
|
||||||
headers={"X-Filename": "video_test01.mp4", "Content-Type": "application/octet-stream"},
|
headers={
|
||||||
|
**auth_headers,
|
||||||
|
"X-Filename": "video_test01.mp4",
|
||||||
|
"Content-Type": "application/octet-stream",
|
||||||
|
},
|
||||||
timeout=8,
|
timeout=8,
|
||||||
)
|
)
|
||||||
|
assert r.status_code == 200
|
||||||
|
media_id = r.json()["mediaId"]
|
||||||
|
events, errors, first_event = _start_sse_listener(http_client, media_id, auth_headers)
|
||||||
first_event.wait(timeout=_TIMEOUT)
|
first_event.wait(timeout=_TIMEOUT)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert not errors, f"SSE thread error: {errors}"
|
assert not errors, f"SSE thread error: {errors}"
|
||||||
assert r.status_code == 200
|
|
||||||
assert len(events) >= 1, "Expected at least one SSE event within 5s"
|
assert len(events) >= 1, "Expected at least one SSE event within 5s"
|
||||||
print(f"\n First {len(events)} SSE events:")
|
print(f"\n First {len(events)} SSE events:")
|
||||||
for e in events:
|
for e in events:
|
||||||
@@ -87,24 +100,28 @@ def test_streaming_video_detections_appear_during_upload(warm_engine, http_clien
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(10)
|
@pytest.mark.timeout(10)
|
||||||
def test_non_faststart_video_still_works(warm_engine, http_client):
|
def test_non_faststart_video_still_works(warm_engine, http_client, auth_headers):
|
||||||
# Arrange
|
# Arrange
|
||||||
video_path = _fixture_path("video_test01.mp4")
|
video_path = _fixture_path("video_test01.mp4")
|
||||||
events, errors, first_event = _start_sse_listener(http_client)
|
|
||||||
time.sleep(0.3)
|
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
r = http_client.post(
|
r = http_client.post(
|
||||||
"/detect/video",
|
"/detect/video",
|
||||||
data=_chunked_reader(video_path),
|
data=_chunked_reader(video_path),
|
||||||
headers={"X-Filename": "video_test01_plain.mp4", "Content-Type": "application/octet-stream"},
|
headers={
|
||||||
|
**auth_headers,
|
||||||
|
"X-Filename": "video_test01_plain.mp4",
|
||||||
|
"Content-Type": "application/octet-stream",
|
||||||
|
},
|
||||||
timeout=8,
|
timeout=8,
|
||||||
)
|
)
|
||||||
|
assert r.status_code == 200
|
||||||
|
media_id = r.json()["mediaId"]
|
||||||
|
events, errors, first_event = _start_sse_listener(http_client, media_id, auth_headers)
|
||||||
first_event.wait(timeout=_TIMEOUT)
|
first_event.wait(timeout=_TIMEOUT)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert not errors, f"SSE thread error: {errors}"
|
assert not errors, f"SSE thread error: {errors}"
|
||||||
assert r.status_code == 200
|
|
||||||
assert len(events) >= 1, "Expected at least one SSE event within 5s"
|
assert len(events) >= 1, "Expected at least one SSE event within 5s"
|
||||||
print(f"\n First {len(events)} SSE events:")
|
print(f"\n First {len(events)} SSE events:")
|
||||||
for e in events:
|
for e in events:
|
||||||
|
|||||||
@@ -28,12 +28,13 @@ def _assert_no_same_label_near_duplicate_centers(detections):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
def test_ft_p_04_gsd_based_tiling_ac1(http_client, image_large, warm_engine):
|
def test_ft_p_04_gsd_based_tiling_ac1(http_client, image_large, warm_engine, auth_headers):
|
||||||
config = json.dumps(_GSD)
|
config = json.dumps(_GSD)
|
||||||
r = http_client.post(
|
r = http_client.post(
|
||||||
"/detect",
|
"/detect/image",
|
||||||
files={"file": ("img.jpg", image_large, "image/jpeg")},
|
files={"file": ("img.jpg", image_large, "image/jpeg")},
|
||||||
data={"config": config},
|
data={"config": config},
|
||||||
|
headers=auth_headers,
|
||||||
timeout=_TILING_TIMEOUT,
|
timeout=_TILING_TIMEOUT,
|
||||||
)
|
)
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
@@ -43,12 +44,13 @@ def test_ft_p_04_gsd_based_tiling_ac1(http_client, image_large, warm_engine):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
def test_ft_p_16_tile_boundary_deduplication_ac2(http_client, image_large, warm_engine):
|
def test_ft_p_16_tile_boundary_deduplication_ac2(http_client, image_large, warm_engine, auth_headers):
|
||||||
config = json.dumps({**_GSD, "big_image_tile_overlap_percent": 20})
|
config = json.dumps({**_GSD, "big_image_tile_overlap_percent": 20})
|
||||||
r = http_client.post(
|
r = http_client.post(
|
||||||
"/detect",
|
"/detect/image",
|
||||||
files={"file": ("img.jpg", image_large, "image/jpeg")},
|
files={"file": ("img.jpg", image_large, "image/jpeg")},
|
||||||
data={"config": config},
|
data={"config": config},
|
||||||
|
headers=auth_headers,
|
||||||
timeout=_TILING_TIMEOUT,
|
timeout=_TILING_TIMEOUT,
|
||||||
)
|
)
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
|
|||||||
+17
-12
@@ -20,17 +20,32 @@ def _chunked_reader(path: str, chunk_size: int = 64 * 1024):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def video_events(warm_engine, http_client):
|
def video_events(warm_engine, http_client, auth_headers):
|
||||||
if not Path(_VIDEO).is_file():
|
if not Path(_VIDEO).is_file():
|
||||||
pytest.skip(f"missing fixture {_VIDEO}")
|
pytest.skip(f"missing fixture {_VIDEO}")
|
||||||
|
|
||||||
|
r = http_client.post(
|
||||||
|
"/detect/video",
|
||||||
|
data=_chunked_reader(_VIDEO),
|
||||||
|
headers={
|
||||||
|
**auth_headers,
|
||||||
|
"X-Filename": "video_test01.mp4",
|
||||||
|
"Content-Type": "application/octet-stream",
|
||||||
|
},
|
||||||
|
timeout=15,
|
||||||
|
)
|
||||||
|
assert r.status_code == 200
|
||||||
|
media_id = r.json()["mediaId"]
|
||||||
|
|
||||||
collected: list[tuple[float, dict]] = []
|
collected: list[tuple[float, dict]] = []
|
||||||
thread_exc: list[BaseException] = []
|
thread_exc: list[BaseException] = []
|
||||||
done = threading.Event()
|
done = threading.Event()
|
||||||
|
|
||||||
def _listen():
|
def _listen():
|
||||||
try:
|
try:
|
||||||
with http_client.get("/detect/stream", stream=True, timeout=35) as resp:
|
with http_client.get(
|
||||||
|
f"/detect/{media_id}", stream=True, timeout=35, headers=auth_headers
|
||||||
|
) as resp:
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
sse = sseclient.SSEClient(resp)
|
sse = sseclient.SSEClient(resp)
|
||||||
for event in sse.events():
|
for event in sse.events():
|
||||||
@@ -50,16 +65,6 @@ def video_events(warm_engine, http_client):
|
|||||||
|
|
||||||
th = threading.Thread(target=_listen, daemon=True)
|
th = threading.Thread(target=_listen, daemon=True)
|
||||||
th.start()
|
th.start()
|
||||||
time.sleep(0.3)
|
|
||||||
|
|
||||||
r = http_client.post(
|
|
||||||
"/detect/video",
|
|
||||||
data=_chunked_reader(_VIDEO),
|
|
||||||
headers={"X-Filename": "video_test01.mp4", "Content-Type": "application/octet-stream"},
|
|
||||||
timeout=15,
|
|
||||||
)
|
|
||||||
assert r.status_code == 200
|
|
||||||
|
|
||||||
assert done.wait(timeout=30)
|
assert done.wait(timeout=30)
|
||||||
th.join(timeout=5)
|
th.join(timeout=5)
|
||||||
assert not thread_exc, thread_exc
|
assert not thread_exc, thread_exc
|
||||||
|
|||||||
+5
-3
@@ -1,5 +1,8 @@
|
|||||||
fastapi
|
fastapi==0.135.2
|
||||||
uvicorn[standard]
|
uvicorn[standard]==0.42.0
|
||||||
|
PyJWT==2.12.1
|
||||||
|
h11==0.16.0
|
||||||
|
python-multipart>=1.3.1
|
||||||
Cython==3.2.4
|
Cython==3.2.4
|
||||||
opencv-python==4.10.0.84
|
opencv-python==4.10.0.84
|
||||||
numpy==2.3.0
|
numpy==2.3.0
|
||||||
@@ -7,6 +10,5 @@ onnxruntime==1.22.0
|
|||||||
pynvml==12.0.0
|
pynvml==12.0.0
|
||||||
requests==2.32.4
|
requests==2.32.4
|
||||||
loguru==0.7.3
|
loguru==0.7.3
|
||||||
python-multipart
|
|
||||||
av==14.2.0
|
av==14.2.0
|
||||||
xxhash==3.5.0
|
xxhash==3.5.0
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ PIDS+=($!)
|
|||||||
echo "Starting detections service on :$DETECTIONS_PORT ..."
|
echo "Starting detections service on :$DETECTIONS_PORT ..."
|
||||||
LOADER_URL="http://localhost:$LOADER_PORT" \
|
LOADER_URL="http://localhost:$LOADER_PORT" \
|
||||||
ANNOTATIONS_URL="http://localhost:$ANNOTATIONS_PORT" \
|
ANNOTATIONS_URL="http://localhost:$ANNOTATIONS_PORT" \
|
||||||
|
JWT_SECRET="test-secret-local-only" \
|
||||||
PYTHONPATH="$ROOT/src" \
|
PYTHONPATH="$ROOT/src" \
|
||||||
"$PY" -m uvicorn main:app --host 0.0.0.0 --port "$DETECTIONS_PORT" \
|
"$PY" -m uvicorn main:app --host 0.0.0.0 --port "$DETECTIONS_PORT" \
|
||||||
--log-level warning >/dev/null 2>&1 &
|
--log-level warning >/dev/null 2>&1 &
|
||||||
@@ -73,5 +74,6 @@ BASE_URL="http://localhost:$DETECTIONS_PORT" \
|
|||||||
MOCK_LOADER_URL="http://localhost:$LOADER_PORT" \
|
MOCK_LOADER_URL="http://localhost:$LOADER_PORT" \
|
||||||
MOCK_ANNOTATIONS_URL="http://localhost:$ANNOTATIONS_PORT" \
|
MOCK_ANNOTATIONS_URL="http://localhost:$ANNOTATIONS_PORT" \
|
||||||
MEDIA_DIR="$FIXTURES" \
|
MEDIA_DIR="$FIXTURES" \
|
||||||
|
JWT_SECRET="test-secret-local-only" \
|
||||||
PYTHONPATH="$ROOT/src" \
|
PYTHONPATH="$ROOT/src" \
|
||||||
"$PY" -m pytest e2e/tests/ tests/ -v --tb=short --durations=0 "$@"
|
"$PY" -m pytest e2e/tests/ tests/ -v --tb=short --durations=0 "$@"
|
||||||
|
|||||||
+119
-124
@@ -11,10 +11,12 @@ from typing import Annotated, Optional
|
|||||||
|
|
||||||
import av
|
import av
|
||||||
import cv2
|
import cv2
|
||||||
|
import jwt as pyjwt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests as http_requests
|
import requests as http_requests
|
||||||
from fastapi import Body, FastAPI, UploadFile, File, Form, HTTPException, Request
|
from fastapi import Body, Depends, FastAPI, File, Form, HTTPException, Request, UploadFile
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from loader_http_client import LoaderHttpClient, LoadResult
|
from loader_http_client import LoaderHttpClient, LoadResult
|
||||||
@@ -24,6 +26,8 @@ executor = ThreadPoolExecutor(max_workers=2)
|
|||||||
|
|
||||||
LOADER_URL = os.environ.get("LOADER_URL", "http://loader:8080")
|
LOADER_URL = os.environ.get("LOADER_URL", "http://loader:8080")
|
||||||
ANNOTATIONS_URL = os.environ.get("ANNOTATIONS_URL", "http://annotations:8080")
|
ANNOTATIONS_URL = os.environ.get("ANNOTATIONS_URL", "http://annotations:8080")
|
||||||
|
JWT_SECRET = os.environ.get("JWT_SECRET", "")
|
||||||
|
ADMIN_API_URL = os.environ.get("ADMIN_API_URL", "")
|
||||||
|
|
||||||
_MEDIA_STATUS_NEW = 1
|
_MEDIA_STATUS_NEW = 1
|
||||||
_MEDIA_STATUS_AI_PROCESSING = 2
|
_MEDIA_STATUS_AI_PROCESSING = 2
|
||||||
@@ -36,9 +40,28 @@ _IMAGE_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif",
|
|||||||
loader_client = LoaderHttpClient(LOADER_URL)
|
loader_client = LoaderHttpClient(LOADER_URL)
|
||||||
annotations_client = LoaderHttpClient(ANNOTATIONS_URL)
|
annotations_client = LoaderHttpClient(ANNOTATIONS_URL)
|
||||||
inference = None
|
inference = None
|
||||||
_event_queues: list[asyncio.Queue] = []
|
_job_queues: dict[str, list[asyncio.Queue]] = {}
|
||||||
|
_job_buffers: dict[str, list[str]] = {}
|
||||||
_active_detections: dict[str, asyncio.Task] = {}
|
_active_detections: dict[str, asyncio.Task] = {}
|
||||||
|
|
||||||
|
_bearer = HTTPBearer(auto_error=False)
|
||||||
|
|
||||||
|
|
||||||
|
async def require_auth(
|
||||||
|
credentials: Optional[HTTPAuthorizationCredentials] = Depends(_bearer),
|
||||||
|
) -> str:
|
||||||
|
if not JWT_SECRET:
|
||||||
|
return ""
|
||||||
|
if not credentials:
|
||||||
|
raise HTTPException(status_code=401, detail="Authentication required")
|
||||||
|
try:
|
||||||
|
payload = pyjwt.decode(credentials.credentials, JWT_SECRET, algorithms=["HS256"])
|
||||||
|
except pyjwt.ExpiredSignatureError:
|
||||||
|
raise HTTPException(status_code=401, detail="Token expired")
|
||||||
|
except pyjwt.InvalidTokenError:
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid token")
|
||||||
|
return str(payload.get("sub") or payload.get("userId") or "")
|
||||||
|
|
||||||
|
|
||||||
class TokenManager:
|
class TokenManager:
|
||||||
def __init__(self, access_token: str, refresh_token: str):
|
def __init__(self, access_token: str, refresh_token: str):
|
||||||
@@ -46,15 +69,17 @@ class TokenManager:
|
|||||||
self.refresh_token = refresh_token
|
self.refresh_token = refresh_token
|
||||||
|
|
||||||
def get_valid_token(self) -> str:
|
def get_valid_token(self) -> str:
|
||||||
exp = self._decode_exp(self.access_token)
|
exp = self._decode_claims(self.access_token).get("exp")
|
||||||
if exp and exp - time.time() < 60:
|
if exp and float(exp) - time.time() < 60:
|
||||||
self._refresh()
|
self._refresh()
|
||||||
return self.access_token
|
return self.access_token
|
||||||
|
|
||||||
def _refresh(self):
|
def _refresh(self):
|
||||||
|
if not ADMIN_API_URL:
|
||||||
|
return
|
||||||
try:
|
try:
|
||||||
resp = http_requests.post(
|
resp = http_requests.post(
|
||||||
f"{ANNOTATIONS_URL}/auth/refresh",
|
f"{ADMIN_API_URL}/auth/refresh",
|
||||||
json={"refreshToken": self.refresh_token},
|
json={"refreshToken": self.refresh_token},
|
||||||
timeout=10,
|
timeout=10,
|
||||||
)
|
)
|
||||||
@@ -64,39 +89,33 @@ class TokenManager:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _decode_exp(token: str) -> Optional[float]:
|
def _decode_claims(token: str) -> dict:
|
||||||
try:
|
try:
|
||||||
|
if JWT_SECRET:
|
||||||
|
return pyjwt.decode(token, JWT_SECRET, algorithms=["HS256"])
|
||||||
payload = token.split(".")[1]
|
payload = token.split(".")[1]
|
||||||
padding = 4 - len(payload) % 4
|
padding = 4 - len(payload) % 4
|
||||||
if padding != 4:
|
if padding != 4:
|
||||||
payload += "=" * padding
|
payload += "=" * padding
|
||||||
data = json.loads(base64.urlsafe_b64decode(payload))
|
return json.loads(base64.urlsafe_b64decode(payload))
|
||||||
return float(data.get("exp", 0))
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return {}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def decode_user_id(token: str) -> Optional[str]:
|
def decode_user_id(token: str) -> Optional[str]:
|
||||||
try:
|
data = TokenManager._decode_claims(token)
|
||||||
payload = token.split(".")[1]
|
uid = (
|
||||||
padding = 4 - len(payload) % 4
|
data.get("sub")
|
||||||
if padding != 4:
|
or data.get("userId")
|
||||||
payload += "=" * padding
|
or data.get("user_id")
|
||||||
data = json.loads(base64.urlsafe_b64decode(payload))
|
or data.get("nameid")
|
||||||
uid = (
|
or data.get(
|
||||||
data.get("sub")
|
"http://schemas.xmlsoap.org/ws/2005/05/identity/claims/nameidentifier"
|
||||||
or data.get("userId")
|
|
||||||
or data.get("user_id")
|
|
||||||
or data.get("nameid")
|
|
||||||
or data.get(
|
|
||||||
"http://schemas.xmlsoap.org/ws/2005/05/identity/claims/nameidentifier"
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
if uid is None:
|
)
|
||||||
return None
|
if uid is None:
|
||||||
return str(uid)
|
|
||||||
except Exception:
|
|
||||||
return None
|
return None
|
||||||
|
return str(uid)
|
||||||
|
|
||||||
|
|
||||||
def get_inference():
|
def get_inference():
|
||||||
@@ -233,24 +252,6 @@ def _normalize_upload_ext(filename: str) -> str:
|
|||||||
return s if s else ""
|
return s if s else ""
|
||||||
|
|
||||||
|
|
||||||
def _detect_upload_kind(filename: str, data: bytes) -> tuple[str, str]:
|
|
||||||
ext = _normalize_upload_ext(filename)
|
|
||||||
if ext in _VIDEO_EXTENSIONS:
|
|
||||||
return "video", ext
|
|
||||||
if ext in _IMAGE_EXTENSIONS:
|
|
||||||
return "image", ext
|
|
||||||
arr = np.frombuffer(data, dtype=np.uint8)
|
|
||||||
if cv2.imdecode(arr, cv2.IMREAD_COLOR) is not None:
|
|
||||||
return "image", ext if ext else ".jpg"
|
|
||||||
try:
|
|
||||||
bio = io.BytesIO(data)
|
|
||||||
with av.open(bio):
|
|
||||||
pass
|
|
||||||
return "video", ext if ext else ".mp4"
|
|
||||||
except Exception:
|
|
||||||
raise HTTPException(status_code=400, detail="Invalid image or video data")
|
|
||||||
|
|
||||||
|
|
||||||
def _is_video_media_path(media_path: str) -> bool:
|
def _is_video_media_path(media_path: str) -> bool:
|
||||||
return Path(media_path).suffix.lower() in _VIDEO_EXTENSIONS
|
return Path(media_path).suffix.lower() in _VIDEO_EXTENSIONS
|
||||||
|
|
||||||
@@ -322,6 +323,21 @@ def detection_to_dto(det) -> DetectionDto:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _enqueue(media_id: str, event: DetectionEvent):
|
||||||
|
data = event.model_dump_json()
|
||||||
|
_job_buffers.setdefault(media_id, []).append(data)
|
||||||
|
for q in _job_queues.get(media_id, []):
|
||||||
|
try:
|
||||||
|
q.put_nowait(data)
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _schedule_buffer_cleanup(media_id: str, delay: float = 300.0):
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
loop.call_later(delay, lambda: _job_buffers.pop(media_id, None))
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
def health() -> HealthResponse:
|
def health() -> HealthResponse:
|
||||||
if inference is None:
|
if inference is None:
|
||||||
@@ -345,11 +361,12 @@ def health() -> HealthResponse:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/detect")
|
@app.post("/detect/image")
|
||||||
async def detect_image(
|
async def detect_image(
|
||||||
request: Request,
|
request: Request,
|
||||||
file: UploadFile = File(...),
|
file: UploadFile = File(...),
|
||||||
config: Optional[str] = Form(None),
|
config: Optional[str] = Form(None),
|
||||||
|
user_id: str = Depends(require_auth),
|
||||||
):
|
):
|
||||||
from media_hash import compute_media_content_hash
|
from media_hash import compute_media_content_hash
|
||||||
from inference import ai_config_from_dict
|
from inference import ai_config_from_dict
|
||||||
@@ -359,26 +376,22 @@ async def detect_image(
|
|||||||
raise HTTPException(status_code=400, detail="Image is empty")
|
raise HTTPException(status_code=400, detail="Image is empty")
|
||||||
|
|
||||||
orig_name = file.filename or "upload"
|
orig_name = file.filename or "upload"
|
||||||
kind, ext = _detect_upload_kind(orig_name, image_bytes)
|
ext = _normalize_upload_ext(orig_name)
|
||||||
|
if ext and ext not in _IMAGE_EXTENSIONS:
|
||||||
|
raise HTTPException(status_code=400, detail="Expected an image file")
|
||||||
|
|
||||||
if kind == "image":
|
arr = np.frombuffer(image_bytes, dtype=np.uint8)
|
||||||
arr = np.frombuffer(image_bytes, dtype=np.uint8)
|
if cv2.imdecode(arr, cv2.IMREAD_COLOR) is None:
|
||||||
if cv2.imdecode(arr, cv2.IMREAD_COLOR) is None:
|
raise HTTPException(status_code=400, detail="Invalid image data")
|
||||||
raise HTTPException(status_code=400, detail="Invalid image data")
|
|
||||||
|
|
||||||
config_dict = {}
|
config_dict = {}
|
||||||
if config:
|
if config:
|
||||||
config_dict = json.loads(config)
|
config_dict = json.loads(config)
|
||||||
|
|
||||||
auth_header = request.headers.get("authorization", "")
|
|
||||||
access_token = auth_header.removeprefix("Bearer ").strip() if auth_header else ""
|
|
||||||
refresh_token = request.headers.get("x-refresh-token", "")
|
refresh_token = request.headers.get("x-refresh-token", "")
|
||||||
|
access_token = request.headers.get("authorization", "").removeprefix("Bearer ").strip()
|
||||||
token_mgr = TokenManager(access_token, refresh_token) if access_token else None
|
token_mgr = TokenManager(access_token, refresh_token) if access_token else None
|
||||||
user_id = TokenManager.decode_user_id(access_token) if access_token else None
|
|
||||||
|
|
||||||
videos_dir = os.environ.get(
|
|
||||||
"VIDEOS_DIR", os.path.join(os.getcwd(), "data", "videos")
|
|
||||||
)
|
|
||||||
images_dir = os.environ.get(
|
images_dir = os.environ.get(
|
||||||
"IMAGES_DIR", os.path.join(os.getcwd(), "data", "images")
|
"IMAGES_DIR", os.path.join(os.getcwd(), "data", "images")
|
||||||
)
|
)
|
||||||
@@ -386,20 +399,16 @@ async def detect_image(
|
|||||||
content_hash = None
|
content_hash = None
|
||||||
if token_mgr and user_id:
|
if token_mgr and user_id:
|
||||||
content_hash = compute_media_content_hash(image_bytes)
|
content_hash = compute_media_content_hash(image_bytes)
|
||||||
base = videos_dir if kind == "video" else images_dir
|
os.makedirs(images_dir, exist_ok=True)
|
||||||
os.makedirs(base, exist_ok=True)
|
save_ext = ext if ext.startswith(".") else f".{ext}" if ext else ".jpg"
|
||||||
if not ext.startswith("."):
|
storage_path = os.path.abspath(os.path.join(images_dir, f"{content_hash}{save_ext}"))
|
||||||
ext = "." + ext
|
with open(storage_path, "wb") as out:
|
||||||
storage_path = os.path.abspath(os.path.join(base, f"{content_hash}{ext}"))
|
out.write(image_bytes)
|
||||||
if kind == "image":
|
|
||||||
with open(storage_path, "wb") as out:
|
|
||||||
out.write(image_bytes)
|
|
||||||
mt = "Video" if kind == "video" else "Image"
|
|
||||||
payload = {
|
payload = {
|
||||||
"id": content_hash,
|
"id": content_hash,
|
||||||
"name": Path(orig_name).name,
|
"name": Path(orig_name).name,
|
||||||
"path": storage_path,
|
"path": storage_path,
|
||||||
"mediaType": mt,
|
"mediaType": "Image",
|
||||||
"mediaStatus": _MEDIA_STATUS_NEW,
|
"mediaStatus": _MEDIA_STATUS_NEW,
|
||||||
"userId": user_id,
|
"userId": user_id,
|
||||||
}
|
}
|
||||||
@@ -411,29 +420,17 @@ async def detect_image(
|
|||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
inf = get_inference()
|
inf = get_inference()
|
||||||
results = []
|
results = []
|
||||||
tmp_video_path = None
|
|
||||||
|
|
||||||
def on_annotation(annotation, percent):
|
def on_annotation(annotation, percent):
|
||||||
results.extend(annotation.detections)
|
results.extend(annotation.detections)
|
||||||
|
|
||||||
ai_cfg = ai_config_from_dict(config_dict)
|
ai_cfg = ai_config_from_dict(config_dict)
|
||||||
|
|
||||||
def run_upload():
|
def run_detect():
|
||||||
nonlocal tmp_video_path
|
inf.run_detect_image(image_bytes, ai_cfg, media_name, on_annotation)
|
||||||
if kind == "video":
|
|
||||||
if storage_path:
|
|
||||||
save = storage_path
|
|
||||||
else:
|
|
||||||
suf = ext if ext.startswith(".") else ".mp4"
|
|
||||||
fd, tmp_video_path = tempfile.mkstemp(suffix=suf)
|
|
||||||
os.close(fd)
|
|
||||||
save = tmp_video_path
|
|
||||||
inf.run_detect_video(image_bytes, ai_cfg, media_name, save, on_annotation)
|
|
||||||
else:
|
|
||||||
inf.run_detect_image(image_bytes, ai_cfg, media_name, on_annotation)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await loop.run_in_executor(executor, run_upload)
|
await loop.run_in_executor(executor, run_detect)
|
||||||
if token_mgr and user_id and content_hash:
|
if token_mgr and user_id and content_hash:
|
||||||
_put_media_status(
|
_put_media_status(
|
||||||
content_hash, _MEDIA_STATUS_AI_PROCESSED, token_mgr.get_valid_token()
|
content_hash, _MEDIA_STATUS_AI_PROCESSED, token_mgr.get_valid_token()
|
||||||
@@ -459,16 +456,13 @@ async def detect_image(
|
|||||||
content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
|
content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
finally:
|
|
||||||
if tmp_video_path and os.path.isfile(tmp_video_path):
|
|
||||||
try:
|
|
||||||
os.unlink(tmp_video_path)
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/detect/video")
|
@app.post("/detect/video")
|
||||||
async def detect_video_upload(request: Request):
|
async def detect_video_upload(
|
||||||
|
request: Request,
|
||||||
|
user_id: str = Depends(require_auth),
|
||||||
|
):
|
||||||
from media_hash import compute_media_content_hash_from_file
|
from media_hash import compute_media_content_hash_from_file
|
||||||
from inference import ai_config_from_dict
|
from inference import ai_config_from_dict
|
||||||
from streaming_buffer import StreamingBuffer
|
from streaming_buffer import StreamingBuffer
|
||||||
@@ -482,11 +476,9 @@ async def detect_video_upload(request: Request):
|
|||||||
config_dict = json.loads(config_json) if config_json else {}
|
config_dict = json.loads(config_json) if config_json else {}
|
||||||
ai_cfg = ai_config_from_dict(config_dict)
|
ai_cfg = ai_config_from_dict(config_dict)
|
||||||
|
|
||||||
auth_header = request.headers.get("authorization", "")
|
|
||||||
access_token = auth_header.removeprefix("Bearer ").strip() if auth_header else ""
|
|
||||||
refresh_token = request.headers.get("x-refresh-token", "")
|
refresh_token = request.headers.get("x-refresh-token", "")
|
||||||
|
access_token = request.headers.get("authorization", "").removeprefix("Bearer ").strip()
|
||||||
token_mgr = TokenManager(access_token, refresh_token) if access_token else None
|
token_mgr = TokenManager(access_token, refresh_token) if access_token else None
|
||||||
user_id = TokenManager.decode_user_id(access_token) if access_token else None
|
|
||||||
|
|
||||||
videos_dir = os.environ.get(
|
videos_dir = os.environ.get(
|
||||||
"VIDEOS_DIR", os.path.join(os.getcwd(), "data", "videos")
|
"VIDEOS_DIR", os.path.join(os.getcwd(), "data", "videos")
|
||||||
@@ -499,33 +491,29 @@ async def detect_video_upload(request: Request):
|
|||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
inf = get_inference()
|
inf = get_inference()
|
||||||
|
|
||||||
def _enqueue(event):
|
|
||||||
for q in _event_queues:
|
|
||||||
try:
|
|
||||||
q.put_nowait(event)
|
|
||||||
except asyncio.QueueFull:
|
|
||||||
pass
|
|
||||||
|
|
||||||
placeholder_id = f"tmp_{os.path.basename(buffer.path)}"
|
placeholder_id = f"tmp_{os.path.basename(buffer.path)}"
|
||||||
|
current_id = [placeholder_id] # mutable — updated to content_hash after upload
|
||||||
|
|
||||||
def on_annotation(annotation, percent):
|
def on_annotation(annotation, percent):
|
||||||
dtos = [detection_to_dto(d) for d in annotation.detections]
|
dtos = [detection_to_dto(d) for d in annotation.detections]
|
||||||
|
mid = current_id[0]
|
||||||
event = DetectionEvent(
|
event = DetectionEvent(
|
||||||
annotations=dtos,
|
annotations=dtos,
|
||||||
mediaId=placeholder_id,
|
mediaId=mid,
|
||||||
mediaStatus="AIProcessing",
|
mediaStatus="AIProcessing",
|
||||||
mediaPercent=percent,
|
mediaPercent=percent,
|
||||||
)
|
)
|
||||||
loop.call_soon_threadsafe(_enqueue, event)
|
loop.call_soon_threadsafe(_enqueue, mid, event)
|
||||||
|
|
||||||
def on_status(media_name_cb, count):
|
def on_status(media_name_cb, count):
|
||||||
|
mid = current_id[0]
|
||||||
event = DetectionEvent(
|
event = DetectionEvent(
|
||||||
annotations=[],
|
annotations=[],
|
||||||
mediaId=placeholder_id,
|
mediaId=mid,
|
||||||
mediaStatus="AIProcessed",
|
mediaStatus="AIProcessed",
|
||||||
mediaPercent=100,
|
mediaPercent=100,
|
||||||
)
|
)
|
||||||
loop.call_soon_threadsafe(_enqueue, event)
|
loop.call_soon_threadsafe(_enqueue, mid, event)
|
||||||
|
|
||||||
def run_inference():
|
def run_inference():
|
||||||
inf.run_detect_video_stream(buffer, ai_cfg, media_name, on_annotation, on_status)
|
inf.run_detect_video_stream(buffer, ai_cfg, media_name, on_annotation, on_status)
|
||||||
@@ -546,6 +534,14 @@ async def detect_video_upload(request: Request):
|
|||||||
ext = "." + ext
|
ext = "." + ext
|
||||||
storage_path = os.path.abspath(os.path.join(videos_dir, f"{content_hash}{ext}"))
|
storage_path = os.path.abspath(os.path.join(videos_dir, f"{content_hash}{ext}"))
|
||||||
|
|
||||||
|
# Re-key buffered events from placeholder_id to content_hash so clients
|
||||||
|
# can subscribe to GET /detect/{content_hash} after POST returns.
|
||||||
|
if placeholder_id in _job_buffers:
|
||||||
|
_job_buffers[content_hash] = _job_buffers.pop(placeholder_id)
|
||||||
|
if placeholder_id in _job_queues:
|
||||||
|
_job_queues[content_hash] = _job_queues.pop(placeholder_id)
|
||||||
|
current_id[0] = content_hash # future on_annotation/on_status callbacks use content_hash
|
||||||
|
|
||||||
if token_mgr and user_id:
|
if token_mgr and user_id:
|
||||||
os.rename(buffer.path, storage_path)
|
os.rename(buffer.path, storage_path)
|
||||||
payload = {
|
payload = {
|
||||||
@@ -574,7 +570,7 @@ async def detect_video_upload(request: Request):
|
|||||||
mediaStatus="AIProcessed",
|
mediaStatus="AIProcessed",
|
||||||
mediaPercent=100,
|
mediaPercent=100,
|
||||||
)
|
)
|
||||||
_enqueue(done_event)
|
_enqueue(content_hash, done_event)
|
||||||
except Exception:
|
except Exception:
|
||||||
if token_mgr and user_id:
|
if token_mgr and user_id:
|
||||||
_put_media_status(
|
_put_media_status(
|
||||||
@@ -585,9 +581,10 @@ async def detect_video_upload(request: Request):
|
|||||||
annotations=[], mediaId=content_hash,
|
annotations=[], mediaId=content_hash,
|
||||||
mediaStatus="Error", mediaPercent=0,
|
mediaStatus="Error", mediaPercent=0,
|
||||||
)
|
)
|
||||||
_enqueue(err_event)
|
_enqueue(content_hash, err_event)
|
||||||
finally:
|
finally:
|
||||||
_active_detections.pop(content_hash, None)
|
_active_detections.pop(content_hash, None)
|
||||||
|
_schedule_buffer_cleanup(content_hash)
|
||||||
buffer.close()
|
buffer.close()
|
||||||
if not (token_mgr and user_id) and os.path.isfile(buffer.path):
|
if not (token_mgr and user_id) and os.path.isfile(buffer.path):
|
||||||
try:
|
try:
|
||||||
@@ -627,14 +624,14 @@ async def detect_media(
|
|||||||
media_id: str,
|
media_id: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
config: Annotated[Optional[AIConfigDto], Body()] = None,
|
config: Annotated[Optional[AIConfigDto], Body()] = None,
|
||||||
|
user_id: str = Depends(require_auth),
|
||||||
):
|
):
|
||||||
existing = _active_detections.get(media_id)
|
existing = _active_detections.get(media_id)
|
||||||
if existing is not None and not existing.done():
|
if existing is not None and not existing.done():
|
||||||
raise HTTPException(status_code=409, detail="Detection already in progress for this media")
|
raise HTTPException(status_code=409, detail="Detection already in progress for this media")
|
||||||
|
|
||||||
auth_header = request.headers.get("authorization", "")
|
|
||||||
access_token = auth_header.removeprefix("Bearer ").strip() if auth_header else ""
|
|
||||||
refresh_token = request.headers.get("x-refresh-token", "")
|
refresh_token = request.headers.get("x-refresh-token", "")
|
||||||
|
access_token = request.headers.get("authorization", "").removeprefix("Bearer ").strip()
|
||||||
token_mgr = TokenManager(access_token, refresh_token) if access_token else None
|
token_mgr = TokenManager(access_token, refresh_token) if access_token else None
|
||||||
|
|
||||||
config_dict, media_path = _resolve_media_for_detect(media_id, token_mgr, config)
|
config_dict, media_path = _resolve_media_for_detect(media_id, token_mgr, config)
|
||||||
@@ -642,13 +639,6 @@ async def detect_media(
|
|||||||
async def run_detection():
|
async def run_detection():
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
def _enqueue(event):
|
|
||||||
for q in _event_queues:
|
|
||||||
try:
|
|
||||||
q.put_nowait(event)
|
|
||||||
except asyncio.QueueFull:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from inference import ai_config_from_dict
|
from inference import ai_config_from_dict
|
||||||
|
|
||||||
@@ -678,7 +668,7 @@ async def detect_media(
|
|||||||
mediaStatus="AIProcessing",
|
mediaStatus="AIProcessing",
|
||||||
mediaPercent=percent,
|
mediaPercent=percent,
|
||||||
)
|
)
|
||||||
loop.call_soon_threadsafe(_enqueue, event)
|
loop.call_soon_threadsafe(_enqueue, media_id, event)
|
||||||
if token_mgr and dtos:
|
if token_mgr and dtos:
|
||||||
_post_annotation_to_service(token_mgr, media_id, annotation, dtos)
|
_post_annotation_to_service(token_mgr, media_id, annotation, dtos)
|
||||||
|
|
||||||
@@ -689,7 +679,7 @@ async def detect_media(
|
|||||||
mediaStatus="AIProcessed",
|
mediaStatus="AIProcessed",
|
||||||
mediaPercent=100,
|
mediaPercent=100,
|
||||||
)
|
)
|
||||||
loop.call_soon_threadsafe(_enqueue, event)
|
loop.call_soon_threadsafe(_enqueue, media_id, event)
|
||||||
if token_mgr:
|
if token_mgr:
|
||||||
_put_media_status(
|
_put_media_status(
|
||||||
media_id,
|
media_id,
|
||||||
@@ -728,28 +718,33 @@ async def detect_media(
|
|||||||
mediaStatus="Error",
|
mediaStatus="Error",
|
||||||
mediaPercent=0,
|
mediaPercent=0,
|
||||||
)
|
)
|
||||||
_enqueue(error_event)
|
_enqueue(media_id, error_event)
|
||||||
finally:
|
finally:
|
||||||
_active_detections.pop(media_id, None)
|
_active_detections.pop(media_id, None)
|
||||||
|
_schedule_buffer_cleanup(media_id)
|
||||||
|
|
||||||
_active_detections[media_id] = asyncio.create_task(run_detection())
|
_active_detections[media_id] = asyncio.create_task(run_detection())
|
||||||
return {"status": "started", "mediaId": media_id}
|
return {"status": "started", "mediaId": media_id}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/detect/stream")
|
@app.get("/detect/{media_id}", dependencies=[Depends(require_auth)])
|
||||||
async def detect_stream():
|
async def detect_events(media_id: str):
|
||||||
queue: asyncio.Queue = asyncio.Queue(maxsize=100)
|
queue: asyncio.Queue = asyncio.Queue(maxsize=200)
|
||||||
_event_queues.append(queue)
|
_job_queues.setdefault(media_id, []).append(queue)
|
||||||
|
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
try:
|
try:
|
||||||
|
for data in list(_job_buffers.get(media_id, [])):
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
while True:
|
while True:
|
||||||
event = await queue.get()
|
data = await queue.get()
|
||||||
yield f"data: {event.model_dump_json()}\n\n"
|
yield f"data: {data}\n\n"
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
_event_queues.remove(queue)
|
queues = _job_queues.get(media_id, [])
|
||||||
|
if queue in queues:
|
||||||
|
queues.remove(queue)
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
event_generator(),
|
event_generator(),
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
@@ -8,6 +9,14 @@ from fastapi import HTTPException
|
|||||||
|
|
||||||
|
|
||||||
def _access_jwt(sub: str = "u1") -> str:
|
def _access_jwt(sub: str = "u1") -> str:
|
||||||
|
secret = os.environ.get("JWT_SECRET", "")
|
||||||
|
if secret:
|
||||||
|
import jwt as pyjwt
|
||||||
|
return pyjwt.encode(
|
||||||
|
{"exp": int(time.time()) + 3600, "sub": sub},
|
||||||
|
secret,
|
||||||
|
algorithm="HS256",
|
||||||
|
)
|
||||||
raw = json.dumps(
|
raw = json.dumps(
|
||||||
{"exp": int(time.time()) + 3600, "sub": sub}, separators=(",", ":")
|
{"exp": int(time.time()) + 3600, "sub": sub}, separators=(",", ":")
|
||||||
).encode()
|
).encode()
|
||||||
@@ -19,11 +28,7 @@ def test_token_manager_decode_user_id_sub():
|
|||||||
# Arrange
|
# Arrange
|
||||||
from main import TokenManager
|
from main import TokenManager
|
||||||
|
|
||||||
raw = json.dumps(
|
token = _access_jwt("user-abc")
|
||||||
{"sub": "user-abc", "exp": int(time.time()) + 3600}, separators=(",", ":")
|
|
||||||
).encode()
|
|
||||||
payload = base64.urlsafe_b64encode(raw).decode().rstrip("=")
|
|
||||||
token = f"hdr.{payload}.sig"
|
|
||||||
# Act
|
# Act
|
||||||
uid = TokenManager.decode_user_id(token)
|
uid = TokenManager.decode_user_id(token)
|
||||||
# Assert
|
# Assert
|
||||||
|
|||||||
@@ -1,12 +1,11 @@
|
|||||||
import base64
|
import builtins
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import threading
|
|
||||||
import time
|
import time
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
|
import jwt as pyjwt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
@@ -17,9 +16,15 @@ import inference as inference_mod
|
|||||||
|
|
||||||
|
|
||||||
def _access_jwt(sub: str = "u1") -> str:
|
def _access_jwt(sub: str = "u1") -> str:
|
||||||
raw = json.dumps(
|
secret = os.environ.get("JWT_SECRET", "")
|
||||||
{"exp": int(time.time()) + 3600, "sub": sub}, separators=(",", ":")
|
if secret:
|
||||||
).encode()
|
return pyjwt.encode(
|
||||||
|
{"exp": int(time.time()) + 3600, "sub": sub},
|
||||||
|
secret,
|
||||||
|
algorithm="HS256",
|
||||||
|
)
|
||||||
|
import base64, json
|
||||||
|
raw = json.dumps({"exp": int(time.time()) + 3600, "sub": sub}, separators=(",", ":")).encode()
|
||||||
payload = base64.urlsafe_b64encode(raw).decode().rstrip("=")
|
payload = base64.urlsafe_b64encode(raw).decode().rstrip("=")
|
||||||
return f"h.{payload}.s"
|
return f"h.{payload}.s"
|
||||||
|
|
||||||
@@ -30,53 +35,7 @@ def _jpeg_bytes() -> bytes:
|
|||||||
|
|
||||||
|
|
||||||
class _FakeInfVideo:
|
class _FakeInfVideo:
|
||||||
def run_detect_video(
|
def run_detect_image(self, image_bytes, ai_cfg, media_name, on_annotation, *args, **kwargs):
|
||||||
self,
|
|
||||||
video_bytes,
|
|
||||||
ai_cfg,
|
|
||||||
media_name,
|
|
||||||
save_path,
|
|
||||||
on_annotation,
|
|
||||||
status_callback=None,
|
|
||||||
):
|
|
||||||
writer = inference_mod._write_video_bytes_to_path
|
|
||||||
ev = threading.Event()
|
|
||||||
t = threading.Thread(
|
|
||||||
target=writer,
|
|
||||||
args=(save_path, video_bytes, ev),
|
|
||||||
)
|
|
||||||
t.start()
|
|
||||||
ev.wait(timeout=60)
|
|
||||||
t.join(timeout=60)
|
|
||||||
|
|
||||||
def run_detect_image(self, *args, **kwargs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeInfVideoConcurrent:
|
|
||||||
def run_detect_video(
|
|
||||||
self,
|
|
||||||
video_bytes,
|
|
||||||
ai_cfg,
|
|
||||||
media_name,
|
|
||||||
save_path,
|
|
||||||
on_annotation,
|
|
||||||
status_callback=None,
|
|
||||||
):
|
|
||||||
holder = {}
|
|
||||||
writer = inference_mod._write_video_bytes_to_path
|
|
||||||
|
|
||||||
def write():
|
|
||||||
holder["tid"] = threading.get_ident()
|
|
||||||
writer(save_path, video_bytes, threading.Event())
|
|
||||||
|
|
||||||
t = threading.Thread(target=write)
|
|
||||||
t.start()
|
|
||||||
holder["caller_tid"] = threading.get_ident()
|
|
||||||
t.join(timeout=60)
|
|
||||||
assert holder["tid"] != holder["caller_tid"]
|
|
||||||
|
|
||||||
def run_detect_image(self, *args, **kwargs):
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@@ -89,92 +48,8 @@ def reset_main_inference():
|
|||||||
main.inference = None
|
main.inference = None
|
||||||
|
|
||||||
|
|
||||||
def test_auth_video_storage_path_opened_wb_once(reset_main_inference):
|
|
||||||
# Arrange
|
|
||||||
import main
|
|
||||||
from media_hash import compute_media_content_hash
|
|
||||||
|
|
||||||
write_paths = []
|
|
||||||
real_write = inference_mod._write_video_bytes_to_path
|
|
||||||
|
|
||||||
def tracking_write(path, data, ev):
|
|
||||||
write_paths.append(os.path.abspath(str(path)))
|
|
||||||
return real_write(path, data, ev)
|
|
||||||
|
|
||||||
video_body = b"vid-bytes-" * 20
|
|
||||||
token = _access_jwt()
|
|
||||||
mock_post = MagicMock()
|
|
||||||
mock_post.return_value.status_code = 201
|
|
||||||
mock_put = MagicMock()
|
|
||||||
mock_put.return_value.status_code = 204
|
|
||||||
with tempfile.TemporaryDirectory() as vd:
|
|
||||||
os.environ["VIDEOS_DIR"] = vd
|
|
||||||
content_hash = compute_media_content_hash(video_body)
|
|
||||||
expected_path = os.path.abspath(os.path.join(vd, f"{content_hash}.mp4"))
|
|
||||||
client = TestClient(main.app)
|
|
||||||
with (
|
|
||||||
patch.object(inference_mod, "_write_video_bytes_to_path", tracking_write),
|
|
||||||
patch.object(main.http_requests, "post", mock_post),
|
|
||||||
patch.object(main.http_requests, "put", mock_put),
|
|
||||||
patch.object(main, "get_inference", return_value=_FakeInfVideo()),
|
|
||||||
):
|
|
||||||
# Act
|
|
||||||
r = client.post(
|
|
||||||
"/detect",
|
|
||||||
files={"file": ("clip.mp4", video_body, "video/mp4")},
|
|
||||||
headers={"Authorization": f"Bearer {token}"},
|
|
||||||
)
|
|
||||||
# Assert
|
|
||||||
assert r.status_code == 200
|
|
||||||
assert write_paths.count(expected_path) == 1
|
|
||||||
with open(expected_path, "rb") as f:
|
|
||||||
assert f.read() == video_body
|
|
||||||
assert mock_post.called
|
|
||||||
assert mock_put.call_count >= 2
|
|
||||||
|
|
||||||
|
|
||||||
def test_non_auth_temp_video_opened_wb_once_and_removed(reset_main_inference):
|
|
||||||
# Arrange
|
|
||||||
import main
|
|
||||||
|
|
||||||
write_paths = []
|
|
||||||
real_write = inference_mod._write_video_bytes_to_path
|
|
||||||
|
|
||||||
def tracking_write(path, data, ev):
|
|
||||||
write_paths.append(os.path.abspath(str(path)))
|
|
||||||
return real_write(path, data, ev)
|
|
||||||
|
|
||||||
video_body = b"tmp-vid-" * 30
|
|
||||||
client = TestClient(main.app)
|
|
||||||
tmp_path_holder = []
|
|
||||||
|
|
||||||
class _CaptureTmp(_FakeInfVideo):
|
|
||||||
def run_detect_video(self, video_bytes, ai_cfg, media_name, save_path, on_annotation, status_callback=None):
|
|
||||||
tmp_path_holder.append(os.path.abspath(str(save_path)))
|
|
||||||
super().run_detect_video(
|
|
||||||
video_bytes, ai_cfg, media_name, save_path, on_annotation, status_callback
|
|
||||||
)
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(inference_mod, "_write_video_bytes_to_path", tracking_write),
|
|
||||||
patch.object(main, "get_inference", return_value=_CaptureTmp()),
|
|
||||||
):
|
|
||||||
# Act
|
|
||||||
r = client.post(
|
|
||||||
"/detect",
|
|
||||||
files={"file": ("n.mp4", video_body, "video/mp4")},
|
|
||||||
)
|
|
||||||
# Assert
|
|
||||||
assert r.status_code == 200
|
|
||||||
assert len(tmp_path_holder) == 1
|
|
||||||
tmp_path = tmp_path_holder[0]
|
|
||||||
assert write_paths.count(tmp_path) == 1
|
|
||||||
assert not os.path.isfile(tmp_path)
|
|
||||||
|
|
||||||
|
|
||||||
def test_auth_image_still_writes_once_before_detect(reset_main_inference):
|
def test_auth_image_still_writes_once_before_detect(reset_main_inference):
|
||||||
# Arrange
|
# Arrange
|
||||||
import builtins
|
|
||||||
import main
|
import main
|
||||||
from media_hash import compute_media_content_hash
|
from media_hash import compute_media_content_hash
|
||||||
|
|
||||||
@@ -205,7 +80,7 @@ def test_auth_image_still_writes_once_before_detect(reset_main_inference):
|
|||||||
):
|
):
|
||||||
# Act
|
# Act
|
||||||
r = client.post(
|
r = client.post(
|
||||||
"/detect",
|
"/detect/image",
|
||||||
files={"file": ("p.jpg", img, "image/jpeg")},
|
files={"file": ("p.jpg", img, "image/jpeg")},
|
||||||
headers={"Authorization": f"Bearer {token}"},
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
)
|
)
|
||||||
@@ -214,31 +89,3 @@ def test_auth_image_still_writes_once_before_detect(reset_main_inference):
|
|||||||
assert wb_hits.count(expected_path) == 1
|
assert wb_hits.count(expected_path) == 1
|
||||||
with real_open(expected_path, "rb") as f:
|
with real_open(expected_path, "rb") as f:
|
||||||
assert f.read() == img
|
assert f.read() == img
|
||||||
|
|
||||||
|
|
||||||
def test_video_writer_runs_in_separate_thread_from_executor(reset_main_inference):
|
|
||||||
# Arrange
|
|
||||||
import main
|
|
||||||
|
|
||||||
token = _access_jwt()
|
|
||||||
mock_post = MagicMock()
|
|
||||||
mock_post.return_value.status_code = 201
|
|
||||||
mock_put = MagicMock()
|
|
||||||
mock_put.return_value.status_code = 204
|
|
||||||
video_body = b"thr-test-" * 15
|
|
||||||
with tempfile.TemporaryDirectory() as vd:
|
|
||||||
os.environ["VIDEOS_DIR"] = vd
|
|
||||||
client = TestClient(main.app)
|
|
||||||
with (
|
|
||||||
patch.object(main.http_requests, "post", mock_post),
|
|
||||||
patch.object(main.http_requests, "put", mock_put),
|
|
||||||
patch.object(main, "get_inference", return_value=_FakeInfVideoConcurrent()),
|
|
||||||
):
|
|
||||||
# Act
|
|
||||||
r = client.post(
|
|
||||||
"/detect",
|
|
||||||
files={"file": ("c.mp4", video_body, "video/mp4")},
|
|
||||||
headers={"Authorization": f"Bearer {token}"},
|
|
||||||
)
|
|
||||||
# Assert
|
|
||||||
assert r.status_code == 200
|
|
||||||
|
|||||||
@@ -276,6 +276,14 @@ class TestMediaContentHashFromFile:
|
|||||||
|
|
||||||
|
|
||||||
def _access_jwt(sub: str = "u1") -> str:
|
def _access_jwt(sub: str = "u1") -> str:
|
||||||
|
import jwt as pyjwt
|
||||||
|
secret = os.environ.get("JWT_SECRET", "")
|
||||||
|
if secret:
|
||||||
|
return pyjwt.encode(
|
||||||
|
{"exp": int(time.time()) + 3600, "sub": sub},
|
||||||
|
secret,
|
||||||
|
algorithm="HS256",
|
||||||
|
)
|
||||||
raw = json.dumps(
|
raw = json.dumps(
|
||||||
{"exp": int(time.time()) + 3600, "sub": sub}, separators=(",", ":")
|
{"exp": int(time.time()) + 3600, "sub": sub}, separators=(",", ":")
|
||||||
).encode()
|
).encode()
|
||||||
@@ -361,7 +369,10 @@ class TestDetectVideoEndpoint:
|
|||||||
os.environ["VIDEOS_DIR"] = vd
|
os.environ["VIDEOS_DIR"] = vd
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
client = TestClient(main.app)
|
client = TestClient(main.app)
|
||||||
with patch.object(main, "get_inference", return_value=_FakeInfStream()):
|
with (
|
||||||
|
patch.object(main, "JWT_SECRET", ""),
|
||||||
|
patch.object(main, "get_inference", return_value=_FakeInfStream()),
|
||||||
|
):
|
||||||
# Act
|
# Act
|
||||||
r = client.post(
|
r = client.post(
|
||||||
"/detect/video",
|
"/detect/video",
|
||||||
@@ -379,12 +390,13 @@ class TestDetectVideoEndpoint:
|
|||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
client = TestClient(main.app)
|
client = TestClient(main.app)
|
||||||
|
|
||||||
# Act
|
# Act — patch JWT_SECRET to "" so auth does not block the extension check
|
||||||
r = client.post(
|
with patch.object(main, "JWT_SECRET", ""):
|
||||||
"/detect/video",
|
r = client.post(
|
||||||
content=b"data",
|
"/detect/video",
|
||||||
headers={"X-Filename": "photo.jpg"},
|
content=b"data",
|
||||||
)
|
headers={"X-Filename": "photo.jpg"},
|
||||||
|
)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert r.status_code == 400
|
assert r.status_code == 400
|
||||||
@@ -411,12 +423,16 @@ class TestDetectVideoEndpoint:
|
|||||||
os.environ["VIDEOS_DIR"] = vd
|
os.environ["VIDEOS_DIR"] = vd
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
client = TestClient(main.app)
|
client = TestClient(main.app)
|
||||||
|
token = _access_jwt()
|
||||||
with patch.object(main, "get_inference", return_value=_CaptureInf()):
|
with patch.object(main, "get_inference", return_value=_CaptureInf()):
|
||||||
# Act
|
# Act
|
||||||
r = client.post(
|
r = client.post(
|
||||||
"/detect/video",
|
"/detect/video",
|
||||||
content=video_body,
|
content=video_body,
|
||||||
headers={"X-Filename": "v.mp4"},
|
headers={
|
||||||
|
"X-Filename": "v.mp4",
|
||||||
|
"Authorization": f"Bearer {token}",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
|||||||
Reference in New Issue
Block a user