# detections/src/main.py — Azaion.Detections FastAPI service
# (recovered file-listing metadata: 627 lines, 20 KiB, Python,
#  last modified 2026-04-01 01:12:05 +03:00)
import asyncio
import base64
import io
import json
import os
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Annotated, Optional
import av
import cv2
import numpy as np
import requests as http_requests
from fastapi import Body, FastAPI, UploadFile, File, Form, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from loader_http_client import LoaderHttpClient, LoadResult
# FastAPI app wiring and module-level shared state.
app = FastAPI(title="Azaion.Detections")
# Dedicated pool so blocking inference work never stalls the event loop.
executor = ThreadPoolExecutor(max_workers=2)
# Downstream service endpoints; overridable via environment for local dev.
LOADER_URL = os.environ.get("LOADER_URL", "http://loader:8080")
ANNOTATIONS_URL = os.environ.get("ANNOTATIONS_URL", "http://annotations:8080")
# Media lifecycle status codes mirrored from the annotations service.
_MEDIA_STATUS_NEW = 1
_MEDIA_STATUS_AI_PROCESSING = 2
_MEDIA_STATUS_AI_PROCESSED = 3
_MEDIA_STATUS_ERROR = 6
# Extensions used to classify uploads before falling back to content sniffing.
_VIDEO_EXTENSIONS = frozenset({".mp4", ".mov", ".webm", ".mkv", ".avi", ".m4v"})
_IMAGE_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"})
loader_client = LoaderHttpClient(LOADER_URL)
annotations_client = LoaderHttpClient(ANNOTATIONS_URL)
# Lazily-created Inference engine; see get_inference().
inference = None
# SSE subscriber queues and in-flight detection tasks keyed by media id.
_event_queues: list[asyncio.Queue] = []
_active_detections: dict[str, asyncio.Task] = {}
class TokenManager:
    """Holds a JWT access/refresh token pair and refreshes the access token
    shortly before it expires.

    Token payloads are decoded WITHOUT signature verification — this class
    only reads claims (`exp`, user id), it does not authenticate.
    """

    def __init__(self, access_token: str, refresh_token: str):
        self.access_token = access_token
        self.refresh_token = refresh_token

    def get_valid_token(self) -> str:
        """Return an access token, refreshing first if it expires within 60 s."""
        exp = self._decode_exp(self.access_token)
        if exp and exp - time.time() < 60:
            self._refresh()
        return self.access_token

    def _refresh(self) -> None:
        """Best-effort refresh against the annotations service.

        On any failure the old token is kept; callers proceed with the
        possibly-expired token and let the downstream service reject it.
        """
        try:
            resp = http_requests.post(
                f"{ANNOTATIONS_URL}/auth/refresh",
                json={"refreshToken": self.refresh_token},
                timeout=10,
            )
            if resp.status_code == 200:
                self.access_token = resp.json()["token"]
        except Exception:
            pass

    @staticmethod
    def _decode_payload(token: str) -> Optional[dict]:
        """Decode the (unverified) JWT payload segment, or None if malformed.

        Shared by _decode_exp and decode_user_id — previously duplicated.
        """
        try:
            payload = token.split(".")[1]
            # base64url requires length % 4 == 0; restore stripped padding.
            padding = 4 - len(payload) % 4
            if padding != 4:
                payload += "=" * padding
            data = json.loads(base64.urlsafe_b64decode(payload))
            return data if isinstance(data, dict) else None
        except Exception:
            return None

    @staticmethod
    def _decode_exp(token: str) -> Optional[float]:
        """Return the token's `exp` claim as a float; None if unreadable."""
        data = TokenManager._decode_payload(token)
        if data is None:
            return None
        try:
            return float(data.get("exp", 0))
        except (TypeError, ValueError):
            return None

    @staticmethod
    def decode_user_id(token: str) -> Optional[str]:
        """Extract the user id from any of the claim names used by our issuers."""
        data = TokenManager._decode_payload(token)
        if data is None:
            return None
        uid = (
            data.get("sub")
            or data.get("userId")
            or data.get("user_id")
            or data.get("nameid")
            or data.get(
                "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/nameidentifier"
            )
        )
        if uid is None:
            return None
        return str(uid)
