mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 06:36:32 +00:00
6c24d09eab
Made-with: Cursor
441 lines
13 KiB
Python
441 lines
13 KiB
Python
import asyncio
|
|
import base64
|
|
import json
|
|
import os
|
|
import time
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from typing import Annotated, Optional
|
|
|
|
import requests as http_requests
|
|
from fastapi import Body, FastAPI, UploadFile, File, Form, HTTPException, Request
|
|
from fastapi.responses import StreamingResponse
|
|
from pydantic import BaseModel
|
|
|
|
from loader_http_client import LoaderHttpClient, LoadResult
|
|
|
|
app = FastAPI(title="Azaion.Detections")

# Worker threads for blocking inference runs; at most two execute at once.
executor = ThreadPoolExecutor(max_workers=2)

# Base URLs of the upstream services, overridable through environment variables.
LOADER_URL = os.getenv("LOADER_URL", "http://loader:8080")
ANNOTATIONS_URL = os.getenv("ANNOTATIONS_URL", "http://annotations:8080")

# Both upstream services are reached through the same client class.
loader_client, annotations_client = (
    LoaderHttpClient(LOADER_URL),
    LoaderHttpClient(ANNOTATIONS_URL),
)

# Inference engine; stays None until get_inference() creates it on first use.
inference = None

# One queue per connected /detect/stream client; detection events are fanned out to each.
_event_queues: list[asyncio.Queue] = []

# media_id -> background detection task, used to reject concurrent runs for the same media.
_active_detections: dict[str, asyncio.Task] = {}
|
|
|
|
|
|
class TokenManager:
    """Hold a JWT access/refresh token pair and refresh the access token near expiry.

    Tokens are decoded WITHOUT signature verification; the claims are only used
    to read the expiry time and the caller's user id.
    """

    # Refresh the access token when fewer than this many seconds remain.
    _REFRESH_MARGIN_SECONDS = 60

    # Claims checked, in priority order, when extracting the user id.
    _USER_ID_CLAIMS = (
        "sub",
        "userId",
        "user_id",
        "nameid",
        "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/nameidentifier",
    )

    def __init__(self, access_token: str, refresh_token: str):
        self.access_token = access_token
        self.refresh_token = refresh_token

    def get_valid_token(self) -> str:
        """Return the access token, refreshing it first if it is about to expire.

        Refreshing is best-effort: if it fails, the current (possibly expiring)
        token is returned unchanged. Tokens without a readable ``exp`` claim are
        never refreshed.
        """
        exp = self._decode_exp(self.access_token)
        if exp and exp - time.time() < self._REFRESH_MARGIN_SECONDS:
            self._refresh()
        return self.access_token

    def _refresh(self) -> None:
        """Exchange the refresh token for a new access token at the annotations service.

        Best-effort: network errors, non-200 responses and malformed bodies leave
        ``access_token`` untouched. Skipped entirely when no refresh token is held.
        """
        if not self.refresh_token:
            # Nothing to exchange; avoid a pointless request with a 10 s timeout.
            return
        try:
            resp = http_requests.post(
                f"{ANNOTATIONS_URL}/auth/refresh",
                json={"refreshToken": self.refresh_token},
                timeout=10,
            )
            if resp.status_code == 200:
                self.access_token = resp.json()["token"]
        except (http_requests.RequestException, ValueError, KeyError, TypeError):
            # ValueError: body is not JSON; KeyError/TypeError: body is not an
            # object or lacks "token". Keep the old token in every case.
            pass

    @staticmethod
    def _decode_payload(token: str) -> Optional[dict]:
        """Return the (unverified) JWT claims as a dict, or None if *token* is malformed."""
        try:
            segment = token.split(".")[1]
            # base64url segments in JWTs are unpadded; restore padding to a multiple of 4.
            segment += "=" * (-len(segment) % 4)
            data = json.loads(base64.urlsafe_b64decode(segment))
        except Exception:
            # The token comes from the client, so any decoding failure simply
            # means "no usable claims".
            return None
        return data if isinstance(data, dict) else None

    @staticmethod
    def _decode_exp(token: str) -> Optional[float]:
        """Return the ``exp`` claim as a float (0.0 when absent), or None if unreadable."""
        data = TokenManager._decode_payload(token)
        if data is None:
            return None
        try:
            return float(data.get("exp", 0))
        except (TypeError, ValueError, OverflowError):
            return None

    @staticmethod
    def decode_user_id(token: str) -> Optional[str]:
        """Return the user id from the first non-empty known claim, as a string.

        Returns None when the token cannot be decoded or no claim in
        ``_USER_ID_CLAIMS`` has a truthy value.
        """
        data = TokenManager._decode_payload(token)
        if data is None:
            return None
        uid = next(
            (data[claim] for claim in TokenManager._USER_ID_CLAIMS if data.get(claim)),
            None,
        )
        return None if uid is None else str(uid)
|
|
|
|
|
|
def get_inference():
    """Return the shared Inference engine, constructing it on the first call.

    The import is deferred so the heavy inference module is only loaded once
    detection is actually requested. There is no lock; in this file it is
    called from async handlers running on the event loop.
    """
    global inference
    if inference is not None:
        return inference
    from inference import Inference
    inference = Inference(loader_client)
    return inference
|
|
|
|
|
|
class DetectionDto(BaseModel):
    """One detection as returned by POST /detect and carried in DetectionEvent.annotations."""

    # Box centre and size, copied from the detection's x, y, w, h by
    # detection_to_dto; units/normalisation are not established here — TODO confirm.
    centerX: float
    centerY: float
    width: float
    height: float
    # Numeric class id and the label looked up for it.
    classNum: int
    label: str
    confidence: float
|
|
|
|
|
|
class DetectionEvent(BaseModel):
    """Progress/status event pushed to /detect/stream subscribers as SSE ``data:`` JSON."""

    # Detections from the latest annotation callback; empty for status-only events.
    annotations: list[DetectionDto]
    mediaId: str
    # Values emitted in this file: "AIProcessing", "AIProcessed", "Error".
    mediaStatus: str
    # Progress reported by inference; 100 on completion, 0 on error.
    mediaPercent: int
|
|
|
|
|
|
class HealthResponse(BaseModel):
    """Response body of GET /health."""

    # Service liveness; the /health handler always reports "healthy".
    status: str
    # First word of the engine's availability status, or "None" if there is
    # no engine yet or it could not be read.
    aiAvailability: str
    engineType: Optional[str] = None
    errorMessage: Optional[str] = None
|
|
|
|
|
|
class AIConfigDto(BaseModel):
    """Optional JSON body of POST /detect/{media_id} with per-request AI settings.

    Field names match the snake_case names in _AI_SETTINGS_FIELD_KEYS;
    _build_media_detect_config_dict merges these fields over the user's saved
    settings.
    """

    # NOTE(review): meanings/units below are inferred from field names only — confirm with inference.
    frame_period_recognition: int = 4
    frame_recognition_seconds: int = 2
    probability_threshold: float = 0.25
    tracking_distance_confidence: float = 0.0
    tracking_probability_increase: float = 0.0
    tracking_intersection_threshold: float = 0.6
    model_batch_size: int = 8
    big_image_tile_overlap_percent: int = 20
    # Camera parameters (presumably metres / millimetres — TODO confirm).
    altitude: float = 400
    focal_length: float = 24
    sensor_width: float = 23.5
|
|
|
|
|
|
def _field_aliases(snake_name: str) -> tuple:
    """Return *snake_name* plus its camelCase and PascalCase spellings, de-duplicated."""
    first, *others = snake_name.split("_")
    camel = first + "".join(word.capitalize() for word in others)
    pascal = camel[:1].upper() + camel[1:]
    # dict.fromkeys drops duplicates while keeping first-seen order, so a
    # one-word name such as "altitude" yields ("altitude", "Altitude").
    return tuple(dict.fromkeys((snake_name, camel, pascal)))


# (canonical snake_case config key, payload keys accepted for it in lookup
# priority order) for every AI/camera setting read from the annotations service.
_AI_SETTINGS_FIELD_KEYS = tuple(
    (field_name, _field_aliases(field_name))
    for field_name in (
        "frame_period_recognition",
        "frame_recognition_seconds",
        "probability_threshold",
        "tracking_distance_confidence",
        "tracking_probability_increase",
        "tracking_intersection_threshold",
        "model_batch_size",
        "big_image_tile_overlap_percent",
        "altitude",
        "focal_length",
        "sensor_width",
    )
)
|
|
|
|
|
|
def _merged_annotation_settings_payload(raw: object) -> dict:
    """Flatten an annotations-service settings payload into snake_case config keys.

    Top-level keys are overlaid first by ``aiRecognitionSettings`` and then by
    ``cameraSettings`` (each only when it is a dict). For every field in
    _AI_SETTINGS_FIELD_KEYS, the first alias holding a non-None value wins.
    Returns {} for empty or non-dict input.
    """
    if not (raw and isinstance(raw, dict)):
        return {}

    flat = dict(raw)
    for section_key in ("aiRecognitionSettings", "cameraSettings"):
        section = raw.get(section_key)
        if isinstance(section, dict):
            flat.update(section)

    resolved = {}
    for snake, aliases in _AI_SETTINGS_FIELD_KEYS:
        candidates = [flat[alias] for alias in aliases if flat.get(alias) is not None]
        if candidates:
            resolved[snake] = candidates[0]
    return resolved
|
|
|
|
|
|
def _build_media_detect_config_dict(
    media_id: str,
    token_mgr: Optional[TokenManager],
    override: Optional[AIConfigDto],
) -> dict:
    """Assemble the inference config dict for a stored media item.

    Layering, later wins:
      1. the caller's saved AI/camera settings from the annotations service
         (only when a token with a decodable user id is available);
      2. the fields the client explicitly sent in ``override``.
    The resolved media path is then stored under ``"paths"``.

    May block on network I/O (token refresh via ``requests`` plus the
    annotations-client lookups), so avoid calling it on the event loop.

    Raises:
        HTTPException: 503 when the annotations service returns no media path.
    """
    cfg: dict = {}
    bearer = ""
    if token_mgr:
        bearer = token_mgr.get_valid_token()
        uid = TokenManager.decode_user_id(token_mgr.access_token)
        if uid:
            raw = annotations_client.fetch_user_ai_settings(uid, bearer)
            cfg.update(_merged_annotation_settings_payload(raw))
    if override is not None:
        # exclude_unset keeps every field the client actually sent, even when
        # its value equals the model default. exclude_defaults would silently
        # drop e.g. an explicit probability_threshold=0.25 and let the saved
        # user setting win instead.
        cfg.update(override.model_dump(exclude_unset=True))
    media_path = annotations_client.fetch_media_path(media_id, bearer)
    if not media_path:
        raise HTTPException(
            status_code=503,
            detail="Could not resolve media path from annotations service",
        )
    cfg["paths"] = [media_path]
    return cfg
|
|
|
|
|
|
def detection_to_dto(det) -> DetectionDto:
    """Convert an inference detection into its API model.

    Copies the box (x, y, w, h), class number and confidence from *det*; the
    label comes from ``constants_inf.get_annotation_name`` for the class.
    """
    from constants_inf import get_annotation_name

    label = get_annotation_name(det.cls)
    fields = {
        "centerX": det.x,
        "centerY": det.y,
        "width": det.w,
        "height": det.h,
        "classNum": det.cls,
        "label": label,
        "confidence": det.confidence,
    }
    return DetectionDto(**fields)
|
|
|
|
|
|
@app.get("/health")
def health() -> HealthResponse:
    """Report liveness plus the inference engine's availability.

    ``status`` is always "healthy". AI details read "None" until the engine
    exists, and any error while reading them is returned in ``errorMessage``
    rather than failing the probe.
    """
    engine = inference
    if engine is None:
        return HealthResponse(status="healthy", aiAvailability="None")

    try:
        availability = engine.ai_availability_status
        words = str(availability).split()
        label = words[0] if words else "None"
        message = getattr(availability, "error_message", None)
        return HealthResponse(
            status="healthy",
            aiAvailability=label,
            engineType=engine.engine_name,
            errorMessage=message,
        )
    except Exception as exc:
        return HealthResponse(status="healthy", aiAvailability="None", errorMessage=str(exc))
|
|
|
|
|
|
@app.post("/detect")
async def detect_image(
    file: UploadFile = File(...),
    config: Optional[str] = Form(None),
):
    """Run detection on one uploaded image and return its detections.

    ``config`` is an optional form field holding a JSON object, which is passed
    to ``ai_config_from_dict``.

    Raises:
        HTTPException:
            400 for an empty/undecodable image, malformed or non-object config
            JSON, or a ValueError from config/inference;
            503 when inference reports it is "not available";
            422 for any other RuntimeError from config/inference.
    """
    import cv2
    import numpy as np
    from pathlib import Path

    from inference import ai_config_from_dict

    image_bytes = await file.read()
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Image is empty")

    arr = np.frombuffer(image_bytes, dtype=np.uint8)
    if cv2.imdecode(arr, cv2.IMREAD_COLOR) is None:
        raise HTTPException(status_code=400, detail="Invalid image data")

    config_dict = {}
    if config:
        # Malformed client JSON is a client error (400), not an unhandled 500.
        try:
            config_dict = json.loads(config)
        except json.JSONDecodeError as e:
            raise HTTPException(status_code=400, detail=f"Invalid config JSON: {e}") from e
        if not isinstance(config_dict, dict):
            raise HTTPException(status_code=400, detail="Config must be a JSON object")

    media_name = Path(file.filename or "upload.jpg").stem.replace(" ", "")
    loop = asyncio.get_running_loop()
    inf = get_inference()
    results = []

    def on_annotation(annotation, percent):
        results.extend(annotation.detections)

    try:
        # Built inside the try so config validation errors map to 4xx exactly
        # like errors raised by the detection run itself.
        ai_cfg = ai_config_from_dict(config_dict)
        await loop.run_in_executor(
            executor, inf.run_detect_image, image_bytes, ai_cfg, media_name, on_annotation
        )
        return [detection_to_dto(d) for d in results]
    except RuntimeError as e:
        if "not available" in str(e):
            raise HTTPException(status_code=503, detail=str(e)) from e
        raise HTTPException(status_code=422, detail=str(e)) from e
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
|
|
|
|
|
|
def _format_video_time(time_ms) -> str:
    """Format a millisecond offset as ``HH:MM:SS`` (sub-second part truncated, negatives clamped to 0)."""
    total_seconds = max(0, int(time_ms)) // 1000
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d}"


def _post_annotation_to_service(token_mgr: TokenManager, media_id: str,
                                annotation, dtos: list[DetectionDto]):
    """Best-effort upload of one annotation (detections plus optional image) to the annotations service.

    Every failure (token refresh, encoding, network) is swallowed so that a
    failed upload never interrupts the detection run that invokes this callback.
    The HTTP response status is not checked.
    """
    try:
        token = token_mgr.get_valid_token()
        payload = {
            "mediaId": media_id,
            "source": 0,
            # annotation.time is treated as milliseconds (the original code
            # divided it by 1000 to get seconds) — TODO confirm with inference.
            # Previously minutes/hours were never carried, giving e.g. "00:00:75".
            "videoTime": _format_video_time(annotation.time or 0),
            "detections": [d.model_dump() for d in dtos],
        }
        if annotation.image:
            payload["image"] = base64.b64encode(annotation.image).decode()
        http_requests.post(
            f"{ANNOTATIONS_URL}/annotations",
            json=payload,
            headers={"Authorization": f"Bearer {token}"},
            timeout=30,
        )
    except Exception:
        pass
|
|
|
|
|
|
@app.post("/detect/{media_id}")
async def detect_media(
    media_id: str,
    request: Request,
    config: Annotated[Optional[AIConfigDto], Body()] = None,
):
    """Start a background detection run for a stored media item.

    Returns immediately with ``{"status": "started", "mediaId": ...}``.
    Progress, completion and failure are published as DetectionEvent objects
    to every /detect/stream subscriber. When a ``Bearer`` token is sent in
    ``Authorization`` (optionally with ``X-Refresh-Token``), the user's saved
    settings are applied and each non-empty annotation is also posted to the
    annotations service.

    Raises:
        HTTPException: 409 if a detection for ``media_id`` is still running;
            503 (from _build_media_detect_config_dict) when the media path
            cannot be resolved.
    """

    def _reject_if_running() -> None:
        existing = _active_detections.get(media_id)
        if existing is not None and not existing.done():
            raise HTTPException(status_code=409, detail="Detection already in progress for this media")

    _reject_if_running()

    auth_header = request.headers.get("authorization", "")
    access_token = auth_header.removeprefix("Bearer ").strip() if auth_header else ""
    refresh_token = request.headers.get("x-refresh-token", "")
    token_mgr = TokenManager(access_token, refresh_token) if access_token else None

    loop = asyncio.get_running_loop()
    # Building the config can make several blocking HTTP calls (token refresh,
    # settings and media-path lookups). Run it on the default thread pool so
    # the event loop, and every open SSE stream, is not stalled. The inference
    # `executor` is deliberately avoided so these calls never queue behind
    # long-running detections.
    config_dict = await loop.run_in_executor(
        None, _build_media_detect_config_dict, media_id, token_mgr, config
    )
    # The await yielded to the event loop, so a concurrent request for the same
    # media may have started its detection meanwhile.
    _reject_if_running()

    async def run_detection():
        def _enqueue(event):
            # Fan out to every SSE subscriber; a subscriber whose queue is full misses the event.
            for q in _event_queues:
                try:
                    q.put_nowait(event)
                except asyncio.QueueFull:
                    pass

        try:
            inf = get_inference()
            if not inf.is_engine_ready:
                raise RuntimeError("Detection service unavailable")

            def on_annotation(annotation, percent):
                # Called from inf.run_detect on the executor thread, hence
                # call_soon_threadsafe to hand the event to the loop.
                dtos = [detection_to_dto(d) for d in annotation.detections]
                event = DetectionEvent(
                    annotations=dtos,
                    mediaId=media_id,
                    mediaStatus="AIProcessing",
                    mediaPercent=percent,
                )
                loop.call_soon_threadsafe(_enqueue, event)
                if token_mgr and dtos:
                    _post_annotation_to_service(token_mgr, media_id, annotation, dtos)

            def on_status(media_name, count):
                event = DetectionEvent(
                    annotations=[],
                    mediaId=media_id,
                    mediaStatus="AIProcessed",
                    mediaPercent=100,
                )
                loop.call_soon_threadsafe(_enqueue, event)

            await loop.run_in_executor(
                executor, inf.run_detect, config_dict, on_annotation, on_status
            )
        except Exception:
            # Any failure (engine unavailable, inference error) is reported to
            # subscribers as an "Error" event; this runs on the loop thread.
            error_event = DetectionEvent(
                annotations=[],
                mediaId=media_id,
                mediaStatus="Error",
                mediaPercent=0,
            )
            _enqueue(error_event)
        finally:
            _active_detections.pop(media_id, None)

    # Keeping the task in the dict also holds a strong reference to it.
    _active_detections[media_id] = asyncio.create_task(run_detection())
    return {"status": "started", "mediaId": media_id}
|
|
|
|
|
|
@app.get("/detect/stream")
async def detect_stream():
    """Server-Sent Events stream of DetectionEvent JSON for all detection runs.

    Each client gets its own queue bounded at 100 events. Producers drop events
    for a client whose queue is full instead of blocking.
    """
    queue: asyncio.Queue = asyncio.Queue(maxsize=100)

    async def event_generator():
        # Subscribe only once streaming actually begins. If the response is
        # abandoned before its body is iterated, a never-started generator
        # never runs its finally-block, so registering earlier leaked the queue.
        _event_queues.append(queue)
        try:
            while True:
                event = await queue.get()
                yield f"data: {event.model_dump_json()}\n\n"
        finally:
            # Runs on normal close and on cancellation (client disconnect).
            # CancelledError is allowed to propagate so the server can finish
            # cancelling the response task cleanly.
            _event_queues.remove(queue)

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|