Files
detections/main.py
T
Oleksandr Bezdieniezhnykh 86b8f076b7 Update health endpoint and refine test documentation
- Modified the health endpoint to return "None" for AI availability when inference is not initialized, improving clarity on system status.
- Enhanced the test documentation to include handling of skipped tests, emphasizing the need for investigation before proceeding.
- Updated test assertions to ensure proper execution order and prevent premature engine initialization.
- Refactored test cases to streamline performance testing and improve readability, removing unnecessary complexity.

These changes aim to enhance the robustness of the health check and improve the overall testing framework.
2026-03-30 01:17:53 +03:00

288 lines
8.8 KiB
Python

import asyncio
import base64
import json
import logging
import os
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

import requests as http_requests
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from loader_http_client import LoaderHttpClient, LoadResult
app = FastAPI(title="Azaion.Detections")

# Worker pool for the blocking inference calls dispatched via run_in_executor;
# at most two such calls run at the same time.
executor = ThreadPoolExecutor(max_workers=2)

# Downstream service endpoints, overridable through the environment.
LOADER_URL = os.getenv("LOADER_URL", "http://loader:8080")
ANNOTATIONS_URL = os.getenv("ANNOTATIONS_URL", "http://annotations:8080")

loader_client = LoaderHttpClient(LOADER_URL)

# Inference engine instance; stays None until get_inference() creates it.
inference = None

# One bounded queue per open /detect/stream connection.
_event_queues: list[asyncio.Queue] = []
# media_id -> background detection task, used to reject concurrent runs.
_active_detections: dict[str, asyncio.Task] = {}
class TokenManager:
    """Hold a caller's access/refresh token pair and renew the access token before it expires.

    ``get_valid_token`` decodes the access token's ``exp`` claim. The signature
    is not verified, because the claim is only used to decide when to refresh.
    Once fewer than ``REFRESH_MARGIN_SECONDS`` remain, it asks the annotations
    service for a new token. Refreshing is best effort: on failure the current
    token is kept and the reason is logged.
    """

    # Refresh when the access token has less than this many seconds left.
    REFRESH_MARGIN_SECONDS = 60

    def __init__(self, access_token: str, refresh_token: str):
        self.access_token = access_token
        self.refresh_token = refresh_token

    def get_valid_token(self) -> str:
        """Return the access token, refreshing it first if it is about to expire.

        Tokens whose ``exp`` claim is missing (decoded as 0) or that cannot be
        decoded at all are returned unchanged, without a refresh attempt.
        """
        exp = self._decode_exp(self.access_token)
        if exp and exp - time.time() < self.REFRESH_MARGIN_SECONDS:
            self._refresh()
        return self.access_token

    def _refresh(self) -> None:
        """Replace ``access_token`` with one issued by ``/auth/refresh`` (best effort).

        Request, HTTP-status and response-parsing failures are logged, not raised.
        """
        logger = logging.getLogger(__name__)
        if not self.refresh_token:
            # Nothing to exchange: skip the round trip and keep the current token.
            logger.debug("Access token near expiry but no refresh token was provided")
            return
        try:
            resp = http_requests.post(
                f"{ANNOTATIONS_URL}/auth/refresh",
                json={"refreshToken": self.refresh_token},
                timeout=10,
            )
            if resp.status_code != 200:
                logger.warning("Token refresh rejected with HTTP %s", resp.status_code)
                return
            self.access_token = resp.json()["token"]
        except Exception:
            # Best effort: the caller proceeds with the current token.
            logger.warning("Token refresh failed", exc_info=True)

    @staticmethod
    def _decode_exp(token: str) -> Optional[float]:
        """Return the ``exp`` claim of a JWT-shaped token, without signature checks.

        Returns ``0.0`` when the payload has no ``exp`` claim and ``None`` when the
        token cannot be split, base64url-decoded or parsed as JSON.
        """
        try:
            payload = token.split(".")[1]
            # JWTs strip base64url padding; restore it so the decoder accepts it.
            payload += "=" * (-len(payload) % 4)
            data = json.loads(base64.urlsafe_b64decode(payload))
            return float(data.get("exp", 0))
        except Exception:
            return None
