"""FastAPI detection service.

Accepts image/video uploads (or media ids already known to the annotations
service), runs AI inference on them in a worker thread pool, and streams
detection events to listeners over SSE.  Media lifecycle status is mirrored
into the annotations service on a best-effort basis.
"""

import asyncio
import base64
import io
import json
import os
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Annotated, Optional

import av
import cv2
import numpy as np
import requests as http_requests
from fastapi import Body, FastAPI, UploadFile, File, Form, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from loader_http_client import LoaderHttpClient, LoadResult

app = FastAPI(title="Azaion.Detections")
# Inference is CPU/GPU heavy and not re-entrant-friendly: keep concurrency low.
executor = ThreadPoolExecutor(max_workers=2)

LOADER_URL = os.environ.get("LOADER_URL", "http://loader:8080")
ANNOTATIONS_URL = os.environ.get("ANNOTATIONS_URL", "http://annotations:8080")

# Media lifecycle status codes understood by the annotations service.
_MEDIA_STATUS_NEW = 1
_MEDIA_STATUS_AI_PROCESSING = 2
_MEDIA_STATUS_AI_PROCESSED = 3
_MEDIA_STATUS_ERROR = 6

_VIDEO_EXTENSIONS = frozenset({".mp4", ".mov", ".webm", ".mkv", ".avi", ".m4v"})
_IMAGE_EXTENSIONS = frozenset(
    {".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"}
)

loader_client = LoaderHttpClient(LOADER_URL)
annotations_client = LoaderHttpClient(ANNOTATIONS_URL)

# Lazily constructed by get_inference(); None until first detection request.
inference = None
# Queues of connected SSE listeners (see /detect/stream).
_event_queues: list[asyncio.Queue] = []
# media_id -> in-flight detection task, used to reject duplicate starts.
_active_detections: dict[str, asyncio.Task] = {}


class TokenManager:
    """Holds a JWT access/refresh token pair and refreshes the access token
    against the annotations service when it is close to expiry."""

    def __init__(self, access_token: str, refresh_token: str):
        self.access_token = access_token
        self.refresh_token = refresh_token

    def get_valid_token(self) -> str:
        """Return the access token, refreshing it first if it expires
        within the next 60 seconds."""
        exp = self._decode_exp(self.access_token)
        if exp and exp - time.time() < 60:
            self._refresh()
        return self.access_token

    def _refresh(self):
        # Best-effort: on any failure we keep the (possibly stale) token
        # and let downstream calls fail with 401 instead of crashing here.
        try:
            resp = http_requests.post(
                f"{ANNOTATIONS_URL}/auth/refresh",
                json={"refreshToken": self.refresh_token},
                timeout=10,
            )
            if resp.status_code == 200:
                self.access_token = resp.json()["token"]
        except Exception:
            pass

    @staticmethod
    def _decode_exp(token: str) -> Optional[float]:
        """Decode the ``exp`` claim from a JWT payload without verifying
        the signature.  Returns None when the token is malformed."""
        try:
            payload = token.split(".")[1]
            # base64url payloads are stored unpadded; restore padding.
            padding = 4 - len(payload) % 4
            if padding != 4:
                payload += "=" * padding
            data = json.loads(base64.urlsafe_b64decode(payload))
            return float(data.get("exp", 0))
        except Exception:
            return None

    @staticmethod
    def decode_user_id(token: str) -> Optional[str]:
        """Extract a user id from a JWT payload, trying the common claim
        names (including the legacy XML-schema name identifier claim).
        Returns None when no id claim is present or the token is malformed."""
        try:
            payload = token.split(".")[1]
            padding = 4 - len(payload) % 4
            if padding != 4:
                payload += "=" * padding
            data = json.loads(base64.urlsafe_b64decode(payload))
            uid = (
                data.get("sub")
                or data.get("userId")
                or data.get("user_id")
                or data.get("nameid")
                or data.get(
                    "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/nameidentifier"
                )
            )
            if uid is None:
                return None
            return str(uid)
        except Exception:
            return None


def get_inference():
    """Return the process-wide Inference singleton, creating it on first use.

    The import is deferred so the (heavy) model stack is only loaded when a
    detection is actually requested.
    """
    global inference
    if inference is None:
        from inference import Inference

        inference = Inference(loader_client)
    return inference


class DetectionDto(BaseModel):
    """One detection box, normalized center/size coordinates."""

    centerX: float
    centerY: float
    width: float
    height: float
    classNum: int
    label: str
    confidence: float


