mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 08:46:32 +00:00
86b8f076b7
- Modified the health endpoint to return "None" for AI availability when inference is not initialized, improving clarity on system status. - Enhanced the test documentation to include handling of skipped tests, emphasizing the need for investigation before proceeding. - Updated test assertions to ensure proper execution order and prevent premature engine initialization. - Refactored test cases to streamline performance testing and improve readability, removing unnecessary complexity. These changes aim to enhance the robustness of the health check and improve the overall testing framework.
288 lines
8.8 KiB
Python
288 lines
8.8 KiB
Python
import asyncio
|
|
import base64
|
|
import json
|
|
import os
|
|
import time
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from typing import Optional
|
|
|
|
import requests as http_requests
|
|
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
|
|
from fastapi.responses import StreamingResponse
|
|
from pydantic import BaseModel
|
|
|
|
from loader_http_client import LoaderHttpClient, LoadResult
|
|
|
|
app = FastAPI(title="Azaion.Detections")

# Shared pool for blocking inference work so the asyncio event loop stays free.
executor = ThreadPoolExecutor(max_workers=2)

# Downstream service base URLs; defaults match the docker-compose service names.
LOADER_URL = os.environ.get("LOADER_URL", "http://loader:8080")
ANNOTATIONS_URL = os.environ.get("ANNOTATIONS_URL", "http://annotations:8080")

loader_client = LoaderHttpClient(LOADER_URL)
# Lazily-created Inference singleton; populated by get_inference() on first use.
inference = None

# One queue per connected SSE client (see detect_stream); detection events are
# fanned out to every queue in this list.
_event_queues: list[asyncio.Queue] = []
# In-flight background detection tasks keyed by media id, used to reject
# duplicate /detect/{media_id} requests.
_active_detections: dict[str, asyncio.Task] = {}
|
class TokenManager:
    """Holds an access/refresh token pair and transparently renews the access
    token via the annotations service shortly before it expires."""

    def __init__(self, access_token: str, refresh_token: str):
        self.access_token = access_token
        self.refresh_token = refresh_token

    def get_valid_token(self) -> str:
        """Return the access token, refreshing first if it expires within 60 s."""
        expires_at = self._decode_exp(self.access_token)
        if expires_at and expires_at - time.time() < 60:
            self._refresh()
        return self.access_token

    def _refresh(self):
        # Best-effort renewal: on any failure we keep the (possibly stale) token.
        try:
            response = http_requests.post(
                f"{ANNOTATIONS_URL}/auth/refresh",
                json={"refreshToken": self.refresh_token},
                timeout=10,
            )
            if response.status_code == 200:
                self.access_token = response.json()["token"]
        except Exception:
            pass

    @staticmethod
    def _decode_exp(token: str) -> Optional[float]:
        """Extract the JWT ``exp`` claim (epoch seconds); None if undecodable."""
        try:
            claims_segment = token.split(".")[1]
            # urlsafe_b64decode needs length % 4 == 0; restore stripped padding.
            missing = -len(claims_segment) % 4
            claims = json.loads(base64.urlsafe_b64decode(claims_segment + "=" * missing))
            return float(claims.get("exp", 0))
        except Exception:
            return None
|
|
|
|
|
|
def get_inference():
    """Return the process-wide Inference singleton, building it on first call."""
    global inference
    if inference is not None:
        return inference
    # Imported lazily so the heavy model stack only loads when actually needed.
    from inference import Inference
    inference = Inference(loader_client)
    return inference
|
|
|
|
|
|
class DetectionDto(BaseModel):
|
|
centerX: float
|
|
centerY: float
|
|
width: float
|
|
height: float
|
|
classNum: int
|
|
label: str
|
|
confidence: float
|
|
|
|
|
|
class DetectionEvent(BaseModel):
    """One server-sent-event payload emitted while a media item is processed."""

    annotations: list[DetectionDto]
    mediaId: str
    # Values emitted in this module: "AIProcessing", "AIProcessed", "Error".
    mediaStatus: str
    # Percent complete; this module emits 0 and 100, intermediate values come
    # from the engine's progress callback.
    mediaPercent: int
|
|
|
|
|
|
class HealthResponse(BaseModel):
    """Payload of GET /health."""

    # Always "healthy" in this module; availability problems are reported via
    # aiAvailability/errorMessage instead of an unhealthy status.
    status: str
    # "None" when the inference engine is uninitialized or status lookup fails.
    aiAvailability: str
    engineType: Optional[str] = None
    errorMessage: Optional[str] = None
|
|
|
|
|
|
class AIConfigDto(BaseModel):
    """Tunable inference/tracking parameters accepted by the /detect endpoints.

    Every field has a server-side default, so clients may POST a partial or
    empty config object.
    """

    frame_period_recognition: int = 4
    frame_recognition_seconds: int = 2
    probability_threshold: float = 0.25
    tracking_distance_confidence: float = 0.0
    tracking_probability_increase: float = 0.0
    tracking_intersection_threshold: float = 0.6
    model_batch_size: int = 1
    big_image_tile_overlap_percent: int = 20
    altitude: float = 400
    focal_length: float = 24
    sensor_width: float = 23.5
    # Safe despite the mutable default: pydantic copies field defaults per
    # model instance.
    paths: list[str] = []
|
|
|
|
|
|
def detection_to_dto(det) -> DetectionDto:
    """Map an internal detection object onto the wire-format DetectionDto."""
    # Imported lazily, mirroring how the inference stack is loaded on demand.
    import constants_inf

    return DetectionDto(
        centerX=det.x,
        centerY=det.y,
        width=det.w,
        height=det.h,
        classNum=det.cls,
        label=constants_inf.get_annotation_name(det.cls),
        confidence=det.confidence,
    )
|
|
|
|
|
|
@app.get("/health")
def health() -> HealthResponse:
    """Report service health plus AI engine availability.

    Always returns status "healthy"; availability is "None" when the engine
    has never been initialized or when status introspection fails.
    """
    if inference is None:
        # Engine not yet constructed (get_inference never called).
        return HealthResponse(status="healthy", aiAvailability="None")
    try:
        availability = inference.ai_availability_status
        availability_text = str(availability)
        # First whitespace-delimited token of the status repr; "None" if blank.
        availability_str = availability_text.split()[0] if availability_text.strip() else "None"
        return HealthResponse(
            status="healthy",
            aiAvailability=availability_str,
            engineType=inference.engine_name,
            errorMessage=getattr(availability, "error_message", None),
        )
    except Exception as exc:
        # Status lookup failed; the endpoint itself still reports healthy.
        return HealthResponse(
            status="healthy",
            aiAvailability="None",
            errorMessage=str(exc),
        )
|
|
|
|
|
|
@app.post("/detect")
async def detect_image(
    file: UploadFile = File(...),
    config: Optional[str] = Form(None),
):
    """Run detection on a single uploaded image.

    Returns a list of DetectionDto. Raises HTTP 400 for empty images or
    malformed config JSON, 503 when the engine is unavailable, and 422 for
    other engine runtime errors.
    """
    image_bytes = await file.read()
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Image is empty")

    config_dict = {}
    if config:
        try:
            config_dict = json.loads(config)
        except json.JSONDecodeError as e:
            # Malformed client JSON is a 400, not an unhandled 500.
            raise HTTPException(status_code=400, detail=f"Invalid config JSON: {e}")

    # We are inside a coroutine, so get_running_loop() is the correct call
    # (get_event_loop() is deprecated in this context).
    loop = asyncio.get_running_loop()
    try:
        inf = get_inference()
        # Heavy inference runs in the shared thread pool to keep the event
        # loop responsive.
        detections = await loop.run_in_executor(
            executor, inf.detect_single_image, image_bytes, config_dict
        )
    except RuntimeError as e:
        if "not available" in str(e):
            raise HTTPException(status_code=503, detail=str(e))
        raise HTTPException(status_code=422, detail=str(e))
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    return [detection_to_dto(d) for d in detections]
|
|
|
|
|
|
def _format_video_time(total_millis) -> str:
    """Format a millisecond offset as HH:MM:SS; falsy input yields 00:00:00."""
    if not total_millis:
        return "00:00:00"
    total_seconds = int(total_millis) // 1000
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d}"


def _post_annotation_to_service(token_mgr: TokenManager, media_id: str,
                                annotation, dtos: list[DetectionDto]):
    """Best-effort push of one annotation (with optional frame image) to the
    annotations service.

    Failures are deliberately swallowed so annotation persistence can never
    abort a running detection.
    """
    try:
        token = token_mgr.get_valid_token()
        image_b64 = base64.b64encode(annotation.image).decode() if annotation.image else None
        payload = {
            "mediaId": media_id,
            "source": 0,
            # Fixed: the old f"00:00:{t // 1000:02d}" format overflowed past
            # 59 seconds (e.g. 125 s became "00:00:125").
            "videoTime": _format_video_time(annotation.time),
            "detections": [d.model_dump() for d in dtos],
        }
        if image_b64:
            payload["image"] = image_b64
        http_requests.post(
            f"{ANNOTATIONS_URL}/annotations",
            json=payload,
            headers={"Authorization": f"Bearer {token}"},
            timeout=30,
        )
    except Exception:
        pass
|
|
|
|
|
|
@app.post("/detect/{media_id}")
async def detect_media(media_id: str, request: Request, config: Optional[AIConfigDto] = None):
    """Start asynchronous detection for a media item.

    Returns immediately with a "started" acknowledgement; progress and results
    are broadcast to SSE subscribers (see detect_stream) and, when auth tokens
    were supplied, pushed to the annotations service. Raises HTTP 409 if a
    detection for the same media id is already running.
    """
    existing = _active_detections.get(media_id)
    if existing is not None and not existing.done():
        raise HTTPException(status_code=409, detail="Detection already in progress for this media")

    # Tokens are optional: without them events are still streamed, but nothing
    # is posted to the annotations service.
    auth_header = request.headers.get("authorization", "")
    access_token = auth_header.removeprefix("Bearer ").strip() if auth_header else ""
    refresh_token = request.headers.get("x-refresh-token", "")
    token_mgr = TokenManager(access_token, refresh_token) if access_token else None

    cfg = config or AIConfigDto()
    config_dict = cfg.model_dump()

    async def run_detection():
        # Loop reference captured here so worker-thread callbacks can schedule
        # queue writes back onto the event loop thread.
        loop = asyncio.get_event_loop()

        def _enqueue(event):
            # Fan the event out to every SSE subscriber; a full (slow) client
            # queue simply drops the event rather than blocking detection.
            for q in _event_queues:
                try:
                    q.put_nowait(event)
                except asyncio.QueueFull:
                    pass

        try:
            inf = get_inference()
            if not inf.is_engine_ready:
                raise RuntimeError("Detection service unavailable")

            def on_annotation(annotation, percent):
                # Called from the executor thread for each produced annotation.
                dtos = [detection_to_dto(d) for d in annotation.detections]
                event = DetectionEvent(
                    annotations=dtos,
                    mediaId=media_id,
                    mediaStatus="AIProcessing",
                    mediaPercent=percent,
                )
                # Queue writes must happen on the loop thread.
                loop.call_soon_threadsafe(_enqueue, event)
                if token_mgr and dtos:
                    # Synchronous best-effort post; runs on the worker thread.
                    _post_annotation_to_service(token_mgr, media_id, annotation, dtos)

            def on_status(media_name, count):
                # Completion callback (args unused): broadcast the 100% event.
                event = DetectionEvent(
                    annotations=[],
                    mediaId=media_id,
                    mediaStatus="AIProcessed",
                    mediaPercent=100,
                )
                loop.call_soon_threadsafe(_enqueue, event)

            # Block a pool thread for the whole detection run; the event loop
            # keeps serving other requests meanwhile.
            await loop.run_in_executor(
                executor, inf.run_detect, config_dict, on_annotation, on_status
            )
        except Exception:
            # Any failure (engine not ready, run_detect error) becomes a
            # single "Error" event for subscribers.
            error_event = DetectionEvent(
                annotations=[],
                mediaId=media_id,
                mediaStatus="Error",
                mediaPercent=0,
            )
            # Already on the loop thread here, so _enqueue is called directly.
            _enqueue(error_event)
        finally:
            # Free the slot so a new detection for this media can start.
            _active_detections.pop(media_id, None)

    _active_detections[media_id] = asyncio.create_task(run_detection())
    return {"status": "started", "mediaId": media_id}
|
|
|
|
|
|
@app.get("/detect/stream")
async def detect_stream():
    """Server-sent-events stream of DetectionEvent payloads.

    Each connected client gets its own bounded queue that detection tasks
    fan events into; the queue is removed again when the client disconnects.
    """
    subscriber: asyncio.Queue = asyncio.Queue(maxsize=100)
    _event_queues.append(subscriber)

    async def event_generator():
        try:
            # Forward every broadcast event to this subscriber as an SSE frame.
            while True:
                evt = await subscriber.get()
                yield f"data: {evt.model_dump_json()}\n\n"
        except asyncio.CancelledError:
            # Client disconnected; fall through to cleanup.
            pass
        finally:
            _event_queues.remove(subscriber)

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|