"""Azaion.Detections: FastAPI service that runs AI inference on media.

Accepts image/video uploads or media ids registered with the annotations
service, runs detection in a worker thread pool, mirrors media status to the
annotations service, and fans detection events out to SSE subscribers.
"""

import asyncio
import base64
import io
import json
import os
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Annotated, Optional

import av
import cv2
import jwt as pyjwt
import numpy as np
import requests as http_requests
from fastapi import Body, Depends, FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel

from loader_http_client import LoaderHttpClient, LoadResult

app = FastAPI(title="Azaion.Detections")
# Inference is CPU/GPU heavy; cap concurrent detection jobs at two.
executor = ThreadPoolExecutor(max_workers=2)

LOADER_URL = os.environ.get("LOADER_URL", "http://loader:8080")
ANNOTATIONS_URL = os.environ.get("ANNOTATIONS_URL", "http://annotations:8080")
JWT_SECRET = os.environ.get("JWT_SECRET", "")
ADMIN_API_URL = os.environ.get("ADMIN_API_URL", "")

# Media status codes understood by the annotations service.
_MEDIA_STATUS_NEW = 1
_MEDIA_STATUS_AI_PROCESSING = 2
_MEDIA_STATUS_AI_PROCESSED = 3
_MEDIA_STATUS_ERROR = 6

_VIDEO_EXTENSIONS = frozenset({".mp4", ".mov", ".webm", ".mkv", ".avi", ".m4v"})
_IMAGE_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"})

loader_client = LoaderHttpClient(LOADER_URL)
annotations_client = LoaderHttpClient(ANNOTATIONS_URL)

# Lazily constructed by get_inference() so the app can boot without the engine.
inference = None

# Per-media SSE fan-out state: live subscriber queues, a replay buffer of
# already-emitted events, and the currently running detection task per media.
_job_queues: dict[str, list[asyncio.Queue]] = {}
_job_buffers: dict[str, list[str]] = {}
_active_detections: dict[str, asyncio.Task] = {}

_bearer = HTTPBearer(auto_error=False)


async def require_auth(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(_bearer),
) -> str:
    """FastAPI dependency: validate the bearer JWT and return the user id.

    When JWT_SECRET is unset, auth is disabled and "" is returned.
    Raises HTTPException 401 on missing/expired/invalid tokens.
    """
    if not JWT_SECRET:
        return ""
    if not credentials:
        raise HTTPException(status_code=401, detail="Authentication required")
    try:
        payload = pyjwt.decode(credentials.credentials, JWT_SECRET, algorithms=["HS256"])
    except pyjwt.ExpiredSignatureError:
        raise HTTPException(status_code=401, detail="Token expired")
    except pyjwt.InvalidTokenError:
        raise HTTPException(status_code=401, detail="Invalid token")
    return str(payload.get("sub") or payload.get("userId") or "")


class TokenManager:
    """Holds an access/refresh token pair and refreshes the access token
    against the admin API shortly before it expires."""

    def __init__(self, access_token: str, refresh_token: str):
        self.access_token = access_token
        self.refresh_token = refresh_token

    def get_valid_token(self) -> str:
        """Return an access token, refreshing it if it expires within 60 s."""
        exp = self._decode_claims(self.access_token).get("exp")
        if exp and float(exp) - time.time() < 60:
            self._refresh()
        return self.access_token

    def _refresh(self):
        """Best-effort refresh via the admin API; failures leave the old token."""
        if not ADMIN_API_URL:
            return
        try:
            resp = http_requests.post(
                f"{ADMIN_API_URL}/auth/refresh",
                json={"refreshToken": self.refresh_token},
                timeout=10,
            )
            if resp.status_code == 200:
                self.access_token = resp.json()["token"]
        except Exception:
            pass

    @staticmethod
    def _decode_claims(token: str) -> dict:
        """Decode JWT claims; verified when JWT_SECRET is set, else unverified.

        BUGFIX: expiry verification is disabled here so that an already-expired
        token still yields its claims — otherwise get_valid_token() could never
        read "exp" from an expired token and would never trigger a refresh.
        """
        try:
            if JWT_SECRET:
                return pyjwt.decode(
                    token,
                    JWT_SECRET,
                    algorithms=["HS256"],
                    options={"verify_exp": False},
                )
            # No secret configured: best-effort base64 decode of the payload.
            payload = token.split(".")[1]
            padding = 4 - len(payload) % 4
            if padding != 4:
                payload += "=" * padding
            return json.loads(base64.urlsafe_b64decode(payload))
        except Exception:
            return {}

    @staticmethod
    def decode_user_id(token: str) -> Optional[str]:
        """Extract the user id from a token, trying the common claim names."""
        data = TokenManager._decode_claims(token)
        uid = (
            data.get("sub")
            or data.get("userId")
            or data.get("user_id")
            or data.get("nameid")
            or data.get(
                "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/nameidentifier"
            )
        )
        if uid is None:
            return None
        return str(uid)


def get_inference():
    """Return the process-wide Inference engine, constructing it on first use."""
    global inference
    if inference is None:
        from inference import Inference

        inference = Inference(loader_client)
    return inference


class DetectionDto(BaseModel):
    """A single detection box in normalized center/size coordinates."""

    centerX: float
    centerY: float
    width: float
    height: float
    classNum: int
    label: str
    confidence: float


class DetectionEvent(BaseModel):
    """One SSE payload: detections plus media progress/status."""

    annotations: list[DetectionDto]
    mediaId: str
    mediaStatus: str
    mediaPercent: int


class HealthResponse(BaseModel):
    status: str
    aiAvailability: str
    engineType: Optional[str] = None
    errorMessage: Optional[str] = None


class AIConfigDto(BaseModel):
    """Per-request AI configuration overrides (defaults mirror the engine's)."""

    frame_period_recognition: int = 4
    frame_recognition_seconds: int = 2
    probability_threshold: float = 0.25
    tracking_distance_confidence: float = 0.0
    tracking_probability_increase: float = 0.0
    tracking_intersection_threshold: float = 0.6
    model_batch_size: int = 8
    big_image_tile_overlap_percent: int = 20
    altitude: float = 400
    focal_length: float = 24
    sensor_width: float = 23.5


# Maps each canonical snake_case settings field to the aliases (snake, camel,
# Pascal) it may appear under in annotations-service payloads.
_AI_SETTINGS_FIELD_KEYS = (
    (
        "frame_period_recognition",
        ("frame_period_recognition", "framePeriodRecognition", "FramePeriodRecognition"),
    ),
    (
        "frame_recognition_seconds",
        ("frame_recognition_seconds", "frameRecognitionSeconds", "FrameRecognitionSeconds"),
    ),
    (
        "probability_threshold",
        ("probability_threshold", "probabilityThreshold", "ProbabilityThreshold"),
    ),
    (
        "tracking_distance_confidence",
        (
            "tracking_distance_confidence",
            "trackingDistanceConfidence",
            "TrackingDistanceConfidence",
        ),
    ),
    (
        "tracking_probability_increase",
        (
            "tracking_probability_increase",
            "trackingProbabilityIncrease",
            "TrackingProbabilityIncrease",
        ),
    ),
    (
        "tracking_intersection_threshold",
        (
            "tracking_intersection_threshold",
            "trackingIntersectionThreshold",
            "TrackingIntersectionThreshold",
        ),
    ),
    (
        "model_batch_size",
        ("model_batch_size", "modelBatchSize", "ModelBatchSize"),
    ),
    (
        "big_image_tile_overlap_percent",
        (
            "big_image_tile_overlap_percent",
            "bigImageTileOverlapPercent",
            "BigImageTileOverlapPercent",
        ),
    ),
    (
        "altitude",
        ("altitude", "Altitude"),
    ),
    (
        "focal_length",
        ("focal_length", "focalLength", "FocalLength"),
    ),
    (
        "sensor_width",
        ("sensor_width", "sensorWidth", "SensorWidth"),
    ),
)


def _merged_annotation_settings_payload(raw: object) -> dict:
    """Flatten an annotations-service settings payload into snake_case keys.

    Nested "aiRecognitionSettings"/"cameraSettings" dicts are merged over the
    top level; only known fields (any alias casing) with non-None values are
    kept. Returns {} for non-dict/empty input.
    """
    if not raw or not isinstance(raw, dict):
        return {}
    merged = dict(raw)
    inner = raw.get("aiRecognitionSettings")
    if isinstance(inner, dict):
        merged.update(inner)
    cam = raw.get("cameraSettings")
    if isinstance(cam, dict):
        merged.update(cam)
    out = {}
    for snake, aliases in _AI_SETTINGS_FIELD_KEYS:
        for key in aliases:
            if key in merged and merged[key] is not None:
                out[snake] = merged[key]
                break
    return out


def _normalize_upload_ext(filename: str) -> str:
    """Return the lowercase file extension including the dot, or ""."""
    s = Path(filename or "").suffix.lower()
    return s if s else ""


def _is_video_media_path(media_path: str) -> bool:
    """True when the path's extension is a known video extension."""
    return Path(media_path).suffix.lower() in _VIDEO_EXTENSIONS


def _post_media_record(payload: dict, bearer: str) -> bool:
    """Best-effort create of a media record in the annotations service."""
    try:
        headers = {"Authorization": f"Bearer {bearer}"}
        r = http_requests.post(
            f"{ANNOTATIONS_URL}/api/media",
            json=payload,
            headers=headers,
            timeout=30,
        )
        return r.status_code in (200, 201)
    except Exception:
        return False


def _put_media_status(media_id: str, media_status: int, bearer: str) -> bool:
    """Best-effort status update for a media record in the annotations service."""
    try:
        headers = {"Authorization": f"Bearer {bearer}"}
        r = http_requests.put(
            f"{ANNOTATIONS_URL}/api/media/{media_id}/status",
            json={"mediaStatus": media_status},
            headers=headers,
            timeout=30,
        )
        return r.status_code in (200, 204)
    except Exception:
        return False


def _resolve_media_for_detect(
    media_id: str,
    token_mgr: Optional[TokenManager],
    override: Optional[AIConfigDto],
) -> tuple[dict, str]:
    """Build the effective AI config and resolve the media file path.

    Order of precedence: user settings from the annotations service, then any
    explicitly-set fields of `override`. Raises HTTPException 503 when the
    media path cannot be resolved.
    """
    cfg: dict = {}
    bearer = ""
    if token_mgr:
        bearer = token_mgr.get_valid_token()
        uid = TokenManager.decode_user_id(token_mgr.access_token)
        if uid:
            raw = annotations_client.fetch_user_ai_settings(uid, bearer)
            cfg.update(_merged_annotation_settings_payload(raw))
    if override is not None:
        # Only fields the caller actually set override the stored settings.
        for k, v in override.model_dump(exclude_defaults=True).items():
            cfg[k] = v
    media_path = annotations_client.fetch_media_path(media_id, bearer)
    if not media_path:
        raise HTTPException(
            status_code=503,
            detail="Could not resolve media path from annotations service",
        )
    return cfg, media_path


def detection_to_dto(det) -> DetectionDto:
    """Convert an engine detection object into the API DTO."""
    import constants_inf

    label = constants_inf.get_annotation_name(det.cls)
    return DetectionDto(
        centerX=det.x,
        centerY=det.y,
        width=det.w,
        height=det.h,
        classNum=det.cls,
        label=label,
        confidence=det.confidence,
    )


def _enqueue(media_id: str, event: DetectionEvent):
    """Record an event in the replay buffer and push it to live subscribers.

    Full subscriber queues drop the event rather than block the producer.
    """
    data = event.model_dump_json()
    _job_buffers.setdefault(media_id, []).append(data)
    for q in _job_queues.get(media_id, []):
        try:
            q.put_nowait(data)
        except asyncio.QueueFull:
            pass


def _schedule_buffer_cleanup(media_id: str, delay: float = 300.0):
    """Drop the replay buffer for `media_id` after `delay` seconds.

    Must be called from within the event loop (all call sites are coroutines).
    """
    loop = asyncio.get_running_loop()
    loop.call_later(delay, lambda: _job_buffers.pop(media_id, None))


@app.get("/health")
def health() -> HealthResponse:
    """Liveness probe; also reports AI engine availability when initialized."""
    if inference is None:
        return HealthResponse(status="healthy", aiAvailability="None")
    try:
        status = inference.ai_availability_status
        status_str = str(status).split()[0] if str(status).strip() else "None"
        error_msg = status.error_message if hasattr(status, 'error_message') else None
        engine_type = inference.engine_name
        return HealthResponse(
            status="healthy",
            aiAvailability=status_str,
            engineType=engine_type,
            errorMessage=error_msg,
        )
    except Exception as e:
        return HealthResponse(
            status="healthy",
            aiAvailability="None",
            errorMessage=str(e),
        )


@app.post("/detect/image")
async def detect_image(
    request: Request,
    file: UploadFile = File(...),
    config: Optional[str] = Form(None),
    user_id: str = Depends(require_auth),
):
    """Run detection on a single uploaded image and return the detections.

    Authenticated uploads are persisted (content-hash keyed) and registered
    with the annotations service, with status mirrored through the job.
    """
    from media_hash import compute_media_content_hash
    from inference import ai_config_from_dict

    image_bytes = await file.read()
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Image is empty")
    orig_name = file.filename or "upload"
    ext = _normalize_upload_ext(orig_name)
    if ext and ext not in _IMAGE_EXTENSIONS:
        raise HTTPException(status_code=400, detail="Expected an image file")
    arr = np.frombuffer(image_bytes, dtype=np.uint8)
    if cv2.imdecode(arr, cv2.IMREAD_COLOR) is None:
        raise HTTPException(status_code=400, detail="Invalid image data")

    config_dict = {}
    if config:
        # BUGFIX: malformed client config must be a 400, not an unhandled 500.
        try:
            config_dict = json.loads(config)
        except json.JSONDecodeError as e:
            raise HTTPException(status_code=400, detail=f"Invalid config JSON: {e}")

    refresh_token = request.headers.get("x-refresh-token", "")
    access_token = request.headers.get("authorization", "").removeprefix("Bearer ").strip()
    token_mgr = TokenManager(access_token, refresh_token) if access_token else None

    images_dir = os.environ.get(
        "IMAGES_DIR", os.path.join(os.getcwd(), "data", "images")
    )
    storage_path = None
    content_hash = None
    if token_mgr and user_id:
        # Persist the upload keyed by content hash and register it.
        content_hash = compute_media_content_hash(image_bytes)
        os.makedirs(images_dir, exist_ok=True)
        if not ext:
            save_ext = ".jpg"
        elif ext.startswith("."):
            save_ext = ext
        else:
            save_ext = f".{ext}"
        storage_path = os.path.abspath(os.path.join(images_dir, f"{content_hash}{save_ext}"))
        with open(storage_path, "wb") as out:
            out.write(image_bytes)
        payload = {
            "id": content_hash,
            "name": Path(orig_name).name,
            "path": storage_path,
            "mediaType": "Image",
            "mediaStatus": _MEDIA_STATUS_NEW,
            "userId": user_id,
        }
        bearer = token_mgr.get_valid_token()
        _post_media_record(payload, bearer)
        _put_media_status(content_hash, _MEDIA_STATUS_AI_PROCESSING, bearer)

    media_name = Path(orig_name).stem.replace(" ", "")
    loop = asyncio.get_running_loop()
    inf = get_inference()
    results = []

    def on_annotation(annotation, percent):
        results.extend(annotation.detections)

    ai_cfg = ai_config_from_dict(config_dict)

    def run_detect():
        inf.run_detect_image(image_bytes, ai_cfg, media_name, on_annotation)

    try:
        await loop.run_in_executor(executor, run_detect)
        if token_mgr and user_id and content_hash:
            _put_media_status(
                content_hash, _MEDIA_STATUS_AI_PROCESSED, token_mgr.get_valid_token()
            )
        return [detection_to_dto(d) for d in results]
    except RuntimeError as e:
        if token_mgr and user_id and content_hash:
            _put_media_status(
                content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
            )
        # Engine-unavailable errors map to 503; other engine errors to 422.
        if "not available" in str(e):
            raise HTTPException(status_code=503, detail=str(e))
        raise HTTPException(status_code=422, detail=str(e))
    except ValueError as e:
        if token_mgr and user_id and content_hash:
            _put_media_status(
                content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
            )
        raise HTTPException(status_code=400, detail=str(e))
    except Exception:
        if token_mgr and user_id and content_hash:
            _put_media_status(
                content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
            )
        raise


@app.post("/detect/video")
async def detect_video_upload(
    request: Request,
    user_id: str = Depends(require_auth),
):
    """Stream a video upload into the engine and start detection immediately.

    The body is streamed into a StreamingBuffer while inference already runs
    on it; events are first buffered under a temporary id and re-keyed to the
    content hash once the upload is complete. Returns the media id clients
    should subscribe to via GET /detect/{mediaId}.
    """
    from media_hash import compute_media_content_hash_from_file
    from inference import ai_config_from_dict
    from streaming_buffer import StreamingBuffer

    filename = request.headers.get("x-filename", "upload.mp4")
    config_json = request.headers.get("x-config", "")
    ext = _normalize_upload_ext(filename)
    if ext not in _VIDEO_EXTENSIONS:
        raise HTTPException(status_code=400, detail="Expected a video file extension")
    # BUGFIX: malformed client config must be a 400, not an unhandled 500.
    try:
        config_dict = json.loads(config_json) if config_json else {}
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid config JSON: {e}")
    ai_cfg = ai_config_from_dict(config_dict)

    refresh_token = request.headers.get("x-refresh-token", "")
    access_token = request.headers.get("authorization", "").removeprefix("Bearer ").strip()
    token_mgr = TokenManager(access_token, refresh_token) if access_token else None

    videos_dir = os.environ.get(
        "VIDEOS_DIR", os.path.join(os.getcwd(), "data", "videos")
    )
    os.makedirs(videos_dir, exist_ok=True)
    content_length = request.headers.get("content-length")
    total_size = int(content_length) if content_length else None
    buffer = StreamingBuffer(videos_dir, total_size=total_size)

    media_name = Path(filename).stem.replace(" ", "")
    loop = asyncio.get_running_loop()
    inf = get_inference()

    placeholder_id = f"tmp_{os.path.basename(buffer.path)}"
    current_id = [placeholder_id]  # mutable — updated to content_hash after upload

    def on_annotation(annotation, percent):
        dtos = [detection_to_dto(d) for d in annotation.detections]
        mid = current_id[0]
        event = DetectionEvent(
            annotations=dtos,
            mediaId=mid,
            mediaStatus="AIProcessing",
            mediaPercent=percent,
        )
        # Called from the worker thread; hop back onto the event loop.
        loop.call_soon_threadsafe(_enqueue, mid, event)

    def on_status(media_name_cb, count):
        mid = current_id[0]
        event = DetectionEvent(
            annotations=[],
            mediaId=mid,
            mediaStatus="AIProcessed",
            mediaPercent=100,
        )
        loop.call_soon_threadsafe(_enqueue, mid, event)

    def run_inference():
        inf.run_detect_video_stream(buffer, ai_cfg, media_name, on_annotation, on_status)

    # Start inference before the upload finishes; it consumes the buffer.
    inference_future = loop.run_in_executor(executor, run_inference)
    try:
        async for chunk in request.stream():
            await loop.run_in_executor(None, buffer.append, chunk)
    except Exception:
        buffer.close_writer()
        buffer.close()
        raise
    buffer.close_writer()

    content_hash = compute_media_content_hash_from_file(buffer.path)
    if not ext.startswith("."):
        ext = "." + ext
    storage_path = os.path.abspath(os.path.join(videos_dir, f"{content_hash}{ext}"))

    # Re-key buffered events from placeholder_id to content_hash so clients
    # can subscribe to GET /detect/{content_hash} after POST returns.
    if placeholder_id in _job_buffers:
        _job_buffers[content_hash] = _job_buffers.pop(placeholder_id)
    if placeholder_id in _job_queues:
        _job_queues[content_hash] = _job_queues.pop(placeholder_id)
    current_id[0] = content_hash  # future on_annotation/on_status callbacks use content_hash

    if token_mgr and user_id:
        os.rename(buffer.path, storage_path)
        payload = {
            "id": content_hash,
            "name": Path(filename).name,
            "path": storage_path,
            "mediaType": "Video",
            "mediaStatus": _MEDIA_STATUS_NEW,
            "userId": user_id,
        }
        bearer = token_mgr.get_valid_token()
        _post_media_record(payload, bearer)
        _put_media_status(content_hash, _MEDIA_STATUS_AI_PROCESSING, bearer)

    async def _wait_inference():
        """Await the worker, emit the terminal event, and clean up."""
        try:
            await inference_future
            if token_mgr and user_id:
                _put_media_status(
                    content_hash,
                    _MEDIA_STATUS_AI_PROCESSED,
                    token_mgr.get_valid_token(),
                )
            done_event = DetectionEvent(
                annotations=[],
                mediaId=content_hash,
                mediaStatus="AIProcessed",
                mediaPercent=100,
            )
            _enqueue(content_hash, done_event)
        except Exception:
            if token_mgr and user_id:
                _put_media_status(
                    content_hash,
                    _MEDIA_STATUS_ERROR,
                    token_mgr.get_valid_token(),
                )
            err_event = DetectionEvent(
                annotations=[],
                mediaId=content_hash,
                mediaStatus="Error",
                mediaPercent=0,
            )
            _enqueue(content_hash, err_event)
        finally:
            _active_detections.pop(content_hash, None)
            _schedule_buffer_cleanup(content_hash)
            buffer.close()
            # Anonymous uploads are not persisted; remove the temp file.
            if not (token_mgr and user_id) and os.path.isfile(buffer.path):
                try:
                    os.unlink(buffer.path)
                except OSError:
                    pass

    _active_detections[content_hash] = asyncio.create_task(_wait_inference())
    return {"status": "started", "mediaId": content_hash}


