mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 06:46:32 +00:00
[AZ-178] Fix Critical/High security findings: auth, CVEs, non-root containers, per-job SSE
- Pin all deps; h11==0.16.0 (CVE-2025-43859), python-multipart>=1.3.1 (CVE-2026-28356), PyJWT==2.12.1
- Add HMAC JWT verification (require_auth FastAPI dependency, JWT_SECRET-gated)
- Fix TokenManager._refresh() to use ADMIN_API_URL instead of ANNOTATIONS_URL
- Rename POST /detect → POST /detect/image (image-only, rejects video files)
- Replace global SSE stream with per-job SSE: GET /detect/{media_id} with event replay buffer
- Apply require_auth to all 4 protected endpoints
- Fix on_annotation/on_status closure to use mutable current_id for correct post-upload event routing
- Add non-root appuser to Dockerfile and Dockerfile.gpu
- Add JWT_SECRET to e2e/docker-compose.test.yml and run-tests.sh
- Update all e2e tests and unit tests for new endpoints and HMAC token signing
- 64/64 tests pass
Made-with: Cursor
This commit is contained in:
+119
-124
@@ -11,10 +11,12 @@ from typing import Annotated, Optional
|
||||
|
||||
import av
|
||||
import cv2
|
||||
import jwt as pyjwt
|
||||
import numpy as np
|
||||
import requests as http_requests
|
||||
from fastapi import Body, FastAPI, UploadFile, File, Form, HTTPException, Request
|
||||
from fastapi import Body, Depends, FastAPI, File, Form, HTTPException, Request, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from pydantic import BaseModel
|
||||
|
||||
from loader_http_client import LoaderHttpClient, LoadResult
|
||||
@@ -24,6 +26,8 @@ executor = ThreadPoolExecutor(max_workers=2)
|
||||
|
||||
LOADER_URL = os.environ.get("LOADER_URL", "http://loader:8080")
|
||||
ANNOTATIONS_URL = os.environ.get("ANNOTATIONS_URL", "http://annotations:8080")
|
||||
JWT_SECRET = os.environ.get("JWT_SECRET", "")
|
||||
ADMIN_API_URL = os.environ.get("ADMIN_API_URL", "")
|
||||
|
||||
_MEDIA_STATUS_NEW = 1
|
||||
_MEDIA_STATUS_AI_PROCESSING = 2
|
||||
@@ -36,9 +40,28 @@ _IMAGE_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif",
|
||||
loader_client = LoaderHttpClient(LOADER_URL)
|
||||
annotations_client = LoaderHttpClient(ANNOTATIONS_URL)
|
||||
inference = None
|
||||
_event_queues: list[asyncio.Queue] = []
|
||||
_job_queues: dict[str, list[asyncio.Queue]] = {}
|
||||
_job_buffers: dict[str, list[str]] = {}
|
||||
_active_detections: dict[str, asyncio.Task] = {}
|
||||
|
||||
_bearer = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
async def require_auth(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(_bearer),
) -> str:
    """FastAPI dependency that verifies an HS256-signed bearer JWT.

    Verification is gated on ``JWT_SECRET``: when it is empty the check is
    skipped entirely and an empty user id is returned.

    Args:
        credentials: Bearer credentials extracted by ``_bearer``; ``None``
            when the request carries no usable ``Authorization`` header
            (``_bearer`` is built with ``auto_error=False``).

    Returns:
        The ``sub`` claim, falling back to ``userId``, as a string; ``""``
        when neither claim is present or when verification is disabled.

    Raises:
        HTTPException: 401 when credentials are missing, the token has
            expired, or the token fails validation.
    """
    if not JWT_SECRET:
        # No secret configured: authentication is disabled for this process.
        return ""

    # A 401 response should tell the client which auth scheme to use.
    challenge = {"WWW-Authenticate": "Bearer"}

    if not credentials:
        raise HTTPException(
            status_code=401, detail="Authentication required", headers=challenge
        )

    try:
        payload = pyjwt.decode(credentials.credentials, JWT_SECRET, algorithms=["HS256"])
    except pyjwt.ExpiredSignatureError as exc:
        # Handled before InvalidTokenError so expiry gets its own message.
        raise HTTPException(
            status_code=401, detail="Token expired", headers=challenge
        ) from exc
    except pyjwt.InvalidTokenError as exc:
        raise HTTPException(
            status_code=401, detail="Invalid token", headers=challenge
        ) from exc

    return str(payload.get("sub") or payload.get("userId") or "")
|
||||
|
||||
|
||||
class TokenManager:
|
||||
def __init__(self, access_token: str, refresh_token: str):
|
||||
@@ -46,15 +69,17 @@ class TokenManager:
|
||||
self.refresh_token = refresh_token
|
||||
|
||||
def get_valid_token(self) -> str:
|
||||
exp = self._decode_exp(self.access_token)
|
||||
if exp and exp - time.time() < 60:
|
||||
exp = self._decode_claims(self.access_token).get("exp")
|
||||
if exp and float(exp) - time.time() < 60:
|
||||
self._refresh()
|
||||
return self.access_token
|
||||
|
||||
def _refresh(self):
|
||||
if not ADMIN_API_URL:
|
||||
return
|
||||
try:
|
||||
resp = http_requests.post(
|
||||
f"{ANNOTATIONS_URL}/auth/refresh",
|
||||
f"{ADMIN_API_URL}/auth/refresh",
|
||||
json={"refreshToken": self.refresh_token},
|
||||
timeout=10,
|
||||
)
|
||||
@@ -64,39 +89,33 @@ class TokenManager:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _decode_exp(token: str) -> Optional[float]:
|
||||
def _decode_claims(token: str) -> dict:
|
||||
try:
|
||||
if JWT_SECRET:
|
||||
return pyjwt.decode(token, JWT_SECRET, algorithms=["HS256"])
|
||||
payload = token.split(".")[1]
|
||||
padding = 4 - len(payload) % 4
|
||||
if padding != 4:
|
||||
payload += "=" * padding
|
||||
data = json.loads(base64.urlsafe_b64decode(payload))
|
||||
return float(data.get("exp", 0))
|
||||
return json.loads(base64.urlsafe_b64decode(payload))
|
||||
except Exception:
|
||||
return None
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def decode_user_id(token: str) -> Optional[str]:
|
||||
try:
|
||||
payload = token.split(".")[1]
|
||||
padding = 4 - len(payload) % 4
|
||||
if padding != 4:
|
||||
payload += "=" * padding
|
||||
data = json.loads(base64.urlsafe_b64decode(payload))
|
||||
uid = (
|
||||
data.get("sub")
|
||||
or data.get("userId")
|
||||
or data.get("user_id")
|
||||
or data.get("nameid")
|
||||
or data.get(
|
||||
"http://schemas.xmlsoap.org/ws/2005/05/identity/claims/nameidentifier"
|
||||
)
|
||||
data = TokenManager._decode_claims(token)
|
||||
uid = (
|
||||
data.get("sub")
|
||||
or data.get("userId")
|
||||
or data.get("user_id")
|
||||
or data.get("nameid")
|
||||
or data.get(
|
||||
"http://schemas.xmlsoap.org/ws/2005/05/identity/claims/nameidentifier"
|
||||
)
|
||||
if uid is None:
|
||||
return None
|
||||
return str(uid)
|
||||
except Exception:
|
||||
)
|
||||
if uid is None:
|
||||
return None
|
||||
return str(uid)
|
||||
|
||||
|
||||
def get_inference():
|
||||
@@ -233,24 +252,6 @@ def _normalize_upload_ext(filename: str) -> str:
|
||||
return s if s else ""
|
||||
|
||||
|
||||
def _detect_upload_kind(filename: str, data: bytes) -> tuple[str, str]:
    """Classify an upload as ``"image"`` or ``"video"``.

    The normalized file extension decides when it is a known video or image
    extension (video is checked first). Otherwise the bytes are sniffed:
    first as an image via OpenCV, then as a container via PyAV.

    Args:
        filename: Original upload name, used only for its extension.
        data: Raw upload bytes.

    Returns:
        ``(kind, ext)``. When the kind was found by sniffing and the
        extension is empty, ``".jpg"`` or ``".mp4"`` is substituted.

    Raises:
        HTTPException: 400 when the data decodes as neither image nor video.
    """
    ext = _normalize_upload_ext(filename)

    # A recognised extension wins over content sniffing.
    for kind, known_exts in (("video", _VIDEO_EXTENSIONS), ("image", _IMAGE_EXTENSIONS)):
        if ext in known_exts:
            return kind, ext

    decoded = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_COLOR)
    if decoded is not None:
        return "image", ext or ".jpg"

    try:
        # Opening the container is the probe; nothing is read from it.
        with av.open(io.BytesIO(data)):
            pass
        return "video", ext or ".mp4"
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image or video data")
|
||||
|
||||
|
||||
def _is_video_media_path(media_path: str) -> bool:
    """Return True when the path's lower-cased suffix is a known video extension."""
    suffix = Path(media_path).suffix.lower()
    return suffix in _VIDEO_EXTENSIONS