def get_inference():
    """Return the shared ``Inference`` object, constructing it on the first call.

    The ``inference`` module is imported here rather than at the top of the
    file, so loading it is deferred until an endpoint first needs the engine.
    The module-level variable is also named ``inference``. This is safe:
    ``from inference import Inference`` binds only ``Inference``, so the
    variable is not overwritten by the module object.
    """
    global inference
    if inference is not None:
        return inference
    from inference import Inference
    inference = Inference(loader_client)
    return inference
class DetectionDto(BaseModel):
    """One detected object, as returned by ``POST /detect`` and embedded in events.

    Built from an engine detection by ``detection_to_dto``. Field names are
    camelCase because the model is serialized with its field names as keys.
    Field order also determines the key order in serialized output.
    """

    centerX: float  # box centre x (engine's det.x)
    centerY: float  # box centre y (engine's det.y)
    width: float  # box width (engine's det.w)
    height: float  # box height (engine's det.h)
    classNum: int  # numeric class id (engine's det.cls)
    label: str  # class name resolved via constants_inf.get_annotation_name
    confidence: float  # detection confidence (engine's det.confidence)
class DetectionEvent(BaseModel):
    """Progress event pushed to ``/detect/stream`` subscribers as an SSE ``data:`` line."""

    annotations: list[DetectionDto]  # detections of the latest annotation; empty for status/error events
    mediaId: str  # media the detection run belongs to
    mediaStatus: str  # emitted values in this file: "AIProcessing", "AIProcessed", "Error"
    mediaPercent: int  # progress as reported by the engine; 100 when done, 0 on error
class HealthResponse(BaseModel):
    """Response body of ``GET /health``."""

    status: str  # the health endpoint always reports "healthy"
    aiAvailability: str  # first word of the engine's availability status, or "None"
    engineType: Optional[str] = None  # engine name, when the engine exists and can be queried
    errorMessage: Optional[str] = None  # engine-reported error, or the text of a failed status read
class AIConfigDto(BaseModel):
    """Detection settings accepted as the optional body of ``POST /detect/{media_id}``.

    When no body is sent, ``AIConfigDto()`` is used. The model is converted with
    ``model_dump()`` and passed to ``Inference.run_detect``. The meaning of each
    field is defined by the inference engine, which is not visible in this file,
    so the per-field notes below are inferred from names only.
    """

    # NOTE(review): semantics below are guesses from field names — confirm against Inference.
    frame_period_recognition: int = 4  # presumably: analyse every N-th video frame
    frame_recognition_seconds: int = 2  # presumably: recognition window length, in seconds
    probability_threshold: float = 0.25  # presumably: minimum confidence to keep a detection
    tracking_distance_confidence: float = 0.0  # tracking tuning parameter
    tracking_probability_increase: float = 0.0  # tracking tuning parameter
    tracking_intersection_threshold: float = 0.6  # presumably an overlap threshold used by tracking
    model_batch_size: int = 1  # presumably: items per inference batch
    big_image_tile_overlap_percent: int = 20  # presumably: tile overlap when splitting large images
    altitude: float = 400  # camera altitude; units not stated here
    focal_length: float = 24  # lens focal length; units not stated here
    sensor_width: float = 23.5  # camera sensor width; units not stated here
    paths: list[str] = []  # presumably: media paths to process
def detection_to_dto(det) -> DetectionDto:
    """Build the API representation of a single engine detection.

    ``det`` must expose ``x``, ``y``, ``w``, ``h``, ``cls`` and ``confidence``.
    The human-readable label is looked up from the class id.
    """
    import constants_inf  # project module, imported locally

    class_id = det.cls
    class_label = constants_inf.get_annotation_name(class_id)
    return DetectionDto(
        centerX=det.x,
        centerY=det.y,
        width=det.w,
        height=det.h,
        classNum=class_id,
        label=class_label,
        confidence=det.confidence,
    )
