mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 12:36:32 +00:00
Add AIAvailabilityStatus and AIRecognitionConfig classes for AI model management
- Introduced `AIAvailabilityStatus` class to manage the availability status of AI models, including methods for setting status and logging messages. - Added `AIRecognitionConfig` class to encapsulate configuration parameters for AI recognition, with a static method for creating instances from dictionaries. - Implemented enums for AI availability states to enhance clarity and maintainability. - Updated related Cython files to support the new classes and ensure proper type handling. These changes aim to improve the structure and functionality of the AI model management system, facilitating better status tracking and configuration handling.
This commit is contained in:
+305
@@ -0,0 +1,305 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional
|
||||
|
||||
import requests as http_requests
|
||||
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from loader_http_client import LoaderHttpClient, LoadResult
|
||||
|
||||
app = FastAPI(title="Azaion.Detections")
|
||||
executor = ThreadPoolExecutor(max_workers=2)
|
||||
|
||||
LOADER_URL = os.environ.get("LOADER_URL", "http://loader:8080")
|
||||
ANNOTATIONS_URL = os.environ.get("ANNOTATIONS_URL", "http://annotations:8080")
|
||||
|
||||
loader_client = LoaderHttpClient(LOADER_URL)
|
||||
inference = None
|
||||
_event_queues: list[asyncio.Queue] = []
|
||||
_active_detections: dict[str, asyncio.Task] = {}
|
||||
|
||||
|
||||
class TokenManager:
|
||||
def __init__(self, access_token: str, refresh_token: str):
|
||||
self.access_token = access_token
|
||||
self.refresh_token = refresh_token
|
||||
|
||||
def get_valid_token(self) -> str:
|
||||
exp = self._decode_exp(self.access_token)
|
||||
if exp and exp - time.time() < 60:
|
||||
self._refresh()
|
||||
return self.access_token
|
||||
|
||||
def _refresh(self):
|
||||
try:
|
||||
resp = http_requests.post(
|
||||
f"{ANNOTATIONS_URL}/auth/refresh",
|
||||
json={"refreshToken": self.refresh_token},
|
||||
timeout=10,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
self.access_token = resp.json()["token"]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _decode_exp(token: str) -> Optional[float]:
|
||||
try:
|
||||
payload = token.split(".")[1]
|
||||
padding = 4 - len(payload) % 4
|
||||
if padding != 4:
|
||||
payload += "=" * padding
|
||||
data = json.loads(base64.urlsafe_b64decode(payload))
|
||||
return float(data.get("exp", 0))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def get_inference():
|
||||
global inference
|
||||
if inference is None:
|
||||
from inference import Inference
|
||||
inference = Inference(loader_client)
|
||||
return inference
|
||||
|
||||
|
||||
class DetectionDto(BaseModel):
|
||||
centerX: float
|
||||
centerY: float
|
||||
width: float
|
||||
height: float
|
||||
classNum: int
|
||||
label: str
|
||||
confidence: float
|
||||
|
||||
|
||||
class DetectionEvent(BaseModel):
|
||||
annotations: list[DetectionDto]
|
||||
mediaId: str
|
||||
mediaStatus: str
|
||||
mediaPercent: int
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
status: str
|
||||
aiAvailability: str
|
||||
engineType: Optional[str] = None
|
||||
errorMessage: Optional[str] = None
|
||||
|
||||
|
||||
class AIConfigDto(BaseModel):
|
||||
frame_period_recognition: int = 4
|
||||
frame_recognition_seconds: int = 2
|
||||
probability_threshold: float = 0.25
|
||||
tracking_distance_confidence: float = 0.0
|
||||
tracking_probability_increase: float = 0.0
|
||||
tracking_intersection_threshold: float = 0.6
|
||||
model_batch_size: int = 8
|
||||
big_image_tile_overlap_percent: int = 20
|
||||
altitude: float = 400
|
||||
focal_length: float = 24
|
||||
sensor_width: float = 23.5
|
||||
paths: list[str] = []
|
||||
|
||||
|
||||
def detection_to_dto(det) -> DetectionDto:
|
||||
import constants_inf
|
||||
label = constants_inf.get_annotation_name(det.cls)
|
||||
return DetectionDto(
|
||||
centerX=det.x,
|
||||
centerY=det.y,
|
||||
width=det.w,
|
||||
height=det.h,
|
||||
classNum=det.cls,
|
||||
label=label,
|
||||
confidence=det.confidence,
|
||||
)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health() -> HealthResponse:
|
||||
if inference is None:
|
||||
return HealthResponse(status="healthy", aiAvailability="None")
|
||||
try:
|
||||
status = inference.ai_availability_status
|
||||
status_str = str(status).split()[0] if str(status).strip() else "None"
|
||||
error_msg = status.error_message if hasattr(status, 'error_message') else None
|
||||
engine_type = inference.engine_name
|
||||
return HealthResponse(
|
||||
status="healthy",
|
||||
aiAvailability=status_str,
|
||||
engineType=engine_type,
|
||||
errorMessage=error_msg,
|
||||
)
|
||||
except Exception as e:
|
||||
return HealthResponse(
|
||||
status="healthy",
|
||||
aiAvailability="None",
|
||||
errorMessage=str(e),
|
||||
)
|
||||
|
||||
|
||||
@app.post("/detect")
|
||||
async def detect_image(
|
||||
file: UploadFile = File(...),
|
||||
config: Optional[str] = Form(None),
|
||||
):
|
||||
import tempfile
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
image_bytes = await file.read()
|
||||
if not image_bytes:
|
||||
raise HTTPException(status_code=400, detail="Image is empty")
|
||||
|
||||
arr = np.frombuffer(image_bytes, dtype=np.uint8)
|
||||
if cv2.imdecode(arr, cv2.IMREAD_COLOR) is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid image data")
|
||||
|
||||
config_dict = {}
|
||||
if config:
|
||||
config_dict = json.loads(config)
|
||||
|
||||
suffix = os.path.splitext(file.filename or "upload.jpg")[1] or ".jpg"
|
||||
tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
|
||||
try:
|
||||
tmp.write(image_bytes)
|
||||
tmp.close()
|
||||
config_dict["paths"] = [tmp.name]
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
inf = get_inference()
|
||||
results = []
|
||||
|
||||
def on_annotation(annotation, percent):
|
||||
results.extend(annotation.detections)
|
||||
|
||||
await loop.run_in_executor(executor, inf.run_detect, config_dict, on_annotation)
|
||||
return [detection_to_dto(d) for d in results]
|
||||
except RuntimeError as e:
|
||||
if "not available" in str(e):
|
||||
raise HTTPException(status_code=503, detail=str(e))
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
finally:
|
||||
os.unlink(tmp.name)
|
||||
|
||||
|
||||
def _post_annotation_to_service(token_mgr: TokenManager, media_id: str,
|
||||
annotation, dtos: list[DetectionDto]):
|
||||
try:
|
||||
token = token_mgr.get_valid_token()
|
||||
image_b64 = base64.b64encode(annotation.image).decode() if annotation.image else None
|
||||
payload = {
|
||||
"mediaId": media_id,
|
||||
"source": 0,
|
||||
"videoTime": f"00:00:{annotation.time // 1000:02d}" if annotation.time else "00:00:00",
|
||||
"detections": [d.model_dump() for d in dtos],
|
||||
}
|
||||
if image_b64:
|
||||
payload["image"] = image_b64
|
||||
http_requests.post(
|
||||
f"{ANNOTATIONS_URL}/annotations",
|
||||
json=payload,
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
timeout=30,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@app.post("/detect/{media_id}")
|
||||
async def detect_media(media_id: str, request: Request, config: Optional[AIConfigDto] = None):
|
||||
existing = _active_detections.get(media_id)
|
||||
if existing is not None and not existing.done():
|
||||
raise HTTPException(status_code=409, detail="Detection already in progress for this media")
|
||||
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
access_token = auth_header.removeprefix("Bearer ").strip() if auth_header else ""
|
||||
refresh_token = request.headers.get("x-refresh-token", "")
|
||||
token_mgr = TokenManager(access_token, refresh_token) if access_token else None
|
||||
|
||||
cfg = config or AIConfigDto()
|
||||
config_dict = cfg.model_dump()
|
||||
|
||||
async def run_detection():
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def _enqueue(event):
|
||||
for q in _event_queues:
|
||||
try:
|
||||
q.put_nowait(event)
|
||||
except asyncio.QueueFull:
|
||||
pass
|
||||
|
||||
try:
|
||||
inf = get_inference()
|
||||
if not inf.is_engine_ready:
|
||||
raise RuntimeError("Detection service unavailable")
|
||||
|
||||
def on_annotation(annotation, percent):
|
||||
dtos = [detection_to_dto(d) for d in annotation.detections]
|
||||
event = DetectionEvent(
|
||||
annotations=dtos,
|
||||
mediaId=media_id,
|
||||
mediaStatus="AIProcessing",
|
||||
mediaPercent=percent,
|
||||
)
|
||||
loop.call_soon_threadsafe(_enqueue, event)
|
||||
if token_mgr and dtos:
|
||||
_post_annotation_to_service(token_mgr, media_id, annotation, dtos)
|
||||
|
||||
def on_status(media_name, count):
|
||||
event = DetectionEvent(
|
||||
annotations=[],
|
||||
mediaId=media_id,
|
||||
mediaStatus="AIProcessed",
|
||||
mediaPercent=100,
|
||||
)
|
||||
loop.call_soon_threadsafe(_enqueue, event)
|
||||
|
||||
await loop.run_in_executor(
|
||||
executor, inf.run_detect, config_dict, on_annotation, on_status
|
||||
)
|
||||
except Exception:
|
||||
error_event = DetectionEvent(
|
||||
annotations=[],
|
||||
mediaId=media_id,
|
||||
mediaStatus="Error",
|
||||
mediaPercent=0,
|
||||
)
|
||||
_enqueue(error_event)
|
||||
finally:
|
||||
_active_detections.pop(media_id, None)
|
||||
|
||||
_active_detections[media_id] = asyncio.create_task(run_detection())
|
||||
return {"status": "started", "mediaId": media_id}
|
||||
|
||||
|
||||
@app.get("/detect/stream")
|
||||
async def detect_stream():
|
||||
queue: asyncio.Queue = asyncio.Queue(maxsize=100)
|
||||
_event_queues.append(queue)
|
||||
|
||||
async def event_generator():
|
||||
try:
|
||||
while True:
|
||||
event = await queue.get()
|
||||
yield f"data: {event.model_dump_json()}\n\n"
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
_event_queues.remove(queue)
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
Reference in New Issue
Block a user