|
||||
|
||||
@@ -322,6 +323,21 @@ def detection_to_dto(det) -> DetectionDto:
|
||||
)
|
||||
|
||||
|
||||
def _enqueue(media_id: str, event: DetectionEvent):
    """Record an event for a job and fan it out to that job's queues.

    The event is serialized to JSON once. The string is appended to the
    buffer kept for ``media_id`` in ``_job_buffers`` and then offered to
    every queue registered for ``media_id`` in ``_job_queues``.

    Args:
        media_id: Job key the event belongs to.
        event: Event to serialize and deliver.
    """
    serialized = event.model_dump_json()

    buffered = _job_buffers.setdefault(media_id, [])
    buffered.append(serialized)

    subscribers = _job_queues.get(media_id, [])
    for subscriber in subscribers:
        try:
            subscriber.put_nowait(serialized)
        except asyncio.QueueFull:
            # Best-effort delivery: a full queue misses this event.
            continue
|
||||
|
||||
|
||||
def _schedule_buffer_cleanup(media_id: str, delay: float = 300.0):
    """Schedule removal of the buffered events for ``media_id``.

    Args:
        media_id: Key in ``_job_buffers`` to drop.
        delay: Seconds to wait before the entry is removed.
    """

    def _drop_buffer():
        # pop with a default: removing an already-absent buffer is a no-op.
        _job_buffers.pop(media_id, None)

    asyncio.get_event_loop().call_later(delay, _drop_buffer)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health() -> HealthResponse:
