import asyncio
import base64
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

import requests as http_requests
from fastapi import FastAPI, UploadFile, File, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from loader_http_client import LoaderHttpClient, LoadResult

app = FastAPI(title="Azaion.Detections")
executor = ThreadPoolExecutor(max_workers=2)

LOADER_URL = os.environ.get("LOADER_URL", "http://loader:8080")
ANNOTATIONS_URL = os.environ.get("ANNOTATIONS_URL", "http://annotations:8080")

loader_client = LoaderHttpClient(LOADER_URL)

# Lazily-created Inference singleton; see get_inference().
inference = None

# One queue per connected /detect/stream subscriber (SSE fan-out).
_event_queues: list[asyncio.Queue] = []
# media_id -> True while a background detection runs (used for the 409 guard).
_active_detections: dict[str, bool] = {}


class TokenManager:
    """Keeps a JWT access token usable by refreshing it shortly before expiry."""

    def __init__(self, access_token: str, refresh_token: str):
        self.access_token = access_token
        self.refresh_token = refresh_token

    def get_valid_token(self) -> str:
        """Return the current access token, refreshing it first when it
        expires within the next 60 seconds (or expiry cannot be read)."""
        exp = self._decode_exp(self.access_token)
        if exp and exp - time.time() < 60:
            self._refresh()
        return self.access_token

    def _refresh(self) -> None:
        """Best-effort token refresh against the annotations service.

        On network failure or an unexpected response body the old token is
        kept; the downstream call may then fail with 401, which is the
        pre-existing behavior.
        """
        try:
            resp = http_requests.post(
                f"{ANNOTATIONS_URL}/auth/refresh",
                json={"refreshToken": self.refresh_token},
                timeout=10,
            )
            if resp.status_code == 200:
                self.access_token = resp.json()["token"]
        except (http_requests.RequestException, ValueError, KeyError):
            # Narrowed from a bare `except Exception`: these are the only
            # failure modes of post()/json()/["token"]; anything else is a bug
            # we want to surface. Best-effort semantics are preserved.
            pass

    @staticmethod
    def _decode_exp(token: str) -> Optional[float]:
        """Read the `exp` claim from a JWT payload WITHOUT verifying the
        signature. Returns None for malformed tokens."""
        try:
            payload = token.split(".")[1]
            # urlsafe_b64decode requires input padded to a multiple of 4.
            padding = 4 - len(payload) % 4
            if padding != 4:
                payload += "=" * padding
            data = json.loads(base64.urlsafe_b64decode(payload))
            return float(data.get("exp", 0))
        except Exception:
            return None


def get_inference():
    """Return the shared Inference engine, constructing it on first use.

    The import is deferred so the API can start serving /health before the
    (heavy) inference stack is loaded.
    """
    global inference
    if inference is None:
        from inference import Inference
        inference = Inference(loader_client)
    return inference


class DetectionDto(BaseModel):
    """A single detected bounding box in center/size coordinates."""
    centerX: float
    centerY: float
    width: float
    height: float
    classNum: int
    label: str
    confidence: float


class DetectionEvent(BaseModel):
    """SSE payload describing detection progress for one media item."""
    annotations: list[DetectionDto]
    mediaId: str
    mediaStatus: str
    mediaPercent: int


class HealthResponse(BaseModel):
    """Response body for GET /health."""
    status: str
    aiAvailability: str
    errorMessage: Optional[str] = None


class AIConfigDto(BaseModel):
    """Tunable inference parameters accepted by POST /detect/{media_id}."""
    frame_period_recognition: int = 4
    frame_recognition_seconds: int = 2
    probability_threshold: float = 0.25
    tracking_distance_confidence: float = 0.0
    tracking_probability_increase: float = 0.0
    tracking_intersection_threshold: float = 0.6
    model_batch_size: int = 1
    big_image_tile_overlap_percent: int = 20
    altitude: float = 400
    focal_length: float = 24
    sensor_width: float = 23.5
    paths: list[str] = []


def detection_to_dto(det) -> DetectionDto:
    """Map an engine detection to the wire DTO, resolving the class number
    to a human-readable label when the class is known (else empty string)."""
    import constants_inf
    entry = constants_inf.annotations_dict.get(det.cls)
    label = entry.name if entry is not None else ""
    return DetectionDto(
        centerX=det.x,
        centerY=det.y,
        width=det.w,
        height=det.h,
        classNum=det.cls,
        label=label,
        confidence=det.confidence,
    )


@app.get("/health")
def health() -> HealthResponse:
    """Liveness probe. Reports AI availability but never itself fails:
    any error while querying the engine is reported as availability "None"."""
    try:
        inf = get_inference()
        status = inf.ai_availability_status
        status_text = str(status)  # hoisted: was computed three times
        status_str = status_text.split()[0] if status_text.strip() else "None"
        error_msg = getattr(status, "error_message", None)
        return HealthResponse(
            status="healthy",
            aiAvailability=status_str,
            errorMessage=error_msg,
        )
    except Exception:
        # The probe must stay green even when the engine cannot load.
        return HealthResponse(
            status="healthy",
            aiAvailability="None",
            errorMessage=None,
        )


@app.post("/detect")
async def detect_image(
    file: UploadFile = File(...),
    config: Optional[str] = None,
):
    """Run single-image detection on an uploaded file.

    `config` is an optional JSON-encoded dict of config overrides.
    Raises 400 for empty input or malformed config JSON, 503 when the
    engine is unavailable, 422 for other engine errors.
    """
    image_bytes = await file.read()
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Image is empty")

    config_dict = {}
    if config:
        try:
            config_dict = json.loads(config)
        except json.JSONDecodeError as e:
            # Previously unguarded: malformed client JSON produced a 500.
            raise HTTPException(status_code=400, detail=f"Invalid config JSON: {e}")

    # get_event_loop() is deprecated inside coroutines; use the running loop.
    loop = asyncio.get_running_loop()
    try:
        inf = get_inference()
        detections = await loop.run_in_executor(
            executor, inf.detect_single_image, image_bytes, config_dict
        )
    except RuntimeError as e:
        if "not available" in str(e):
            raise HTTPException(status_code=503, detail=str(e))
        raise HTTPException(status_code=422, detail=str(e))
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    return [detection_to_dto(d) for d in detections]
def _post_annotation_to_service(token_mgr: TokenManager, media_id: str, annotation, dtos: list[DetectionDto]):
    """Best-effort POST of one annotation (plus optional frame image) to the
    annotations service. Failures are swallowed deliberately so that a flaky
    annotations service never aborts a running detection."""
    try:
        token = token_mgr.get_valid_token()
        image_b64 = base64.b64encode(annotation.image).decode() if annotation.image else None
        # annotation.time is in milliseconds. The previous format string
        # ("00:00:{seconds:02d}") was wrong for timestamps >= 60 s, emitting
        # e.g. "00:00:75"; fold into proper HH:MM:SS instead.
        total_seconds = (annotation.time or 0) // 1000
        minutes, seconds = divmod(total_seconds, 60)
        hours, minutes = divmod(minutes, 60)
        payload = {
            "mediaId": media_id,
            "source": 0,
            "videoTime": f"{hours:02d}:{minutes:02d}:{seconds:02d}",
            "detections": [d.model_dump() for d in dtos],
        }
        if image_b64:
            payload["image"] = image_b64
        http_requests.post(
            f"{ANNOTATIONS_URL}/annotations",
            json=payload,
            headers={"Authorization": f"Bearer {token}"},
            timeout=30,
        )
    except Exception:
        # Deliberate best-effort delivery; see docstring.
        pass


def _broadcast(event: DetectionEvent) -> None:
    """Fan one event out to every SSE subscriber queue. Events are dropped
    for full queues so a slow consumer cannot stall detection. Must be
    called on the event-loop thread (asyncio.Queue is not thread-safe)."""
    for subscriber in _event_queues:
        try:
            subscriber.put_nowait(event)
        except asyncio.QueueFull:
            pass


@app.post("/detect/{media_id}")
async def detect_media(media_id: str, request: Request, config: Optional[AIConfigDto] = None):
    """Start background detection for a media item.

    Returns immediately with {"status": "started"}; progress/result events
    are broadcast to /detect/stream subscribers. When the caller supplied a
    bearer token, detected annotations are also forwarded to the
    annotations service. Raises 409 if a detection for this media is
    already in progress.
    """
    if media_id in _active_detections:
        raise HTTPException(status_code=409, detail="Detection already in progress for this media")

    auth_header = request.headers.get("authorization", "")
    access_token = auth_header.removeprefix("Bearer ").strip() if auth_header else ""
    refresh_token = request.headers.get("x-refresh-token", "")
    token_mgr = TokenManager(access_token, refresh_token) if access_token else None

    cfg = config or AIConfigDto()
    config_dict = cfg.model_dump()
    # No await between the membership check above and this assignment, so
    # the guard is race-free on the single-threaded event loop.
    _active_detections[media_id] = True

    async def run_detection():
        # get_event_loop() is deprecated inside coroutines.
        loop = asyncio.get_running_loop()
        try:
            inf = get_inference()
            if inf.engine is None:
                raise RuntimeError("Detection service unavailable")

            def on_annotation(annotation, percent):
                # Invoked on the executor thread by inf.run_detect.
                dtos = [detection_to_dto(d) for d in annotation.detections]
                event = DetectionEvent(
                    annotations=dtos,
                    mediaId=media_id,
                    mediaStatus="AIProcessing",
                    mediaPercent=percent,
                )
                # asyncio.Queue is not thread-safe and put_nowait() from a
                # worker thread does not reliably wake loop-side waiters;
                # marshal the enqueue back onto the event loop.
                loop.call_soon_threadsafe(_broadcast, event)
                if token_mgr and dtos:
                    _post_annotation_to_service(token_mgr, media_id, annotation, dtos)

            def on_status(media_name, count):
                # Also invoked on the executor thread; arguments unused,
                # the signature is dictated by inf.run_detect.
                event = DetectionEvent(
                    annotations=[],
                    mediaId=media_id,
                    mediaStatus="AIProcessed",
                    mediaPercent=100,
                )
                loop.call_soon_threadsafe(_broadcast, event)

            await loop.run_in_executor(
                executor, inf.run_detect, config_dict, on_annotation, on_status
            )
        except Exception:
            # Already back on the event loop here; broadcast directly.
            _broadcast(DetectionEvent(
                annotations=[],
                mediaId=media_id,
                mediaStatus="Error",
                mediaPercent=0,
            ))
        finally:
            _active_detections.pop(media_id, None)

    asyncio.create_task(run_detection())
    return {"status": "started", "mediaId": media_id}


@app.get("/detect/stream")
async def detect_stream():
    """SSE endpoint streaming DetectionEvent JSON until the client
    disconnects. Each subscriber gets a bounded queue (size 100); producers
    drop events for full queues rather than block."""
    queue: asyncio.Queue = asyncio.Queue(maxsize=100)
    _event_queues.append(queue)

    async def event_generator():
        try:
            while True:
                event = await queue.get()
                yield f"data: {event.model_dump_json()}\n\n"
        except asyncio.CancelledError:
            pass
        finally:
            # Unsubscribe on disconnect/cancellation.
            _event_queues.remove(queue)

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )