Files
detections/src/main.py
T
Oleksandr Bezdieniezhnykh 097811a67b [AZ-178] Fix Critical/High security findings: auth, CVEs, non-root containers, per-job SSE
- Pin all deps; h11==0.16.0 (CVE-2025-43859), python-multipart>=1.3.1 (CVE-2026-28356), PyJWT==2.12.1
- Add HMAC JWT verification (require_auth FastAPI dependency, JWT_SECRET-gated)
- Fix TokenManager._refresh() to use ADMIN_API_URL instead of ANNOTATIONS_URL
- Rename POST /detect → POST /detect/image (image-only, rejects video files)
- Replace global SSE stream with per-job SSE: GET /detect/{media_id} with event replay buffer
- Apply require_auth to all 4 protected endpoints
- Fix on_annotation/on_status closure to use mutable current_id for correct post-upload event routing
- Add non-root appuser to Dockerfile and Dockerfile.gpu
- Add JWT_SECRET to e2e/docker-compose.test.yml and run-tests.sh
- Update all e2e tests and unit tests for new endpoints and HMAC token signing
- 64/64 tests pass

Made-with: Cursor
2026-04-02 06:32:12 +03:00

754 lines
24 KiB
Python

import asyncio
import base64
import io
import json
import os
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Annotated, Optional
import av
import cv2
import jwt as pyjwt
import numpy as np
import requests as http_requests
from fastapi import Body, Depends, FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel
from loader_http_client import LoaderHttpClient, LoadResult
app = FastAPI(title="Azaion.Detections")
# Dedicated worker pool for blocking inference calls so the event loop stays free.
executor = ThreadPoolExecutor(max_workers=2)
LOADER_URL = os.environ.get("LOADER_URL", "http://loader:8080")
ANNOTATIONS_URL = os.environ.get("ANNOTATIONS_URL", "http://annotations:8080")
# HS256 signing secret; when empty, bearer-token verification is disabled
# entirely (see require_auth).
JWT_SECRET = os.environ.get("JWT_SECRET", "")
# Base URL for token refresh; when empty, TokenManager._refresh is a no-op.
ADMIN_API_URL = os.environ.get("ADMIN_API_URL", "")
# Numeric media-status codes as expected by the annotations service API.
_MEDIA_STATUS_NEW = 1
_MEDIA_STATUS_AI_PROCESSING = 2
_MEDIA_STATUS_AI_PROCESSED = 3
_MEDIA_STATUS_ERROR = 6
# Lowercased extensions (with leading dot) used to classify uploads.
_VIDEO_EXTENSIONS = frozenset({".mp4", ".mov", ".webm", ".mkv", ".avi", ".m4v"})
_IMAGE_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"})
loader_client = LoaderHttpClient(LOADER_URL)
annotations_client = LoaderHttpClient(ANNOTATIONS_URL)
# Lazily constructed engine singleton (see get_inference).
inference = None
# Per-media-id SSE fan-out state: live subscriber queues, serialized-event
# replay buffers, and the in-flight detection task per media id.
_job_queues: dict[str, list[asyncio.Queue]] = {}
_job_buffers: dict[str, list[str]] = {}
_active_detections: dict[str, asyncio.Task] = {}
# auto_error=False so require_auth can decide between "auth disabled" and 401.
_bearer = HTTPBearer(auto_error=False)
async def require_auth(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(_bearer),
) -> str:
    """FastAPI dependency: verify the HS256 bearer token and return a user id.

    When JWT_SECRET is unset, verification is disabled and "" is returned.
    Raises HTTP 401 for missing, expired, or otherwise invalid tokens.
    """
    if not JWT_SECRET:
        # Auth is opt-in via configuration.
        return ""
    if credentials is None:
        raise HTTPException(status_code=401, detail="Authentication required")
    try:
        claims = pyjwt.decode(credentials.credentials, JWT_SECRET, algorithms=["HS256"])
    except pyjwt.ExpiredSignatureError:
        raise HTTPException(status_code=401, detail="Token expired")
    except pyjwt.InvalidTokenError:
        raise HTTPException(status_code=401, detail="Invalid token")
    user_id = claims.get("sub") or claims.get("userId") or ""
    return str(user_id)
