mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 08:46:32 +00:00
[AZ-173] [AZ-174] Stream-based detection API and DB-driven AI config
Made-with: Cursor
This commit is contained in:
+158
-23
@@ -4,10 +4,10 @@ import json
|
||||
import os
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional
|
||||
from typing import Annotated, Optional
|
||||
|
||||
import requests as http_requests
|
||||
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
|
||||
from fastapi import Body, FastAPI, UploadFile, File, Form, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -20,6 +20,7 @@ LOADER_URL = os.environ.get("LOADER_URL", "http://loader:8080")
|
||||
ANNOTATIONS_URL = os.environ.get("ANNOTATIONS_URL", "http://annotations:8080")
|
||||
|
||||
loader_client = LoaderHttpClient(LOADER_URL)
|
||||
annotations_client = LoaderHttpClient(ANNOTATIONS_URL)
|
||||
inference = None
|
||||
_event_queues: list[asyncio.Queue] = []
|
||||
_active_detections: dict[str, asyncio.Task] = {}
|
||||
@@ -60,6 +61,29 @@ class TokenManager:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def decode_user_id(token: str) -> Optional[str]:
|
||||
try:
|
||||
payload = token.split(".")[1]
|
||||
padding = 4 - len(payload) % 4
|
||||
if padding != 4:
|
||||
payload += "=" * padding
|
||||
data = json.loads(base64.urlsafe_b64decode(payload))
|
||||
uid = (
|
||||
data.get("sub")
|
||||
or data.get("userId")
|
||||
or data.get("user_id")
|
||||
or data.get("nameid")
|
||||
or data.get(
|
||||
"http://schemas.xmlsoap.org/ws/2005/05/identity/claims/nameidentifier"
|
||||
)
|
||||
)
|
||||
if uid is None:
|
||||
return None
|
||||
return str(uid)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def get_inference():
|
||||
global inference
|
||||
@@ -105,7 +129,115 @@ class AIConfigDto(BaseModel):
|
||||
altitude: float = 400
|
||||
focal_length: float = 24
|
||||
sensor_width: float = 23.5
|
||||
paths: list[str] = []
|
||||
|
||||
|
||||
_AI_SETTINGS_FIELD_KEYS = (
|
||||
(
|
||||
"frame_period_recognition",
|
||||
("frame_period_recognition", "framePeriodRecognition", "FramePeriodRecognition"),
|
||||
),
|
||||
(
|
||||
"frame_recognition_seconds",
|
||||
("frame_recognition_seconds", "frameRecognitionSeconds", "FrameRecognitionSeconds"),
|
||||
),
|
||||
(
|
||||
"probability_threshold",
|
||||
("probability_threshold", "probabilityThreshold", "ProbabilityThreshold"),
|
||||
),
|
||||
(
|
||||
"tracking_distance_confidence",
|
||||
(
|
||||
"tracking_distance_confidence",
|
||||
"trackingDistanceConfidence",
|
||||
"TrackingDistanceConfidence",
|
||||
),
|
||||
),
|
||||
(
|
||||
"tracking_probability_increase",
|
||||
(
|
||||
"tracking_probability_increase",
|
||||
"trackingProbabilityIncrease",
|
||||
"TrackingProbabilityIncrease",
|
||||
),
|
||||
),
|
||||
(
|
||||
"tracking_intersection_threshold",
|
||||
(
|
||||
"tracking_intersection_threshold",
|
||||
"trackingIntersectionThreshold",
|
||||
"TrackingIntersectionThreshold",
|
||||
),
|
||||
),
|
||||
(
|
||||
"model_batch_size",
|
||||
("model_batch_size", "modelBatchSize", "ModelBatchSize"),
|
||||
),
|
||||
(
|
||||
"big_image_tile_overlap_percent",
|
||||
(
|
||||
"big_image_tile_overlap_percent",
|
||||
"bigImageTileOverlapPercent",
|
||||
"BigImageTileOverlapPercent",
|
||||
),
|
||||
),
|
||||
(
|
||||
"altitude",
|
||||
("altitude", "Altitude"),
|
||||
),
|
||||
(
|
||||
"focal_length",
|
||||
("focal_length", "focalLength", "FocalLength"),
|
||||
),
|
||||
(
|
||||
"sensor_width",
|
||||
("sensor_width", "sensorWidth", "SensorWidth"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _merged_annotation_settings_payload(raw: object) -> dict:
|
||||
if not raw or not isinstance(raw, dict):
|
||||
return {}
|
||||
merged = dict(raw)
|
||||
inner = raw.get("aiRecognitionSettings")
|
||||
if isinstance(inner, dict):
|
||||
merged.update(inner)
|
||||
cam = raw.get("cameraSettings")
|
||||
if isinstance(cam, dict):
|
||||
merged.update(cam)
|
||||
out = {}
|
||||
for snake, aliases in _AI_SETTINGS_FIELD_KEYS:
|
||||
for key in aliases:
|
||||
if key in merged and merged[key] is not None:
|
||||
out[snake] = merged[key]
|
||||
break
|
||||
return out
|
||||
|
||||
|
||||
def _build_media_detect_config_dict(
|
||||
media_id: str,
|
||||
token_mgr: Optional[TokenManager],
|
||||
override: Optional[AIConfigDto],
|
||||
) -> dict:
|
||||
cfg: dict = {}
|
||||
bearer = ""
|
||||
if token_mgr:
|
||||
bearer = token_mgr.get_valid_token()
|
||||
uid = TokenManager.decode_user_id(token_mgr.access_token)
|
||||
if uid:
|
||||
raw = annotations_client.fetch_user_ai_settings(uid, bearer)
|
||||
cfg.update(_merged_annotation_settings_payload(raw))
|
||||
if override is not None:
|
||||
for k, v in override.model_dump(exclude_defaults=True).items():
|
||||
cfg[k] = v
|
||||
media_path = annotations_client.fetch_media_path(media_id, bearer)
|
||||
if not media_path:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Could not resolve media path from annotations service",
|
||||
)
|
||||
cfg["paths"] = [media_path]
|
||||
return cfg
|
||||
|
||||
|
||||
def detection_to_dto(det) -> DetectionDto:
|
||||
@@ -150,9 +282,11 @@ async def detect_image(
|
||||
file: UploadFile = File(...),
|
||||
config: Optional[str] = Form(None),
|
||||
):
|
||||
import tempfile
|
||||
import cv2
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
from inference import ai_config_from_dict
|
||||
|
||||
image_bytes = await file.read()
|
||||
if not image_bytes:
|
||||
@@ -166,21 +300,21 @@ async def detect_image(
|
||||
if config:
|
||||
config_dict = json.loads(config)
|
||||
|
||||
suffix = os.path.splitext(file.filename or "upload.jpg")[1] or ".jpg"
|
||||
tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
|
||||
media_name = Path(file.filename or "upload.jpg").stem.replace(" ", "")
|
||||
loop = asyncio.get_event_loop()
|
||||
inf = get_inference()
|
||||
results = []
|
||||
|
||||
def on_annotation(annotation, percent):
|
||||
results.extend(annotation.detections)
|
||||
|
||||
ai_cfg = ai_config_from_dict(config_dict)
|
||||
|
||||
def run_img():
|
||||
inf.run_detect_image(image_bytes, ai_cfg, media_name, on_annotation)
|
||||
|
||||
try:
|
||||
tmp.write(image_bytes)
|
||||
tmp.close()
|
||||
config_dict["paths"] = [tmp.name]
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
inf = get_inference()
|
||||
results = []
|
||||
|
||||
def on_annotation(annotation, percent):
|
||||
results.extend(annotation.detections)
|
||||
|
||||
await loop.run_in_executor(executor, inf.run_detect, config_dict, on_annotation)
|
||||
await loop.run_in_executor(executor, run_img)
|
||||
return [detection_to_dto(d) for d in results]
|
||||
except RuntimeError as e:
|
||||
if "not available" in str(e):
|
||||
@@ -188,8 +322,6 @@ async def detect_image(
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
finally:
|
||||
os.unlink(tmp.name)
|
||||
|
||||
|
||||
def _post_annotation_to_service(token_mgr: TokenManager, media_id: str,
|
||||
@@ -216,7 +348,11 @@ def _post_annotation_to_service(token_mgr: TokenManager, media_id: str,
|
||||
|
||||
|
||||
@app.post("/detect/{media_id}")
|
||||
async def detect_media(media_id: str, request: Request, config: Optional[AIConfigDto] = None):
|
||||
async def detect_media(
|
||||
media_id: str,
|
||||
request: Request,
|
||||
config: Annotated[Optional[AIConfigDto], Body()] = None,
|
||||
):
|
||||
existing = _active_detections.get(media_id)
|
||||
if existing is not None and not existing.done():
|
||||
raise HTTPException(status_code=409, detail="Detection already in progress for this media")
|
||||
@@ -226,8 +362,7 @@ async def detect_media(media_id: str, request: Request, config: Optional[AIConfi
|
||||
refresh_token = request.headers.get("x-refresh-token", "")
|
||||
token_mgr = TokenManager(access_token, refresh_token) if access_token else None
|
||||
|
||||
cfg = config or AIConfigDto()
|
||||
config_dict = cfg.model_dump()
|
||||
config_dict = _build_media_detect_config_dict(media_id, token_mgr, config)
|
||||
|
||||
async def run_detection():
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
Reference in New Issue
Block a user