mirror of
https://github.com/azaion/detections.git
synced 2026-04-26 00:46:31 +00:00
[AZ-180] Refactor detection event handling and improve SSE support
- Updated the detection image endpoint to require a channel ID for event streaming.
- Introduced a new endpoint for streaming detection events, allowing clients to receive real-time updates.
- Enhanced the internal buffering mechanism for detection events to manage multiple channels.
- Refactored the inference module to support the new event handling structure.

Made-with: Cursor
This commit is contained in:
+164
-131
@@ -5,6 +5,7 @@ import json
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from collections import deque
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Optional
|
||||
@@ -15,7 +16,7 @@ import jwt as pyjwt
|
||||
import numpy as np
|
||||
import requests as http_requests
|
||||
from fastapi import Body, Depends, FastAPI, File, Form, HTTPException, Request, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -37,11 +38,14 @@ _MEDIA_STATUS_ERROR = 6
|
||||
_VIDEO_EXTENSIONS = frozenset({".mp4", ".mov", ".webm", ".mkv", ".avi", ".m4v"})
|
||||
_IMAGE_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"})
|
||||
|
||||
_BUFFER_TTL_MS = 10_000
|
||||
_BUFFER_MAX = 200
|
||||
|
||||
loader_client = LoaderHttpClient(LOADER_URL)
|
||||
annotations_client = LoaderHttpClient(ANNOTATIONS_URL)
|
||||
inference = None
|
||||
_job_queues: dict[str, list[asyncio.Queue]] = {}
|
||||
_job_buffers: dict[str, list[str]] = {}
|
||||
_channel_buffers: dict[str, deque] = {}
|
||||
_active_detections: dict[str, asyncio.Task] = {}
|
||||
|
||||
_bearer = HTTPBearer(auto_error=False)
|
||||
@@ -323,21 +327,50 @@ def detection_to_dto(det) -> DetectionDto:
|
||||
)
|
||||
|
||||
|
||||
def _enqueue(media_id: str, event: DetectionEvent):
|
||||
def _post_annotation_to_service(token_mgr: TokenManager, media_id: str,
                                annotation, dtos: list[DetectionDto]):
    """Send one annotation and its detections to the annotations service.

    The call is best-effort: any failure (token refresh, serialization,
    network error, non-2xx response) is logged and suppressed so that a
    problem with the annotations service never aborts the detection run
    that invoked this callback.

    Args:
        token_mgr: Provides the bearer token via ``get_valid_token()``.
        media_id: Identifier sent as ``mediaId`` in the payload.
        annotation: Object exposing ``time`` and ``image``. ``time`` is
            divided by 1000 to obtain seconds, so it is presumably in
            milliseconds (TODO confirm against the inference module).
            ``image`` is base64-encoded when truthy.
        dtos: Detections serialized with ``model_dump()``.
    """
    import logging

    try:
        token = token_mgr.get_valid_token()

        # Render the offset as HH:MM:SS. The previous "00:00:{seconds}" form
        # produced invalid values such as "00:00:75" past the first minute.
        total_seconds = int(annotation.time or 0) // 1000
        minutes, seconds = divmod(total_seconds, 60)
        hours, minutes = divmod(minutes, 60)

        payload = {
            "mediaId": media_id,
            "source": 0,
            "videoTime": f"{hours:02d}:{minutes:02d}:{seconds:02d}",
            "detections": [d.model_dump() for d in dtos],
        }
        if annotation.image:
            payload["image"] = base64.b64encode(annotation.image).decode()

        response = http_requests.post(
            f"{ANNOTATIONS_URL}/annotations",
            json=payload,
            headers={"Authorization": f"Bearer {token}"},
            timeout=30,
        )
        # Turn an HTTP error status into an exception so it is logged below
        # instead of being silently ignored.
        response.raise_for_status()
    except Exception:
        # Deliberately non-fatal, but leave a trace for operators.
        logging.getLogger(__name__).warning(
            "Failed to post annotation for media %s", media_id, exc_info=True
        )
|
||||
|
||||
|
||||
def _cleanup_channel(channel_id: str):
    """Discard the replay buffer held for *channel_id*, if there is one."""
    try:
        del _channel_buffers[channel_id]
    except KeyError:
        # Nothing was buffered for this channel (or it was already removed).
        pass