class TokenManager:
    """Holds an access/refresh token pair and refreshes the access token via
    the admin API shortly before it expires."""

    def __init__(self, access_token: str, refresh_token: str):
        self.access_token = access_token
        self.refresh_token = refresh_token

    def get_valid_token(self) -> str:
        """Return the access token, refreshing it first when it expires in <60s."""
        expires_at = self._decode_claims(self.access_token).get("exp")
        if expires_at and float(expires_at) - time.time() < 60:
            self._refresh()
        return self.access_token

    def _refresh(self):
        # Best-effort refresh: on any failure the old token is kept so the
        # caller can still attempt its request.
        if not ADMIN_API_URL:
            return
        try:
            response = http_requests.post(
                f"{ADMIN_API_URL}/auth/refresh",
                json={"refreshToken": self.refresh_token},
                timeout=10,
            )
            if response.status_code == 200:
                self.access_token = response.json()["token"]
        except Exception:
            pass

    @staticmethod
    def _decode_claims(token: str) -> dict:
        # Verified decode when a secret is configured; otherwise read the
        # unverified payload segment. Any failure yields an empty dict.
        try:
            if JWT_SECRET:
                return pyjwt.decode(token, JWT_SECRET, algorithms=["HS256"])
            segment = token.split(".")[1]
            pad_len = -len(segment) % 4
            if pad_len:
                segment += "=" * pad_len
            return json.loads(base64.urlsafe_b64decode(segment))
        except Exception:
            return {}

    @staticmethod
    def decode_user_id(token: str) -> Optional[str]:
        """Extract a user identifier from the token claims, or None if absent."""
        claims = TokenManager._decode_claims(token)
        uid = (
            claims.get("sub")
            or claims.get("userId")
            or claims.get("user_id")
            or claims.get("nameid")
            or claims.get(
                "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/nameidentifier"
            )
        )
        return str(uid) if uid is not None else None
def get_inference():
    """Lazily create and cache the process-wide Inference engine.

    The import is deferred so model/engine initialization happens on first
    detection request rather than at service startup.
    """
    global inference
    if inference is None:
        from inference import Inference
        inference = Inference(loader_client)
    return inference
class DetectionDto(BaseModel):
    """One detected object: a normalized-name bounding box with class info."""
    centerX: float
    centerY: float
    width: float
    height: float
    classNum: int
    label: str
    confidence: float
class DetectionEvent(BaseModel):
    """Progress event pushed over the per-job SSE stream."""
    annotations: list[DetectionDto]
    mediaId: str
    # Status name as a string, e.g. "AIProcessing", "AIProcessed", "Error".
    mediaStatus: str
    mediaPercent: int
class HealthResponse(BaseModel):
    """Payload for GET /health describing service and AI engine state."""
    status: str
    aiAvailability: str
    engineType: Optional[str] = None
    errorMessage: Optional[str] = None
class AIConfigDto(BaseModel):
    """Optional per-request AI configuration overrides.

    Only fields the caller explicitly sets override the user's stored
    settings (see _resolve_media_for_detect's exclude_defaults dump).
    Units for altitude/focal_length/sensor_width are not established in
    this file — presumably meters/mm/mm; confirm against the engine.
    """
    frame_period_recognition: int = 4
    frame_recognition_seconds: int = 2
    probability_threshold: float = 0.25
    tracking_distance_confidence: float = 0.0
    tracking_probability_increase: float = 0.0
    tracking_intersection_threshold: float = 0.6
    model_batch_size: int = 8
    big_image_tile_overlap_percent: int = 20
    altitude: float = 400
    focal_length: float = 24
    sensor_width: float = 23.5
_AI_SETTINGS_FIELD_KEYS = (
(
"frame_period_recognition",
("frame_period_recognition", "framePeriodRecognition", "FramePeriodRecognition"),
),
(
"frame_recognition_seconds",
("frame_recognition_seconds", "frameRecognitionSeconds", "FrameRecognitionSeconds"),
),
(
"probability_threshold",
("probability_threshold", "probabilityThreshold", "ProbabilityThreshold"),
),
(
"tracking_distance_confidence",
(
"tracking_distance_confidence",
"trackingDistanceConfidence",
"TrackingDistanceConfidence",
),
),
(
"tracking_probability_increase",
(
"tracking_probability_increase",
"trackingProbabilityIncrease",
"TrackingProbabilityIncrease",
),
),
(
"tracking_intersection_threshold",
(
"tracking_intersection_threshold",
"trackingIntersectionThreshold",
"TrackingIntersectionThreshold",
),
),
(
"model_batch_size",
("model_batch_size", "modelBatchSize", "ModelBatchSize"),
),
(
"big_image_tile_overlap_percent",
(
"big_image_tile_overlap_percent",
"bigImageTileOverlapPercent",
"BigImageTileOverlapPercent",
),
),
(
"altitude",
("altitude", "Altitude"),
),
(
"focal_length",
("focal_length", "focalLength", "FocalLength"),
),
(
"sensor_width",
("sensor_width", "sensorWidth", "SensorWidth"),
),
)
def _merged_annotation_settings_payload(raw: object) -> dict:
if not raw or not isinstance(raw, dict):
return {}
merged = dict(raw)
inner = raw.get("aiRecognitionSettings")
if isinstance(inner, dict):
merged.update(inner)
cam = raw.get("cameraSettings")
if isinstance(cam, dict):
merged.update(cam)
out = {}
for snake, aliases in _AI_SETTINGS_FIELD_KEYS:
for key in aliases:
if key in merged and merged[key] is not None:
out[snake] = merged[key]
break
return out
def _normalize_upload_ext(filename: str) -> str:
s = Path(filename or "").suffix.lower()
return s if s else ""
def _is_video_media_path(media_path: str) -> bool:
    """True when the path's extension marks a known video container."""
    suffix = Path(media_path).suffix.lower()
    return suffix in _VIDEO_EXTENSIONS
def _post_media_record(payload: dict, bearer: str) -> bool:
    """Best-effort create of a media record in the annotations service.

    Returns True on a 200/201 response; any transport failure yields False
    rather than propagating.
    """
    try:
        response = http_requests.post(
            f"{ANNOTATIONS_URL}/api/media",
            json=payload,
            headers={"Authorization": f"Bearer {bearer}"},
            timeout=30,
        )
    except Exception:
        return False
    return response.status_code in (200, 201)
def _put_media_status(media_id: str, media_status: int, bearer: str) -> bool:
    """Best-effort status update for a media record.

    Returns True on a 200/204 response; any transport failure yields False
    rather than propagating.
    """
    try:
        response = http_requests.put(
            f"{ANNOTATIONS_URL}/api/media/{media_id}/status",
            json={"mediaStatus": media_status},
            headers={"Authorization": f"Bearer {bearer}"},
            timeout=30,
        )
    except Exception:
        return False
    return response.status_code in (200, 204)
def _resolve_media_for_detect(
    media_id: str,
    token_mgr: Optional[TokenManager],
    override: Optional[AIConfigDto],
) -> tuple[dict, str]:
    """Build the effective AI config and resolve the media file path.

    The user's stored settings (fetched via the annotations service) form
    the base; fields explicitly set in `override` win. Raises HTTP 503 when
    the media path cannot be resolved.
    """
    config: dict = {}
    bearer_token = ""
    if token_mgr:
        bearer_token = token_mgr.get_valid_token()
        user_id = TokenManager.decode_user_id(token_mgr.access_token)
        if user_id:
            raw_settings = annotations_client.fetch_user_ai_settings(user_id, bearer_token)
            config.update(_merged_annotation_settings_payload(raw_settings))
    if override is not None:
        # exclude_defaults: only caller-set fields shadow the user defaults.
        config.update(override.model_dump(exclude_defaults=True))
    media_path = annotations_client.fetch_media_path(media_id, bearer_token)
    if not media_path:
        raise HTTPException(
            status_code=503,
            detail="Could not resolve media path from annotations service",
        )
    return config, media_path
def detection_to_dto(det) -> DetectionDto:
    """Convert an engine detection object into the API transfer object."""
    import constants_inf
    return DetectionDto(
        centerX=det.x,
        centerY=det.y,
        width=det.w,
        height=det.h,
        classNum=det.cls,
        label=constants_inf.get_annotation_name(det.cls),
        confidence=det.confidence,
    )
def _enqueue(media_id: str, event: DetectionEvent):
    """Append an event to the job's replay buffer and fan it out to all live
    SSE subscriber queues."""
    serialized = event.model_dump_json()
    _job_buffers.setdefault(media_id, []).append(serialized)
    for subscriber in _job_queues.get(media_id, []):
        try:
            subscriber.put_nowait(serialized)
        except asyncio.QueueFull:
            # Back-pressure policy: drop for slow consumers rather than block.
            pass
def _schedule_buffer_cleanup(media_id: str, delay: float = 300.0):
    """Drop the event replay buffer for media_id after `delay` seconds."""
    def _drop_buffer():
        _job_buffers.pop(media_id, None)
    asyncio.get_event_loop().call_later(delay, _drop_buffer)
@app.get("/health")
def health() -> HealthResponse:
if inference is None:
return HealthResponse(status="healthy", aiAvailability="None")
try:
status = inference.ai_availability_status
status_str = str(status).split()[0] if str(status).strip() else "None"
error_msg = status.error_message if hasattr(status, 'error_message') else None
engine_type = inference.engine_name
return HealthResponse(
status="healthy",
aiAvailability=status_str,
engineType=engine_type,
errorMessage=error_msg,
)
except Exception as e:
return HealthResponse(
status="healthy",
aiAvailability="None",
errorMessage=str(e),
)
@app.post("/detect/image")
async def detect_image(
request: Request,
file: UploadFile = File(...),
config: Optional[str] = Form(None),
user_id: str = Depends(require_auth),
):
from media_hash import compute_media_content_hash
from inference import ai_config_from_dict
image_bytes = await file.read()
if not image_bytes:
raise HTTPException(status_code=400, detail="Image is empty")
orig_name = file.filename or "upload"
ext = _normalize_upload_ext(orig_name)
if ext and ext not in _IMAGE_EXTENSIONS:
raise HTTPException(status_code=400, detail="Expected an image file")
arr = np.frombuffer(image_bytes, dtype=np.uint8)
if cv2.imdecode(arr, cv2.IMREAD_COLOR) is None:
raise HTTPException(status_code=400, detail="Invalid image data")
config_dict = {}
if config:
config_dict = json.loads(config)
refresh_token = request.headers.get("x-refresh-token", "")
access_token = request.headers.get("authorization", "").removeprefix("Bearer ").strip()
token_mgr = TokenManager(access_token, refresh_token) if access_token else None
images_dir = os.environ.get(
"IMAGES_DIR", os.path.join(os.getcwd(), "data", "images")
)
storage_path = None
content_hash = None
if token_mgr and user_id:
content_hash = compute_media_content_hash(image_bytes)
os.makedirs(images_dir, exist_ok=True)
save_ext = ext if ext.startswith(".") else f".{ext}" if ext else ".jpg"
storage_path = os.path.abspath(os.path.join(images_dir, f"{content_hash}{save_ext}"))
with open(storage_path, "wb") as out:
out.write(image_bytes)
payload = {
"id": content_hash,
"name": Path(orig_name).name,
"path": storage_path,
"mediaType": "Image",
"mediaStatus": _MEDIA_STATUS_NEW,
"userId": user_id,
}
bearer = token_mgr.get_valid_token()
_post_media_record(payload, bearer)
_put_media_status(content_hash, _MEDIA_STATUS_AI_PROCESSING, bearer)
media_name = Path(orig_name).stem.replace(" ", "")
loop = asyncio.get_event_loop()
inf = get_inference()
results = []
def on_annotation(annotation, percent):
results.extend(annotation.detections)
ai_cfg = ai_config_from_dict(config_dict)
def run_detect():
inf.run_detect_image(image_bytes, ai_cfg, media_name, on_annotation)
try:
await loop.run_in_executor(executor, run_detect)
if token_mgr and user_id and content_hash:
_put_media_status(
content_hash, _MEDIA_STATUS_AI_PROCESSED, token_mgr.get_valid_token()
)
return [detection_to_dto(d) for d in results]
except RuntimeError as e:
if token_mgr and user_id and content_hash:
_put_media_status(
content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
)
if "not available" in str(e):
raise HTTPException(status_code=503, detail=str(e))
raise HTTPException(status_code=422, detail=str(e))
except ValueError as e:
if token_mgr and user_id and content_hash:
_put_media_status(
content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
)
raise HTTPException(status_code=400, detail=str(e))
except Exception:
if token_mgr and user_id and content_hash:
_put_media_status(
content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
)
raise
@app.post("/detect/video")
async def detect_video_upload(
request: Request,
user_id: str = Depends(require_auth),
):
from media_hash import compute_media_content_hash_from_file
from inference import ai_config_from_dict
from streaming_buffer import StreamingBuffer
filename = request.headers.get("x-filename", "upload.mp4")
config_json = request.headers.get("x-config", "")
ext = _normalize_upload_ext(filename)
if ext not in _VIDEO_EXTENSIONS:
raise HTTPException(status_code=400, detail="Expected a video file extension")
config_dict = json.loads(config_json) if config_json else {}
ai_cfg = ai_config_from_dict(config_dict)
refresh_token = request.headers.get("x-refresh-token", "")
access_token = request.headers.get("authorization", "").removeprefix("Bearer ").strip()
token_mgr = TokenManager(access_token, refresh_token) if access_token else None
videos_dir = os.environ.get(
"VIDEOS_DIR", os.path.join(os.getcwd(), "data", "videos")
)
os.makedirs(videos_dir, exist_ok=True)
content_length = request.headers.get("content-length")
total_size = int(content_length) if content_length else None
buffer = StreamingBuffer(videos_dir, total_size=total_size)
media_name = Path(filename).stem.replace(" ", "")
loop = asyncio.get_event_loop()
inf = get_inference()
placeholder_id = f"tmp_{os.path.basename(buffer.path)}"
current_id = [placeholder_id] # mutable — updated to content_hash after upload
def on_annotation(annotation, percent):
dtos = [detection_to_dto(d) for d in annotation.detections]
mid = current_id[0]
event = DetectionEvent(
annotations=dtos,
mediaId=mid,
mediaStatus="AIProcessing",
mediaPercent=percent,
)
loop.call_soon_threadsafe(_enqueue, mid, event)
def on_status(media_name_cb, count):
mid = current_id[0]
event = DetectionEvent(
annotations=[],
mediaId=mid,
mediaStatus="AIProcessed",
mediaPercent=100,
)
loop.call_soon_threadsafe(_enqueue, mid, event)
def run_inference():
inf.run_detect_video_stream(buffer, ai_cfg, media_name, on_annotation, on_status)
inference_future = loop.run_in_executor(executor, run_inference)
try:
async for chunk in request.stream():
await loop.run_in_executor(None, buffer.append, chunk)
except Exception:
buffer.close_writer()
buffer.close()
raise
buffer.close_writer()
content_hash = compute_media_content_hash_from_file(buffer.path)
if not ext.startswith("."):
ext = "." + ext
storage_path = os.path.abspath(os.path.join(videos_dir, f"{content_hash}{ext}"))
# Re-key buffered events from placeholder_id to content_hash so clients
# can subscribe to GET /detect/{content_hash} after POST returns.
if placeholder_id in _job_buffers:
_job_buffers[content_hash] = _job_buffers.pop(placeholder_id)
if placeholder_id in _job_queues:
_job_queues[content_hash] = _job_queues.pop(placeholder_id)
current_id[0] = content_hash # future on_annotation/on_status callbacks use content_hash
if token_mgr and user_id:
os.rename(buffer.path, storage_path)
payload = {
"id": content_hash,
"name": Path(filename).name,
"path": storage_path,
"mediaType": "Video",
"mediaStatus": _MEDIA_STATUS_NEW,
"userId": user_id,
}
bearer = token_mgr.get_valid_token()
_post_media_record(payload, bearer)
_put_media_status(content_hash, _MEDIA_STATUS_AI_PROCESSING, bearer)
async def _wait_inference():
try:
await inference_future
if token_mgr and user_id:
_put_media_status(
content_hash, _MEDIA_STATUS_AI_PROCESSED,
token_mgr.get_valid_token(),
)
done_event = DetectionEvent(
annotations=[],
mediaId=content_hash,
mediaStatus="AIProcessed",
mediaPercent=100,
)
_enqueue(content_hash, done_event)
except Exception:
if token_mgr and user_id:
_put_media_status(
content_hash, _MEDIA_STATUS_ERROR,
token_mgr.get_valid_token(),
)
err_event = DetectionEvent(
annotations=[], mediaId=content_hash,
mediaStatus="Error", mediaPercent=0,
)
_enqueue(content_hash, err_event)
finally:
_active_detections.pop(content_hash, None)
_schedule_buffer_cleanup(content_hash)
buffer.close()
if not (token_mgr and user_id) and os.path.isfile(buffer.path):
try:
os.unlink(buffer.path)
except OSError:
pass
_active_detections[content_hash] = asyncio.create_task(_wait_inference())
return {"status": "started", "mediaId": content_hash}
def _post_annotation_to_service(token_mgr: TokenManager, media_id: str,
                                annotation, dtos: list[DetectionDto]):
    """Best-effort push of one annotation frame to the annotations service.

    Failures are swallowed deliberately: losing a single annotation post
    must not abort the detection run.
    """
    try:
        token = token_mgr.get_valid_token()
        image_b64 = base64.b64encode(annotation.image).decode() if annotation.image else None
        # annotation.time is divided by 1000 before formatting, so it is
        # presumably milliseconds — TODO(review): confirm against the engine.
        total_seconds = annotation.time // 1000 if annotation.time else 0
        # Fix: the old f-string emitted invalid values like "00:00:75" for
        # clips longer than a minute; roll seconds into minutes and hours.
        hours, remainder = divmod(total_seconds, 3600)
        minutes, seconds = divmod(remainder, 60)
        payload = {
            "mediaId": media_id,
            "source": 0,
            "videoTime": f"{hours:02d}:{minutes:02d}:{seconds:02d}",
            "detections": [d.model_dump() for d in dtos],
        }
        if image_b64:
            payload["image"] = image_b64
        http_requests.post(
            f"{ANNOTATIONS_URL}/annotations",
            json=payload,
            headers={"Authorization": f"Bearer {token}"},
            timeout=30,
        )
    except Exception:
        pass
@app.post("/detect/{media_id}")
async def detect_media(
media_id: str,
request: Request,
config: Annotated[Optional[AIConfigDto], Body()] = None,
user_id: str = Depends(require_auth),
):
existing = _active_detections.get(media_id)
if existing is not None and not existing.done():
raise HTTPException(status_code=409, detail="Detection already in progress for this media")
refresh_token = request.headers.get("x-refresh-token", "")
access_token = request.headers.get("authorization", "").removeprefix("Bearer ").strip()
token_mgr = TokenManager(access_token, refresh_token) if access_token else None
config_dict, media_path = _resolve_media_for_detect(media_id, token_mgr, config)
async def run_detection():
loop = asyncio.get_event_loop()
try:
from inference import ai_config_from_dict
if token_mgr:
_put_media_status(
media_id,
_MEDIA_STATUS_AI_PROCESSING,
token_mgr.get_valid_token(),
)
with open(media_path, "rb") as mf:
file_bytes = mf.read()
video = _is_video_media_path(media_path)
stem_name = Path(media_path).stem.replace(" ", "")
ai_cfg = ai_config_from_dict(config_dict)
inf = get_inference()
if not inf.is_engine_ready:
raise RuntimeError("Detection service unavailable")
def on_annotation(annotation, percent):
dtos = [detection_to_dto(d) for d in annotation.detections]
event = DetectionEvent(
annotations=dtos,
mediaId=media_id,
mediaStatus="AIProcessing",
mediaPercent=percent,
)
loop.call_soon_threadsafe(_enqueue, media_id, event)
if token_mgr and dtos:
_post_annotation_to_service(token_mgr, media_id, annotation, dtos)
def on_status(media_name_cb, count):
event = DetectionEvent(
annotations=[],
mediaId=media_id,
mediaStatus="AIProcessed",
mediaPercent=100,
)
loop.call_soon_threadsafe(_enqueue, media_id, event)
if token_mgr:
_put_media_status(
media_id,
_MEDIA_STATUS_AI_PROCESSED,
token_mgr.get_valid_token(),
)
def run_sync():
if video:
inf.run_detect_video(
file_bytes,
ai_cfg,
stem_name,
media_path,
on_annotation,
on_status,
)
else:
inf.run_detect_image(
file_bytes,
ai_cfg,
stem_name,
on_annotation,
on_status,
)
await loop.run_in_executor(executor, run_sync)
except Exception:
if token_mgr:
_put_media_status(
media_id, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
)
error_event = DetectionEvent(
annotations=[],
mediaId=media_id,
mediaStatus="Error",
mediaPercent=0,
)
_enqueue(media_id, error_event)
finally:
_active_detections.pop(media_id, None)
_schedule_buffer_cleanup(media_id)
_active_detections[media_id] = asyncio.create_task(run_detection())
return {"status": "started", "mediaId": media_id}
@app.get("/detect/{media_id}", dependencies=[Depends(require_auth)])
async def detect_events(media_id: str):
queue: asyncio.Queue = asyncio.Queue(maxsize=200)
_job_queues.setdefault(media_id, []).append(queue)
async def event_generator():
try:
for data in list(_job_buffers.get(media_id, [])):
yield f"data: {data}\n\n"
while True:
data = await queue.get()
yield f"data: {data}\n\n"
except asyncio.CancelledError:
pass
finally:
queues = _job_queues.get(media_id, [])
if queue in queues:
queues.remove(queue)
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
)