class DetectionEvent(BaseModel):
    """SSE payload: detections for one frame plus media progress."""

    annotations: list[DetectionDto]
    mediaId: str
    mediaStatus: str
    mediaPercent: int


class HealthResponse(BaseModel):
    status: str
    aiAvailability: str
    engineType: Optional[str] = None
    errorMessage: Optional[str] = None


class AIConfigDto(BaseModel):
    """Per-request AI tuning knobs; defaults mirror the engine defaults."""

    frame_period_recognition: int = 4
    frame_recognition_seconds: int = 2
    probability_threshold: float = 0.25
    tracking_distance_confidence: float = 0.0
    tracking_probability_increase: float = 0.0
    tracking_intersection_threshold: float = 0.6
    model_batch_size: int = 8
    big_image_tile_overlap_percent: int = 20
    altitude: float = 400
    focal_length: float = 24
    sensor_width: float = 23.5


# Maps each canonical snake_case setting to the alias spellings (snake_case,
# camelCase, PascalCase) the annotations service may return it under.
_AI_SETTINGS_FIELD_KEYS = (
    (
        "frame_period_recognition",
        ("frame_period_recognition", "framePeriodRecognition", "FramePeriodRecognition"),
    ),
    (
        "frame_recognition_seconds",
        ("frame_recognition_seconds", "frameRecognitionSeconds", "FrameRecognitionSeconds"),
    ),
    (
        "probability_threshold",
        ("probability_threshold", "probabilityThreshold", "ProbabilityThreshold"),
    ),
    (
        "tracking_distance_confidence",
        (
            "tracking_distance_confidence",
            "trackingDistanceConfidence",
            "TrackingDistanceConfidence",
        ),
    ),
    (
        "tracking_probability_increase",
        (
            "tracking_probability_increase",
            "trackingProbabilityIncrease",
            "TrackingProbabilityIncrease",
        ),
    ),
    (
        "tracking_intersection_threshold",
        (
            "tracking_intersection_threshold",
            "trackingIntersectionThreshold",
            "TrackingIntersectionThreshold",
        ),
    ),
    (
        "model_batch_size",
        ("model_batch_size", "modelBatchSize", "ModelBatchSize"),
    ),
    (
        "big_image_tile_overlap_percent",
        (
            "big_image_tile_overlap_percent",
            "bigImageTileOverlapPercent",
            "BigImageTileOverlapPercent",
        ),
    ),
    (
        "altitude",
        ("altitude", "Altitude"),
    ),
    (
        "focal_length",
        ("focal_length", "focalLength", "FocalLength"),
    ),
    (
        "sensor_width",
        ("sensor_width", "sensorWidth", "SensorWidth"),
    ),
)


def _merged_annotation_settings_payload(raw: object) -> dict:
    """Normalize an AI-settings payload from the annotations service into a
    flat snake_case dict.

    Nested ``aiRecognitionSettings`` / ``cameraSettings`` sections are merged
    over the top-level keys, then each known field is resolved through its
    alias list (first non-None alias wins).  Non-dict input yields ``{}``.
    """
    if not raw or not isinstance(raw, dict):
        return {}
    merged = dict(raw)
    inner = raw.get("aiRecognitionSettings")
    if isinstance(inner, dict):
        merged.update(inner)
    cam = raw.get("cameraSettings")
    if isinstance(cam, dict):
        merged.update(cam)
    out = {}
    for snake, aliases in _AI_SETTINGS_FIELD_KEYS:
        for key in aliases:
            if key in merged and merged[key] is not None:
                out[snake] = merged[key]
                break
    return out


def _normalize_upload_ext(filename: str) -> str:
    """Return the lowercase file extension (with dot) or '' if none."""
    s = Path(filename or "").suffix.lower()
    return s if s else ""


def _detect_upload_kind(filename: str, data: bytes) -> tuple[str, str]:
    """Classify an upload as ('image'|'video', extension).

    Extension is trusted first; otherwise the bytes are probed with OpenCV
    (image) then PyAV (video).  Raises HTTP 400 when neither decoder
    recognizes the data.
    """
    ext = _normalize_upload_ext(filename)
    if ext in _VIDEO_EXTENSIONS:
        return "video", ext
    if ext in _IMAGE_EXTENSIONS:
        return "image", ext
    arr = np.frombuffer(data, dtype=np.uint8)
    if cv2.imdecode(arr, cv2.IMREAD_COLOR) is not None:
        return "image", ext if ext else ".jpg"
    try:
        bio = io.BytesIO(data)
        with av.open(bio):
            pass
        return "video", ext if ext else ".mp4"
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image or video data")


def _is_video_media_path(media_path: str) -> bool:
    return Path(media_path).suffix.lower() in _VIDEO_EXTENSIONS


def _post_media_record(payload: dict, bearer: str) -> bool:
    """Create a media record in the annotations service.

    Best-effort: network failures are swallowed and reported as False.
    """
    try:
        headers = {"Authorization": f"Bearer {bearer}"}
        r = http_requests.post(
            f"{ANNOTATIONS_URL}/api/media",
            json=payload,
            headers=headers,
            timeout=30,
        )
        return r.status_code in (200, 201)
    except Exception:
        return False


def _put_media_status(media_id: str, media_status: int, bearer: str) -> bool:
    """Update a media record's lifecycle status (best-effort, see above)."""
    try:
        headers = {"Authorization": f"Bearer {bearer}"}
        r = http_requests.put(
            f"{ANNOTATIONS_URL}/api/media/{media_id}/status",
            json={"mediaStatus": media_status},
            headers=headers,
            timeout=30,
        )
        return r.status_code in (200, 204)
    except Exception:
        return False


def _resolve_media_for_detect(
    media_id: str,
    token_mgr: Optional[TokenManager],
    override: Optional[AIConfigDto],
) -> tuple[dict, str]:
    """Resolve (ai_config_dict, media_path) for a stored media id.

    Config precedence: the user's saved AI settings (when a token is
    available), overlaid with any explicitly supplied ``override`` fields.
    Raises HTTP 503 when the annotations service cannot resolve the path.
    """
    cfg: dict = {}
    bearer = ""
    if token_mgr:
        bearer = token_mgr.get_valid_token()
        uid = TokenManager.decode_user_id(token_mgr.access_token)
        if uid:
            raw = annotations_client.fetch_user_ai_settings(uid, bearer)
            cfg.update(_merged_annotation_settings_payload(raw))
    if override is not None:
        # Only explicitly-set fields override saved settings.
        for k, v in override.model_dump(exclude_defaults=True).items():
            cfg[k] = v
    media_path = annotations_client.fetch_media_path(media_id, bearer)
    if not media_path:
        raise HTTPException(
            status_code=503,
            detail="Could not resolve media path from annotations service",
        )
    return cfg, media_path


def detection_to_dto(det) -> DetectionDto:
    """Convert an engine detection object into the wire DTO."""
    import constants_inf

    label = constants_inf.get_annotation_name(det.cls)
    return DetectionDto(
        centerX=det.x,
        centerY=det.y,
        width=det.w,
        height=det.h,
        classNum=det.cls,
        label=label,
        confidence=det.confidence,
    )


@app.get("/health")
def health() -> HealthResponse:
    """Liveness probe; also reports engine availability once it is loaded."""
    if inference is None:
        # Engine not yet created — the service itself is still healthy.
        return HealthResponse(status="healthy", aiAvailability="None")
    try:
        status = inference.ai_availability_status
        status_str = str(status).split()[0] if str(status).strip() else "None"
        error_msg = status.error_message if hasattr(status, 'error_message') else None
        engine_type = inference.engine_name
        return HealthResponse(
            status="healthy",
            aiAvailability=status_str,
            engineType=engine_type,
            errorMessage=error_msg,
        )
    except Exception as e:
        return HealthResponse(
            status="healthy",
            aiAvailability="None",
            errorMessage=str(e),
        )


@app.post("/detect")
async def detect_image(
    request: Request,
    file: UploadFile = File(...),
    config: Optional[str] = Form(None),
):
    """Run detection on an uploaded image or video and return the DTOs.

    When the caller is authenticated, the upload is also persisted and a
    media record is mirrored into the annotations service (best-effort).
    """
    from media_hash import compute_media_content_hash
    from inference import ai_config_from_dict

    image_bytes = await file.read()
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Image is empty")
    orig_name = file.filename or "upload"
    kind, ext = _detect_upload_kind(orig_name, image_bytes)
    if kind == "image":
        arr = np.frombuffer(image_bytes, dtype=np.uint8)
        if cv2.imdecode(arr, cv2.IMREAD_COLOR) is None:
            raise HTTPException(status_code=400, detail="Invalid image data")
    config_dict = {}
    if config:
        # FIX: malformed JSON previously escaped as a 500; report it as a
        # client error instead.
        try:
            config_dict = json.loads(config)
        except (json.JSONDecodeError, ValueError):
            raise HTTPException(status_code=400, detail="Invalid config JSON")
    auth_header = request.headers.get("authorization", "")
    access_token = auth_header.removeprefix("Bearer ").strip() if auth_header else ""
    refresh_token = request.headers.get("x-refresh-token", "")
    token_mgr = TokenManager(access_token, refresh_token) if access_token else None
    user_id = TokenManager.decode_user_id(access_token) if access_token else None
    videos_dir = os.environ.get(
        "VIDEOS_DIR", os.path.join(os.getcwd(), "data", "videos")
    )
    images_dir = os.environ.get(
        "IMAGES_DIR", os.path.join(os.getcwd(), "data", "images")
    )
    storage_path = None
    content_hash = None
    if token_mgr and user_id:
        # Persist the upload under its content hash and mirror a media
        # record (status NEW -> AI_PROCESSING) into the annotations service.
        content_hash = compute_media_content_hash(image_bytes)
        base = videos_dir if kind == "video" else images_dir
        os.makedirs(base, exist_ok=True)
        if not ext.startswith("."):
            ext = "." + ext
        storage_path = os.path.abspath(os.path.join(base, f"{content_hash}{ext}"))
        if kind == "image":
            with open(storage_path, "wb") as out:
                out.write(image_bytes)
        mt = "Video" if kind == "video" else "Image"
        payload = {
            "id": content_hash,
            "name": Path(orig_name).name,
            "path": storage_path,
            "mediaType": mt,
            "mediaStatus": _MEDIA_STATUS_NEW,
            "userId": user_id,
        }
        bearer = token_mgr.get_valid_token()
        _post_media_record(payload, bearer)
        _put_media_status(content_hash, _MEDIA_STATUS_AI_PROCESSING, bearer)
    media_name = Path(orig_name).stem.replace(" ", "")
    # FIX: asyncio.get_event_loop() is deprecated inside coroutines.
    loop = asyncio.get_running_loop()
    inf = get_inference()
    results = []
    tmp_video_path = None

    def on_annotation(annotation, percent):
        # Collect detections from every processed frame.
        results.extend(annotation.detections)

    ai_cfg = ai_config_from_dict(config_dict)

    def run_upload():
        # Runs on the thread pool; videos need a writable output path
        # (either the persisted storage path or a throwaway temp file).
        nonlocal tmp_video_path
        if kind == "video":
            if storage_path:
                save = storage_path
            else:
                suf = ext if ext.startswith(".") else ".mp4"
                fd, tmp_video_path = tempfile.mkstemp(suffix=suf)
                os.close(fd)
                save = tmp_video_path
            inf.run_detect_video(image_bytes, ai_cfg, media_name, save, on_annotation)
        else:
            inf.run_detect_image(image_bytes, ai_cfg, media_name, on_annotation)

    try:
        await loop.run_in_executor(executor, run_upload)
        if token_mgr and user_id and content_hash:
            _put_media_status(
                content_hash, _MEDIA_STATUS_AI_PROCESSED, token_mgr.get_valid_token()
            )
        return [detection_to_dto(d) for d in results]
    except RuntimeError as e:
        if token_mgr and user_id and content_hash:
            _put_media_status(
                content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
            )
        # Engine-unavailable errors map to 503, other engine errors to 422.
        if "not available" in str(e):
            raise HTTPException(status_code=503, detail=str(e))
        raise HTTPException(status_code=422, detail=str(e))
    except ValueError as e:
        if token_mgr and user_id and content_hash:
            _put_media_status(
                content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
            )
        raise HTTPException(status_code=400, detail=str(e))
    except Exception:
        if token_mgr and user_id and content_hash:
            _put_media_status(
                content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
            )
        raise
    finally:
        # Remove the temp video file if one was created.
        if tmp_video_path and os.path.isfile(tmp_video_path):
            try:
                os.unlink(tmp_video_path)
            except OSError:
                pass


def _post_annotation_to_service(
    token_mgr: TokenManager, media_id: str, annotation, dtos: list[DetectionDto]
):
    """Push one frame's detections (and optional frame image) to the
    annotations service.  Best-effort: failures are swallowed."""
    try:
        token = token_mgr.get_valid_token()
        image_b64 = (
            base64.b64encode(annotation.image).decode() if annotation.image else None
        )
        payload = {
            "mediaId": media_id,
            "source": 0,
            # NOTE(review): annotation.time appears to be milliseconds; the
            # format only encodes seconds (no minutes/hours) — confirm upstream.
            "videoTime": f"00:00:{annotation.time // 1000:02d}"
            if annotation.time
            else "00:00:00",
            "detections": [d.model_dump() for d in dtos],
        }
        if image_b64:
            payload["image"] = image_b64
        http_requests.post(
            f"{ANNOTATIONS_URL}/annotations",
            json=payload,
            headers={"Authorization": f"Bearer {token}"},
            timeout=30,
        )
    except Exception:
        pass


@app.post("/detect/{media_id}")
async def detect_media(
    media_id: str,
    request: Request,
    config: Annotated[Optional[AIConfigDto], Body()] = None,
):
    """Start background detection for a media id already known to the
    annotations service.  Progress is broadcast via /detect/stream."""
    existing = _active_detections.get(media_id)
    if existing is not None and not existing.done():
        raise HTTPException(
            status_code=409, detail="Detection already in progress for this media"
        )
    auth_header = request.headers.get("authorization", "")
    access_token = auth_header.removeprefix("Bearer ").strip() if auth_header else ""
    refresh_token = request.headers.get("x-refresh-token", "")
    token_mgr = TokenManager(access_token, refresh_token) if access_token else None
    config_dict, media_path = _resolve_media_for_detect(media_id, token_mgr, config)

    async def run_detection():
        # FIX: asyncio.get_event_loop() is deprecated inside coroutines.
        loop = asyncio.get_running_loop()

        def _enqueue(event):
            # Fan out to every connected SSE listener; drop on overflow.
            for q in _event_queues:
                try:
                    q.put_nowait(event)
                except asyncio.QueueFull:
                    pass

        try:
            from inference import ai_config_from_dict

            if token_mgr:
                _put_media_status(
                    media_id,
                    _MEDIA_STATUS_AI_PROCESSING,
                    token_mgr.get_valid_token(),
                )
            with open(media_path, "rb") as mf:
                file_bytes = mf.read()
            video = _is_video_media_path(media_path)
            stem_name = Path(media_path).stem.replace(" ", "")
            ai_cfg = ai_config_from_dict(config_dict)
            inf = get_inference()
            if not inf.is_engine_ready:
                raise RuntimeError("Detection service unavailable")

            def on_annotation(annotation, percent):
                # Called from the worker thread: marshal the event back onto
                # the loop, then (synchronously) mirror to annotations.
                dtos = [detection_to_dto(d) for d in annotation.detections]
                event = DetectionEvent(
                    annotations=dtos,
                    mediaId=media_id,
                    mediaStatus="AIProcessing",
                    mediaPercent=percent,
                )
                loop.call_soon_threadsafe(_enqueue, event)
                if token_mgr and dtos:
                    _post_annotation_to_service(token_mgr, media_id, annotation, dtos)

            def on_status(media_name_cb, count):
                # Called from the worker thread when processing completes.
                event = DetectionEvent(
                    annotations=[],
                    mediaId=media_id,
                    mediaStatus="AIProcessed",
                    mediaPercent=100,
                )
                loop.call_soon_threadsafe(_enqueue, event)
                if token_mgr:
                    _put_media_status(
                        media_id,
                        _MEDIA_STATUS_AI_PROCESSED,
                        token_mgr.get_valid_token(),
                    )

            def run_sync():
                if video:
                    inf.run_detect_video(
                        file_bytes,
                        ai_cfg,
                        stem_name,
                        media_path,
                        on_annotation,
                        on_status,
                    )
                else:
                    inf.run_detect_image(
                        file_bytes,
                        ai_cfg,
                        stem_name,
                        on_annotation,
                        on_status,
                    )

            await loop.run_in_executor(executor, run_sync)
        except Exception:
            if token_mgr:
                _put_media_status(
                    media_id, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
                )
            error_event = DetectionEvent(
                annotations=[],
                mediaId=media_id,
                mediaStatus="Error",
                mediaPercent=0,
            )
            _enqueue(error_event)
        finally:
            _active_detections.pop(media_id, None)

    _active_detections[media_id] = asyncio.create_task(run_detection())
    return {"status": "started", "mediaId": media_id}


@app.get("/detect/stream")
async def detect_stream():
    """SSE endpoint: stream DetectionEvent JSON to the client until it
    disconnects.  Each client gets its own bounded queue."""
    queue: asyncio.Queue = asyncio.Queue(maxsize=100)
    _event_queues.append(queue)

    async def event_generator():
        try:
            while True:
                event = await queue.get()
                yield f"data: {event.model_dump_json()}\n\n"
        except asyncio.CancelledError:
            pass
        finally:
            _event_queues.remove(queue)

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )