Files
detections/src/main.py
T
Roman Meshko 911da5cb1c
ci/woodpecker/push/build-arm Pipeline was successful
ci/woodpecker/manual/build-arm Pipeline was successful
Update file with test results (#2)
* Skip GSD and size filtering without altitude

* Update files

* Skip GSD and size filtering without altitude
2026-04-23 21:01:25 +03:00

795 lines
26 KiB
Python

import asyncio
import base64
import io
import json
import os
import tempfile
import time
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Annotated, Optional
import av
import cv2
import jwt as pyjwt
import numpy as np
import requests as http_requests
from loguru import logger
from fastapi import Body, Depends, FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.responses import Response, StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel
from loader_http_client import LoaderHttpClient, LoadResult
app = FastAPI(title="Azaion.Detections")
# Shared pool for blocking inference work; max_workers=2 bounds how many
# detections run concurrently.
executor = ThreadPoolExecutor(max_workers=2)
# Sibling-service endpoints and secrets come from the environment
# (defaults match the docker-compose service names).
LOADER_URL = os.environ.get("LOADER_URL", "http://loader:8080")
ANNOTATIONS_URL = os.environ.get("ANNOTATIONS_URL", "http://annotations:8080")
JWT_SECRET = os.environ.get("JWT_SECRET", "")  # empty => auth is disabled (see require_auth)
ADMIN_API_URL = os.environ.get("ADMIN_API_URL", "")  # empty => token refresh is a no-op
# Media status codes sent to the annotations service — presumably mirror its
# enum; confirm against that service if values change.
_MEDIA_STATUS_NEW = 1
_MEDIA_STATUS_AI_PROCESSING = 2
_MEDIA_STATUS_AI_PROCESSED = 3
_MEDIA_STATUS_ERROR = 6
# Upload extensions accepted by the /detect endpoints.
_VIDEO_EXTENSIONS = frozenset({".mp4", ".mov", ".webm", ".mkv", ".avi", ".m4v"})
_IMAGE_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"})
# SSE replay buffer tuning: events are kept this long / at most this many
# per channel so late subscribers can catch up (see _enqueue / detect_events).
_BUFFER_TTL_MS = 10_000
_BUFFER_MAX = 200
loader_client = LoaderHttpClient(LOADER_URL)
annotations_client = LoaderHttpClient(ANNOTATIONS_URL)
# Lazily-created singleton Inference engine (see get_inference()).
inference = None
# channel_id -> live SSE subscriber queues.
_job_queues: dict[str, list[asyncio.Queue]] = {}
# channel_id -> deque[(timestamp_ms, serialized_event)] replay buffer.
_channel_buffers: dict[str, deque] = {}
# media_id -> running detection task; guards duplicate /detect/{media_id} runs.
_active_detections: dict[str, asyncio.Task] = {}
_bearer = HTTPBearer(auto_error=False)
async def require_auth(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(_bearer),
) -> str:
    """FastAPI dependency: validate the bearer JWT and return the user id.

    When JWT_SECRET is unset, authentication is disabled and "" is returned
    for every request. Raises 401 for a missing, expired, or invalid token.
    """
    if not JWT_SECRET:
        return ""
    if credentials is None:
        raise HTTPException(status_code=401, detail="Authentication required")
    try:
        claims = pyjwt.decode(credentials.credentials, JWT_SECRET, algorithms=["HS256"])
    except pyjwt.ExpiredSignatureError:
        raise HTTPException(status_code=401, detail="Token expired")
    except pyjwt.InvalidTokenError:
        raise HTTPException(status_code=401, detail="Invalid token")
    return str(claims.get("sub") or claims.get("userId") or "")
class TokenManager:
    """Holds a user's access/refresh token pair and transparently refreshes
    the access token against the admin API shortly before it expires."""

    def __init__(self, access_token: str, refresh_token: str):
        self.access_token = access_token
        self.refresh_token = refresh_token

    def get_valid_token(self) -> str:
        """Return the access token, refreshing it first when it expires in <60s."""
        exp = self._decode_claims(self.access_token).get("exp")
        if exp and float(exp) - time.time() < 60:
            self._refresh()
        return self.access_token

    def _refresh(self):
        """Best-effort refresh: on failure the old token is kept so the
        downstream call fails visibly (401) rather than crashing here."""
        if not ADMIN_API_URL:
            return
        try:
            resp = http_requests.post(
                f"{ADMIN_API_URL}/auth/refresh",
                json={"refreshToken": self.refresh_token},
                timeout=10,
            )
            if resp.status_code == 200:
                self.access_token = resp.json()["token"]
            else:
                # Previously silent; log so stale-token 401s downstream can
                # be traced back to a failed refresh.
                logger.warning(f"Token refresh returned status {resp.status_code}")
        except Exception as exc:
            # Was `except Exception: pass` — keep best-effort semantics but
            # record why the refresh did not happen.
            logger.warning(f"Token refresh request failed: {exc}")

    @staticmethod
    def _decode_claims(token: str) -> dict:
        """Decode JWT claims.

        Signature-verified with JWT_SECRET when it is set; otherwise the
        payload segment is base64-decoded WITHOUT verification (claims are
        only used for expiry/user-id hints). Returns {} for any malformed
        token instead of raising.
        """
        try:
            if JWT_SECRET:
                return pyjwt.decode(token, JWT_SECRET, algorithms=["HS256"])
            payload = token.split(".")[1]
            # urlsafe_b64decode requires padding to a multiple of 4.
            padding = 4 - len(payload) % 4
            if padding != 4:
                payload += "=" * padding
            return json.loads(base64.urlsafe_b64decode(payload))
        except Exception:
            return {}

    @staticmethod
    def decode_user_id(token: str) -> Optional[str]:
        """Extract the user id from a token, trying the claim names used by
        the various backends (JWT standard, camelCase, .NET identity)."""
        data = TokenManager._decode_claims(token)
        uid = (
            data.get("sub")
            or data.get("userId")
            or data.get("user_id")
            or data.get("nameid")
            or data.get(
                "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/nameidentifier"
            )
        )
        return None if uid is None else str(uid)
def get_inference():
    """Lazily construct and cache the singleton Inference engine."""
    global inference
    if inference is not None:
        return inference
    # Deferred import: the inference stack is heavy and not needed at startup.
    from inference import Inference
    inference = Inference(loader_client)
    return inference
class DetectionDto(BaseModel):
    """A single detected object as serialized to clients.

    Box fields are center/size values; whether they are pixels or normalized
    coordinates depends on the inference engine's output — confirm against
    `Inference` before relying on units.
    """
    centerX: float
    centerY: float
    width: float
    height: float
    classNum: int      # numeric class id from the model
    label: str         # human-readable name (resolved via constants_inf)
    confidence: float  # detection score
class DetectionEvent(BaseModel):
    """One SSE payload: a batch of detections plus media processing progress."""
    annotations: list[DetectionDto]
    mediaId: str
    # Values emitted in this module: "Started", "AIProcessing", "AIProcessed", "Error".
    mediaStatus: str
    mediaPercent: int  # 0-100 progress
class HealthResponse(BaseModel):
    """Response body for GET /health."""
    status: str                       # always "healthy" in this module
    aiAvailability: str               # engine availability; "None" when not loaded
    engineType: Optional[str] = None  # inference engine name, when available
    errorMessage: Optional[str] = None
class AIConfigDto(BaseModel):
    """Per-request AI configuration overrides.

    Only fields explicitly set by the caller override stored user settings
    (see `_resolve_media_for_detect`, which uses exclude_defaults).
    """
    frame_period_recognition: int = 4
    frame_recognition_seconds: int = 2
    probability_threshold: float = 0.25
    tracking_distance_confidence: float = 0.0
    tracking_probability_increase: float = 0.0
    tracking_intersection_threshold: float = 0.6
    model_batch_size: int = 8
    big_image_tile_overlap_percent: int = 20
    # None when unknown; presumably meters — TODO confirm against the engine.
    altitude: Optional[float] = None
    focal_length: float = 24    # camera focal length (mm, presumably)
    sensor_width: float = 23.5  # sensor width (mm, presumably)
_AI_SETTINGS_FIELD_KEYS = (
(
"frame_period_recognition",
("frame_period_recognition", "framePeriodRecognition", "FramePeriodRecognition"),
),
(
"frame_recognition_seconds",
("frame_recognition_seconds", "frameRecognitionSeconds", "FrameRecognitionSeconds"),
),
(
"probability_threshold",
("probability_threshold", "probabilityThreshold", "ProbabilityThreshold"),
),
(
"tracking_distance_confidence",
(
"tracking_distance_confidence",
"trackingDistanceConfidence",
"TrackingDistanceConfidence",
),
),
(
"tracking_probability_increase",
(
"tracking_probability_increase",
"trackingProbabilityIncrease",
"TrackingProbabilityIncrease",
),
),
(
"tracking_intersection_threshold",
(
"tracking_intersection_threshold",
"trackingIntersectionThreshold",
"TrackingIntersectionThreshold",
),
),
(
"model_batch_size",
("model_batch_size", "modelBatchSize", "ModelBatchSize"),
),
(
"big_image_tile_overlap_percent",
(
"big_image_tile_overlap_percent",
"bigImageTileOverlapPercent",
"BigImageTileOverlapPercent",
),
),
(
"altitude",
("altitude", "Altitude"),
),
(
"focal_length",
("focal_length", "focalLength", "FocalLength"),
),
(
"sensor_width",
("sensor_width", "sensorWidth", "SensorWidth"),
),
)
def _merged_annotation_settings_payload(raw: object) -> dict:
if not raw or not isinstance(raw, dict):
return {}
merged = dict(raw)
inner = raw.get("aiRecognitionSettings")
if isinstance(inner, dict):
merged.update(inner)
cam = raw.get("cameraSettings")
if isinstance(cam, dict):
merged.update(cam)
out = {}
for snake, aliases in _AI_SETTINGS_FIELD_KEYS:
for key in aliases:
if key in merged and merged[key] is not None:
out[snake] = merged[key]
break
return out
def _normalize_upload_ext(filename: str) -> str:
s = Path(filename or "").suffix.lower()
return s if s else ""
def _is_video_media_path(media_path: str) -> bool:
    """True when the path's extension marks it as a video file."""
    suffix = Path(media_path).suffix.lower()
    return suffix in _VIDEO_EXTENSIONS
def _post_media_record(payload: dict, bearer: str) -> bool:
    """Create a media record in the annotations service (best-effort).

    Returns True on 200/201; network errors are logged and reported as False.
    """
    try:
        response = http_requests.post(
            f"{ANNOTATIONS_URL}/api/media",
            json=payload,
            headers={"Authorization": f"Bearer {bearer}"},
            timeout=30,
        )
    except Exception as exc:
        logger.warning(f"Failed to create media record in annotations service: {exc}")
        return False
    return response.status_code in (200, 201)
def _put_media_status(media_id: str, media_status: int, bearer: str) -> bool:
    """Update a media record's status in the annotations service (best-effort).

    Returns True on 200/204; network errors are logged and reported as False.
    """
    try:
        response = http_requests.put(
            f"{ANNOTATIONS_URL}/api/media/{media_id}/status",
            json={"mediaStatus": media_status},
            headers={"Authorization": f"Bearer {bearer}"},
            timeout=30,
        )
    except Exception as exc:
        logger.warning(f"Failed to update media status in annotations service for {media_id}: {exc}")
        return False
    return response.status_code in (200, 204)
def _resolve_media_for_detect(
    media_id: str,
    token_mgr: Optional[TokenManager],
    override: Optional[AIConfigDto],
) -> tuple[dict, str]:
    """Build the effective AI config and resolve the media file path.

    Stored user settings (when authenticated) form the base config; only the
    override fields the caller explicitly set are layered on top. Raises
    HTTPException(503) when the annotations service cannot resolve the path.
    """
    config: dict = {}
    bearer = ""
    if token_mgr:
        bearer = token_mgr.get_valid_token()
        user_id = TokenManager.decode_user_id(token_mgr.access_token)
        if user_id:
            settings = annotations_client.fetch_user_ai_settings(user_id, bearer)
            config.update(_merged_annotation_settings_payload(settings))
    if override is not None:
        # exclude_defaults: untouched DTO fields must not clobber user settings.
        config.update(override.model_dump(exclude_defaults=True))
    media_path = annotations_client.fetch_media_path(media_id, bearer)
    if not media_path:
        raise HTTPException(
            status_code=503,
            detail="Could not resolve media path from annotations service",
        )
    return config, media_path
def detection_to_dto(det) -> DetectionDto:
    """Convert an inference detection object into the wire-format DTO."""
    # Deferred import keeps the inference stack out of module import time.
    import constants_inf
    return DetectionDto(
        centerX=det.x,
        centerY=det.y,
        width=det.w,
        height=det.h,
        classNum=det.cls,
        label=constants_inf.get_annotation_name(det.cls),
        confidence=det.confidence,
    )
def _post_annotation_to_service(token_mgr: TokenManager, media_id: str,
                                annotation, dtos: list[DetectionDto]):
    """Push one annotation (detections + optional frame image) to the
    annotations service. Failures are logged, never raised."""
    try:
        token = token_mgr.get_valid_token()
        # annotation.time is in milliseconds; render as HH:MM:SS.
        millis = annotation.time or 0
        seconds_total = int(millis // 1000)
        hours = seconds_total // 3600
        minutes = (seconds_total % 3600) // 60
        seconds = seconds_total % 60
        payload = {
            "mediaId": media_id,
            "source": 0,
            "videoTime": f"{hours:02d}:{minutes:02d}:{seconds:02d}",
            "detections": [d.model_dump() for d in dtos],
        }
        if annotation.image:
            payload["image"] = base64.b64encode(annotation.image).decode()
        http_requests.post(
            f"{ANNOTATIONS_URL}/annotations",
            json=payload,
            headers={"Authorization": f"Bearer {token}"},
            timeout=30,
        )
    except Exception as exc:
        logger.warning(
            f"Failed to post annotation to annotations service for media {media_id}: {exc}"
        )
def _cleanup_channel(channel_id: str):
    """Drop a channel's SSE replay buffer (scheduled after a job finishes)."""
    try:
        del _channel_buffers[channel_id]
    except KeyError:
        pass
def _enqueue(channel_id: str, event: DetectionEvent):
    """Record *event* in the channel's replay buffer and fan it out to every
    live SSE subscriber. Events are dropped for subscribers whose queue is full."""
    timestamp_ms = int(time.time() * 1000)
    serialized = event.model_dump_json()
    buffer = _channel_buffers.setdefault(channel_id, deque(maxlen=_BUFFER_MAX))
    buffer.append((timestamp_ms, serialized))
    # Evict replay entries older than the TTL.
    expiry = timestamp_ms - _BUFFER_TTL_MS
    while buffer and buffer[0][0] < expiry:
        buffer.popleft()
    for subscriber in _job_queues.get(channel_id, []):
        try:
            subscriber.put_nowait((timestamp_ms, serialized))
        except asyncio.QueueFull:
            pass
@app.get("/health")
def health() -> HealthResponse:
    """Liveness probe that also reports AI engine availability.

    Never forces the lazy inference engine to load; before first use the
    availability is reported as "None".
    """
    if inference is None:
        return HealthResponse(status="healthy", aiAvailability="None")
    try:
        availability = inference.ai_availability_status
        text = str(availability)
        availability_str = text.split()[0] if text.strip() else "None"
        return HealthResponse(
            status="healthy",
            aiAvailability=availability_str,
            engineType=inference.engine_name,
            errorMessage=getattr(availability, "error_message", None),
        )
    except Exception as e:
        return HealthResponse(
            status="healthy",
            aiAvailability="None",
            errorMessage=str(e),
        )
@app.get("/detect/events/{channel_id}", dependencies=[Depends(require_auth)])
async def detect_events(channel_id: str, request: Request, after_ts: Optional[int] = None):
    """SSE stream of detection events for a channel.

    A subscriber queue is registered for live fan-out; when *after_ts* is
    given, buffered events newer than that timestamp are replayed first.
    The queue is unregistered when the client disconnects.
    """
    subscriber: asyncio.Queue = asyncio.Queue(maxsize=200)
    _job_queues.setdefault(channel_id, []).append(subscriber)

    async def stream():
        try:
            if after_ts is not None:
                # Snapshot the ring buffer to avoid mutation during iteration.
                for ts_ms, payload in list(_channel_buffers.get(channel_id, [])):
                    if ts_ms > after_ts:
                        yield f"id: {ts_ms}\ndata: {payload}\n\n"
            while True:
                ts_ms, payload = await subscriber.get()
                yield f"id: {ts_ms}\ndata: {payload}\n\n"
        except asyncio.CancelledError:
            pass
        finally:
            remaining = _job_queues.get(channel_id, [])
            if subscriber in remaining:
                remaining.remove(subscriber)
            if not remaining:
                _job_queues.pop(channel_id, None)

    return StreamingResponse(
        stream(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
@app.post("/detect/image")
async def detect_image(
    request: Request,
    file: UploadFile = File(...),
    config: Optional[str] = Form(None),
    user_id: str = Depends(require_auth),
):
    """Accept an uploaded image and run detection in the background.

    Responds 202 immediately; progress and results are published as
    DetectionEvents on the SSE channel named by the X-Channel-Id header.
    Authenticated uploads are additionally persisted (content-hash keyed)
    and registered with the annotations service.
    """
    from media_hash import compute_media_content_hash
    from inference import ai_config_from_dict
    image_bytes = await file.read()
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Image is empty")
    orig_name = file.filename or "upload"
    ext = _normalize_upload_ext(orig_name)
    # A missing extension is tolerated; only a wrong one is rejected.
    if ext and ext not in _IMAGE_EXTENSIONS:
        raise HTTPException(status_code=400, detail="Expected an image file")
    # Decode check: reject bytes OpenCV cannot parse as an image.
    arr = np.frombuffer(image_bytes, dtype=np.uint8)
    if cv2.imdecode(arr, cv2.IMREAD_COLOR) is None:
        raise HTTPException(status_code=400, detail="Invalid image data")
    channel_id = request.headers.get("x-channel-id", "")
    if not channel_id:
        raise HTTPException(status_code=400, detail="X-Channel-Id header required")
    config_dict = {}
    if config:
        # NOTE(review): malformed JSON here raises and surfaces as a 500,
        # not a 400 — confirm whether clients depend on that.
        config_dict = json.loads(config)
    refresh_token = request.headers.get("x-refresh-token", "")
    access_token = request.headers.get("authorization", "").removeprefix("Bearer ").strip()
    token_mgr = TokenManager(access_token, refresh_token) if access_token else None
    images_dir = os.environ.get(
        "IMAGES_DIR", os.path.join(os.getcwd(), "data", "images")
    )
    content_hash = None
    if token_mgr and user_id:
        # Authenticated path: persist the image keyed by content hash and
        # create/advance the media record in the annotations service.
        content_hash = compute_media_content_hash(image_bytes)
        os.makedirs(images_dir, exist_ok=True)
        save_ext = ext if ext.startswith(".") else f".{ext}" if ext else ".jpg"
        storage_path = os.path.abspath(os.path.join(images_dir, f"{content_hash}{save_ext}"))
        with open(storage_path, "wb") as out:
            out.write(image_bytes)
        payload = {
            "id": content_hash,
            "name": Path(orig_name).name,
            "path": storage_path,
            "mediaType": "Image",
            "mediaStatus": _MEDIA_STATUS_NEW,
            "userId": user_id,
        }
        bearer = token_mgr.get_valid_token()
        _post_media_record(payload, bearer)
        _put_media_status(content_hash, _MEDIA_STATUS_AI_PROCESSING, bearer)
    media_name = Path(orig_name).stem.replace(" ", "")
    # Anonymous runs have no content hash; events fall back to the channel id.
    media_id = content_hash or channel_id
    loop = asyncio.get_event_loop()
    inf = get_inference()

    async def run_detection():
        """Background task: run inference and publish progress/terminal events."""
        ai_cfg = ai_config_from_dict(config_dict)

        def on_annotation(annotation, percent):
            # Called from the executor thread — hop back onto the loop to enqueue.
            dtos = [detection_to_dto(d) for d in annotation.detections]
            event = DetectionEvent(
                annotations=dtos,
                mediaId=media_id,
                mediaStatus="AIProcessing",
                mediaPercent=percent,
            )
            loop.call_soon_threadsafe(_enqueue, channel_id, event)
            if token_mgr and content_hash and dtos:
                _post_annotation_to_service(token_mgr, content_hash, annotation, dtos)

        def run_sync():
            inf.run_detect_image(image_bytes, ai_cfg, media_name, on_annotation)

        try:
            await loop.run_in_executor(executor, run_sync)
            _enqueue(channel_id, DetectionEvent(
                annotations=[], mediaId=media_id,
                mediaStatus="AIProcessed", mediaPercent=100,
            ))
            if token_mgr and content_hash:
                _put_media_status(
                    content_hash, _MEDIA_STATUS_AI_PROCESSED, token_mgr.get_valid_token()
                )
        except RuntimeError as e:
            if token_mgr and content_hash:
                _put_media_status(
                    content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
                )
            _enqueue(channel_id, DetectionEvent(
                annotations=[], mediaId=media_id,
                mediaStatus="Error", mediaPercent=0,
            ))
            # Engine-unavailable is an expected condition and is swallowed;
            # any other RuntimeError propagates to the task.
            if "not available" in str(e):
                return
            raise
        except Exception:
            if token_mgr and content_hash:
                _put_media_status(
                    content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
                )
            _enqueue(channel_id, DetectionEvent(
                annotations=[], mediaId=media_id,
                mediaStatus="Error", mediaPercent=0,
            ))
            raise
        finally:
            # Give SSE subscribers time to drain before dropping the replay buffer.
            loop.call_later(10.0, _cleanup_channel, channel_id)

    asyncio.create_task(run_detection())
    return Response(status_code=202)
@app.post("/detect/video")
async def detect_video_upload(
    request: Request,
    user_id: str = Depends(require_auth),
):
    """Accept a raw streamed video body and run detection concurrently.

    The body is spooled through a StreamingBuffer so inference can start
    before the upload finishes. Metadata travels in X-* headers because the
    body is the bare video bytes (not multipart). Responds 202; events are
    published on the channel's SSE stream.
    """
    from media_hash import compute_media_content_hash_from_file
    from inference import ai_config_from_dict
    from streaming_buffer import StreamingBuffer
    channel_id = request.headers.get("x-channel-id", "")
    if not channel_id:
        raise HTTPException(status_code=400, detail="X-Channel-Id header required")
    filename = request.headers.get("x-filename", "upload.mp4")
    config_json = request.headers.get("x-config", "")
    ext = _normalize_upload_ext(filename)
    if ext not in _VIDEO_EXTENSIONS:
        raise HTTPException(status_code=400, detail="Expected a video file extension")
    config_dict = json.loads(config_json) if config_json else {}
    ai_cfg = ai_config_from_dict(config_dict)
    refresh_token = request.headers.get("x-refresh-token", "")
    access_token = request.headers.get("authorization", "").removeprefix("Bearer ").strip()
    token_mgr = TokenManager(access_token, refresh_token) if access_token else None
    videos_dir = os.environ.get(
        "VIDEOS_DIR", os.path.join(os.getcwd(), "data", "videos")
    )
    os.makedirs(videos_dir, exist_ok=True)
    content_length = request.headers.get("content-length")
    total_size = int(content_length) if content_length else None
    buffer = StreamingBuffer(videos_dir, total_size=total_size)
    media_name = Path(filename).stem.replace(" ", "")
    loop = asyncio.get_event_loop()
    inf = get_inference()
    # Mutable one-cell holder: events start under the channel id and switch
    # to the content hash once the full upload has been hashed.
    current_media_id = [channel_id]

    def on_annotation(annotation, percent):
        # Runs on the executor thread; marshal the event onto the loop.
        dtos = [detection_to_dto(d) for d in annotation.detections]
        mid = current_media_id[0]
        event = DetectionEvent(
            annotations=dtos,
            mediaId=mid,
            mediaStatus="AIProcessing",
            mediaPercent=percent,
        )
        loop.call_soon_threadsafe(_enqueue, channel_id, event)
        # Annotations are only persisted once the real media id (hash) is known.
        if token_mgr and mid != channel_id and dtos:
            _post_annotation_to_service(token_mgr, mid, annotation, dtos)

    def run_inference():
        inf.run_detect_video_stream(buffer, ai_cfg, media_name, on_annotation, lambda *_: None)

    # Start inference immediately; it consumes the buffer as chunks arrive.
    inference_future = loop.run_in_executor(executor, run_inference)
    try:
        async for chunk in request.stream():
            # buffer.append presumably does blocking disk I/O — keep it off
            # the event loop (TODO confirm against StreamingBuffer).
            await loop.run_in_executor(None, buffer.append, chunk)
    except Exception:
        buffer.close_writer()
        buffer.close()
        raise
    buffer.close_writer()
    content_hash = compute_media_content_hash_from_file(buffer.path)
    if not ext.startswith("."):
        ext = "." + ext
    storage_path = os.path.abspath(os.path.join(videos_dir, f"{content_hash}{ext}"))
    current_media_id[0] = content_hash
    _enqueue(channel_id, DetectionEvent(
        annotations=[],
        mediaId=content_hash,
        mediaStatus="Started",
        mediaPercent=0,
    ))
    if token_mgr and user_id:
        # Persist under the content hash and register with annotations.
        # NOTE(review): os.rename fails across filesystems — assumes the
        # spool directory and VIDEOS_DIR share a mount; confirm deployment.
        os.rename(buffer.path, storage_path)
        payload = {
            "id": content_hash,
            "name": Path(filename).name,
            "path": storage_path,
            "mediaType": "Video",
            "mediaStatus": _MEDIA_STATUS_NEW,
            "userId": user_id,
        }
        bearer = token_mgr.get_valid_token()
        _post_media_record(payload, bearer)
        _put_media_status(content_hash, _MEDIA_STATUS_AI_PROCESSING, bearer)

    async def _wait_inference():
        """Await inference completion and publish the terminal event."""
        try:
            await inference_future
            if token_mgr and user_id:
                _put_media_status(
                    content_hash, _MEDIA_STATUS_AI_PROCESSED,
                    token_mgr.get_valid_token(),
                )
            # Brief yield so in-flight AIProcessing events land first.
            await asyncio.sleep(0.01)
            _enqueue(channel_id, DetectionEvent(
                annotations=[],
                mediaId=content_hash,
                mediaStatus="AIProcessed",
                mediaPercent=100,
            ))
        except Exception:
            if token_mgr and user_id:
                _put_media_status(
                    content_hash, _MEDIA_STATUS_ERROR,
                    token_mgr.get_valid_token(),
                )
            _enqueue(channel_id, DetectionEvent(
                annotations=[], mediaId=content_hash,
                mediaStatus="Error", mediaPercent=0,
            ))
        finally:
            loop.call_later(10.0, _cleanup_channel, channel_id)
            buffer.close()
            # Anonymous uploads are not kept: remove the spool file
            # (authenticated ones were renamed into storage above).
            if not (token_mgr and user_id) and os.path.isfile(buffer.path):
                try:
                    os.unlink(buffer.path)
                except OSError:
                    pass

    asyncio.create_task(_wait_inference())
    return Response(status_code=202)
@app.post("/detect/{media_id}")
async def detect_media(
    media_id: str,
    request: Request,
    config: Annotated[Optional[AIConfigDto], Body()] = None,
    user_id: str = Depends(require_auth),
):
    """Run detection over media already registered with the annotations
    service. Responds 202; 409 when a run is already active for this media.
    Results are published on the channel named by X-Channel-Id.
    """
    if media_id in _active_detections:
        raise HTTPException(status_code=409, detail="Detection already in progress for this media")
    channel_id = request.headers.get("x-channel-id", "")
    if not channel_id:
        raise HTTPException(status_code=400, detail="X-Channel-Id header required")
    refresh_token = request.headers.get("x-refresh-token", "")
    access_token = request.headers.get("authorization", "").removeprefix("Bearer ").strip()
    token_mgr = TokenManager(access_token, refresh_token) if access_token else None
    # NOTE(review): this helper makes blocking HTTP calls inside the async
    # handler — consider run_in_executor if the annotations service is slow.
    config_dict, media_path = _resolve_media_for_detect(media_id, token_mgr, config)

    async def run_detection():
        """Background task: load the file, run inference, publish events."""
        loop = asyncio.get_event_loop()
        try:
            from inference import ai_config_from_dict
            if token_mgr:
                _put_media_status(
                    media_id,
                    _MEDIA_STATUS_AI_PROCESSING,
                    token_mgr.get_valid_token(),
                )
            # The whole file is read into memory for both image and video paths.
            with open(media_path, "rb") as mf:
                file_bytes = mf.read()
            video = _is_video_media_path(media_path)
            stem_name = Path(media_path).stem.replace(" ", "")
            ai_cfg = ai_config_from_dict(config_dict)
            inf = get_inference()
            if not inf.is_engine_ready:
                raise RuntimeError("Detection service unavailable")

            def on_annotation(annotation, percent):
                # Executor-thread callback: marshal the event onto the loop.
                dtos = [detection_to_dto(d) for d in annotation.detections]
                event = DetectionEvent(
                    annotations=dtos,
                    mediaId=media_id,
                    mediaStatus="AIProcessing",
                    mediaPercent=percent,
                )
                loop.call_soon_threadsafe(_enqueue, channel_id, event)
                if token_mgr and dtos:
                    _post_annotation_to_service(token_mgr, media_id, annotation, dtos)

            def on_status(media_name_cb, count):
                # Invoked by the engine when processing completes.
                event = DetectionEvent(
                    annotations=[],
                    mediaId=media_id,
                    mediaStatus="AIProcessed",
                    mediaPercent=100,
                )
                loop.call_soon_threadsafe(_enqueue, channel_id, event)
                if token_mgr:
                    _put_media_status(
                        media_id,
                        _MEDIA_STATUS_AI_PROCESSED,
                        token_mgr.get_valid_token(),
                    )

            def run_sync():
                if video:
                    inf.run_detect_video(
                        file_bytes,
                        ai_cfg,
                        stem_name,
                        media_path,
                        on_annotation,
                        on_status,
                    )
                else:
                    inf.run_detect_image(
                        file_bytes,
                        ai_cfg,
                        stem_name,
                        on_annotation,
                        on_status,
                    )

            await loop.run_in_executor(executor, run_sync)
        except Exception:
            # NOTE(review): the exception is converted to an Error event
            # without being logged — consider logger.exception here.
            if token_mgr:
                _put_media_status(
                    media_id, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
                )
            error_event = DetectionEvent(
                annotations=[],
                mediaId=media_id,
                mediaStatus="Error",
                mediaPercent=0,
            )
            _enqueue(channel_id, error_event)
        finally:
            # Clear the 409 guard shortly after completion, then drop the
            # channel's replay buffer once subscribers have had time to drain.
            loop.call_later(5.0, lambda: _active_detections.pop(media_id, None))
            loop.call_later(10.0, _cleanup_channel, channel_id)

    _active_detections[media_id] = asyncio.create_task(run_detection())
    return Response(status_code=202)