@app.get("/health")
def health() -> HealthResponse:
    """Liveness probe that also reports the inference engine's state.

    ``status`` is always ``"healthy"``. ``aiAvailability`` is ``"None"`` in two
    cases: no engine has been created yet, or its status could not be read
    (the exception text is then returned in ``errorMessage``). Otherwise it is
    the first word of the engine's availability status.
    """
    engine = inference
    if engine is None:
        return HealthResponse(status="healthy", aiAvailability="None")

    try:
        availability = engine.ai_availability_status
        words = str(availability).split()
        return HealthResponse(
            status="healthy",
            aiAvailability=words[0] if words else "None",
            errorMessage=getattr(availability, "error_message", None),
            engineType=engine.engine_name,
        )
    except Exception as exc:
        return HealthResponse(
            status="healthy",
            aiAvailability="None",
            errorMessage=str(exc),
        )
@app.post("/detect")
async def detect_image(
    file: UploadFile = File(...),
    config: Optional[str] = Form(None),
):
    """Run detection on one uploaded image and return its detections.

    Args:
        file: Image to analyse, sent as multipart form data.
        config: Optional JSON object with settings, passed as a dict to
            ``Inference.detect_single_image``.

    Returns:
        A list of ``DetectionDto``.

    Raises:
        HTTPException: 400 for an empty image, a ``config`` that is not a JSON
            object, or a ``ValueError`` from the engine. 503 when the engine's
            ``RuntimeError`` says it is "not available". 422 for any other
            ``RuntimeError``.
    """
    image_bytes = await file.read()
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Image is empty")

    config_dict: dict = {}
    if config:
        # Reject malformed client input with 400 instead of letting the
        # decode error surface as a 500.
        try:
            config_dict = json.loads(config)
        except ValueError as e:
            raise HTTPException(status_code=400, detail=f"Invalid config JSON: {e}") from e
        if not isinstance(config_dict, dict):
            raise HTTPException(status_code=400, detail="Config must be a JSON object")

    loop = asyncio.get_running_loop()
    try:
        inf = get_inference()
        # Detection is blocking; run it on the shared executor, off the event loop.
        detections = await loop.run_in_executor(
            executor, inf.detect_single_image, image_bytes, config_dict
        )
    except RuntimeError as e:
        if "not available" in str(e):
            raise HTTPException(status_code=503, detail=str(e)) from e
        raise HTTPException(status_code=422, detail=str(e)) from e
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    return [detection_to_dto(d) for d in detections]
def _post_annotation_to_service(token_mgr: TokenManager, media_id: str,
annotation, dtos: list[DetectionDto]):
try:
token = token_mgr.get_valid_token()
image_b64 = base64.b64encode(annotation.image).decode() if annotation.image else None
payload = {
"mediaId": media_id,
"source": 0,
"videoTime": f"00:00:{annotation.time // 1000:02d}" if annotation.time else "00:00:00",
"detections": [d.model_dump() for d in dtos],
}
if image_b64:
payload["image"] = image_b64
http_requests.post(
f"{ANNOTATIONS_URL}/annotations",
json=payload,
headers={"Authorization": f"Bearer {token}"},
timeout=30,
)
except Exception:
pass
@app.post("/detect/{media_id}")
async def detect_media(media_id: str, request: Request, config: Optional[AIConfigDto] = None):
    """Start detection for ``media_id`` in the background and return immediately.

    Progress is published as ``DetectionEvent``s to every ``/detect/stream``
    subscriber. If the request carries an access token in ``Authorization``,
    each annotation with detections is also posted to the annotations service.
    The ``x-refresh-token`` header lets that token be renewed mid-run.

    Args:
        media_id: Identifier echoed in every event of this run.
        request: Incoming request, read for the auth headers.
        config: Optional detection settings; defaults to ``AIConfigDto()``.

    Returns:
        ``{"status": "started", "mediaId": media_id}``.

    Raises:
        HTTPException: 409 if a detection for this media is still running.
    """
    existing = _active_detections.get(media_id)
    if existing is not None and not existing.done():
        raise HTTPException(status_code=409, detail="Detection already in progress for this media")

    # Accept "Bearer <token>" with any casing of the scheme (auth schemes are
    # case-insensitive); a header without that scheme is used verbatim.
    auth_header = request.headers.get("authorization", "").strip()
    scheme, _, credentials = auth_header.partition(" ")
    access_token = credentials.strip() if scheme.lower() == "bearer" else auth_header
    refresh_token = request.headers.get("x-refresh-token", "")
    token_mgr = TokenManager(access_token, refresh_token) if access_token else None

    config_dict = (config or AIConfigDto()).model_dump()

    async def run_detection():
        loop = asyncio.get_running_loop()

        def _enqueue(event):
            # Runs on the event-loop thread. Subscribers whose queue is full
            # miss this event.
            for q in _event_queues:
                try:
                    q.put_nowait(event)
                except asyncio.QueueFull:
                    pass

        try:
            inf = get_inference()
            if not inf.is_engine_ready:
                raise RuntimeError("Detection service unavailable")

            # These callbacks are handed to inf.run_detect, which runs on the
            # executor thread. They are presumably invoked there, hence events
            # are marshalled back to the loop with call_soon_threadsafe.
            def on_annotation(annotation, percent):
                dtos = [detection_to_dto(d) for d in annotation.detections]
                event = DetectionEvent(
                    annotations=dtos,
                    mediaId=media_id,
                    mediaStatus="AIProcessing",
                    mediaPercent=percent,
                )
                loop.call_soon_threadsafe(_enqueue, event)
                if token_mgr and dtos:
                    _post_annotation_to_service(token_mgr, media_id, annotation, dtos)

            def on_status(media_name, count):
                event = DetectionEvent(
                    annotations=[],
                    mediaId=media_id,
                    mediaStatus="AIProcessed",
                    mediaPercent=100,
                )
                loop.call_soon_threadsafe(_enqueue, event)

            await loop.run_in_executor(
                executor, inf.run_detect, config_dict, on_annotation, on_status
            )
        except Exception:
            # Subscribers only receive a bare "Error" status, so record the cause here.
            logging.getLogger(__name__).exception("Detection failed for media %s", media_id)
            error_event = DetectionEvent(
                annotations=[],
                mediaId=media_id,
                mediaStatus="Error",
                mediaPercent=0,
            )
            _enqueue(error_event)
        finally:
            _active_detections.pop(media_id, None)

    # Keeping the task in _active_detections also holds a strong reference to it
    # until it finishes.
    _active_detections[media_id] = asyncio.create_task(run_detection())
    return {"status": "started", "mediaId": media_id}
@app.get("/detect/stream")
async def detect_stream():
    """Stream detection events to the client as Server-Sent Events.

    Each connection gets its own queue (at most 100 pending events), registered
    in ``_event_queues`` while the stream is being consumed. Every event is sent
    as one ``data: <json>`` SSE message.
    """
    queue: asyncio.Queue = asyncio.Queue(maxsize=100)

    async def event_generator():
        # Register inside the generator so the matching ``finally`` always
        # unregisters it. An async generator that is never started (e.g. the
        # client disconnects before streaming begins) never runs its
        # ``finally``, so registering outside would leak the queue.
        _event_queues.append(queue)
        try:
            while True:
                event = await queue.get()
                yield f"data: {event.model_dump_json()}\n\n"
        finally:
            # Also runs on cancellation (client disconnect), which is left to
            # propagate instead of being swallowed.
            _event_queues.remove(queue)

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )