mirror of
https://github.com/azaion/detections-semantic.git
synced 2026-04-22 08:56:38 +00:00
Initial commit
Made-with: Cursor
This commit is contained in:
@@ -0,0 +1,291 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional
|
||||
|
||||
import requests as http_requests
|
||||
from fastapi import FastAPI, UploadFile, File, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from loader_http_client import LoaderHttpClient, LoadResult
|
||||
|
||||
app = FastAPI(title="Azaion.Detections")
|
||||
executor = ThreadPoolExecutor(max_workers=2)
|
||||
|
||||
LOADER_URL = os.environ.get("LOADER_URL", "http://loader:8080")
|
||||
ANNOTATIONS_URL = os.environ.get("ANNOTATIONS_URL", "http://annotations:8080")
|
||||
|
||||
loader_client = LoaderHttpClient(LOADER_URL)
|
||||
inference = None
|
||||
_event_queues: list[asyncio.Queue] = []
|
||||
_active_detections: dict[str, bool] = {}
|
||||
|
||||
|
||||
class TokenManager:
|
||||
def __init__(self, access_token: str, refresh_token: str):
|
||||
self.access_token = access_token
|
||||
self.refresh_token = refresh_token
|
||||
|
||||
def get_valid_token(self) -> str:
|
||||
exp = self._decode_exp(self.access_token)
|
||||
if exp and exp - time.time() < 60:
|
||||
self._refresh()
|
||||
return self.access_token
|
||||
|
||||
def _refresh(self):
|
||||
try:
|
||||
resp = http_requests.post(
|
||||
f"{ANNOTATIONS_URL}/auth/refresh",
|
||||
json={"refreshToken": self.refresh_token},
|
||||
timeout=10,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
self.access_token = resp.json()["token"]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _decode_exp(token: str) -> Optional[float]:
|
||||
try:
|
||||
payload = token.split(".")[1]
|
||||
padding = 4 - len(payload) % 4
|
||||
if padding != 4:
|
||||
payload += "=" * padding
|
||||
data = json.loads(base64.urlsafe_b64decode(payload))
|
||||
return float(data.get("exp", 0))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def get_inference():
|
||||
global inference
|
||||
if inference is None:
|
||||
from inference import Inference
|
||||
inference = Inference(loader_client)
|
||||
return inference
|
||||
|
||||
|
||||
class DetectionDto(BaseModel):
|
||||
centerX: float
|
||||
centerY: float
|
||||
width: float
|
||||
height: float
|
||||
classNum: int
|
||||
label: str
|
||||
confidence: float
|
||||
|
||||
|
||||
class DetectionEvent(BaseModel):
|
||||
annotations: list[DetectionDto]
|
||||
mediaId: str
|
||||
mediaStatus: str
|
||||
mediaPercent: int
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
status: str
|
||||
aiAvailability: str
|
||||
errorMessage: Optional[str] = None
|
||||
|
||||
|
||||
class AIConfigDto(BaseModel):
|
||||
frame_period_recognition: int = 4
|
||||
frame_recognition_seconds: int = 2
|
||||
probability_threshold: float = 0.25
|
||||
tracking_distance_confidence: float = 0.0
|
||||
tracking_probability_increase: float = 0.0
|
||||
tracking_intersection_threshold: float = 0.6
|
||||
model_batch_size: int = 1
|
||||
big_image_tile_overlap_percent: int = 20
|
||||
altitude: float = 400
|
||||
focal_length: float = 24
|
||||
sensor_width: float = 23.5
|
||||
paths: list[str] = []
|
||||
|
||||
|
||||
def detection_to_dto(det) -> DetectionDto:
|
||||
import constants_inf
|
||||
label = ""
|
||||
if det.cls in constants_inf.annotations_dict:
|
||||
label = constants_inf.annotations_dict[det.cls].name
|
||||
return DetectionDto(
|
||||
centerX=det.x,
|
||||
centerY=det.y,
|
||||
width=det.w,
|
||||
height=det.h,
|
||||
classNum=det.cls,
|
||||
label=label,
|
||||
confidence=det.confidence,
|
||||
)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health() -> HealthResponse:
|
||||
try:
|
||||
inf = get_inference()
|
||||
status = inf.ai_availability_status
|
||||
status_str = str(status).split()[0] if str(status).strip() else "None"
|
||||
error_msg = status.error_message if hasattr(status, 'error_message') else None
|
||||
return HealthResponse(
|
||||
status="healthy",
|
||||
aiAvailability=status_str,
|
||||
errorMessage=error_msg,
|
||||
)
|
||||
except Exception:
|
||||
return HealthResponse(
|
||||
status="healthy",
|
||||
aiAvailability="None",
|
||||
errorMessage=None,
|
||||
)
|
||||
|
||||
|
||||
@app.post("/detect")
|
||||
async def detect_image(
|
||||
file: UploadFile = File(...),
|
||||
config: Optional[str] = None,
|
||||
):
|
||||
image_bytes = await file.read()
|
||||
if not image_bytes:
|
||||
raise HTTPException(status_code=400, detail="Image is empty")
|
||||
|
||||
config_dict = {}
|
||||
if config:
|
||||
config_dict = json.loads(config)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
inf = get_inference()
|
||||
detections = await loop.run_in_executor(
|
||||
executor, inf.detect_single_image, image_bytes, config_dict
|
||||
)
|
||||
except RuntimeError as e:
|
||||
if "not available" in str(e):
|
||||
raise HTTPException(status_code=503, detail=str(e))
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
return [detection_to_dto(d) for d in detections]
|
||||
|
||||
|
||||
def _post_annotation_to_service(token_mgr: TokenManager, media_id: str,
|
||||
annotation, dtos: list[DetectionDto]):
|
||||
try:
|
||||
token = token_mgr.get_valid_token()
|
||||
image_b64 = base64.b64encode(annotation.image).decode() if annotation.image else None
|
||||
payload = {
|
||||
"mediaId": media_id,
|
||||
"source": 0,
|
||||
"videoTime": f"00:00:{annotation.time // 1000:02d}" if annotation.time else "00:00:00",
|
||||
"detections": [d.model_dump() for d in dtos],
|
||||
}
|
||||
if image_b64:
|
||||
payload["image"] = image_b64
|
||||
http_requests.post(
|
||||
f"{ANNOTATIONS_URL}/annotations",
|
||||
json=payload,
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
timeout=30,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@app.post("/detect/{media_id}")
|
||||
async def detect_media(media_id: str, request: Request, config: Optional[AIConfigDto] = None):
|
||||
if media_id in _active_detections:
|
||||
raise HTTPException(status_code=409, detail="Detection already in progress for this media")
|
||||
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
access_token = auth_header.removeprefix("Bearer ").strip() if auth_header else ""
|
||||
refresh_token = request.headers.get("x-refresh-token", "")
|
||||
token_mgr = TokenManager(access_token, refresh_token) if access_token else None
|
||||
|
||||
cfg = config or AIConfigDto()
|
||||
config_dict = cfg.model_dump()
|
||||
|
||||
_active_detections[media_id] = True
|
||||
|
||||
async def run_detection():
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
inf = get_inference()
|
||||
if inf.engine is None:
|
||||
raise RuntimeError("Detection service unavailable")
|
||||
|
||||
def on_annotation(annotation, percent):
|
||||
dtos = [detection_to_dto(d) for d in annotation.detections]
|
||||
event = DetectionEvent(
|
||||
annotations=dtos,
|
||||
mediaId=media_id,
|
||||
mediaStatus="AIProcessing",
|
||||
mediaPercent=percent,
|
||||
)
|
||||
for q in _event_queues:
|
||||
try:
|
||||
q.put_nowait(event)
|
||||
except asyncio.QueueFull:
|
||||
pass
|
||||
|
||||
if token_mgr and dtos:
|
||||
_post_annotation_to_service(token_mgr, media_id, annotation, dtos)
|
||||
|
||||
def on_status(media_name, count):
|
||||
event = DetectionEvent(
|
||||
annotations=[],
|
||||
mediaId=media_id,
|
||||
mediaStatus="AIProcessed",
|
||||
mediaPercent=100,
|
||||
)
|
||||
for q in _event_queues:
|
||||
try:
|
||||
q.put_nowait(event)
|
||||
except asyncio.QueueFull:
|
||||
pass
|
||||
|
||||
await loop.run_in_executor(
|
||||
executor, inf.run_detect, config_dict, on_annotation, on_status
|
||||
)
|
||||
except Exception:
|
||||
error_event = DetectionEvent(
|
||||
annotations=[],
|
||||
mediaId=media_id,
|
||||
mediaStatus="Error",
|
||||
mediaPercent=0,
|
||||
)
|
||||
for q in _event_queues:
|
||||
try:
|
||||
q.put_nowait(error_event)
|
||||
except asyncio.QueueFull:
|
||||
pass
|
||||
finally:
|
||||
_active_detections.pop(media_id, None)
|
||||
|
||||
asyncio.create_task(run_detection())
|
||||
return {"status": "started", "mediaId": media_id}
|
||||
|
||||
|
||||
@app.get("/detect/stream")
|
||||
async def detect_stream():
|
||||
queue: asyncio.Queue = asyncio.Queue(maxsize=100)
|
||||
_event_queues.append(queue)
|
||||
|
||||
async def event_generator():
|
||||
try:
|
||||
while True:
|
||||
event = await queue.get()
|
||||
yield f"data: {event.model_dump_json()}\n\n"
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
_event_queues.remove(queue)
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
Reference in New Issue
Block a user