|
||||
|
||||
|
||||
def _enqueue(channel_id: str, event: DetectionEvent):
    """Publish *event* on a channel: buffer it for replay and fan it out.

    The event is serialized once with ``model_dump_json()`` and stamped with
    the current wall-clock time in milliseconds. The ``(timestamp, json)``
    pair is appended to the channel's bounded replay buffer and offered to
    every queue currently subscribed to the channel.

    Args:
        channel_id: Key into ``_channel_buffers`` and ``_job_queues``.
        event: The detection event to publish.
    """
    now_ms = int(time.time() * 1000)
    data = event.model_dump_json()
    item = (now_ms, data)

    buf = _channel_buffers.get(channel_id)
    if buf is None:
        # maxlen caps the buffer at _BUFFER_MAX entries; the oldest entry is
        # dropped automatically once the cap is reached.
        buf = _channel_buffers[channel_id] = deque(maxlen=_BUFFER_MAX)
    buf.append(item)

    # Expire entries older than the replay window.
    cutoff = now_ms - _BUFFER_TTL_MS
    while buf and buf[0][0] < cutoff:
        buf.popleft()

    # Subscribers unpack (timestamp, json) pairs, so only tuples are queued.
    for q in _job_queues.get(channel_id, []):
        try:
            q.put_nowait(item)
        except asyncio.QueueFull:
            # Best-effort delivery: a subscriber that is not draining its
            # queue loses this event rather than blocking the publisher.
            pass
|
||||
|
||||
|
||||
def _schedule_buffer_cleanup(media_id: str, delay: float = 300.0):
    """Arrange for the buffered events of *media_id* to be dropped later.

    Args:
        media_id: Key of the entry to remove from ``_job_buffers``.
        delay: Seconds to wait before the entry is removed.
    """
    def _drop_buffer():
        # Missing keys are ignored: the buffer may already be gone.
        _job_buffers.pop(media_id, None)

    asyncio.get_event_loop().call_later(delay, _drop_buffer)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health() -> HealthResponse:
|
||||
if inference is None:
|
||||
@@ -361,6 +394,36 @@ def health() -> HealthResponse:
|
||||
)
|
||||
|
||||
|
||||
@app.get("/detect/events/{channel_id}", dependencies=[Depends(require_auth)])
async def detect_events(channel_id: str, request: Request, after_ts: Optional[int] = None):
    """Stream the detection events of one channel as Server-Sent Events.

    A queue is registered in ``_job_queues`` for the channel so that newly
    published events are delivered live. Each event is framed as
    ``id: <timestamp-ms>`` followed by ``data: <json>``.

    Args:
        channel_id: Channel whose events are streamed.
        request: Incoming request; its ``Last-Event-ID`` header is used as
            the replay position when ``after_ts`` is not supplied.
        after_ts: When given, buffered events from ``_channel_buffers`` with
            a timestamp strictly greater than this value are replayed before
            live streaming starts.

    Returns:
        A ``text/event-stream`` ``StreamingResponse`` that runs until the
        client disconnects.
    """
    queue: asyncio.Queue = asyncio.Queue(maxsize=200)
    _job_queues.setdefault(channel_id, []).append(queue)

    # The stream emits "id:" fields, so a reconnecting SSE client reports the
    # last id it saw in the Last-Event-ID header. Honour it as a fallback
    # replay position; the explicit query parameter takes precedence.
    replay_after = after_ts
    if replay_after is None:
        try:
            replay_after = int(request.headers.get("last-event-id", ""))
        except ValueError:
            replay_after = None

    def frame(ts_ms, data):
        return f"id: {ts_ms}\ndata: {data}\n\n"

    async def event_generator():
        try:
            if replay_after is not None:
                backlog = [
                    item for item in _channel_buffers.get(channel_id, ())
                    if item[0] > replay_after
                ]
                # Events published after the queue was registered but before
                # this point sit in both the buffer and the queue. Drain the
                # queue now (no await in between, so nothing can interleave)
                # and drop the entries the replay is about to send anyway.
                unmatched = list(backlog)
                pending = []
                while not queue.empty():
                    item = queue.get_nowait()
                    if item in unmatched:
                        unmatched.remove(item)
                    else:
                        pending.append(item)
                for ts_ms, data in backlog:
                    yield frame(ts_ms, data)
                for ts_ms, data in pending:
                    yield frame(ts_ms, data)
            while True:
                ts_ms, data = await queue.get()
                yield frame(ts_ms, data)
        finally:
            # Runs on normal exit and on cancellation (client disconnect).
            # CancelledError is intentionally not caught so that it keeps
            # propagating to the server after the subscription is removed.
            queues = _job_queues.get(channel_id, [])
            if queue in queues:
                queues.remove(queue)
            if not queues:
                _job_queues.pop(channel_id, None)

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|
||||
|
||||
|
||||
@app.post("/detect/image")
|
||||
async def detect_image(
|
||||
request: Request,
|
||||
@@ -384,6 +447,10 @@ async def detect_image(
|
||||
if cv2.imdecode(arr, cv2.IMREAD_COLOR) is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid image data")
|
||||
|
||||
channel_id = request.headers.get("x-channel-id", "")
|
||||
if not channel_id:
|
||||
raise HTTPException(status_code=400, detail="X-Channel-Id header required")
|
||||
|
||||
config_dict = {}
|
||||
if config:
|
||||
config_dict = json.loads(config)
|
||||
@@ -395,7 +462,6 @@ async def detect_image(
|
||||
images_dir = os.environ.get(
|
||||
"IMAGES_DIR", os.path.join(os.getcwd(), "data", "images")
|
||||
)
|
||||
storage_path = None
|
||||
content_hash = None
|
||||
if token_mgr and user_id:
|
||||
content_hash = compute_media_content_hash(image_bytes)
|
||||
@@ -417,45 +483,65 @@ async def detect_image(
|
||||
_put_media_status(content_hash, _MEDIA_STATUS_AI_PROCESSING, bearer)
|
||||
|
||||
media_name = Path(orig_name).stem.replace(" ", "")
|
||||
media_id = content_hash or channel_id
|
||||
loop = asyncio.get_event_loop()
|
||||
inf = get_inference()
|
||||
results = []
|
||||
|
||||
def on_annotation(annotation, percent):
|
||||
results.extend(annotation.detections)
|
||||
async def run_detection():
|
||||
ai_cfg = ai_config_from_dict(config_dict)
|
||||
|
||||
ai_cfg = ai_config_from_dict(config_dict)
|
||||
def on_annotation(annotation, percent):
|
||||
dtos = [detection_to_dto(d) for d in annotation.detections]
|
||||
event = DetectionEvent(
|
||||
annotations=dtos,
|
||||
mediaId=media_id,
|
||||
mediaStatus="AIProcessing",
|
||||
mediaPercent=percent,
|
||||
)
|
||||
loop.call_soon_threadsafe(_enqueue, channel_id, event)
|
||||
if token_mgr and content_hash and dtos:
|
||||
_post_annotation_to_service(token_mgr, content_hash, annotation, dtos)
|
||||
|
||||
def run_detect():
|
||||
inf.run_detect_image(image_bytes, ai_cfg, media_name, on_annotation)
|
||||
def run_sync():
|
||||
inf.run_detect_image(image_bytes, ai_cfg, media_name, on_annotation)
|
||||
|
||||
try:
|
||||
await loop.run_in_executor(executor, run_detect)
|
||||
if token_mgr and user_id and content_hash:
|
||||
_put_media_status(
|
||||
content_hash, _MEDIA_STATUS_AI_PROCESSED, token_mgr.get_valid_token()
|
||||
)
|
||||
return [detection_to_dto(d) for d in results]
|
||||
except RuntimeError as e:
|
||||
if token_mgr and user_id and content_hash:
|
||||
_put_media_status(
|
||||
content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
|
||||
)
|
||||
if "not available" in str(e):
|
||||
raise HTTPException(status_code=503, detail=str(e))
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
except ValueError as e:
|
||||
if token_mgr and user_id and content_hash:
|
||||
_put_media_status(
|
||||
content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception:
|
||||
if token_mgr and user_id and content_hash:
|
||||
_put_media_status(
|
||||
content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
|
||||
)
|
||||
raise
|
||||
try:
|
||||
await loop.run_in_executor(executor, run_sync)
|
||||
_enqueue(channel_id, DetectionEvent(
|
||||
annotations=[], mediaId=media_id,
|
||||
mediaStatus="AIProcessed", mediaPercent=100,
|
||||
))
|
||||
if token_mgr and content_hash:
|
||||
_put_media_status(
|
||||
content_hash, _MEDIA_STATUS_AI_PROCESSED, token_mgr.get_valid_token()
|
||||
)
|
||||
except RuntimeError as e:
|
||||
if token_mgr and content_hash:
|
||||
_put_media_status(
|
||||
content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
|
||||
)
|
||||
_enqueue(channel_id, DetectionEvent(
|
||||
annotations=[], mediaId=media_id,
|
||||
mediaStatus="Error", mediaPercent=0,
|
||||
))
|
||||
if "not available" in str(e):
|
||||
return
|
||||
raise
|
||||
except Exception:
|
||||
if token_mgr and content_hash:
|
||||
_put_media_status(
|
||||
content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
|
||||
)
|
||||
_enqueue(channel_id, DetectionEvent(
|
||||
annotations=[], mediaId=media_id,
|
||||
mediaStatus="Error", mediaPercent=0,
|
||||
))
|
||||
raise
|
||||
finally:
|
||||
loop.call_later(10.0, _cleanup_channel, channel_id)
|
||||
|
||||
asyncio.create_task(run_detection())
|
||||
return Response(status_code=202)
|
||||
|
||||
|
||||
@app.post("/detect/video")
|
||||
@@ -467,6 +553,10 @@ async def detect_video_upload(
|
||||
from inference import ai_config_from_dict
|
||||
from streaming_buffer import StreamingBuffer
|
||||
|
||||
channel_id = request.headers.get("x-channel-id", "")
|
||||
if not channel_id:
|
||||
raise HTTPException(status_code=400, detail="X-Channel-Id header required")
|
||||
|
||||
filename = request.headers.get("x-filename", "upload.mp4")
|
||||
config_json = request.headers.get("x-config", "")
|
||||
ext = _normalize_upload_ext(filename)
|
||||
@@ -491,32 +581,23 @@ async def detect_video_upload(
|
||||
loop = asyncio.get_event_loop()
|
||||
inf = get_inference()
|
||||
|
||||
placeholder_id = f"tmp_{os.path.basename(buffer.path)}"
|
||||
current_id = [placeholder_id] # mutable — updated to content_hash after upload
|
||||
current_media_id = [channel_id]
|
||||
|
||||
def on_annotation(annotation, percent):
|
||||
dtos = [detection_to_dto(d) for d in annotation.detections]
|
||||
mid = current_id[0]
|
||||
mid = current_media_id[0]
|
||||
event = DetectionEvent(
|
||||
annotations=dtos,
|
||||
mediaId=mid,
|
||||
mediaStatus="AIProcessing",
|
||||
mediaPercent=percent,
|
||||
)
|
||||
loop.call_soon_threadsafe(_enqueue, mid, event)
|
||||
|
||||
def on_status(media_name_cb, count):
|
||||
mid = current_id[0]
|
||||
event = DetectionEvent(
|
||||
annotations=[],
|
||||
mediaId=mid,
|
||||
mediaStatus="AIProcessed",
|
||||
mediaPercent=100,
|
||||
)
|
||||
loop.call_soon_threadsafe(_enqueue, mid, event)
|
||||
loop.call_soon_threadsafe(_enqueue, channel_id, event)
|
||||
if token_mgr and mid != channel_id and dtos:
|
||||
_post_annotation_to_service(token_mgr, mid, annotation, dtos)
|
||||
|
||||
def run_inference():
|
||||
inf.run_detect_video_stream(buffer, ai_cfg, media_name, on_annotation, on_status)
|
||||
inf.run_detect_video_stream(buffer, ai_cfg, media_name, on_annotation, lambda *_: None)
|
||||
|
||||
inference_future = loop.run_in_executor(executor, run_inference)
|
||||
|
||||
@@ -533,14 +614,14 @@ async def detect_video_upload(
|
||||
if not ext.startswith("."):
|
||||
ext = "." + ext
|
||||
storage_path = os.path.abspath(os.path.join(videos_dir, f"{content_hash}{ext}"))
|
||||
current_media_id[0] = content_hash
|
||||
|
||||
# Re-key buffered events from placeholder_id to content_hash so clients
|
||||
# can subscribe to GET /detect/{content_hash} after POST returns.
|
||||
if placeholder_id in _job_buffers:
|
||||
_job_buffers[content_hash] = _job_buffers.pop(placeholder_id)
|
||||
if placeholder_id in _job_queues:
|
||||
_job_queues[content_hash] = _job_queues.pop(placeholder_id)
|
||||
current_id[0] = content_hash # future on_annotation/on_status callbacks use content_hash
|
||||
_enqueue(channel_id, DetectionEvent(
|
||||
annotations=[],
|
||||
mediaId=content_hash,
|
||||
mediaStatus="Started",
|
||||
mediaPercent=0,
|
||||
))
|
||||
|
||||
if token_mgr and user_id:
|
||||
os.rename(buffer.path, storage_path)
|
||||
@@ -564,27 +645,24 @@ async def detect_video_upload(
|
||||
content_hash, _MEDIA_STATUS_AI_PROCESSED,
|
||||
token_mgr.get_valid_token(),
|
||||
)
|
||||
done_event = DetectionEvent(
|
||||
_enqueue(channel_id, DetectionEvent(
|
||||
annotations=[],
|
||||
mediaId=content_hash,
|
||||
mediaStatus="AIProcessed",
|
||||
mediaPercent=100,
|
||||
)
|
||||
_enqueue(content_hash, done_event)
|
||||
))
|
||||
except Exception:
|
||||
if token_mgr and user_id:
|
||||
_put_media_status(
|
||||
content_hash, _MEDIA_STATUS_ERROR,
|
||||
token_mgr.get_valid_token(),
|
||||
)
|
||||
err_event = DetectionEvent(
|
||||
_enqueue(channel_id, DetectionEvent(
|
||||
annotations=[], mediaId=content_hash,
|
||||
mediaStatus="Error", mediaPercent=0,
|
||||
)
|
||||
_enqueue(content_hash, err_event)
|
||||
))
|
||||
finally:
|
||||
_active_detections.pop(content_hash, None)
|
||||
_schedule_buffer_cleanup(content_hash)
|
||||
loop.call_later(10.0, _cleanup_channel, channel_id)
|
||||
buffer.close()
|
||||
if not (token_mgr and user_id) and os.path.isfile(buffer.path):
|
||||
try:
|
||||
@@ -592,31 +670,8 @@ async def detect_video_upload(
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
_active_detections[content_hash] = asyncio.create_task(_wait_inference())
|
||||
return {"status": "started", "mediaId": content_hash}
|
||||
|
||||
|
||||
def _post_annotation_to_service(token_mgr: TokenManager, media_id: str,
                                annotation, dtos: list[DetectionDto]):
    """Send one annotation and its detections to the annotations service.

    The call is best-effort: any failure (token refresh, serialization,
    network error, non-2xx response) is logged and suppressed so that a
    problem with the annotations service never aborts the detection run
    that invoked this callback.

    Args:
        token_mgr: Provides the bearer token via ``get_valid_token()``.
        media_id: Identifier sent as ``mediaId`` in the payload.
        annotation: Object exposing ``time`` and ``image``. ``time`` is
            divided by 1000 to obtain seconds, so it is presumably in
            milliseconds (TODO confirm against the inference module).
            ``image`` is base64-encoded when truthy.
        dtos: Detections serialized with ``model_dump()``.
    """
    import logging

    try:
        token = token_mgr.get_valid_token()

        # Render the offset as HH:MM:SS. The previous "00:00:{seconds}" form
        # produced invalid values such as "00:00:75" past the first minute.
        total_seconds = int(annotation.time or 0) // 1000
        minutes, seconds = divmod(total_seconds, 60)
        hours, minutes = divmod(minutes, 60)

        payload = {
            "mediaId": media_id,
            "source": 0,
            "videoTime": f"{hours:02d}:{minutes:02d}:{seconds:02d}",
            "detections": [d.model_dump() for d in dtos],
        }
        if annotation.image:
            payload["image"] = base64.b64encode(annotation.image).decode()

        response = http_requests.post(
            f"{ANNOTATIONS_URL}/annotations",
            json=payload,
            headers={"Authorization": f"Bearer {token}"},
            timeout=30,
        )
        # Turn an HTTP error status into an exception so it is logged below
        # instead of being silently ignored.
        response.raise_for_status()
    except Exception:
        # Deliberately non-fatal, but leave a trace for operators.
        logging.getLogger(__name__).warning(
            "Failed to post annotation for media %s", media_id, exc_info=True
        )
|
||||
asyncio.create_task(_wait_inference())
|
||||
return Response(status_code=202)
|
||||
|
||||
|
||||
@app.post("/detect/{media_id}")
|
||||
@@ -630,6 +685,10 @@ async def detect_media(
|
||||
if existing is not None and not existing.done():
|
||||
raise HTTPException(status_code=409, detail="Detection already in progress for this media")
|
||||
|
||||
channel_id = request.headers.get("x-channel-id", "")
|
||||
if not channel_id:
|
||||
raise HTTPException(status_code=400, detail="X-Channel-Id header required")
|
||||
|
||||
refresh_token = request.headers.get("x-refresh-token", "")
|
||||
access_token = request.headers.get("authorization", "").removeprefix("Bearer ").strip()
|
||||
token_mgr = TokenManager(access_token, refresh_token) if access_token else None
|
||||
@@ -668,7 +727,7 @@ async def detect_media(
|
||||
mediaStatus="AIProcessing",
|
||||
mediaPercent=percent,
|
||||
)
|
||||
loop.call_soon_threadsafe(_enqueue, media_id, event)
|
||||
loop.call_soon_threadsafe(_enqueue, channel_id, event)
|
||||
if token_mgr and dtos:
|
||||
_post_annotation_to_service(token_mgr, media_id, annotation, dtos)
|
||||
|
||||
@@ -679,7 +738,7 @@ async def detect_media(
|
||||
mediaStatus="AIProcessed",
|
||||
mediaPercent=100,
|
||||
)
|
||||
loop.call_soon_threadsafe(_enqueue, media_id, event)
|
||||
loop.call_soon_threadsafe(_enqueue, channel_id, event)
|
||||
if token_mgr:
|
||||
_put_media_status(
|
||||
media_id,
|
||||
@@ -718,36 +777,10 @@ async def detect_media(
|
||||
mediaStatus="Error",
|
||||
mediaPercent=0,
|
||||
)
|
||||
_enqueue(media_id, error_event)
|
||||
_enqueue(channel_id, error_event)
|
||||
finally:
|
||||
_active_detections.pop(media_id, None)
|
||||
_schedule_buffer_cleanup(media_id)
|
||||
loop.call_later(10.0, _cleanup_channel, channel_id)
|
||||
|
||||
_active_detections[media_id] = asyncio.create_task(run_detection())
|
||||
return {"status": "started", "mediaId": media_id}
|
||||
|
||||
|
||||
@app.get("/detect/{media_id}", dependencies=[Depends(require_auth)])
async def detect_events(media_id: str):
    """Stream the detection events of one media item as Server-Sent Events.

    Everything already recorded in ``_job_buffers`` for the media item is
    replayed first, then events arriving on this subscriber's queue are
    forwarded as they are published. Each event is framed as
    ``data: <json>``.

    Args:
        media_id: Key into ``_job_buffers`` and ``_job_queues``.

    Returns:
        A ``text/event-stream`` ``StreamingResponse`` that runs until the
        client disconnects.
    """
    queue: asyncio.Queue = asyncio.Queue(maxsize=200)
    _job_queues.setdefault(media_id, []).append(queue)

    async def event_generator():
        try:
            backlog = list(_job_buffers.get(media_id, []))
            # Events published after the queue was registered but before this
            # snapshot sit in both the buffer and the queue. Drain the queue
            # now (no await in between) and drop what the replay already
            # covers, so those events are not delivered twice.
            unmatched = list(backlog)
            pending = []
            while not queue.empty():
                item = queue.get_nowait()
                if item in unmatched:
                    unmatched.remove(item)
                else:
                    pending.append(item)
            for data in backlog:
                yield f"data: {data}\n\n"
            for data in pending:
                yield f"data: {data}\n\n"
            while True:
                data = await queue.get()
                yield f"data: {data}\n\n"
        finally:
            # Runs on normal exit and on cancellation (client disconnect).
            # CancelledError is intentionally not caught so that it keeps
            # propagating to the server after the subscription is removed.
            queues = _job_queues.get(media_id, [])
            if queue in queues:
                queues.remove(queue)
            if not queues:
                # Do not keep an empty subscriber list around per media id.
                _job_queues.pop(media_id, None)

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|
||||
return Response(status_code=202)
|
||||
|
||||
Reference in New Issue
Block a user