|
||||
if inference is None:
|
||||
@@ -345,11 +361,12 @@ def health() -> HealthResponse:
|
||||
)
|
||||
|
||||
|
||||
@app.post("/detect")
|
||||
@app.post("/detect/image")
|
||||
async def detect_image(
|
||||
request: Request,
|
||||
file: UploadFile = File(...),
|
||||
config: Optional[str] = Form(None),
|
||||
user_id: str = Depends(require_auth),
|
||||
):
|
||||
from media_hash import compute_media_content_hash
|
||||
from inference import ai_config_from_dict
|
||||
@@ -359,26 +376,22 @@ async def detect_image(
|
||||
raise HTTPException(status_code=400, detail="Image is empty")
|
||||
|
||||
orig_name = file.filename or "upload"
|
||||
kind, ext = _detect_upload_kind(orig_name, image_bytes)
|
||||
ext = _normalize_upload_ext(orig_name)
|
||||
if ext and ext not in _IMAGE_EXTENSIONS:
|
||||
raise HTTPException(status_code=400, detail="Expected an image file")
|
||||
|
||||
if kind == "image":
|
||||
arr = np.frombuffer(image_bytes, dtype=np.uint8)
|
||||
if cv2.imdecode(arr, cv2.IMREAD_COLOR) is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid image data")
|
||||
arr = np.frombuffer(image_bytes, dtype=np.uint8)
|
||||
if cv2.imdecode(arr, cv2.IMREAD_COLOR) is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid image data")
|
||||
|
||||
config_dict = {}
|
||||
if config:
|
||||
config_dict = json.loads(config)
|
||||
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
access_token = auth_header.removeprefix("Bearer ").strip() if auth_header else ""
|
||||
refresh_token = request.headers.get("x-refresh-token", "")
|
||||
access_token = request.headers.get("authorization", "").removeprefix("Bearer ").strip()
|
||||
token_mgr = TokenManager(access_token, refresh_token) if access_token else None
|
||||
user_id = TokenManager.decode_user_id(access_token) if access_token else None
|
||||
|
||||
videos_dir = os.environ.get(
|
||||
"VIDEOS_DIR", os.path.join(os.getcwd(), "data", "videos")
|
||||
)
|
||||
images_dir = os.environ.get(
|
||||
"IMAGES_DIR", os.path.join(os.getcwd(), "data", "images")
|
||||
)
|
||||
@@ -386,20 +399,16 @@ async def detect_image(
|
||||
content_hash = None
|
||||
if token_mgr and user_id:
|
||||
content_hash = compute_media_content_hash(image_bytes)
|
||||
base = videos_dir if kind == "video" else images_dir
|
||||
os.makedirs(base, exist_ok=True)
|
||||
if not ext.startswith("."):
|
||||
ext = "." + ext
|
||||
storage_path = os.path.abspath(os.path.join(base, f"{content_hash}{ext}"))
|
||||
if kind == "image":
|
||||
with open(storage_path, "wb") as out:
|
||||
out.write(image_bytes)
|
||||
mt = "Video" if kind == "video" else "Image"
|
||||
os.makedirs(images_dir, exist_ok=True)
|
||||
save_ext = ext if ext.startswith(".") else f".{ext}" if ext else ".jpg"
|
||||
storage_path = os.path.abspath(os.path.join(images_dir, f"{content_hash}{save_ext}"))
|
||||
with open(storage_path, "wb") as out:
|
||||
out.write(image_bytes)
|
||||
payload = {
|
||||
"id": content_hash,
|
||||
"name": Path(orig_name).name,
|
||||
"path": storage_path,
|
||||
"mediaType": mt,
|
||||
"mediaType": "Image",
|
||||
"mediaStatus": _MEDIA_STATUS_NEW,
|
||||
"userId": user_id,
|
||||
}
|
||||
@@ -411,29 +420,17 @@ async def detect_image(
|
||||
loop = asyncio.get_event_loop()
|
||||
inf = get_inference()
|
||||
results = []
|
||||
tmp_video_path = None
|
||||
|
||||
def on_annotation(annotation, percent):
|
||||
results.extend(annotation.detections)
|
||||
|
||||
ai_cfg = ai_config_from_dict(config_dict)
|
||||
|
||||
def run_upload():
|
||||
nonlocal tmp_video_path
|
||||
if kind == "video":
|
||||
if storage_path:
|
||||
save = storage_path
|
||||
else:
|
||||
suf = ext if ext.startswith(".") else ".mp4"
|
||||
fd, tmp_video_path = tempfile.mkstemp(suffix=suf)
|
||||
os.close(fd)
|
||||
save = tmp_video_path
|
||||
inf.run_detect_video(image_bytes, ai_cfg, media_name, save, on_annotation)
|
||||
else:
|
||||
inf.run_detect_image(image_bytes, ai_cfg, media_name, on_annotation)
|
||||
def run_detect():
|
||||
inf.run_detect_image(image_bytes, ai_cfg, media_name, on_annotation)
|
||||
|
||||
try:
|
||||
await loop.run_in_executor(executor, run_upload)
|
||||
await loop.run_in_executor(executor, run_detect)
|
||||
if token_mgr and user_id and content_hash:
|
||||
_put_media_status(
|
||||
content_hash, _MEDIA_STATUS_AI_PROCESSED, token_mgr.get_valid_token()
|
||||
@@ -459,16 +456,13 @@ async def detect_image(
|
||||
content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
if tmp_video_path and os.path.isfile(tmp_video_path):
|
||||
try:
|
||||
os.unlink(tmp_video_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
@app.post("/detect/video")
|
||||
async def detect_video_upload(request: Request):
|
||||
async def detect_video_upload(
|
||||
request: Request,
|
||||
user_id: str = Depends(require_auth),
|
||||
):
|
||||
from media_hash import compute_media_content_hash_from_file
|
||||
from inference import ai_config_from_dict
|
||||
from streaming_buffer import StreamingBuffer
|
||||
@@ -482,11 +476,9 @@ async def detect_video_upload(request: Request):
|
||||
config_dict = json.loads(config_json) if config_json else {}
|
||||
ai_cfg = ai_config_from_dict(config_dict)
|
||||
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
access_token = auth_header.removeprefix("Bearer ").strip() if auth_header else ""
|
||||
refresh_token = request.headers.get("x-refresh-token", "")
|
||||
access_token = request.headers.get("authorization", "").removeprefix("Bearer ").strip()
|
||||
token_mgr = TokenManager(access_token, refresh_token) if access_token else None
|
||||
user_id = TokenManager.decode_user_id(access_token) if access_token else None
|
||||
|
||||
videos_dir = os.environ.get(
|
||||
"VIDEOS_DIR", os.path.join(os.getcwd(), "data", "videos")
|
||||
@@ -499,33 +491,29 @@ async def detect_video_upload(request: Request):
|
||||
loop = asyncio.get_event_loop()
|
||||
inf = get_inference()
|
||||
|
||||
def _enqueue(event):
|
||||
for q in _event_queues:
|
||||
try:
|
||||
q.put_nowait(event)
|
||||
except asyncio.QueueFull:
|
||||
pass
|
||||
|
||||
placeholder_id = f"tmp_{os.path.basename(buffer.path)}"
|
||||
current_id = [placeholder_id] # mutable — updated to content_hash after upload
|
||||
|
||||
def on_annotation(annotation, percent):
|
||||
dtos = [detection_to_dto(d) for d in annotation.detections]
|
||||
mid = current_id[0]
|
||||
event = DetectionEvent(
|
||||
annotations=dtos,
|
||||
mediaId=placeholder_id,
|
||||
mediaId=mid,
|
||||
mediaStatus="AIProcessing",
|
||||
mediaPercent=percent,
|
||||
)
|
||||
loop.call_soon_threadsafe(_enqueue, event)
|
||||
loop.call_soon_threadsafe(_enqueue, mid, event)
|
||||
|
||||
def on_status(media_name_cb, count):
|
||||
mid = current_id[0]
|
||||
event = DetectionEvent(
|
||||
annotations=[],
|
||||
mediaId=placeholder_id,
|
||||
mediaId=mid,
|
||||
mediaStatus="AIProcessed",
|
||||
mediaPercent=100,
|
||||
)
|
||||
loop.call_soon_threadsafe(_enqueue, event)
|
||||
loop.call_soon_threadsafe(_enqueue, mid, event)
|
||||
|
||||
def run_inference():
|
||||
inf.run_detect_video_stream(buffer, ai_cfg, media_name, on_annotation, on_status)
|
||||
@@ -546,6 +534,14 @@ async def detect_video_upload(request: Request):
|
||||
ext = "." + ext
|
||||
storage_path = os.path.abspath(os.path.join(videos_dir, f"{content_hash}{ext}"))
|
||||
|
||||
# Re-key buffered events from placeholder_id to content_hash so clients
|
||||
# can subscribe to GET /detect/{content_hash} after POST returns.
|
||||
if placeholder_id in _job_buffers:
|
||||
_job_buffers[content_hash] = _job_buffers.pop(placeholder_id)
|
||||
if placeholder_id in _job_queues:
|
||||
_job_queues[content_hash] = _job_queues.pop(placeholder_id)
|
||||
current_id[0] = content_hash # future on_annotation/on_status callbacks use content_hash
|
||||
|
||||
if token_mgr and user_id:
|
||||
os.rename(buffer.path, storage_path)
|
||||
payload = {
|
||||
@@ -574,7 +570,7 @@ async def detect_video_upload(request: Request):
|
||||
mediaStatus="AIProcessed",
|
||||
mediaPercent=100,
|
||||
)
|
||||
_enqueue(done_event)
|
||||
_enqueue(content_hash, done_event)
|
||||
except Exception:
|
||||
if token_mgr and user_id:
|
||||
_put_media_status(
|
||||
@@ -585,9 +581,10 @@ async def detect_video_upload(request: Request):
|
||||
annotations=[], mediaId=content_hash,
|
||||
mediaStatus="Error", mediaPercent=0,
|
||||
)
|
||||
_enqueue(err_event)
|
||||
_enqueue(content_hash, err_event)
|
||||
finally:
|
||||
_active_detections.pop(content_hash, None)
|
||||
_schedule_buffer_cleanup(content_hash)
|
||||
buffer.close()
|
||||
if not (token_mgr and user_id) and os.path.isfile(buffer.path):
|
||||
try:
|
||||
@@ -627,14 +624,14 @@ async def detect_media(
|
||||
media_id: str,
|
||||
request: Request,
|
||||
config: Annotated[Optional[AIConfigDto], Body()] = None,
|
||||
user_id: str = Depends(require_auth),
|
||||
):
|
||||
existing = _active_detections.get(media_id)
|
||||
if existing is not None and not existing.done():
|
||||
raise HTTPException(status_code=409, detail="Detection already in progress for this media")
|
||||
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
access_token = auth_header.removeprefix("Bearer ").strip() if auth_header else ""
|
||||
refresh_token = request.headers.get("x-refresh-token", "")
|
||||
access_token = request.headers.get("authorization", "").removeprefix("Bearer ").strip()
|
||||
token_mgr = TokenManager(access_token, refresh_token) if access_token else None
|
||||
|
||||
config_dict, media_path = _resolve_media_for_detect(media_id, token_mgr, config)
|
||||
@@ -642,13 +639,6 @@ async def detect_media(
|
||||
async def run_detection():
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def _enqueue(event):
|
||||
for q in _event_queues:
|
||||
try:
|
||||
q.put_nowait(event)
|
||||
except asyncio.QueueFull:
|
||||
pass
|
||||
|
||||
try:
|
||||
from inference import ai_config_from_dict
|
||||
|
||||
@@ -678,7 +668,7 @@ async def detect_media(
|
||||
mediaStatus="AIProcessing",
|
||||
mediaPercent=percent,
|
||||
)
|
||||
loop.call_soon_threadsafe(_enqueue, event)
|
||||
loop.call_soon_threadsafe(_enqueue, media_id, event)
|
||||
if token_mgr and dtos:
|
||||
_post_annotation_to_service(token_mgr, media_id, annotation, dtos)
|
||||
|
||||
@@ -689,7 +679,7 @@ async def detect_media(
|
||||
mediaStatus="AIProcessed",
|
||||
mediaPercent=100,
|
||||
)
|
||||
loop.call_soon_threadsafe(_enqueue, event)
|
||||
loop.call_soon_threadsafe(_enqueue, media_id, event)
|
||||
if token_mgr:
|
||||
_put_media_status(
|
||||
media_id,
|
||||
@@ -728,28 +718,33 @@ async def detect_media(
|
||||
mediaStatus="Error",
|
||||
mediaPercent=0,
|
||||
)
|
||||
_enqueue(error_event)
|
||||
_enqueue(media_id, error_event)
|
||||
finally:
|
||||
_active_detections.pop(media_id, None)
|
||||
_schedule_buffer_cleanup(media_id)
|
||||
|
||||
_active_detections[media_id] = asyncio.create_task(run_detection())
|
||||
return {"status": "started", "mediaId": media_id}
|
||||
|
||||
|
||||
@app.get("/detect/stream")
|
||||
async def detect_stream():
|
||||
queue: asyncio.Queue = asyncio.Queue(maxsize=100)
|
||||
_event_queues.append(queue)
|
||||
@app.get("/detect/{media_id}", dependencies=[Depends(require_auth)])
|
||||
async def detect_events(media_id: str):
|
||||
queue: asyncio.Queue = asyncio.Queue(maxsize=200)
|
||||
_job_queues.setdefault(media_id, []).append(queue)
|
||||
|
||||
async def event_generator():
|
||||
try:
|
||||
for data in list(_job_buffers.get(media_id, [])):
|
||||
yield f"data: {data}\n\n"
|
||||
while True:
|
||||
event = await queue.get()
|
||||
yield f"data: {event.model_dump_json()}\n\n"
|
||||
data = await queue.get()
|
||||
yield f"data: {data}\n\n"
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
_event_queues.remove(queue)
|
||||
queues = _job_queues.get(media_id, [])
|
||||
if queue in queues:
|
||||
queues.remove(queue)
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
|
||||
Reference in New Issue
Block a user