Mirror of https://github.com/azaion/detections.git, synced 2026-04-22 08:46:32 +00:00
27f4aceb52
- Updated the `Inference` class to replace the `get_onnx_engine_bytes` method with `download_model`, allowing dynamic model loading based on a specified filename.
- Modified the `convert_and_upload_model` method to accept `source_bytes` instead of `onnx_engine_bytes`, making model conversion more flexible.
- Introduced a new `engine_name` property on the `Inference` class for easier access to engine details.
- Adjusted the `AIRecognitionConfig` structure to include a new method pointer, `from_dict`, improving configuration handling.
- Updated various test cases to reflect changes in model paths and timeout settings, keeping the tests consistent and reliable.
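The commit message sketches a reshaped `Inference` surface. As a rough illustration only — the signatures, the model filename, and the return types below are assumptions, since none of them appear in this file — the new flow might look like:

# Hypothetical sketch of the reworked Inference API described in the commit
# message. download_model's filename argument, its bytes return value, and
# the call order are assumptions, not code from this repository.
from inference import Inference
from loader_http_client import LoaderHttpClient

inf = Inference(LoaderHttpClient("http://loader:8080"))

# download_model replaces get_onnx_engine_bytes: load a model by filename.
source_bytes = inf.download_model("model.onnx")  # filename is illustrative

# convert_and_upload_model now takes raw source bytes rather than onnx_engine_bytes.
inf.convert_and_upload_model(source_bytes)

# New property exposing the active engine's name.
print(inf.engine_name)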
287 lines
8.7 KiB
Python
import asyncio
import base64
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

import requests as http_requests
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from loader_http_client import LoaderHttpClient, LoadResult


app = FastAPI(title="Azaion.Detections")
executor = ThreadPoolExecutor(max_workers=2)

LOADER_URL = os.environ.get("LOADER_URL", "http://loader:8080")
ANNOTATIONS_URL = os.environ.get("ANNOTATIONS_URL", "http://annotations:8080")

loader_client = LoaderHttpClient(LOADER_URL)
inference = None  # lazily created by get_inference()
_event_queues: list[asyncio.Queue] = []  # one queue per connected SSE client
_active_detections: dict[str, asyncio.Task] = {}  # in-flight detections keyed by media id


class TokenManager:
    """Holds a JWT access/refresh token pair and refreshes the access token near expiry."""

    def __init__(self, access_token: str, refresh_token: str):
        self.access_token = access_token
        self.refresh_token = refresh_token

    def get_valid_token(self) -> str:
        # Refresh when the access token expires within the next 60 seconds.
        exp = self._decode_exp(self.access_token)
        if exp and exp - time.time() < 60:
            self._refresh()
        return self.access_token

    def _refresh(self):
        # Best-effort refresh: on any failure, keep using the current token.
        try:
            resp = http_requests.post(
                f"{ANNOTATIONS_URL}/auth/refresh",
                json={"refreshToken": self.refresh_token},
                timeout=10,
            )
            if resp.status_code == 200:
                self.access_token = resp.json()["token"]
        except Exception:
            pass

    @staticmethod
    def _decode_exp(token: str) -> Optional[float]:
        # Decode the JWT payload (second dot-separated segment) without verifying
        # the signature, restoring the base64url padding that JWTs strip.
        try:
            payload = token.split(".")[1]
            padding = 4 - len(payload) % 4
            if padding != 4:
                payload += "=" * padding
            data = json.loads(base64.urlsafe_b64decode(payload))
            return float(data.get("exp", 0))
        except Exception:
            return None


def get_inference():
    # Lazily construct the shared Inference instance on first use; the import
    # is deferred so the app can start even before the engine is needed.
    global inference
    if inference is None:
        from inference import Inference
        inference = Inference(loader_client)
    return inference


class DetectionDto(BaseModel):
    centerX: float
    centerY: float
    width: float
    height: float
    classNum: int
    label: str
    confidence: float


class DetectionEvent(BaseModel):
    annotations: list[DetectionDto]
    mediaId: str
    mediaStatus: str
    mediaPercent: int


class HealthResponse(BaseModel):
    status: str
    aiAvailability: str
    engineType: Optional[str] = None
    errorMessage: Optional[str] = None


class AIConfigDto(BaseModel):
    frame_period_recognition: int = 4
    frame_recognition_seconds: int = 2
    probability_threshold: float = 0.25
    tracking_distance_confidence: float = 0.0
    tracking_probability_increase: float = 0.0
    tracking_intersection_threshold: float = 0.6
    model_batch_size: int = 1
    big_image_tile_overlap_percent: int = 20
    altitude: float = 400
    focal_length: float = 24
    sensor_width: float = 23.5
    paths: list[str] = []


def detection_to_dto(det) -> DetectionDto:
    import constants_inf

    label = constants_inf.get_annotation_name(det.cls)
    return DetectionDto(
        centerX=det.x,
        centerY=det.y,
        width=det.w,
        height=det.h,
        classNum=det.cls,
        label=label,
        confidence=det.confidence,
    )


@app.get("/health")
|
|
def health() -> HealthResponse:
|
|
try:
|
|
inf = get_inference()
|
|
status = inf.ai_availability_status
|
|
status_str = str(status).split()[0] if str(status).strip() else "None"
|
|
error_msg = status.error_message if hasattr(status, 'error_message') else None
|
|
engine_type = inf.engine_name
|
|
return HealthResponse(
|
|
status="healthy",
|
|
aiAvailability=status_str,
|
|
engineType=engine_type,
|
|
errorMessage=error_msg,
|
|
)
|
|
except Exception as e:
|
|
return HealthResponse(
|
|
status="healthy",
|
|
aiAvailability="None",
|
|
errorMessage=str(e),
|
|
)
|
|
|
|
|
|
@app.post("/detect")
|
|
async def detect_image(
|
|
file: UploadFile = File(...),
|
|
config: Optional[str] = Form(None),
|
|
):
|
|
image_bytes = await file.read()
|
|
if not image_bytes:
|
|
raise HTTPException(status_code=400, detail="Image is empty")
|
|
|
|
config_dict = {}
|
|
if config:
|
|
config_dict = json.loads(config)
|
|
|
|
loop = asyncio.get_event_loop()
|
|
try:
|
|
inf = get_inference()
|
|
detections = await loop.run_in_executor(
|
|
executor, inf.detect_single_image, image_bytes, config_dict
|
|
)
|
|
except RuntimeError as e:
|
|
if "not available" in str(e):
|
|
raise HTTPException(status_code=503, detail=str(e))
|
|
raise HTTPException(status_code=422, detail=str(e))
|
|
except ValueError as e:
|
|
raise HTTPException(status_code=400, detail=str(e))
|
|
|
|
return [detection_to_dto(d) for d in detections]
|
|
|
|
|
|
def _post_annotation_to_service(token_mgr: TokenManager, media_id: str,
                                annotation, dtos: list[DetectionDto]):
    # Best-effort push of detections to the annotations service; failures are ignored.
    try:
        token = token_mgr.get_valid_token()
        image_b64 = base64.b64encode(annotation.image).decode() if annotation.image else None
        # annotation.time is in milliseconds; format as HH:MM:SS.
        total_seconds = annotation.time // 1000 if annotation.time else 0
        payload = {
            "mediaId": media_id,
            "source": 0,
            "videoTime": f"{total_seconds // 3600:02d}:{(total_seconds % 3600) // 60:02d}:{total_seconds % 60:02d}",
            "detections": [d.model_dump() for d in dtos],
        }
        if image_b64:
            payload["image"] = image_b64
        http_requests.post(
            f"{ANNOTATIONS_URL}/annotations",
            json=payload,
            headers={"Authorization": f"Bearer {token}"},
            timeout=30,
        )
    except Exception:
        pass


@app.post("/detect/{media_id}")
|
|
async def detect_media(media_id: str, request: Request, config: Optional[AIConfigDto] = None):
|
|
existing = _active_detections.get(media_id)
|
|
if existing is not None and not existing.done():
|
|
raise HTTPException(status_code=409, detail="Detection already in progress for this media")
|
|
|
|
auth_header = request.headers.get("authorization", "")
|
|
access_token = auth_header.removeprefix("Bearer ").strip() if auth_header else ""
|
|
refresh_token = request.headers.get("x-refresh-token", "")
|
|
token_mgr = TokenManager(access_token, refresh_token) if access_token else None
|
|
|
|
cfg = config or AIConfigDto()
|
|
config_dict = cfg.model_dump()
|
|
|
|
async def run_detection():
|
|
loop = asyncio.get_event_loop()
|
|
|
|
def _enqueue(event):
|
|
for q in _event_queues:
|
|
try:
|
|
q.put_nowait(event)
|
|
except asyncio.QueueFull:
|
|
pass
|
|
|
|
try:
|
|
inf = get_inference()
|
|
if not inf.is_engine_ready:
|
|
raise RuntimeError("Detection service unavailable")
|
|
|
|
def on_annotation(annotation, percent):
|
|
dtos = [detection_to_dto(d) for d in annotation.detections]
|
|
event = DetectionEvent(
|
|
annotations=dtos,
|
|
mediaId=media_id,
|
|
mediaStatus="AIProcessing",
|
|
mediaPercent=percent,
|
|
)
|
|
loop.call_soon_threadsafe(_enqueue, event)
|
|
if token_mgr and dtos:
|
|
_post_annotation_to_service(token_mgr, media_id, annotation, dtos)
|
|
|
|
def on_status(media_name, count):
|
|
event = DetectionEvent(
|
|
annotations=[],
|
|
mediaId=media_id,
|
|
mediaStatus="AIProcessed",
|
|
mediaPercent=100,
|
|
)
|
|
loop.call_soon_threadsafe(_enqueue, event)
|
|
|
|
await loop.run_in_executor(
|
|
executor, inf.run_detect, config_dict, on_annotation, on_status
|
|
)
|
|
except Exception:
|
|
error_event = DetectionEvent(
|
|
annotations=[],
|
|
mediaId=media_id,
|
|
mediaStatus="Error",
|
|
mediaPercent=0,
|
|
)
|
|
_enqueue(error_event)
|
|
finally:
|
|
_active_detections.pop(media_id, None)
|
|
|
|
_active_detections[media_id] = asyncio.create_task(run_detection())
|
|
return {"status": "started", "mediaId": media_id}
|
|
|
|
|
|
@app.get("/detect/stream")
|
|
async def detect_stream():
|
|
queue: asyncio.Queue = asyncio.Queue(maxsize=100)
|
|
_event_queues.append(queue)
|
|
|
|
async def event_generator():
|
|
try:
|
|
while True:
|
|
event = await queue.get()
|
|
yield f"data: {event.model_dump_json()}\n\n"
|
|
except asyncio.CancelledError:
|
|
pass
|
|
finally:
|
|
_event_queues.remove(queue)
|
|
|
|
return StreamingResponse(
|
|
event_generator(),
|
|
media_type="text/event-stream",
|
|
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
|
)
|
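
For context, a client of this service starts a detection with POST /detect/{media_id} and then follows progress on the shared /detect/stream SSE endpoint. A minimal client sketch — the host, media id, and tokens are placeholders, not values from the repository:

# Minimal client sketch for the endpoints above; BASE, the media id, and the
# tokens are placeholders. Event fields match the DetectionEvent model.
import json
import requests

BASE = "http://localhost:8080"

# Kick off detection; a second call while one is already running for the
# same media id returns 409.
resp = requests.post(
    f"{BASE}/detect/my-media-id",
    headers={
        "Authorization": "Bearer <access-token>",
        "X-Refresh-Token": "<refresh-token>",
    },
    timeout=10,
)
resp.raise_for_status()

# Read SSE events ("data: {...}" lines) until this media finishes or errors.
with requests.get(f"{BASE}/detect/stream", stream=True, timeout=None) as stream:
    for line in stream.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        event = json.loads(line[len("data: "):])
        print(event["mediaStatus"], event["mediaPercent"])
        if event["mediaId"] == "my-media-id" and event["mediaStatus"] in ("AIProcessed", "Error"):
            break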