def get_inference():
    """Return the process-wide Inference engine, creating it on first use."""
    global inference
    if inference is not None:
        return inference
    from inference import Inference
    inference = Inference(loader_client)
    return inference
class DetectionDto(BaseModel):
    """Single detection box as sent to clients.

    Coordinates/sizes come straight from the engine's detection object
    (presumably normalized to the frame — confirm against Inference).
    """

    centerX: float  # box center, X
    centerY: float  # box center, Y
    width: float
    height: float
    classNum: int  # numeric class id; label mapping lives in constants_inf
    label: str  # human-readable class name
    confidence: float  # model confidence score
class DetectionEvent(BaseModel):
    """Progress event pushed to SSE subscribers of /detect/stream."""

    annotations: list[DetectionDto]  # detections for the current callback batch
    mediaId: str
    mediaStatus: str  # "AIProcessing" | "AIProcessed" | "Error"
    mediaPercent: int  # completion percentage, 0-100
class HealthResponse(BaseModel):
    """Response body for GET /health."""

    status: str  # always "healthy" when the service can answer at all
    aiAvailability: str  # engine availability; "None" until first detect call
    engineType: Optional[str] = None
    errorMessage: Optional[str] = None
class AIConfigDto(BaseModel):
    """Per-request AI/recognition settings; defaults mirror the engine's."""

    frame_period_recognition: int = 4
    frame_recognition_seconds: int = 2
    probability_threshold: float = 0.25
    tracking_distance_confidence: float = 0.0
    tracking_probability_increase: float = 0.0
    tracking_intersection_threshold: float = 0.6
    model_batch_size: int = 8
    big_image_tile_overlap_percent: int = 20
    # Camera parameters — presumably consumed by the engine for scale/geo
    # calculations; units not documented here, confirm against Inference.
    altitude: float = 400
    focal_length: float = 24
    sensor_width: float = 23.5
_AI_SETTINGS_FIELD_KEYS = (
(
"frame_period_recognition",
("frame_period_recognition", "framePeriodRecognition", "FramePeriodRecognition"),
),
(
"frame_recognition_seconds",
("frame_recognition_seconds", "frameRecognitionSeconds", "FrameRecognitionSeconds"),
),
(
"probability_threshold",
("probability_threshold", "probabilityThreshold", "ProbabilityThreshold"),
),
(
"tracking_distance_confidence",
(
"tracking_distance_confidence",
"trackingDistanceConfidence",
"TrackingDistanceConfidence",
),
),
(
"tracking_probability_increase",
(
"tracking_probability_increase",
"trackingProbabilityIncrease",
"TrackingProbabilityIncrease",
),
),
(
"tracking_intersection_threshold",
(
"tracking_intersection_threshold",
"trackingIntersectionThreshold",
"TrackingIntersectionThreshold",
),
),
(
"model_batch_size",
("model_batch_size", "modelBatchSize", "ModelBatchSize"),
),
(
"big_image_tile_overlap_percent",
(
"big_image_tile_overlap_percent",
"bigImageTileOverlapPercent",
"BigImageTileOverlapPercent",
),
),
(
"altitude",
("altitude", "Altitude"),
),
(
"focal_length",
("focal_length", "focalLength", "FocalLength"),
),
(
"sensor_width",
("sensor_width", "sensorWidth", "SensorWidth"),
),
)
def _merged_annotation_settings_payload(raw: object) -> dict:
if not raw or not isinstance(raw, dict):
return {}
merged = dict(raw)
inner = raw.get("aiRecognitionSettings")
if isinstance(inner, dict):
merged.update(inner)
cam = raw.get("cameraSettings")
if isinstance(cam, dict):
merged.update(cam)
out = {}
for snake, aliases in _AI_SETTINGS_FIELD_KEYS:
for key in aliases:
if key in merged and merged[key] is not None:
out[snake] = merged[key]
break
return out
def _normalize_upload_ext(filename: str) -> str:
s = Path(filename or "").suffix.lower()
return s if s else ""
def _detect_upload_kind(filename: str, data: bytes) -> tuple[str, str]:
    """Classify an upload as ("video" | "image", extension).

    A recognized extension is trusted as-is; otherwise the bytes are sniffed
    with OpenCV (image) and then PyAV (video). Raises HTTPException 400 when
    neither decoder accepts the payload.
    """
    ext = _normalize_upload_ext(filename)
    if ext in _VIDEO_EXTENSIONS:
        return "video", ext
    if ext in _IMAGE_EXTENSIONS:
        return "image", ext
    # Unknown/missing extension: sniff the actual content.
    decoded = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_COLOR)
    if decoded is not None:
        return "image", ext or ".jpg"
    try:
        with av.open(io.BytesIO(data)):
            pass
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image or video data")
    return "video", ext or ".mp4"
def _is_video_media_path(media_path: str) -> bool:
    """True when *media_path* carries a known video file extension."""
    suffix = Path(media_path).suffix.lower()
    return suffix in _VIDEO_EXTENSIONS
def _post_media_record(payload: dict, bearer: str) -> bool:
    """Create a media record in the annotations service (best-effort).

    Returns True on 200/201, False on any other status or network error.
    """
    try:
        response = http_requests.post(
            f"{ANNOTATIONS_URL}/api/media",
            json=payload,
            headers={"Authorization": f"Bearer {bearer}"},
            timeout=30,
        )
    except Exception:
        return False
    return response.status_code in (200, 201)
def _put_media_status(media_id: str, media_status: int, bearer: str) -> bool:
    """Update a media record's lifecycle status (best-effort).

    Returns True on 200/204, False on any other status or network error.
    """
    try:
        response = http_requests.put(
            f"{ANNOTATIONS_URL}/api/media/{media_id}/status",
            json={"mediaStatus": media_status},
            headers={"Authorization": f"Bearer {bearer}"},
            timeout=30,
        )
    except Exception:
        return False
    return response.status_code in (200, 204)
def _resolve_media_for_detect(
    media_id: str,
    token_mgr: Optional[TokenManager],
    override: Optional[AIConfigDto],
) -> tuple[dict, str]:
    """Build the effective AI config and resolve the media file path.

    Precedence: the authenticated user's stored AI settings are loaded first,
    then any fields the caller explicitly set in *override* win over them.
    Raises HTTPException 503 when the annotations service cannot resolve the
    media path.
    """
    cfg: dict = {}
    bearer = ""
    if token_mgr:
        bearer = token_mgr.get_valid_token()
        uid = TokenManager.decode_user_id(token_mgr.access_token)
        if uid:
            raw = annotations_client.fetch_user_ai_settings(uid, bearer)
            cfg.update(_merged_annotation_settings_payload(raw))
    if override is not None:
        # exclude_defaults: only fields the caller actually set override stored ones.
        for k, v in override.model_dump(exclude_defaults=True).items():
            cfg[k] = v
    media_path = annotations_client.fetch_media_path(media_id, bearer)
    if not media_path:
        raise HTTPException(
            status_code=503,
            detail="Could not resolve media path from annotations service",
        )
    return cfg, media_path
def detection_to_dto(det) -> DetectionDto:
    """Map an engine detection object onto the wire DTO."""
    # Imported lazily to avoid loading engine constants at module import time.
    import constants_inf

    return DetectionDto(
        centerX=det.x,
        centerY=det.y,
        width=det.w,
        height=det.h,
        classNum=det.cls,
        label=constants_inf.get_annotation_name(det.cls),
        confidence=det.confidence,
    )
@app.get("/health")
def health() -> HealthResponse:
    """Liveness / AI-availability probe.

    Reports aiAvailability "None" until the engine has been lazily created
    by a detect call. Never raises: introspection failures are folded into
    the response body instead of producing a 500.
    """
    if inference is None:
        return HealthResponse(status="healthy", aiAvailability="None")
    try:
        status = inference.ai_availability_status
        # First whitespace-separated word of the status' string form —
        # presumably an enum-like value; confirm against Inference.
        status_str = str(status).split()[0] if str(status).strip() else "None"
        error_msg = status.error_message if hasattr(status, 'error_message') else None
        engine_type = inference.engine_name
        return HealthResponse(
            status="healthy",
            aiAvailability=status_str,
            engineType=engine_type,
            errorMessage=error_msg,
        )
    except Exception as e:
        # Engine introspection failed; report unavailable rather than error out.
        return HealthResponse(
            status="healthy",
            aiAvailability="None",
            errorMessage=str(e),
        )
@app.post("/detect")
async def detect_image(
    request: Request,
    file: UploadFile = File(...),
    config: Optional[str] = Form(None),
):
    """Run detection on an uploaded image or video and return the detections.

    When the caller is authenticated (Bearer token with a resolvable user id),
    the upload is also persisted under a content-hash filename and registered
    with the annotations service, with its status tracked through the
    processing lifecycle (New -> AIProcessing -> AIProcessed/Error).

    Raises HTTPException 400 on empty/invalid media or malformed config JSON,
    503 when the engine is unavailable, 422 for other engine runtime errors.
    """
    from media_hash import compute_media_content_hash
    from inference import ai_config_from_dict

    image_bytes = await file.read()
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Image is empty")
    orig_name = file.filename or "upload"
    kind, ext = _detect_upload_kind(orig_name, image_bytes)
    if kind == "image":
        # Extension said "image" but the bytes may still be garbage: validate early.
        arr = np.frombuffer(image_bytes, dtype=np.uint8)
        if cv2.imdecode(arr, cv2.IMREAD_COLOR) is None:
            raise HTTPException(status_code=400, detail="Invalid image data")
    config_dict = {}
    if config:
        try:
            config_dict = json.loads(config)
        except ValueError:
            # Fix: malformed client JSON previously escaped as a 500; it is a
            # client error and must be reported as 400.
            raise HTTPException(status_code=400, detail="Invalid config JSON")
    auth_header = request.headers.get("authorization", "")
    access_token = auth_header.removeprefix("Bearer ").strip() if auth_header else ""
    refresh_token = request.headers.get("x-refresh-token", "")
    token_mgr = TokenManager(access_token, refresh_token) if access_token else None
    user_id = TokenManager.decode_user_id(access_token) if access_token else None
    videos_dir = os.environ.get(
        "VIDEOS_DIR", os.path.join(os.getcwd(), "data", "videos")
    )
    images_dir = os.environ.get(
        "IMAGES_DIR", os.path.join(os.getcwd(), "data", "images")
    )
    storage_path = None
    content_hash = None
    if token_mgr and user_id:
        # Authenticated path: persist the upload and register it upstream.
        content_hash = compute_media_content_hash(image_bytes)
        base = videos_dir if kind == "video" else images_dir
        os.makedirs(base, exist_ok=True)
        if not ext.startswith("."):
            ext = "." + ext
        storage_path = os.path.abspath(os.path.join(base, f"{content_hash}{ext}"))
        if kind == "image":
            # Videos are written later by the engine itself (see run_upload).
            with open(storage_path, "wb") as out:
                out.write(image_bytes)
        mt = "Video" if kind == "video" else "Image"
        payload = {
            "id": content_hash,
            "name": Path(orig_name).name,
            "path": storage_path,
            "mediaType": mt,
            "mediaStatus": _MEDIA_STATUS_NEW,
            "userId": user_id,
        }
        bearer = token_mgr.get_valid_token()
        _post_media_record(payload, bearer)
        _put_media_status(content_hash, _MEDIA_STATUS_AI_PROCESSING, bearer)
    media_name = Path(orig_name).stem.replace(" ", "")
    loop = asyncio.get_event_loop()
    inf = get_inference()
    results = []
    tmp_video_path = None

    def on_annotation(annotation, percent):
        # Accumulate detections across callbacks; returned to the client at the end.
        results.extend(annotation.detections)

    ai_cfg = ai_config_from_dict(config_dict)

    def run_upload():
        """Blocking engine call, executed on the worker thread pool."""
        nonlocal tmp_video_path
        if kind == "video":
            if storage_path:
                save = storage_path
            else:
                # Anonymous upload: the engine needs a file path, use a temp file.
                suf = ext if ext.startswith(".") else ".mp4"
                fd, tmp_video_path = tempfile.mkstemp(suffix=suf)
                os.close(fd)
                save = tmp_video_path
            inf.run_detect_video(image_bytes, ai_cfg, media_name, save, on_annotation)
        else:
            inf.run_detect_image(image_bytes, ai_cfg, media_name, on_annotation)

    try:
        await loop.run_in_executor(executor, run_upload)
        if token_mgr and user_id and content_hash:
            _put_media_status(
                content_hash, _MEDIA_STATUS_AI_PROCESSED, token_mgr.get_valid_token()
            )
        return [detection_to_dto(d) for d in results]
    except RuntimeError as e:
        if token_mgr and user_id and content_hash:
            _put_media_status(
                content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
            )
        # Engine-unavailable errors map to 503; other runtime failures to 422.
        if "not available" in str(e):
            raise HTTPException(status_code=503, detail=str(e))
        raise HTTPException(status_code=422, detail=str(e))
    except ValueError as e:
        if token_mgr and user_id and content_hash:
            _put_media_status(
                content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
            )
        raise HTTPException(status_code=400, detail=str(e))
    except Exception:
        if token_mgr and user_id and content_hash:
            _put_media_status(
                content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
            )
        raise
    finally:
        # Remove the temp video regardless of outcome.
        if tmp_video_path and os.path.isfile(tmp_video_path):
            try:
                os.unlink(tmp_video_path)
            except OSError:
                pass
def _post_annotation_to_service(token_mgr: TokenManager, media_id: str,
                                annotation, dtos: list[DetectionDto]):
    """Best-effort POST of one annotation (detections + optional frame image)
    to the annotations service; all failures are swallowed so that an
    in-flight detection run is never interrupted by upstream hiccups.
    """
    try:
        token = token_mgr.get_valid_token()
        image_b64 = base64.b64encode(annotation.image).decode() if annotation.image else None
        # annotation.time appears to be milliseconds from media start — TODO
        # confirm against the engine. Fix: the old f-string put the full
        # second count into the seconds slot ("00:00:75" past one minute);
        # split into proper H:M:S fields instead.
        total_seconds = int(annotation.time // 1000) if annotation.time else 0
        hours, remainder = divmod(total_seconds, 3600)
        minutes, seconds = divmod(remainder, 60)
        payload = {
            "mediaId": media_id,
            "source": 0,
            "videoTime": f"{hours:02d}:{minutes:02d}:{seconds:02d}",
            "detections": [d.model_dump() for d in dtos],
        }
        if image_b64:
            payload["image"] = image_b64
        http_requests.post(
            f"{ANNOTATIONS_URL}/annotations",
            json=payload,
            headers={"Authorization": f"Bearer {token}"},
            timeout=30,
        )
    except Exception:
        pass
@app.post("/detect/{media_id}")
async def detect_media(
    media_id: str,
    request: Request,
    config: Annotated[Optional[AIConfigDto], Body()] = None,
):
    """Start background detection for an already-registered media id.

    Returns immediately with {"status": "started"}; progress and results are
    pushed to SSE subscribers (/detect/stream) and, when authenticated, to
    the annotations service. Raises HTTPException 409 if a detection for
    this media id is already running.
    """
    existing = _active_detections.get(media_id)
    if existing is not None and not existing.done():
        raise HTTPException(status_code=409, detail="Detection already in progress for this media")
    auth_header = request.headers.get("authorization", "")
    access_token = auth_header.removeprefix("Bearer ").strip() if auth_header else ""
    refresh_token = request.headers.get("x-refresh-token", "")
    token_mgr = TokenManager(access_token, refresh_token) if access_token else None
    # Resolve config + media path up-front so failures surface as HTTP errors
    # rather than silent background-task deaths.
    config_dict, media_path = _resolve_media_for_detect(media_id, token_mgr, config)

    async def run_detection():
        loop = asyncio.get_event_loop()

        def _enqueue(event):
            # Fan the event out to every SSE subscriber; drop on full queues.
            for q in _event_queues:
                try:
                    q.put_nowait(event)
                except asyncio.QueueFull:
                    pass

        try:
            from inference import ai_config_from_dict
            if token_mgr:
                _put_media_status(
                    media_id,
                    _MEDIA_STATUS_AI_PROCESSING,
                    token_mgr.get_valid_token(),
                )
            with open(media_path, "rb") as mf:
                file_bytes = mf.read()
            video = _is_video_media_path(media_path)
            stem_name = Path(media_path).stem.replace(" ", "")
            ai_cfg = ai_config_from_dict(config_dict)
            inf = get_inference()
            if not inf.is_engine_ready:
                raise RuntimeError("Detection service unavailable")

            def on_annotation(annotation, percent):
                # Runs on the worker thread: marshal events onto the loop.
                dtos = [detection_to_dto(d) for d in annotation.detections]
                event = DetectionEvent(
                    annotations=dtos,
                    mediaId=media_id,
                    mediaStatus="AIProcessing",
                    mediaPercent=percent,
                )
                loop.call_soon_threadsafe(_enqueue, event)
                if token_mgr and dtos:
                    _post_annotation_to_service(token_mgr, media_id, annotation, dtos)

            def on_status(media_name_cb, count):
                # Terminal callback: announce completion and persist final status.
                event = DetectionEvent(
                    annotations=[],
                    mediaId=media_id,
                    mediaStatus="AIProcessed",
                    mediaPercent=100,
                )
                loop.call_soon_threadsafe(_enqueue, event)
                if token_mgr:
                    _put_media_status(
                        media_id,
                        _MEDIA_STATUS_AI_PROCESSED,
                        token_mgr.get_valid_token(),
                    )

            def run_sync():
                # Blocking engine call, executed on the worker thread pool.
                if video:
                    inf.run_detect_video(
                        file_bytes,
                        ai_cfg,
                        stem_name,
                        media_path,
                        on_annotation,
                        on_status,
                    )
                else:
                    inf.run_detect_image(
                        file_bytes,
                        ai_cfg,
                        stem_name,
                        on_annotation,
                        on_status,
                    )

            await loop.run_in_executor(executor, run_sync)
        except Exception:
            # Any failure: mark the media errored and notify subscribers.
            if token_mgr:
                _put_media_status(
                    media_id, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
                )
            error_event = DetectionEvent(
                annotations=[],
                mediaId=media_id,
                mediaStatus="Error",
                mediaPercent=0,
            )
            _enqueue(error_event)
        finally:
            # Allow a new detection for this media id to be started.
            _active_detections.pop(media_id, None)

    _active_detections[media_id] = asyncio.create_task(run_detection())
    return {"status": "started", "mediaId": media_id}
@app.get("/detect/stream")
async def detect_stream():
    """Server-sent-events stream of DetectionEvent updates for all media."""
    subscriber: asyncio.Queue = asyncio.Queue(maxsize=100)
    _event_queues.append(subscriber)

    async def stream():
        try:
            while True:
                event = await subscriber.get()
                yield f"data: {event.model_dump_json()}\n\n"
        except asyncio.CancelledError:
            pass
        finally:
            # Unregister this subscriber once the client disconnects.
            _event_queues.remove(subscriber)

    return StreamingResponse(
        stream(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )