mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 11:16:31 +00:00
[AZ-173] [AZ-174] Stream-based detection API and DB-driven AI config
Made-with: Cursor
This commit is contained in:
@@ -4,6 +4,6 @@
|
|||||||
flow: existing-code
|
flow: existing-code
|
||||||
step: 9
|
step: 9
|
||||||
name: Implement
|
name: Implement
|
||||||
status: not_started
|
status: in_progress
|
||||||
sub_step: 0
|
sub_step: 0
|
||||||
retry_count: 0
|
retry_count: 0
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
from flask import Flask, request
|
from flask import Flask, request
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
@@ -25,6 +27,35 @@ def auth_refresh():
|
|||||||
return {"token": "refreshed-test-token"}
|
return {"token": "refreshed-test-token"}
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/api/users/<user_id>/ai-settings", methods=["GET"])
|
||||||
|
def user_ai_settings(user_id):
|
||||||
|
if _fail():
|
||||||
|
return "", 503
|
||||||
|
return {
|
||||||
|
"frame_period_recognition": 4,
|
||||||
|
"frame_recognition_seconds": 2,
|
||||||
|
"probability_threshold": 0.25,
|
||||||
|
"tracking_distance_confidence": 0.1,
|
||||||
|
"tracking_probability_increase": 0.1,
|
||||||
|
"tracking_intersection_threshold": 0.6,
|
||||||
|
"model_batch_size": 8,
|
||||||
|
"big_image_tile_overlap_percent": 20,
|
||||||
|
"altitude": 400,
|
||||||
|
"focal_length": 24,
|
||||||
|
"sensor_width": 23.5,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/api/media/<media_id>", methods=["GET"])
|
||||||
|
def media_path(media_id):
|
||||||
|
if _fail():
|
||||||
|
return "", 503
|
||||||
|
root = os.environ.get("MEDIA_DIR", "/media")
|
||||||
|
if media_id.startswith("sse-") or media_id.startswith("video-"):
|
||||||
|
return {"path": f"{root}/video_test01.mp4"}
|
||||||
|
return {"path": f"{root}/image_small.jpg"}
|
||||||
|
|
||||||
|
|
||||||
@app.route("/mock/config", methods=["POST"])
|
@app.route("/mock/config", methods=["POST"])
|
||||||
def mock_config():
|
def mock_config():
|
||||||
global _mode
|
global _mode
|
||||||
|
|||||||
@@ -3,3 +3,5 @@ pytest-csv
|
|||||||
requests==2.32.4
|
requests==2.32.4
|
||||||
sseclient-py
|
sseclient-py
|
||||||
pytest-timeout
|
pytest-timeout
|
||||||
|
flask
|
||||||
|
gunicorn
|
||||||
|
|||||||
@@ -1,35 +1,17 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
_MEDIA = os.environ.get("MEDIA_DIR", "/media")
|
|
||||||
|
|
||||||
|
|
||||||
def _ai_config_video() -> dict:
|
def _ai_config_video() -> dict:
|
||||||
return {
|
return {}
|
||||||
"probability_threshold": 0.25,
|
|
||||||
"tracking_intersection_threshold": 0.6,
|
|
||||||
"altitude": 400,
|
|
||||||
"focal_length": 24,
|
|
||||||
"sensor_width": 23.5,
|
|
||||||
"paths": [f"{_MEDIA}/video_test01.mp4"],
|
|
||||||
"frame_period_recognition": 4,
|
|
||||||
"frame_recognition_seconds": 2,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _ai_config_image() -> dict:
|
def _ai_config_image() -> dict:
|
||||||
return {
|
return {}
|
||||||
"probability_threshold": 0.25,
|
|
||||||
"altitude": 400,
|
|
||||||
"focal_length": 24,
|
|
||||||
"sensor_width": 23.5,
|
|
||||||
"paths": [f"{_MEDIA}/image_small.jpg"],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_ft_p08_immediate_async_response(
|
def test_ft_p08_immediate_async_response(
|
||||||
|
|||||||
+2
-13
@@ -20,20 +20,9 @@ def _make_jwt() -> str:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def video_events(warm_engine, http_client, video_short_path):
|
def video_events(warm_engine, http_client):
|
||||||
media_id = f"video-{uuid.uuid4().hex}"
|
media_id = f"video-{uuid.uuid4().hex}"
|
||||||
body = {
|
body = {}
|
||||||
"probability_threshold": 0.25,
|
|
||||||
"frame_period_recognition": 4,
|
|
||||||
"frame_recognition_seconds": 2,
|
|
||||||
"tracking_distance_confidence": 0.1,
|
|
||||||
"tracking_probability_increase": 0.1,
|
|
||||||
"tracking_intersection_threshold": 0.6,
|
|
||||||
"altitude": 400.0,
|
|
||||||
"focal_length": 24.0,
|
|
||||||
"sensor_width": 23.5,
|
|
||||||
"paths": [video_short_path],
|
|
||||||
}
|
|
||||||
token = _make_jwt()
|
token = _make_jwt()
|
||||||
|
|
||||||
collected: list[tuple[float, dict]] = []
|
collected: list[tuple[float, dict]] = []
|
||||||
|
|||||||
@@ -8,3 +8,4 @@ pynvml==12.0.0
|
|||||||
requests==2.32.4
|
requests==2.32.4
|
||||||
loguru==0.7.3
|
loguru==0.7.3
|
||||||
python-multipart
|
python-multipart
|
||||||
|
av==14.2.0
|
||||||
|
|||||||
+11
-6
@@ -17,8 +17,10 @@ cleanup() {
|
|||||||
}
|
}
|
||||||
trap cleanup EXIT
|
trap cleanup EXIT
|
||||||
|
|
||||||
|
PY="$(command -v python3 2>/dev/null || command -v python 2>/dev/null || echo python)"
|
||||||
|
|
||||||
echo "Building Cython extensions ..."
|
echo "Building Cython extensions ..."
|
||||||
python setup.py build_ext --inplace
|
"$PY" setup.py build_ext --inplace
|
||||||
|
|
||||||
for port in $LOADER_PORT $ANNOTATIONS_PORT $DETECTIONS_PORT; do
|
for port in $LOADER_PORT $ANNOTATIONS_PORT $DETECTIONS_PORT; do
|
||||||
if lsof -ti :"$port" >/dev/null 2>&1; then
|
if lsof -ti :"$port" >/dev/null 2>&1; then
|
||||||
@@ -29,13 +31,15 @@ for port in $LOADER_PORT $ANNOTATIONS_PORT $DETECTIONS_PORT; do
|
|||||||
done
|
done
|
||||||
|
|
||||||
echo "Starting mock-loader on :$LOADER_PORT ..."
|
echo "Starting mock-loader on :$LOADER_PORT ..."
|
||||||
MODELS_ROOT="$FIXTURES" \
|
cd "$ROOT"
|
||||||
python -m gunicorn --bind "0.0.0.0:$LOADER_PORT" --workers 1 --timeout 120 \
|
MODELS_ROOT="$FIXTURES" PYTHONPATH="$ROOT" \
|
||||||
|
"$PY" -m gunicorn --bind "0.0.0.0:$LOADER_PORT" --workers 1 --timeout 120 \
|
||||||
'e2e.mocks.loader.app:app' >/dev/null 2>&1 &
|
'e2e.mocks.loader.app:app' >/dev/null 2>&1 &
|
||||||
PIDS+=($!)
|
PIDS+=($!)
|
||||||
|
|
||||||
echo "Starting mock-annotations on :$ANNOTATIONS_PORT ..."
|
echo "Starting mock-annotations on :$ANNOTATIONS_PORT ..."
|
||||||
python -m gunicorn --bind "0.0.0.0:$ANNOTATIONS_PORT" --workers 1 --timeout 120 \
|
MEDIA_DIR="$FIXTURES" PYTHONPATH="$ROOT" \
|
||||||
|
"$PY" -m gunicorn --bind "0.0.0.0:$ANNOTATIONS_PORT" --workers 1 --timeout 120 \
|
||||||
'e2e.mocks.annotations.app:app' >/dev/null 2>&1 &
|
'e2e.mocks.annotations.app:app' >/dev/null 2>&1 &
|
||||||
PIDS+=($!)
|
PIDS+=($!)
|
||||||
|
|
||||||
@@ -43,7 +47,7 @@ echo "Starting detections service on :$DETECTIONS_PORT ..."
|
|||||||
LOADER_URL="http://localhost:$LOADER_PORT" \
|
LOADER_URL="http://localhost:$LOADER_PORT" \
|
||||||
ANNOTATIONS_URL="http://localhost:$ANNOTATIONS_PORT" \
|
ANNOTATIONS_URL="http://localhost:$ANNOTATIONS_PORT" \
|
||||||
PYTHONPATH="$ROOT/src" \
|
PYTHONPATH="$ROOT/src" \
|
||||||
python -m uvicorn main:app --host 0.0.0.0 --port "$DETECTIONS_PORT" \
|
"$PY" -m uvicorn main:app --host 0.0.0.0 --port "$DETECTIONS_PORT" \
|
||||||
--log-level warning >/dev/null 2>&1 &
|
--log-level warning >/dev/null 2>&1 &
|
||||||
PIDS+=($!)
|
PIDS+=($!)
|
||||||
|
|
||||||
@@ -66,4 +70,5 @@ BASE_URL="http://localhost:$DETECTIONS_PORT" \
|
|||||||
MOCK_LOADER_URL="http://localhost:$LOADER_PORT" \
|
MOCK_LOADER_URL="http://localhost:$LOADER_PORT" \
|
||||||
MOCK_ANNOTATIONS_URL="http://localhost:$ANNOTATIONS_PORT" \
|
MOCK_ANNOTATIONS_URL="http://localhost:$ANNOTATIONS_PORT" \
|
||||||
MEDIA_DIR="$FIXTURES" \
|
MEDIA_DIR="$FIXTURES" \
|
||||||
python -m pytest e2e/tests/ -v --tb=short "$@"
|
PYTHONPATH="$ROOT/src" \
|
||||||
|
"$PY" -m pytest e2e/tests/ tests/ -v --tb=short "$@"
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ cdef class AIRecognitionConfig:
|
|||||||
|
|
||||||
cdef public int big_image_tile_overlap_percent
|
cdef public int big_image_tile_overlap_percent
|
||||||
|
|
||||||
cdef public list[str] paths
|
|
||||||
cdef public int model_batch_size
|
cdef public int model_batch_size
|
||||||
|
|
||||||
cdef public double altitude
|
cdef public double altitude
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ cdef class AIRecognitionConfig:
|
|||||||
tracking_distance_confidence,
|
tracking_distance_confidence,
|
||||||
tracking_probability_increase,
|
tracking_probability_increase,
|
||||||
tracking_intersection_threshold,
|
tracking_intersection_threshold,
|
||||||
paths,
|
|
||||||
model_batch_size,
|
model_batch_size,
|
||||||
big_image_tile_overlap_percent,
|
big_image_tile_overlap_percent,
|
||||||
altitude,
|
altitude,
|
||||||
@@ -22,7 +21,6 @@ cdef class AIRecognitionConfig:
|
|||||||
self.tracking_probability_increase = tracking_probability_increase
|
self.tracking_probability_increase = tracking_probability_increase
|
||||||
self.tracking_intersection_threshold = tracking_intersection_threshold
|
self.tracking_intersection_threshold = tracking_intersection_threshold
|
||||||
|
|
||||||
self.paths = paths
|
|
||||||
self.model_batch_size = model_batch_size
|
self.model_batch_size = model_batch_size
|
||||||
|
|
||||||
self.big_image_tile_overlap_percent = big_image_tile_overlap_percent
|
self.big_image_tile_overlap_percent = big_image_tile_overlap_percent
|
||||||
@@ -37,7 +35,6 @@ cdef class AIRecognitionConfig:
|
|||||||
f'intersection_threshold : {self.tracking_intersection_threshold}, '
|
f'intersection_threshold : {self.tracking_intersection_threshold}, '
|
||||||
f'frame_period_recognition : {self.frame_period_recognition}, '
|
f'frame_period_recognition : {self.frame_period_recognition}, '
|
||||||
f'big_image_tile_overlap_percent: {self.big_image_tile_overlap_percent}, '
|
f'big_image_tile_overlap_percent: {self.big_image_tile_overlap_percent}, '
|
||||||
f'paths: {self.paths}, '
|
|
||||||
f'model_batch_size: {self.model_batch_size}, '
|
f'model_batch_size: {self.model_batch_size}, '
|
||||||
f'altitude: {self.altitude}, '
|
f'altitude: {self.altitude}, '
|
||||||
f'focal_length: {self.focal_length}, '
|
f'focal_length: {self.focal_length}, '
|
||||||
@@ -55,7 +52,6 @@ cdef class AIRecognitionConfig:
|
|||||||
data.get("tracking_probability_increase", 0.0),
|
data.get("tracking_probability_increase", 0.0),
|
||||||
data.get("tracking_intersection_threshold", 0.6),
|
data.get("tracking_intersection_threshold", 0.6),
|
||||||
|
|
||||||
data.get("paths", []),
|
|
||||||
data.get("model_batch_size", 8),
|
data.get("model_batch_size", 8),
|
||||||
|
|
||||||
data.get("big_image_tile_overlap_percent", 20),
|
data.get("big_image_tile_overlap_percent", 20),
|
||||||
|
|||||||
+166
-27
@@ -1,7 +1,11 @@
|
|||||||
|
import io
|
||||||
import mimetypes
|
import mimetypes
|
||||||
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import av
|
||||||
import cv2
|
import cv2
|
||||||
|
import numpy as np
|
||||||
cimport constants_inf
|
cimport constants_inf
|
||||||
|
|
||||||
from ai_availability_status cimport AIAvailabilityEnum, AIAvailabilityStatus
|
from ai_availability_status cimport AIAvailabilityEnum, AIAvailabilityStatus
|
||||||
@@ -13,6 +17,18 @@ from threading import Thread
|
|||||||
from engines import EngineClass
|
from engines import EngineClass
|
||||||
|
|
||||||
|
|
||||||
|
def ai_config_from_dict(dict data):
|
||||||
|
return AIRecognitionConfig.from_dict(data)
|
||||||
|
|
||||||
|
|
||||||
|
def _write_video_bytes_to_path(str path, bytes data, object done_event):
|
||||||
|
try:
|
||||||
|
with open(path, 'wb') as f:
|
||||||
|
f.write(data)
|
||||||
|
finally:
|
||||||
|
done_event.set()
|
||||||
|
|
||||||
|
|
||||||
cdef class Inference:
|
cdef class Inference:
|
||||||
cdef LoaderHttpClient loader_client
|
cdef LoaderHttpClient loader_client
|
||||||
cdef InferenceEngine engine
|
cdef InferenceEngine engine
|
||||||
@@ -135,6 +151,7 @@ cdef class Inference:
|
|||||||
cpdef run_detect(self, dict config_dict, object annotation_callback, object status_callback=None):
|
cpdef run_detect(self, dict config_dict, object annotation_callback, object status_callback=None):
|
||||||
cdef list[str] videos = []
|
cdef list[str] videos = []
|
||||||
cdef list[str] images = []
|
cdef list[str] images = []
|
||||||
|
cdef object media_paths = config_dict.get("paths", [])
|
||||||
cdef AIRecognitionConfig ai_config = AIRecognitionConfig.from_dict(config_dict)
|
cdef AIRecognitionConfig ai_config = AIRecognitionConfig.from_dict(config_dict)
|
||||||
if ai_config is None:
|
if ai_config is None:
|
||||||
raise Exception('ai recognition config is empty')
|
raise Exception('ai recognition config is empty')
|
||||||
@@ -148,7 +165,7 @@ cdef class Inference:
|
|||||||
return
|
return
|
||||||
|
|
||||||
self.detection_counts = {}
|
self.detection_counts = {}
|
||||||
for p in ai_config.paths:
|
for p in media_paths:
|
||||||
media_name = Path(<str>p).stem.replace(" ", "")
|
media_name = Path(<str>p).stem.replace(" ", "")
|
||||||
self.detection_counts[media_name] = 0
|
self.detection_counts[media_name] = 0
|
||||||
if self.is_video(p):
|
if self.is_video(p):
|
||||||
@@ -163,22 +180,147 @@ cdef class Inference:
|
|||||||
constants_inf.log(<str>f'run inference on {v}...')
|
constants_inf.log(<str>f'run inference on {v}...')
|
||||||
self._process_video(ai_config, v)
|
self._process_video(ai_config, v)
|
||||||
|
|
||||||
|
cpdef run_detect_image(self, bytes image_bytes, AIRecognitionConfig ai_config, str media_name,
|
||||||
|
object annotation_callback, object status_callback=None):
|
||||||
|
cdef list all_frame_data = []
|
||||||
|
cdef str original_media_name
|
||||||
|
self._annotation_callback = annotation_callback
|
||||||
|
self._status_callback = status_callback
|
||||||
|
self.stop_signal = <bint>False
|
||||||
|
self.init_ai()
|
||||||
|
if self.engine is None:
|
||||||
|
constants_inf.log(<str> "AI engine not available. Conversion may be in progress. Skipping inference.")
|
||||||
|
return
|
||||||
|
if not image_bytes:
|
||||||
|
return
|
||||||
|
frame = cv2.imdecode(np.frombuffer(image_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)
|
||||||
|
if frame is None:
|
||||||
|
constants_inf.logerror(<str>'Failed to decode image bytes')
|
||||||
|
return
|
||||||
|
original_media_name = media_name.replace(" ", "")
|
||||||
|
self.detection_counts = {}
|
||||||
|
self.detection_counts[original_media_name] = 0
|
||||||
|
self._tile_detections = {}
|
||||||
|
self._append_image_frame_entries(ai_config, all_frame_data, frame, original_media_name)
|
||||||
|
self._finalize_image_inference(ai_config, all_frame_data)
|
||||||
|
|
||||||
|
cpdef run_detect_video(self, bytes video_bytes, AIRecognitionConfig ai_config, str media_name, str save_path,
|
||||||
|
object annotation_callback, object status_callback=None):
|
||||||
|
cdef str original_media_name
|
||||||
|
self._annotation_callback = annotation_callback
|
||||||
|
self._status_callback = status_callback
|
||||||
|
self.stop_signal = <bint>False
|
||||||
|
self.init_ai()
|
||||||
|
if self.engine is None:
|
||||||
|
constants_inf.log(<str> "AI engine not available. Conversion may be in progress. Skipping inference.")
|
||||||
|
return
|
||||||
|
if not video_bytes:
|
||||||
|
return
|
||||||
|
original_media_name = media_name.replace(" ", "")
|
||||||
|
self.detection_counts = {}
|
||||||
|
self.detection_counts[original_media_name] = 0
|
||||||
|
writer_done = threading.Event()
|
||||||
|
wt = threading.Thread(
|
||||||
|
target=_write_video_bytes_to_path,
|
||||||
|
args=(save_path, video_bytes, writer_done),
|
||||||
|
daemon=True,
|
||||||
|
)
|
||||||
|
wt.start()
|
||||||
|
try:
|
||||||
|
bio = io.BytesIO(video_bytes)
|
||||||
|
container = av.open(bio)
|
||||||
|
try:
|
||||||
|
self._process_video_pyav(ai_config, original_media_name, container)
|
||||||
|
finally:
|
||||||
|
container.close()
|
||||||
|
finally:
|
||||||
|
writer_done.wait()
|
||||||
|
wt.join(timeout=3600)
|
||||||
|
|
||||||
|
cdef _process_video_pyav(self, AIRecognitionConfig ai_config, str original_media_name, object container):
|
||||||
|
cdef int frame_count = 0
|
||||||
|
cdef int batch_count = 0
|
||||||
|
cdef list batch_frames = []
|
||||||
|
cdef list[long] batch_timestamps = []
|
||||||
|
cdef int model_h, model_w
|
||||||
|
cdef int total_frames
|
||||||
|
cdef int tf
|
||||||
|
cdef double duration_sec
|
||||||
|
cdef double fps
|
||||||
|
self._previous_annotation = <Annotation>None
|
||||||
|
model_h, model_w = self.engine.get_input_shape()
|
||||||
|
streams = container.streams.video
|
||||||
|
if not streams:
|
||||||
|
constants_inf.logerror(<str>'No video stream in container')
|
||||||
|
self.send_detection_status()
|
||||||
|
return
|
||||||
|
vstream = streams[0]
|
||||||
|
total_frames = 0
|
||||||
|
if vstream.frames is not None and int(vstream.frames) > 0:
|
||||||
|
total_frames = int(vstream.frames)
|
||||||
|
else:
|
||||||
|
duration_sec = 0.0
|
||||||
|
if vstream.duration is not None and vstream.time_base is not None:
|
||||||
|
duration_sec = float(vstream.duration * vstream.time_base)
|
||||||
|
fps = 25.0
|
||||||
|
if vstream.average_rate is not None:
|
||||||
|
fps = float(vstream.average_rate)
|
||||||
|
if duration_sec > 0:
|
||||||
|
total_frames = int(duration_sec * fps)
|
||||||
|
if total_frames < 1:
|
||||||
|
total_frames = 1
|
||||||
|
tf = total_frames
|
||||||
|
constants_inf.log(<str>f'Video (PyAV): ~{tf} frames est, {vstream.width}x{vstream.height}')
|
||||||
|
cdef int effective_batch = min(self.engine.max_batch_size, ai_config.model_batch_size)
|
||||||
|
if effective_batch < 1:
|
||||||
|
effective_batch = 1
|
||||||
|
for av_frame in container.decode(vstream):
|
||||||
|
if self.stop_signal:
|
||||||
|
break
|
||||||
|
frame_count += 1
|
||||||
|
arr = av_frame.to_ndarray(format='bgr24')
|
||||||
|
if frame_count % ai_config.frame_period_recognition == 0:
|
||||||
|
ts_ms = 0
|
||||||
|
if av_frame.time is not None:
|
||||||
|
ts_ms = int(av_frame.time * 1000)
|
||||||
|
elif av_frame.pts is not None and vstream.time_base is not None:
|
||||||
|
ts_ms = int(float(av_frame.pts) * float(vstream.time_base) * 1000)
|
||||||
|
batch_frames.append(arr)
|
||||||
|
batch_timestamps.append(<long>ts_ms)
|
||||||
|
if len(batch_frames) >= effective_batch:
|
||||||
|
batch_count += 1
|
||||||
|
tf = total_frames if total_frames > 0 else max(frame_count, 1)
|
||||||
|
constants_inf.log(<str>f'Video batch {batch_count}: frame {frame_count}/{tf} ({frame_count*100//tf}%)')
|
||||||
|
self._process_video_batch(ai_config, batch_frames, batch_timestamps, original_media_name, frame_count, tf, model_w)
|
||||||
|
batch_frames = []
|
||||||
|
batch_timestamps = []
|
||||||
|
if batch_frames:
|
||||||
|
batch_count += 1
|
||||||
|
tf = total_frames if total_frames > 0 else max(frame_count, 1)
|
||||||
|
constants_inf.log(<str>f'Video batch {batch_count} (flush): {len(batch_frames)} remaining frames')
|
||||||
|
self._process_video_batch(ai_config, batch_frames, batch_timestamps, original_media_name, frame_count, tf, model_w)
|
||||||
|
constants_inf.log(<str>f'Video done: {frame_count} frames read, {batch_count} batches processed')
|
||||||
|
self.send_detection_status()
|
||||||
|
|
||||||
cdef _process_video(self, AIRecognitionConfig ai_config, str video_name):
|
cdef _process_video(self, AIRecognitionConfig ai_config, str video_name):
|
||||||
cdef int frame_count = 0
|
cdef int frame_count = 0
|
||||||
cdef int batch_count = 0
|
cdef int batch_count = 0
|
||||||
cdef list batch_frames = []
|
cdef list batch_frames = []
|
||||||
cdef list[long] batch_timestamps = []
|
cdef list[long] batch_timestamps = []
|
||||||
cdef Annotation annotation
|
|
||||||
cdef int model_h, model_w
|
cdef int model_h, model_w
|
||||||
|
cdef str original_media_name
|
||||||
self._previous_annotation = <Annotation>None
|
self._previous_annotation = <Annotation>None
|
||||||
|
|
||||||
model_h, model_w = self.engine.get_input_shape()
|
model_h, model_w = self.engine.get_input_shape()
|
||||||
|
original_media_name = Path(<str>video_name).stem.replace(" ", "")
|
||||||
|
|
||||||
v_input = cv2.VideoCapture(<str>video_name)
|
v_input = cv2.VideoCapture(<str>video_name)
|
||||||
if not v_input.isOpened():
|
if not v_input.isOpened():
|
||||||
constants_inf.logerror(<str>f'Failed to open video: {video_name}')
|
constants_inf.logerror(<str>f'Failed to open video: {video_name}')
|
||||||
return
|
return
|
||||||
total_frames = int(v_input.get(cv2.CAP_PROP_FRAME_COUNT))
|
total_frames = int(v_input.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||||
|
if total_frames < 1:
|
||||||
|
total_frames = 1
|
||||||
fps = v_input.get(cv2.CAP_PROP_FPS)
|
fps = v_input.get(cv2.CAP_PROP_FPS)
|
||||||
width = int(v_input.get(cv2.CAP_PROP_FRAME_WIDTH))
|
width = int(v_input.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||||
height = int(v_input.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
height = int(v_input.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||||
@@ -201,21 +343,21 @@ cdef class Inference:
|
|||||||
if len(batch_frames) >= effective_batch:
|
if len(batch_frames) >= effective_batch:
|
||||||
batch_count += 1
|
batch_count += 1
|
||||||
constants_inf.log(<str>f'Video batch {batch_count}: frame {frame_count}/{total_frames} ({frame_count*100//total_frames}%)')
|
constants_inf.log(<str>f'Video batch {batch_count}: frame {frame_count}/{total_frames} ({frame_count*100//total_frames}%)')
|
||||||
self._process_video_batch(ai_config, batch_frames, batch_timestamps, video_name, frame_count, total_frames, model_w)
|
self._process_video_batch(ai_config, batch_frames, batch_timestamps, original_media_name, frame_count, total_frames, model_w)
|
||||||
batch_frames = []
|
batch_frames = []
|
||||||
batch_timestamps = []
|
batch_timestamps = []
|
||||||
|
|
||||||
if batch_frames:
|
if batch_frames:
|
||||||
batch_count += 1
|
batch_count += 1
|
||||||
constants_inf.log(<str>f'Video batch {batch_count} (flush): {len(batch_frames)} remaining frames')
|
constants_inf.log(<str>f'Video batch {batch_count} (flush): {len(batch_frames)} remaining frames')
|
||||||
self._process_video_batch(ai_config, batch_frames, batch_timestamps, video_name, frame_count, total_frames, model_w)
|
self._process_video_batch(ai_config, batch_frames, batch_timestamps, original_media_name, frame_count, total_frames, model_w)
|
||||||
|
|
||||||
v_input.release()
|
v_input.release()
|
||||||
constants_inf.log(<str>f'Video done: {frame_count} frames read, {batch_count} batches processed')
|
constants_inf.log(<str>f'Video done: {frame_count} frames read, {batch_count} batches processed')
|
||||||
self.send_detection_status()
|
self.send_detection_status()
|
||||||
|
|
||||||
cdef _process_video_batch(self, AIRecognitionConfig ai_config, list batch_frames,
|
cdef _process_video_batch(self, AIRecognitionConfig ai_config, list batch_frames,
|
||||||
list batch_timestamps, str video_name,
|
list batch_timestamps, str original_media_name,
|
||||||
int frame_count, int total_frames, int model_w):
|
int frame_count, int total_frames, int model_w):
|
||||||
cdef Annotation annotation
|
cdef Annotation annotation
|
||||||
list_detections = self.engine.process_frames(batch_frames, ai_config)
|
list_detections = self.engine.process_frames(batch_frames, ai_config)
|
||||||
@@ -225,7 +367,6 @@ cdef class Inference:
|
|||||||
|
|
||||||
for i in range(len(list_detections)):
|
for i in range(len(list_detections)):
|
||||||
detections = list_detections[i]
|
detections = list_detections[i]
|
||||||
original_media_name = Path(<str>video_name).stem.replace(" ", "")
|
|
||||||
name = f'{original_media_name}_{constants_inf.format_time(batch_timestamps[i])}'
|
name = f'{original_media_name}_{constants_inf.format_time(batch_timestamps[i])}'
|
||||||
annotation = Annotation(name, original_media_name, batch_timestamps[i], detections)
|
annotation = Annotation(name, original_media_name, batch_timestamps[i], detections)
|
||||||
|
|
||||||
@@ -247,56 +388,54 @@ cdef class Inference:
|
|||||||
cb = self._annotation_callback
|
cb = self._annotation_callback
|
||||||
cb(annotation, percent)
|
cb(annotation, percent)
|
||||||
|
|
||||||
cdef _process_images(self, AIRecognitionConfig ai_config, list[str] image_paths):
|
cdef _append_image_frame_entries(self, AIRecognitionConfig ai_config, list all_frame_data, frame, str original_media_name):
|
||||||
cdef list all_frame_data = []
|
|
||||||
cdef double ground_sampling_distance
|
cdef double ground_sampling_distance
|
||||||
cdef int model_h, model_w
|
cdef int model_h, model_w
|
||||||
|
cdef int img_h, img_w
|
||||||
model_h, model_w = self.engine.get_input_shape()
|
model_h, model_w = self.engine.get_input_shape()
|
||||||
self._tile_detections = {}
|
|
||||||
|
|
||||||
for path in image_paths:
|
|
||||||
frame = cv2.imread(<str>path)
|
|
||||||
if frame is None:
|
|
||||||
constants_inf.logerror(<str>f'Failed to read image {path}')
|
|
||||||
continue
|
|
||||||
img_h, img_w, _ = frame.shape
|
img_h, img_w, _ = frame.shape
|
||||||
original_media_name = Path(<str> path).stem.replace(" ", "")
|
|
||||||
|
|
||||||
ground_sampling_distance = ai_config.sensor_width * ai_config.altitude / (ai_config.focal_length * img_w)
|
ground_sampling_distance = ai_config.sensor_width * ai_config.altitude / (ai_config.focal_length * img_w)
|
||||||
constants_inf.log(<str>f'ground sampling distance: {ground_sampling_distance}')
|
constants_inf.log(<str>f'ground sampling distance: {ground_sampling_distance}')
|
||||||
|
|
||||||
if img_h <= 1.5 * model_h and img_w <= 1.5 * model_w:
|
if img_h <= 1.5 * model_h and img_w <= 1.5 * model_w:
|
||||||
all_frame_data.append((frame, original_media_name, f'{original_media_name}_000000', ground_sampling_distance))
|
all_frame_data.append((frame, original_media_name, f'{original_media_name}_000000', ground_sampling_distance))
|
||||||
else:
|
else:
|
||||||
tile_size = int(constants_inf.METERS_IN_TILE / ground_sampling_distance)
|
tile_size = int(constants_inf.METERS_IN_TILE / ground_sampling_distance)
|
||||||
constants_inf.log(<str> f'calc tile size: {tile_size}')
|
constants_inf.log(<str> f'calc tile size: {tile_size}')
|
||||||
res = self.split_to_tiles(frame, path, tile_size, ai_config.big_image_tile_overlap_percent)
|
res = self.split_to_tiles(frame, original_media_name, tile_size, ai_config.big_image_tile_overlap_percent)
|
||||||
for tile_frame, omn, tile_name in res:
|
for tile_frame, omn, tile_name in res:
|
||||||
all_frame_data.append((tile_frame, omn, tile_name, ground_sampling_distance))
|
all_frame_data.append((tile_frame, omn, tile_name, ground_sampling_distance))
|
||||||
|
|
||||||
|
cdef _finalize_image_inference(self, AIRecognitionConfig ai_config, list all_frame_data):
|
||||||
if not all_frame_data:
|
if not all_frame_data:
|
||||||
return
|
return
|
||||||
|
|
||||||
frames = [fd[0] for fd in all_frame_data]
|
frames = [fd[0] for fd in all_frame_data]
|
||||||
all_dets = self.engine.process_frames(frames, ai_config)
|
all_dets = self.engine.process_frames(frames, ai_config)
|
||||||
|
|
||||||
for i in range(len(all_dets)):
|
for i in range(len(all_dets)):
|
||||||
frame_entry = all_frame_data[i]
|
frame_entry = all_frame_data[i]
|
||||||
f = frame_entry[0]
|
f = frame_entry[0]
|
||||||
original_media_name = frame_entry[1]
|
original_media_name = frame_entry[1]
|
||||||
name = frame_entry[2]
|
name = frame_entry[2]
|
||||||
gsd = frame_entry[3]
|
gsd = frame_entry[3]
|
||||||
|
|
||||||
annotation = Annotation(name, original_media_name, 0, all_dets[i])
|
annotation = Annotation(name, original_media_name, 0, all_dets[i])
|
||||||
if self.is_valid_image_annotation(annotation, gsd, f.shape):
|
if self.is_valid_image_annotation(annotation, gsd, f.shape):
|
||||||
constants_inf.log(<str> f'Detected {annotation}')
|
constants_inf.log(<str> f'Detected {annotation}')
|
||||||
_, image = cv2.imencode('.jpg', f)
|
_, image = cv2.imencode('.jpg', f)
|
||||||
annotation.image = image.tobytes()
|
annotation.image = image.tobytes()
|
||||||
self.on_annotation(annotation)
|
self.on_annotation(annotation)
|
||||||
|
|
||||||
self.send_detection_status()
|
self.send_detection_status()
|
||||||
|
|
||||||
|
cdef _process_images(self, AIRecognitionConfig ai_config, list[str] image_paths):
|
||||||
|
cdef list all_frame_data = []
|
||||||
|
self._tile_detections = {}
|
||||||
|
for path in image_paths:
|
||||||
|
frame = cv2.imread(<str>path)
|
||||||
|
if frame is None:
|
||||||
|
constants_inf.logerror(<str>f'Failed to read image {path}')
|
||||||
|
continue
|
||||||
|
original_media_name = Path(<str> path).stem.replace(" ", "")
|
||||||
|
self._append_image_frame_entries(ai_config, all_frame_data, frame, original_media_name)
|
||||||
|
self._finalize_image_inference(ai_config, all_frame_data)
|
||||||
|
|
||||||
cdef send_detection_status(self):
|
cdef send_detection_status(self):
|
||||||
if self._status_callback is not None:
|
if self._status_callback is not None:
|
||||||
cb = self._status_callback
|
cb = self._status_callback
|
||||||
@@ -304,14 +443,14 @@ cdef class Inference:
|
|||||||
cb(media_name, self.detection_counts[media_name])
|
cb(media_name, self.detection_counts[media_name])
|
||||||
self.detection_counts.clear()
|
self.detection_counts.clear()
|
||||||
|
|
||||||
cdef split_to_tiles(self, frame, path, tile_size, overlap_percent):
|
cdef split_to_tiles(self, frame, str media_stem, tile_size, overlap_percent):
|
||||||
constants_inf.log(<str>f'splitting image {path} to tiles...')
|
constants_inf.log(<str>f'splitting image {media_stem} to tiles...')
|
||||||
img_h, img_w, _ = frame.shape
|
img_h, img_w, _ = frame.shape
|
||||||
stride_w = int(tile_size * (1 - overlap_percent / 100))
|
stride_w = int(tile_size * (1 - overlap_percent / 100))
|
||||||
stride_h = int(tile_size * (1 - overlap_percent / 100))
|
stride_h = int(tile_size * (1 - overlap_percent / 100))
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
original_media_name = Path(<str> path).stem.replace(" ", "")
|
original_media_name = media_stem
|
||||||
for y in range(0, img_h, stride_h):
|
for y in range(0, img_h, stride_h):
|
||||||
for x in range(0, img_w, stride_w):
|
for x in range(0, img_w, stride_w):
|
||||||
x_end = min(x + tile_size, img_w)
|
x_end = min(x + tile_size, img_w)
|
||||||
|
|||||||
@@ -6,3 +6,5 @@ cdef class LoaderHttpClient:
|
|||||||
cdef str base_url
|
cdef str base_url
|
||||||
cdef LoadResult load_big_small_resource(self, str filename, str directory)
|
cdef LoadResult load_big_small_resource(self, str filename, str directory)
|
||||||
cdef LoadResult upload_big_small_resource(self, bytes content, str filename, str directory)
|
cdef LoadResult upload_big_small_resource(self, bytes content, str filename, str directory)
|
||||||
|
cpdef object fetch_user_ai_settings(self, str user_id, str bearer_token)
|
||||||
|
cpdef object fetch_media_path(self, str media_id, str bearer_token)
|
||||||
|
|||||||
@@ -41,3 +41,38 @@ cdef class LoaderHttpClient:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"LoaderHttpClient.upload_big_small_resource failed: {e}")
|
logger.error(f"LoaderHttpClient.upload_big_small_resource failed: {e}")
|
||||||
return LoadResult(str(e))
|
return LoadResult(str(e))
|
||||||
|
|
||||||
|
cpdef object fetch_user_ai_settings(self, str user_id, str bearer_token):
|
||||||
|
try:
|
||||||
|
headers = {}
|
||||||
|
if bearer_token:
|
||||||
|
headers["Authorization"] = f"Bearer {bearer_token}"
|
||||||
|
response = requests.get(
|
||||||
|
f"{self.base_url}/api/users/{user_id}/ai-settings",
|
||||||
|
headers=headers,
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
if response.status_code != 200:
|
||||||
|
return None
|
||||||
|
return response.json()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"LoaderHttpClient.fetch_user_ai_settings failed: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
cpdef object fetch_media_path(self, str media_id, str bearer_token):
|
||||||
|
try:
|
||||||
|
headers = {}
|
||||||
|
if bearer_token:
|
||||||
|
headers["Authorization"] = f"Bearer {bearer_token}"
|
||||||
|
response = requests.get(
|
||||||
|
f"{self.base_url}/api/media/{media_id}",
|
||||||
|
headers=headers,
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
if response.status_code != 200:
|
||||||
|
return None
|
||||||
|
data = response.json()
|
||||||
|
return data.get("path")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"LoaderHttpClient.fetch_media_path failed: {e}")
|
||||||
|
return None
|
||||||
|
|||||||
+152
-17
@@ -4,10 +4,10 @@ import json
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from typing import Optional
|
from typing import Annotated, Optional
|
||||||
|
|
||||||
import requests as http_requests
|
import requests as http_requests
|
||||||
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
|
from fastapi import Body, FastAPI, UploadFile, File, Form, HTTPException, Request
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -20,6 +20,7 @@ LOADER_URL = os.environ.get("LOADER_URL", "http://loader:8080")
|
|||||||
ANNOTATIONS_URL = os.environ.get("ANNOTATIONS_URL", "http://annotations:8080")
|
ANNOTATIONS_URL = os.environ.get("ANNOTATIONS_URL", "http://annotations:8080")
|
||||||
|
|
||||||
loader_client = LoaderHttpClient(LOADER_URL)
|
loader_client = LoaderHttpClient(LOADER_URL)
|
||||||
|
annotations_client = LoaderHttpClient(ANNOTATIONS_URL)
|
||||||
inference = None
|
inference = None
|
||||||
_event_queues: list[asyncio.Queue] = []
|
_event_queues: list[asyncio.Queue] = []
|
||||||
_active_detections: dict[str, asyncio.Task] = {}
|
_active_detections: dict[str, asyncio.Task] = {}
|
||||||
@@ -60,6 +61,29 @@ class TokenManager:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def decode_user_id(token: str) -> Optional[str]:
|
||||||
|
try:
|
||||||
|
payload = token.split(".")[1]
|
||||||
|
padding = 4 - len(payload) % 4
|
||||||
|
if padding != 4:
|
||||||
|
payload += "=" * padding
|
||||||
|
data = json.loads(base64.urlsafe_b64decode(payload))
|
||||||
|
uid = (
|
||||||
|
data.get("sub")
|
||||||
|
or data.get("userId")
|
||||||
|
or data.get("user_id")
|
||||||
|
or data.get("nameid")
|
||||||
|
or data.get(
|
||||||
|
"http://schemas.xmlsoap.org/ws/2005/05/identity/claims/nameidentifier"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if uid is None:
|
||||||
|
return None
|
||||||
|
return str(uid)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_inference():
|
def get_inference():
|
||||||
global inference
|
global inference
|
||||||
@@ -105,7 +129,115 @@ class AIConfigDto(BaseModel):
|
|||||||
altitude: float = 400
|
altitude: float = 400
|
||||||
focal_length: float = 24
|
focal_length: float = 24
|
||||||
sensor_width: float = 23.5
|
sensor_width: float = 23.5
|
||||||
paths: list[str] = []
|
|
||||||
|
|
||||||
|
_AI_SETTINGS_FIELD_KEYS = (
|
||||||
|
(
|
||||||
|
"frame_period_recognition",
|
||||||
|
("frame_period_recognition", "framePeriodRecognition", "FramePeriodRecognition"),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"frame_recognition_seconds",
|
||||||
|
("frame_recognition_seconds", "frameRecognitionSeconds", "FrameRecognitionSeconds"),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"probability_threshold",
|
||||||
|
("probability_threshold", "probabilityThreshold", "ProbabilityThreshold"),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"tracking_distance_confidence",
|
||||||
|
(
|
||||||
|
"tracking_distance_confidence",
|
||||||
|
"trackingDistanceConfidence",
|
||||||
|
"TrackingDistanceConfidence",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"tracking_probability_increase",
|
||||||
|
(
|
||||||
|
"tracking_probability_increase",
|
||||||
|
"trackingProbabilityIncrease",
|
||||||
|
"TrackingProbabilityIncrease",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"tracking_intersection_threshold",
|
||||||
|
(
|
||||||
|
"tracking_intersection_threshold",
|
||||||
|
"trackingIntersectionThreshold",
|
||||||
|
"TrackingIntersectionThreshold",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"model_batch_size",
|
||||||
|
("model_batch_size", "modelBatchSize", "ModelBatchSize"),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"big_image_tile_overlap_percent",
|
||||||
|
(
|
||||||
|
"big_image_tile_overlap_percent",
|
||||||
|
"bigImageTileOverlapPercent",
|
||||||
|
"BigImageTileOverlapPercent",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"altitude",
|
||||||
|
("altitude", "Altitude"),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"focal_length",
|
||||||
|
("focal_length", "focalLength", "FocalLength"),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"sensor_width",
|
||||||
|
("sensor_width", "sensorWidth", "SensorWidth"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _merged_annotation_settings_payload(raw: object) -> dict:
|
||||||
|
if not raw or not isinstance(raw, dict):
|
||||||
|
return {}
|
||||||
|
merged = dict(raw)
|
||||||
|
inner = raw.get("aiRecognitionSettings")
|
||||||
|
if isinstance(inner, dict):
|
||||||
|
merged.update(inner)
|
||||||
|
cam = raw.get("cameraSettings")
|
||||||
|
if isinstance(cam, dict):
|
||||||
|
merged.update(cam)
|
||||||
|
out = {}
|
||||||
|
for snake, aliases in _AI_SETTINGS_FIELD_KEYS:
|
||||||
|
for key in aliases:
|
||||||
|
if key in merged and merged[key] is not None:
|
||||||
|
out[snake] = merged[key]
|
||||||
|
break
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _build_media_detect_config_dict(
|
||||||
|
media_id: str,
|
||||||
|
token_mgr: Optional[TokenManager],
|
||||||
|
override: Optional[AIConfigDto],
|
||||||
|
) -> dict:
|
||||||
|
cfg: dict = {}
|
||||||
|
bearer = ""
|
||||||
|
if token_mgr:
|
||||||
|
bearer = token_mgr.get_valid_token()
|
||||||
|
uid = TokenManager.decode_user_id(token_mgr.access_token)
|
||||||
|
if uid:
|
||||||
|
raw = annotations_client.fetch_user_ai_settings(uid, bearer)
|
||||||
|
cfg.update(_merged_annotation_settings_payload(raw))
|
||||||
|
if override is not None:
|
||||||
|
for k, v in override.model_dump(exclude_defaults=True).items():
|
||||||
|
cfg[k] = v
|
||||||
|
media_path = annotations_client.fetch_media_path(media_id, bearer)
|
||||||
|
if not media_path:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=503,
|
||||||
|
detail="Could not resolve media path from annotations service",
|
||||||
|
)
|
||||||
|
cfg["paths"] = [media_path]
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
def detection_to_dto(det) -> DetectionDto:
|
def detection_to_dto(det) -> DetectionDto:
|
||||||
@@ -150,9 +282,11 @@ async def detect_image(
|
|||||||
file: UploadFile = File(...),
|
file: UploadFile = File(...),
|
||||||
config: Optional[str] = Form(None),
|
config: Optional[str] = Form(None),
|
||||||
):
|
):
|
||||||
import tempfile
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from inference import ai_config_from_dict
|
||||||
|
|
||||||
image_bytes = await file.read()
|
image_bytes = await file.read()
|
||||||
if not image_bytes:
|
if not image_bytes:
|
||||||
@@ -166,13 +300,7 @@ async def detect_image(
|
|||||||
if config:
|
if config:
|
||||||
config_dict = json.loads(config)
|
config_dict = json.loads(config)
|
||||||
|
|
||||||
suffix = os.path.splitext(file.filename or "upload.jpg")[1] or ".jpg"
|
media_name = Path(file.filename or "upload.jpg").stem.replace(" ", "")
|
||||||
tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
|
|
||||||
try:
|
|
||||||
tmp.write(image_bytes)
|
|
||||||
tmp.close()
|
|
||||||
config_dict["paths"] = [tmp.name]
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
inf = get_inference()
|
inf = get_inference()
|
||||||
results = []
|
results = []
|
||||||
@@ -180,7 +308,13 @@ async def detect_image(
|
|||||||
def on_annotation(annotation, percent):
|
def on_annotation(annotation, percent):
|
||||||
results.extend(annotation.detections)
|
results.extend(annotation.detections)
|
||||||
|
|
||||||
await loop.run_in_executor(executor, inf.run_detect, config_dict, on_annotation)
|
ai_cfg = ai_config_from_dict(config_dict)
|
||||||
|
|
||||||
|
def run_img():
|
||||||
|
inf.run_detect_image(image_bytes, ai_cfg, media_name, on_annotation)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await loop.run_in_executor(executor, run_img)
|
||||||
return [detection_to_dto(d) for d in results]
|
return [detection_to_dto(d) for d in results]
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
if "not available" in str(e):
|
if "not available" in str(e):
|
||||||
@@ -188,8 +322,6 @@ async def detect_image(
|
|||||||
raise HTTPException(status_code=422, detail=str(e))
|
raise HTTPException(status_code=422, detail=str(e))
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
finally:
|
|
||||||
os.unlink(tmp.name)
|
|
||||||
|
|
||||||
|
|
||||||
def _post_annotation_to_service(token_mgr: TokenManager, media_id: str,
|
def _post_annotation_to_service(token_mgr: TokenManager, media_id: str,
|
||||||
@@ -216,7 +348,11 @@ def _post_annotation_to_service(token_mgr: TokenManager, media_id: str,
|
|||||||
|
|
||||||
|
|
||||||
@app.post("/detect/{media_id}")
|
@app.post("/detect/{media_id}")
|
||||||
async def detect_media(media_id: str, request: Request, config: Optional[AIConfigDto] = None):
|
async def detect_media(
|
||||||
|
media_id: str,
|
||||||
|
request: Request,
|
||||||
|
config: Annotated[Optional[AIConfigDto], Body()] = None,
|
||||||
|
):
|
||||||
existing = _active_detections.get(media_id)
|
existing = _active_detections.get(media_id)
|
||||||
if existing is not None and not existing.done():
|
if existing is not None and not existing.done():
|
||||||
raise HTTPException(status_code=409, detail="Detection already in progress for this media")
|
raise HTTPException(status_code=409, detail="Detection already in progress for this media")
|
||||||
@@ -226,8 +362,7 @@ async def detect_media(media_id: str, request: Request, config: Optional[AIConfi
|
|||||||
refresh_token = request.headers.get("x-refresh-token", "")
|
refresh_token = request.headers.get("x-refresh-token", "")
|
||||||
token_mgr = TokenManager(access_token, refresh_token) if access_token else None
|
token_mgr = TokenManager(access_token, refresh_token) if access_token else None
|
||||||
|
|
||||||
cfg = config or AIConfigDto()
|
config_dict = _build_media_detect_config_dict(media_id, token_mgr, config)
|
||||||
config_dict = cfg.model_dump()
|
|
||||||
|
|
||||||
async def run_detection():
|
async def run_detection():
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|||||||
@@ -0,0 +1,15 @@
|
|||||||
|
def test_ai_config_from_dict_defaults():
|
||||||
|
from inference import ai_config_from_dict
|
||||||
|
|
||||||
|
cfg = ai_config_from_dict({})
|
||||||
|
assert cfg.model_batch_size == 8
|
||||||
|
assert cfg.frame_period_recognition == 4
|
||||||
|
assert cfg.frame_recognition_seconds == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_ai_config_from_dict_overrides():
|
||||||
|
from inference import ai_config_from_dict
|
||||||
|
|
||||||
|
cfg = ai_config_from_dict({"model_batch_size": 4, "probability_threshold": 0.5})
|
||||||
|
assert cfg.model_batch_size == 4
|
||||||
|
assert cfg.probability_threshold == 0.5
|
||||||
@@ -0,0 +1,126 @@
|
|||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
|
||||||
|
def _access_jwt(sub: str = "u1") -> str:
|
||||||
|
raw = json.dumps(
|
||||||
|
{"exp": int(time.time()) + 3600, "sub": sub}, separators=(",", ":")
|
||||||
|
).encode()
|
||||||
|
payload = base64.urlsafe_b64encode(raw).decode().rstrip("=")
|
||||||
|
return f"h.{payload}.s"
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_manager_decode_user_id_sub():
|
||||||
|
# Arrange
|
||||||
|
from main import TokenManager
|
||||||
|
|
||||||
|
raw = json.dumps(
|
||||||
|
{"sub": "user-abc", "exp": int(time.time()) + 3600}, separators=(",", ":")
|
||||||
|
).encode()
|
||||||
|
payload = base64.urlsafe_b64encode(raw).decode().rstrip("=")
|
||||||
|
token = f"hdr.{payload}.sig"
|
||||||
|
# Act
|
||||||
|
uid = TokenManager.decode_user_id(token)
|
||||||
|
# Assert
|
||||||
|
assert uid == "user-abc"
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_manager_decode_user_id_invalid():
|
||||||
|
# Arrange
|
||||||
|
from main import TokenManager
|
||||||
|
|
||||||
|
# Act
|
||||||
|
uid = TokenManager.decode_user_id("not-a-jwt")
|
||||||
|
# Assert
|
||||||
|
assert uid is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_merged_annotation_settings_pascal_case():
|
||||||
|
# Arrange
|
||||||
|
from main import _merged_annotation_settings_payload
|
||||||
|
|
||||||
|
raw = {
|
||||||
|
"FramePeriodRecognition": 5,
|
||||||
|
"ProbabilityThreshold": 0.4,
|
||||||
|
"Altitude": 300,
|
||||||
|
"FocalLength": 35,
|
||||||
|
"SensorWidth": 36,
|
||||||
|
}
|
||||||
|
# Act
|
||||||
|
out = _merged_annotation_settings_payload(raw)
|
||||||
|
# Assert
|
||||||
|
assert out["frame_period_recognition"] == 5
|
||||||
|
assert out["probability_threshold"] == 0.4
|
||||||
|
assert out["altitude"] == 300
|
||||||
|
|
||||||
|
|
||||||
|
def test_merged_annotation_nested_sections():
|
||||||
|
# Arrange
|
||||||
|
from main import _merged_annotation_settings_payload
|
||||||
|
|
||||||
|
raw = {
|
||||||
|
"aiRecognitionSettings": {"modelBatchSize": 4},
|
||||||
|
"cameraSettings": {"altitude": 100},
|
||||||
|
}
|
||||||
|
# Act
|
||||||
|
out = _merged_annotation_settings_payload(raw)
|
||||||
|
# Assert
|
||||||
|
assert out["model_batch_size"] == 4
|
||||||
|
assert out["altitude"] == 100
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_media_detect_config_uses_api_path_and_defaults_when_api_empty():
|
||||||
|
# Arrange
|
||||||
|
import main
|
||||||
|
|
||||||
|
tm = main.TokenManager(_access_jwt(), "")
|
||||||
|
mock_ann = MagicMock()
|
||||||
|
mock_ann.fetch_user_ai_settings.return_value = None
|
||||||
|
mock_ann.fetch_media_path.return_value = "/m/file.jpg"
|
||||||
|
with patch("main.annotations_client", mock_ann):
|
||||||
|
# Act
|
||||||
|
cfg = main._build_media_detect_config_dict("mid-1", tm, None)
|
||||||
|
# Assert
|
||||||
|
assert cfg["paths"] == ["/m/file.jpg"]
|
||||||
|
assert "probability_threshold" not in cfg
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_media_detect_config_override_wins():
|
||||||
|
# Arrange
|
||||||
|
import main
|
||||||
|
|
||||||
|
tm = main.TokenManager(_access_jwt(), "")
|
||||||
|
override = main.AIConfigDto(probability_threshold=0.99)
|
||||||
|
mock_ann = MagicMock()
|
||||||
|
mock_ann.fetch_user_ai_settings.return_value = {
|
||||||
|
"probabilityThreshold": 0.2,
|
||||||
|
"altitude": 500,
|
||||||
|
}
|
||||||
|
mock_ann.fetch_media_path.return_value = "/m/v.mp4"
|
||||||
|
with patch("main.annotations_client", mock_ann):
|
||||||
|
# Act
|
||||||
|
cfg = main._build_media_detect_config_dict("vid-1", tm, override)
|
||||||
|
# Assert
|
||||||
|
assert cfg["probability_threshold"] == 0.99
|
||||||
|
assert cfg["altitude"] == 500
|
||||||
|
assert cfg["paths"] == ["/m/v.mp4"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_media_detect_config_raises_when_no_media_path():
|
||||||
|
# Arrange
|
||||||
|
import main
|
||||||
|
|
||||||
|
tm = main.TokenManager(_access_jwt(), "")
|
||||||
|
mock_ann = MagicMock()
|
||||||
|
mock_ann.fetch_user_ai_settings.return_value = {}
|
||||||
|
mock_ann.fetch_media_path.return_value = None
|
||||||
|
with patch("main.annotations_client", mock_ann):
|
||||||
|
# Act / Assert
|
||||||
|
with pytest.raises(HTTPException) as exc:
|
||||||
|
main._build_media_detect_config_dict("missing", tm, None)
|
||||||
|
assert exc.value.status_code == 503
|
||||||
Reference in New Issue
Block a user