def _post_annotation_to_service(token_mgr: TokenManager, media_id: str, annotation, dtos: list[DetectionDto]):
    """Best-effort push of one annotation (detections + optional frame image)
    to the annotations service."""
    try:
        token = token_mgr.get_valid_token()
        image_b64 = base64.b64encode(annotation.image).decode() if annotation.image else None
        # BUGFIX: format the timestamp as proper HH:MM:SS; the old
        # "00:00:{seconds}" form produced values like "00:00:75" past 59 s.
        total_seconds = annotation.time // 1000 if annotation.time else 0
        video_time = (
            f"{total_seconds // 3600:02d}:"
            f"{(total_seconds % 3600) // 60:02d}:"
            f"{total_seconds % 60:02d}"
        )
        payload = {
            "mediaId": media_id,
            "source": 0,
            "videoTime": video_time,
            "detections": [d.model_dump() for d in dtos],
        }
        if image_b64:
            payload["image"] = image_b64
        http_requests.post(
            f"{ANNOTATIONS_URL}/annotations",
            json=payload,
            headers={"Authorization": f"Bearer {token}"},
            timeout=30,
        )
    except Exception:
        pass


@app.post("/detect/{media_id}")
async def detect_media(
    media_id: str,
    request: Request,
    config: Annotated[Optional[AIConfigDto], Body()] = None,
    user_id: str = Depends(require_auth),
):
    """Start detection for media already registered with the annotations
    service. Returns immediately; progress streams via GET /detect/{media_id}.
    """
    existing = _active_detections.get(media_id)
    if existing is not None and not existing.done():
        raise HTTPException(status_code=409, detail="Detection already in progress for this media")

    refresh_token = request.headers.get("x-refresh-token", "")
    access_token = request.headers.get("authorization", "").removeprefix("Bearer ").strip()
    token_mgr = TokenManager(access_token, refresh_token) if access_token else None

    config_dict, media_path = _resolve_media_for_detect(media_id, token_mgr, config)

    async def run_detection():
        loop = asyncio.get_running_loop()
        try:
            from inference import ai_config_from_dict

            if token_mgr:
                _put_media_status(
                    media_id,
                    _MEDIA_STATUS_AI_PROCESSING,
                    token_mgr.get_valid_token(),
                )
            with open(media_path, "rb") as mf:
                file_bytes = mf.read()
            video = _is_video_media_path(media_path)
            stem_name = Path(media_path).stem.replace(" ", "")
            ai_cfg = ai_config_from_dict(config_dict)
            inf = get_inference()
            if not inf.is_engine_ready:
                raise RuntimeError("Detection service unavailable")

            def on_annotation(annotation, percent):
                dtos = [detection_to_dto(d) for d in annotation.detections]
                event = DetectionEvent(
                    annotations=dtos,
                    mediaId=media_id,
                    mediaStatus="AIProcessing",
                    mediaPercent=percent,
                )
                # Worker-thread callback; marshal the event onto the loop.
                loop.call_soon_threadsafe(_enqueue, media_id, event)
                if token_mgr and dtos:
                    _post_annotation_to_service(token_mgr, media_id, annotation, dtos)

            def on_status(media_name_cb, count):
                event = DetectionEvent(
                    annotations=[],
                    mediaId=media_id,
                    mediaStatus="AIProcessed",
                    mediaPercent=100,
                )
                loop.call_soon_threadsafe(_enqueue, media_id, event)
                if token_mgr:
                    _put_media_status(
                        media_id,
                        _MEDIA_STATUS_AI_PROCESSED,
                        token_mgr.get_valid_token(),
                    )

            def run_sync():
                if video:
                    inf.run_detect_video(
                        file_bytes,
                        ai_cfg,
                        stem_name,
                        media_path,
                        on_annotation,
                        on_status,
                    )
                else:
                    inf.run_detect_image(
                        file_bytes,
                        ai_cfg,
                        stem_name,
                        on_annotation,
                        on_status,
                    )

            await loop.run_in_executor(executor, run_sync)
        except Exception:
            if token_mgr:
                _put_media_status(
                    media_id, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
                )
            error_event = DetectionEvent(
                annotations=[],
                mediaId=media_id,
                mediaStatus="Error",
                mediaPercent=0,
            )
            _enqueue(media_id, error_event)
        finally:
            _active_detections.pop(media_id, None)
            _schedule_buffer_cleanup(media_id)

    _active_detections[media_id] = asyncio.create_task(run_detection())
    return {"status": "started", "mediaId": media_id}


@app.get("/detect/{media_id}", dependencies=[Depends(require_auth)])
async def detect_events(media_id: str):
    """SSE stream of detection events: replays buffered events, then live ones."""
    queue: asyncio.Queue = asyncio.Queue(maxsize=200)
    _job_queues.setdefault(media_id, []).append(queue)

    async def event_generator():
        try:
            # Replay anything emitted before this client subscribed.
            for data in list(_job_buffers.get(media_id, [])):
                yield f"data: {data}\n\n"
            while True:
                data = await queue.get()
                yield f"data: {data}\n\n"
        except asyncio.CancelledError:
            pass
        finally:
            queues = _job_queues.get(media_id, [])
            if queue in queues:
                queues.remove(queue)

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )