[AZ-180] Refactor detection event handling and improve SSE support

- Updated the detection image endpoint to require a channel ID for event streaming.
- Introduced a new endpoint for streaming detection events, allowing clients to receive real-time updates.
- Enhanced the internal buffering mechanism for detection events to manage multiple channels.
- Refactored the inference module to support the new event handling structure.

Made-with: Cursor
This commit is contained in:
Oleksandr Bezdieniezhnykh
2026-04-03 02:42:05 +03:00
parent 2c35e59a77
commit 8baa96978b
26 changed files with 819 additions and 413 deletions
+29 -8
View File
@@ -1,6 +1,16 @@
import os
import platform
import sys
from loguru import logger
from engines.engine_factory import (
EngineFactory,
OnnxEngineFactory,
CoreMLEngineFactory,
TensorRTEngineFactory,
JetsonTensorRTEngineFactory,
)
def _check_tensor_gpu_index():
try:
@@ -35,18 +45,29 @@ def _is_apple_silicon():
return False
def _is_jetson():
return (
platform.machine() == "aarch64"
and tensor_gpu_index > -1
and os.path.isfile("/etc/nv_tegra_release")
)
tensor_gpu_index = _check_tensor_gpu_index()
def _select_engine_class():
def _create_engine_factory() -> EngineFactory:
if _is_jetson():
logger.info("Engine factory: JetsonTensorRTEngineFactory")
return JetsonTensorRTEngineFactory()
if tensor_gpu_index > -1:
from engines.tensorrt_engine import TensorRTEngine # pyright: ignore[reportMissingImports]
return TensorRTEngine
logger.info("Engine factory: TensorRTEngineFactory")
return TensorRTEngineFactory()
if _is_apple_silicon():
from engines.coreml_engine import CoreMLEngine
return CoreMLEngine
from engines.onnx_engine import OnnxEngine
return OnnxEngine
logger.info("Engine factory: CoreMLEngineFactory")
return CoreMLEngineFactory()
logger.info("Engine factory: OnnxEngineFactory")
return OnnxEngineFactory()
EngineClass = _select_engine_class()
engine_factory = _create_engine_factory()
-4
View File
@@ -30,10 +30,6 @@ cdef class CoreMLEngine(InferenceEngine):
constants_inf.log(<str>f'CoreML model: {self.img_width}x{self.img_height}')
@staticmethod
def get_engine_filename():
return "azaion_coreml.zip"
@staticmethod
def _extract_from_zip(model_bytes):
tmpdir = tempfile.mkdtemp()
+109
View File
@@ -0,0 +1,109 @@
import os
import tempfile
class EngineFactory:
has_build_step = False
def create(self, model_bytes: bytes):
raise NotImplementedError
def load_engine(self, loader_client, models_dir: str):
filename = self._get_ai_engine_filename()
if filename is None:
return None
try:
res = loader_client.load_big_small_resource(filename, models_dir)
if res.err is None:
return self.create(res.data)
except Exception:
pass
return None
def _get_ai_engine_filename(self) -> str | None:
return None
def get_source_filename(self) -> str | None:
return None
def build_from_source(self, onnx_bytes: bytes, loader_client, models_dir: str):
raise NotImplementedError(f"{type(self).__name__} does not support building from source")
class OnnxEngineFactory(EngineFactory):
def create(self, model_bytes: bytes):
from engines.onnx_engine import OnnxEngine
return OnnxEngine(model_bytes)
def get_source_filename(self) -> str:
import constants_inf
return constants_inf.AI_ONNX_MODEL_FILE
class CoreMLEngineFactory(EngineFactory):
def create(self, model_bytes: bytes):
from engines.coreml_engine import CoreMLEngine
return CoreMLEngine(model_bytes)
def _get_ai_engine_filename(self) -> str:
return "azaion_coreml.zip"
class TensorRTEngineFactory(EngineFactory):
has_build_step = True
def create(self, model_bytes: bytes):
from engines.tensorrt_engine import TensorRTEngine
return TensorRTEngine(model_bytes)
def _get_ai_engine_filename(self) -> str | None:
from engines.tensorrt_engine import TensorRTEngine
return TensorRTEngine.get_engine_filename()
def get_source_filename(self) -> str:
import constants_inf
return constants_inf.AI_ONNX_MODEL_FILE
def build_from_source(self, onnx_bytes: bytes, loader_client, models_dir: str):
from engines.tensorrt_engine import TensorRTEngine
engine_bytes = TensorRTEngine.convert_from_source(onnx_bytes, None)
return engine_bytes, TensorRTEngine.get_engine_filename()
class JetsonTensorRTEngineFactory(TensorRTEngineFactory):
def create(self, model_bytes: bytes):
from engines.jetson_tensorrt_engine import JetsonTensorRTEngine
return JetsonTensorRTEngine(model_bytes)
def _get_ai_engine_filename(self) -> str | None:
from engines.tensorrt_engine import TensorRTEngine
return TensorRTEngine.get_engine_filename("int8")
def build_from_source(self, onnx_bytes: bytes, loader_client, models_dir: str):
from engines.tensorrt_engine import TensorRTEngine
calib_cache_path = self._download_calib_cache(loader_client, models_dir)
try:
engine_bytes = TensorRTEngine.convert_from_source(onnx_bytes, calib_cache_path)
return engine_bytes, TensorRTEngine.get_engine_filename("int8")
finally:
if calib_cache_path is not None:
try:
os.unlink(calib_cache_path)
except Exception:
pass
def _download_calib_cache(self, loader_client, models_dir: str) -> str | None:
import constants_inf
try:
res = loader_client.load_big_small_resource(constants_inf.INT8_CALIB_CACHE_FILE, models_dir)
if res.err is not None:
constants_inf.log(f"INT8 calibration cache not available: {res.err}")
return None
fd, path = tempfile.mkstemp(suffix=".cache")
with os.fdopen(fd, "wb") as f:
f.write(res.data)
constants_inf.log("INT8 calibration cache downloaded")
return path
except Exception as e:
constants_inf.log(f"INT8 calibration cache download failed: {str(e)}")
return None
+5
View File
@@ -0,0 +1,5 @@
from engines.tensorrt_engine cimport TensorRTEngine
cdef class JetsonTensorRTEngine(TensorRTEngine):
pass
+5
View File
@@ -0,0 +1,5 @@
from engines.tensorrt_engine cimport TensorRTEngine
cdef class JetsonTensorRTEngine(TensorRTEngine):
pass
+1 -1
View File
@@ -23,7 +23,7 @@ cdef class OnnxEngine(InferenceEngine):
self.model_inputs = self.session.get_inputs()
self.input_name = self.model_inputs[0].name
self.input_shape = self.model_inputs[0].shape
if self.input_shape[0] not in (-1, None, "N"):
if isinstance(self.input_shape[0], int) and self.input_shape[0] > 0:
self.max_batch_size = self.input_shape[0]
constants_inf.log(f'AI detection model input: {self.model_inputs} {self.input_shape}')
model_meta = self.session.get_modelmeta()
-5
View File
@@ -113,11 +113,6 @@ cdef class TensorRTEngine(InferenceEngine):
except Exception:
return None
@staticmethod
def get_source_filename():
import constants_inf
return constants_inf.AI_ONNX_MODEL_FILE
@staticmethod
def convert_from_source(bytes onnx_model, str calib_cache_path=None):
gpu_mem = TensorRTEngine.get_gpu_memory_bytes(0)
+22 -55
View File
@@ -1,6 +1,4 @@
import io
import os
import tempfile
import threading
import av
@@ -14,7 +12,7 @@ from ai_config cimport AIRecognitionConfig
from engines.inference_engine cimport InferenceEngine
from loader_http_client cimport LoaderHttpClient
from threading import Thread
from engines import EngineClass
from engines import engine_factory
def ai_config_from_dict(dict data):
@@ -76,29 +74,23 @@ cdef class Inference:
raise Exception(res.err)
return <bytes>res.data
cdef convert_and_upload_model(self, bytes source_bytes, str engine_filename, str calib_cache_path):
cdef convert_and_upload_model(self, bytes source_bytes, str models_dir):
try:
self.ai_availability_status.set_status(AIAvailabilityEnum.CONVERTING)
models_dir = constants_inf.MODELS_FOLDER
model_bytes = EngineClass.convert_from_source(source_bytes, calib_cache_path)
engine_bytes, engine_filename = engine_factory.build_from_source(source_bytes, self.loader_client, models_dir)
self.ai_availability_status.set_status(AIAvailabilityEnum.UPLOADING)
res = self.loader_client.upload_big_small_resource(model_bytes, engine_filename, models_dir)
res = self.loader_client.upload_big_small_resource(engine_bytes, engine_filename, models_dir)
if res.err is not None:
self.ai_availability_status.set_status(AIAvailabilityEnum.WARNING, <str>f"Failed to upload converted model: {res.err}")
self._converted_model_bytes = model_bytes
self._converted_model_bytes = engine_bytes
self.ai_availability_status.set_status(AIAvailabilityEnum.ENABLED)
except Exception as e:
self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, <str> str(e))
self._converted_model_bytes = <bytes>None
finally:
self.is_building_engine = <bint>False
if calib_cache_path is not None:
try:
os.unlink(calib_cache_path)
except Exception:
pass
cdef init_ai(self):
constants_inf.log(<str> 'init AI...')
@@ -110,7 +102,7 @@ cdef class Inference:
if self._converted_model_bytes is not None:
try:
self.engine = EngineClass(self._converted_model_bytes)
self.engine = engine_factory.create(self._converted_model_bytes)
self.ai_availability_status.set_status(AIAvailabilityEnum.ENABLED)
except Exception as e:
self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, <str> str(e))
@@ -119,58 +111,33 @@ cdef class Inference:
return
models_dir = constants_inf.MODELS_FOLDER
engine_filename_fp16 = EngineClass.get_engine_filename()
if engine_filename_fp16 is not None:
engine_filename_int8 = EngineClass.get_engine_filename(<str>"int8")
for candidate in [engine_filename_int8, engine_filename_fp16]:
try:
self.ai_availability_status.set_status(AIAvailabilityEnum.DOWNLOADING)
res = self.loader_client.load_big_small_resource(candidate, models_dir)
if res.err is not None:
raise Exception(res.err)
self.engine = EngineClass(res.data)
self.ai_availability_status.set_status(AIAvailabilityEnum.ENABLED)
return
except Exception:
pass
self.ai_availability_status.set_status(AIAvailabilityEnum.DOWNLOADING)
engine = engine_factory.load_engine(self.loader_client, models_dir)
if engine is not None:
self.engine = engine
self.ai_availability_status.set_status(AIAvailabilityEnum.ENABLED)
return
source_filename = EngineClass.get_source_filename()
if source_filename is None:
self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, <str>"Pre-built engine not found and no source available")
return
source_filename = engine_factory.get_source_filename()
if source_filename is None:
self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, <str>"No engine available and no source to build from")
return
source_bytes = self.download_model(source_filename)
if engine_factory.has_build_step:
self.ai_availability_status.set_status(AIAvailabilityEnum.WARNING, <str>"Cached engine not found, converting from source")
source_bytes = self.download_model(source_filename)
calib_cache_path = self._try_download_calib_cache(models_dir)
target_engine_filename = EngineClass.get_engine_filename(<str>"int8") if calib_cache_path is not None else engine_filename_fp16
self.is_building_engine = <bint>True
thread = Thread(target=self.convert_and_upload_model, args=(source_bytes, target_engine_filename, calib_cache_path))
thread = Thread(target=self.convert_and_upload_model, args=(source_bytes, models_dir))
thread.daemon = True
thread.start()
return
else:
self.engine = EngineClass(<bytes>self.download_model(constants_inf.AI_ONNX_MODEL_FILE))
self.engine = engine_factory.create(source_bytes)
self.ai_availability_status.set_status(AIAvailabilityEnum.ENABLED)
self.is_building_engine = <bint>False
except Exception as e:
self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, <str>str(e))
self.is_building_engine = <bint>False
cdef str _try_download_calib_cache(self, str models_dir):
try:
res = self.loader_client.load_big_small_resource(constants_inf.INT8_CALIB_CACHE_FILE, models_dir)
if res.err is not None:
constants_inf.log(<str>f"INT8 calibration cache not available: {res.err}")
return <str>None
fd, path = tempfile.mkstemp(suffix='.cache')
with os.fdopen(fd, 'wb') as f:
f.write(res.data)
constants_inf.log(<str>'INT8 calibration cache downloaded')
return <str>path
except Exception as e:
constants_inf.log(<str>f"INT8 calibration cache download failed: {str(e)}")
return <str>None
cpdef run_detect_image(self, bytes image_bytes, AIRecognitionConfig ai_config, str media_name,
object annotation_callback, object status_callback=None):
cdef list all_frame_data = []
+164 -131
View File
@@ -5,6 +5,7 @@ import json
import os
import tempfile
import time
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Annotated, Optional
@@ -15,7 +16,7 @@ import jwt as pyjwt
import numpy as np
import requests as http_requests
from fastapi import Body, Depends, FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.responses import StreamingResponse
from fastapi.responses import Response, StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel
@@ -37,11 +38,14 @@ _MEDIA_STATUS_ERROR = 6
_VIDEO_EXTENSIONS = frozenset({".mp4", ".mov", ".webm", ".mkv", ".avi", ".m4v"})
_IMAGE_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"})
_BUFFER_TTL_MS = 10_000
_BUFFER_MAX = 200
loader_client = LoaderHttpClient(LOADER_URL)
annotations_client = LoaderHttpClient(ANNOTATIONS_URL)
inference = None
_job_queues: dict[str, list[asyncio.Queue]] = {}
_job_buffers: dict[str, list[str]] = {}
_channel_buffers: dict[str, deque] = {}
_active_detections: dict[str, asyncio.Task] = {}
_bearer = HTTPBearer(auto_error=False)
@@ -323,21 +327,50 @@ def detection_to_dto(det) -> DetectionDto:
)
def _enqueue(media_id: str, event: DetectionEvent):
def _post_annotation_to_service(token_mgr: TokenManager, media_id: str,
annotation, dtos: list[DetectionDto]):
try:
token = token_mgr.get_valid_token()
image_b64 = base64.b64encode(annotation.image).decode() if annotation.image else None
payload = {
"mediaId": media_id,
"source": 0,
"videoTime": f"00:00:{annotation.time // 1000:02d}" if annotation.time else "00:00:00",
"detections": [d.model_dump() for d in dtos],
}
if image_b64:
payload["image"] = image_b64
http_requests.post(
f"{ANNOTATIONS_URL}/annotations",
json=payload,
headers={"Authorization": f"Bearer {token}"},
timeout=30,
)
except Exception:
pass
def _cleanup_channel(channel_id: str):
_channel_buffers.pop(channel_id, None)
def _enqueue(channel_id: str, event: DetectionEvent):
now_ms = int(time.time() * 1000)
data = event.model_dump_json()
_job_buffers.setdefault(media_id, []).append(data)
for q in _job_queues.get(media_id, []):
buf = _channel_buffers.setdefault(channel_id, deque(maxlen=_BUFFER_MAX))
buf.append((now_ms, data))
cutoff = now_ms - _BUFFER_TTL_MS
while buf and buf[0][0] < cutoff:
buf.popleft()
for q in _job_queues.get(channel_id, []):
try:
q.put_nowait(data)
q.put_nowait((now_ms, data))
except asyncio.QueueFull:
pass
def _schedule_buffer_cleanup(media_id: str, delay: float = 300.0):
loop = asyncio.get_event_loop()
loop.call_later(delay, lambda: _job_buffers.pop(media_id, None))
@app.get("/health")
def health() -> HealthResponse:
if inference is None:
@@ -361,6 +394,36 @@ def health() -> HealthResponse:
)
@app.get("/detect/events/{channel_id}", dependencies=[Depends(require_auth)])
async def detect_events(channel_id: str, request: Request, after_ts: Optional[int] = None):
queue: asyncio.Queue = asyncio.Queue(maxsize=200)
_job_queues.setdefault(channel_id, []).append(queue)
async def event_generator():
try:
if after_ts is not None:
for ts_ms, data in list(_channel_buffers.get(channel_id, [])):
if ts_ms > after_ts:
yield f"id: {ts_ms}\ndata: {data}\n\n"
while True:
ts_ms, data = await queue.get()
yield f"id: {ts_ms}\ndata: {data}\n\n"
except asyncio.CancelledError:
pass
finally:
queues = _job_queues.get(channel_id, [])
if queue in queues:
queues.remove(queue)
if not queues:
_job_queues.pop(channel_id, None)
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
)
@app.post("/detect/image")
async def detect_image(
request: Request,
@@ -384,6 +447,10 @@ async def detect_image(
if cv2.imdecode(arr, cv2.IMREAD_COLOR) is None:
raise HTTPException(status_code=400, detail="Invalid image data")
channel_id = request.headers.get("x-channel-id", "")
if not channel_id:
raise HTTPException(status_code=400, detail="X-Channel-Id header required")
config_dict = {}
if config:
config_dict = json.loads(config)
@@ -395,7 +462,6 @@ async def detect_image(
images_dir = os.environ.get(
"IMAGES_DIR", os.path.join(os.getcwd(), "data", "images")
)
storage_path = None
content_hash = None
if token_mgr and user_id:
content_hash = compute_media_content_hash(image_bytes)
@@ -417,45 +483,65 @@ async def detect_image(
_put_media_status(content_hash, _MEDIA_STATUS_AI_PROCESSING, bearer)
media_name = Path(orig_name).stem.replace(" ", "")
media_id = content_hash or channel_id
loop = asyncio.get_event_loop()
inf = get_inference()
results = []
def on_annotation(annotation, percent):
results.extend(annotation.detections)
async def run_detection():
ai_cfg = ai_config_from_dict(config_dict)
ai_cfg = ai_config_from_dict(config_dict)
def on_annotation(annotation, percent):
dtos = [detection_to_dto(d) for d in annotation.detections]
event = DetectionEvent(
annotations=dtos,
mediaId=media_id,
mediaStatus="AIProcessing",
mediaPercent=percent,
)
loop.call_soon_threadsafe(_enqueue, channel_id, event)
if token_mgr and content_hash and dtos:
_post_annotation_to_service(token_mgr, content_hash, annotation, dtos)
def run_detect():
inf.run_detect_image(image_bytes, ai_cfg, media_name, on_annotation)
def run_sync():
inf.run_detect_image(image_bytes, ai_cfg, media_name, on_annotation)
try:
await loop.run_in_executor(executor, run_detect)
if token_mgr and user_id and content_hash:
_put_media_status(
content_hash, _MEDIA_STATUS_AI_PROCESSED, token_mgr.get_valid_token()
)
return [detection_to_dto(d) for d in results]
except RuntimeError as e:
if token_mgr and user_id and content_hash:
_put_media_status(
content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
)
if "not available" in str(e):
raise HTTPException(status_code=503, detail=str(e))
raise HTTPException(status_code=422, detail=str(e))
except ValueError as e:
if token_mgr and user_id and content_hash:
_put_media_status(
content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
)
raise HTTPException(status_code=400, detail=str(e))
except Exception:
if token_mgr and user_id and content_hash:
_put_media_status(
content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
)
raise
try:
await loop.run_in_executor(executor, run_sync)
_enqueue(channel_id, DetectionEvent(
annotations=[], mediaId=media_id,
mediaStatus="AIProcessed", mediaPercent=100,
))
if token_mgr and content_hash:
_put_media_status(
content_hash, _MEDIA_STATUS_AI_PROCESSED, token_mgr.get_valid_token()
)
except RuntimeError as e:
if token_mgr and content_hash:
_put_media_status(
content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
)
_enqueue(channel_id, DetectionEvent(
annotations=[], mediaId=media_id,
mediaStatus="Error", mediaPercent=0,
))
if "not available" in str(e):
return
raise
except Exception:
if token_mgr and content_hash:
_put_media_status(
content_hash, _MEDIA_STATUS_ERROR, token_mgr.get_valid_token()
)
_enqueue(channel_id, DetectionEvent(
annotations=[], mediaId=media_id,
mediaStatus="Error", mediaPercent=0,
))
raise
finally:
loop.call_later(10.0, _cleanup_channel, channel_id)
asyncio.create_task(run_detection())
return Response(status_code=202)
@app.post("/detect/video")
@@ -467,6 +553,10 @@ async def detect_video_upload(
from inference import ai_config_from_dict
from streaming_buffer import StreamingBuffer
channel_id = request.headers.get("x-channel-id", "")
if not channel_id:
raise HTTPException(status_code=400, detail="X-Channel-Id header required")
filename = request.headers.get("x-filename", "upload.mp4")
config_json = request.headers.get("x-config", "")
ext = _normalize_upload_ext(filename)
@@ -491,32 +581,23 @@ async def detect_video_upload(
loop = asyncio.get_event_loop()
inf = get_inference()
placeholder_id = f"tmp_{os.path.basename(buffer.path)}"
current_id = [placeholder_id] # mutable — updated to content_hash after upload
current_media_id = [channel_id]
def on_annotation(annotation, percent):
dtos = [detection_to_dto(d) for d in annotation.detections]
mid = current_id[0]
mid = current_media_id[0]
event = DetectionEvent(
annotations=dtos,
mediaId=mid,
mediaStatus="AIProcessing",
mediaPercent=percent,
)
loop.call_soon_threadsafe(_enqueue, mid, event)
def on_status(media_name_cb, count):
mid = current_id[0]
event = DetectionEvent(
annotations=[],
mediaId=mid,
mediaStatus="AIProcessed",
mediaPercent=100,
)
loop.call_soon_threadsafe(_enqueue, mid, event)
loop.call_soon_threadsafe(_enqueue, channel_id, event)
if token_mgr and mid != channel_id and dtos:
_post_annotation_to_service(token_mgr, mid, annotation, dtos)
def run_inference():
inf.run_detect_video_stream(buffer, ai_cfg, media_name, on_annotation, on_status)
inf.run_detect_video_stream(buffer, ai_cfg, media_name, on_annotation, lambda *_: None)
inference_future = loop.run_in_executor(executor, run_inference)
@@ -533,14 +614,14 @@ async def detect_video_upload(
if not ext.startswith("."):
ext = "." + ext
storage_path = os.path.abspath(os.path.join(videos_dir, f"{content_hash}{ext}"))
current_media_id[0] = content_hash
# Re-key buffered events from placeholder_id to content_hash so clients
# can subscribe to GET /detect/{content_hash} after POST returns.
if placeholder_id in _job_buffers:
_job_buffers[content_hash] = _job_buffers.pop(placeholder_id)
if placeholder_id in _job_queues:
_job_queues[content_hash] = _job_queues.pop(placeholder_id)
current_id[0] = content_hash # future on_annotation/on_status callbacks use content_hash
_enqueue(channel_id, DetectionEvent(
annotations=[],
mediaId=content_hash,
mediaStatus="Started",
mediaPercent=0,
))
if token_mgr and user_id:
os.rename(buffer.path, storage_path)
@@ -564,27 +645,24 @@ async def detect_video_upload(
content_hash, _MEDIA_STATUS_AI_PROCESSED,
token_mgr.get_valid_token(),
)
done_event = DetectionEvent(
_enqueue(channel_id, DetectionEvent(
annotations=[],
mediaId=content_hash,
mediaStatus="AIProcessed",
mediaPercent=100,
)
_enqueue(content_hash, done_event)
))
except Exception:
if token_mgr and user_id:
_put_media_status(
content_hash, _MEDIA_STATUS_ERROR,
token_mgr.get_valid_token(),
)
err_event = DetectionEvent(
_enqueue(channel_id, DetectionEvent(
annotations=[], mediaId=content_hash,
mediaStatus="Error", mediaPercent=0,
)
_enqueue(content_hash, err_event)
))
finally:
_active_detections.pop(content_hash, None)
_schedule_buffer_cleanup(content_hash)
loop.call_later(10.0, _cleanup_channel, channel_id)
buffer.close()
if not (token_mgr and user_id) and os.path.isfile(buffer.path):
try:
@@ -592,31 +670,8 @@ async def detect_video_upload(
except OSError:
pass
_active_detections[content_hash] = asyncio.create_task(_wait_inference())
return {"status": "started", "mediaId": content_hash}
def _post_annotation_to_service(token_mgr: TokenManager, media_id: str,
annotation, dtos: list[DetectionDto]):
try:
token = token_mgr.get_valid_token()
image_b64 = base64.b64encode(annotation.image).decode() if annotation.image else None
payload = {
"mediaId": media_id,
"source": 0,
"videoTime": f"00:00:{annotation.time // 1000:02d}" if annotation.time else "00:00:00",
"detections": [d.model_dump() for d in dtos],
}
if image_b64:
payload["image"] = image_b64
http_requests.post(
f"{ANNOTATIONS_URL}/annotations",
json=payload,
headers={"Authorization": f"Bearer {token}"},
timeout=30,
)
except Exception:
pass
asyncio.create_task(_wait_inference())
return Response(status_code=202)
@app.post("/detect/{media_id}")
@@ -630,6 +685,10 @@ async def detect_media(
if existing is not None and not existing.done():
raise HTTPException(status_code=409, detail="Detection already in progress for this media")
channel_id = request.headers.get("x-channel-id", "")
if not channel_id:
raise HTTPException(status_code=400, detail="X-Channel-Id header required")
refresh_token = request.headers.get("x-refresh-token", "")
access_token = request.headers.get("authorization", "").removeprefix("Bearer ").strip()
token_mgr = TokenManager(access_token, refresh_token) if access_token else None
@@ -668,7 +727,7 @@ async def detect_media(
mediaStatus="AIProcessing",
mediaPercent=percent,
)
loop.call_soon_threadsafe(_enqueue, media_id, event)
loop.call_soon_threadsafe(_enqueue, channel_id, event)
if token_mgr and dtos:
_post_annotation_to_service(token_mgr, media_id, annotation, dtos)
@@ -679,7 +738,7 @@ async def detect_media(
mediaStatus="AIProcessed",
mediaPercent=100,
)
loop.call_soon_threadsafe(_enqueue, media_id, event)
loop.call_soon_threadsafe(_enqueue, channel_id, event)
if token_mgr:
_put_media_status(
media_id,
@@ -718,36 +777,10 @@ async def detect_media(
mediaStatus="Error",
mediaPercent=0,
)
_enqueue(media_id, error_event)
_enqueue(channel_id, error_event)
finally:
_active_detections.pop(media_id, None)
_schedule_buffer_cleanup(media_id)
loop.call_later(10.0, _cleanup_channel, channel_id)
_active_detections[media_id] = asyncio.create_task(run_detection())
return {"status": "started", "mediaId": media_id}
@app.get("/detect/{media_id}", dependencies=[Depends(require_auth)])
async def detect_events(media_id: str):
queue: asyncio.Queue = asyncio.Queue(maxsize=200)
_job_queues.setdefault(media_id, []).append(queue)
async def event_generator():
try:
for data in list(_job_buffers.get(media_id, [])):
yield f"data: {data}\n\n"
while True:
data = await queue.get()
yield f"data: {data}\n\n"
except asyncio.CancelledError:
pass
finally:
queues = _job_queues.get(media_id, [])
if queue in queues:
queues.remove(queue)
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
)
return Response(status_code=202)