diff --git a/_docs/_autopilot_state.md b/_docs/_autopilot_state.md index d3cb6c4..8d8bb3a 100644 --- a/_docs/_autopilot_state.md +++ b/_docs/_autopilot_state.md @@ -4,6 +4,6 @@ flow: existing-code step: 9 name: Implement -status: not_started +status: in_progress sub_step: 0 retry_count: 0 diff --git a/e2e/mocks/annotations/app.py b/e2e/mocks/annotations/app.py index 647cf9b..25778b7 100644 --- a/e2e/mocks/annotations/app.py +++ b/e2e/mocks/annotations/app.py @@ -1,3 +1,5 @@ +import os + from flask import Flask, request app = Flask(__name__) @@ -25,6 +27,35 @@ def auth_refresh(): return {"token": "refreshed-test-token"} +@app.route("/api/users//ai-settings", methods=["GET"]) +def user_ai_settings(user_id): + if _fail(): + return "", 503 + return { + "frame_period_recognition": 4, + "frame_recognition_seconds": 2, + "probability_threshold": 0.25, + "tracking_distance_confidence": 0.1, + "tracking_probability_increase": 0.1, + "tracking_intersection_threshold": 0.6, + "model_batch_size": 8, + "big_image_tile_overlap_percent": 20, + "altitude": 400, + "focal_length": 24, + "sensor_width": 23.5, + } + + +@app.route("/api/media/", methods=["GET"]) +def media_path(media_id): + if _fail(): + return "", 503 + root = os.environ.get("MEDIA_DIR", "/media") + if media_id.startswith("sse-") or media_id.startswith("video-"): + return {"path": f"{root}/video_test01.mp4"} + return {"path": f"{root}/image_small.jpg"} + + @app.route("/mock/config", methods=["POST"]) def mock_config(): global _mode diff --git a/e2e/requirements.txt b/e2e/requirements.txt index 96d5b33..f8d8c84 100644 --- a/e2e/requirements.txt +++ b/e2e/requirements.txt @@ -3,3 +3,5 @@ pytest-csv requests==2.32.4 sseclient-py pytest-timeout +flask +gunicorn diff --git a/e2e/tests/test_async_sse.py b/e2e/tests/test_async_sse.py index 8ba0344..3322090 100644 --- a/e2e/tests/test_async_sse.py +++ b/e2e/tests/test_async_sse.py @@ -1,35 +1,17 @@ import json -import os import threading import time import uuid import pytest -_MEDIA = os.environ.get("MEDIA_DIR", "/media") - def _ai_config_video() -> dict: - return { - "probability_threshold": 0.25, - "tracking_intersection_threshold": 0.6, - "altitude": 400, - "focal_length": 24, - "sensor_width": 23.5, - "paths": [f"{_MEDIA}/video_test01.mp4"], - "frame_period_recognition": 4, - "frame_recognition_seconds": 2, - } + return {} def _ai_config_image() -> dict: - return { - "probability_threshold": 0.25, - "altitude": 400, - "focal_length": 24, - "sensor_width": 23.5, - "paths": [f"{_MEDIA}/image_small.jpg"], - } + return {} def test_ft_p08_immediate_async_response( diff --git a/e2e/tests/test_video.py b/e2e/tests/test_video.py index 5b64ac4..30f627d 100644 --- a/e2e/tests/test_video.py +++ b/e2e/tests/test_video.py @@ -20,20 +20,9 @@ def _make_jwt() -> str: @pytest.fixture(scope="module") -def video_events(warm_engine, http_client, video_short_path): +def video_events(warm_engine, http_client): media_id = f"video-{uuid.uuid4().hex}" - body = { - "probability_threshold": 0.25, - "frame_period_recognition": 4, - "frame_recognition_seconds": 2, - "tracking_distance_confidence": 0.1, - "tracking_probability_increase": 0.1, - "tracking_intersection_threshold": 0.6, - "altitude": 400.0, - "focal_length": 24.0, - "sensor_width": 23.5, - "paths": [video_short_path], - } + body = {} token = _make_jwt() collected: list[tuple[float, dict]] = [] diff --git a/requirements.txt b/requirements.txt index 43f8105..7cf4c35 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ pynvml==12.0.0 requests==2.32.4 loguru==0.7.3 python-multipart +av==14.2.0 diff --git a/run-tests.sh b/run-tests.sh index 193ce87..7662fda 100755 --- a/run-tests.sh +++ b/run-tests.sh @@ -17,8 +17,10 @@ cleanup() { } trap cleanup EXIT +PY="$(command -v python3 2>/dev/null || command -v python 2>/dev/null || echo python)" + echo "Building Cython extensions ..." -python setup.py build_ext --inplace +"$PY" setup.py build_ext --inplace for port in $LOADER_PORT $ANNOTATIONS_PORT $DETECTIONS_PORT; do if lsof -ti :"$port" >/dev/null 2>&1; then @@ -29,13 +31,15 @@ for port in $LOADER_PORT $ANNOTATIONS_PORT $DETECTIONS_PORT; do done echo "Starting mock-loader on :$LOADER_PORT ..." -MODELS_ROOT="$FIXTURES" \ - python -m gunicorn --bind "0.0.0.0:$LOADER_PORT" --workers 1 --timeout 120 \ +cd "$ROOT" +MODELS_ROOT="$FIXTURES" PYTHONPATH="$ROOT" \ + "$PY" -m gunicorn --bind "0.0.0.0:$LOADER_PORT" --workers 1 --timeout 120 \ 'e2e.mocks.loader.app:app' >/dev/null 2>&1 & PIDS+=($!) echo "Starting mock-annotations on :$ANNOTATIONS_PORT ..." -python -m gunicorn --bind "0.0.0.0:$ANNOTATIONS_PORT" --workers 1 --timeout 120 \ +MEDIA_DIR="$FIXTURES" PYTHONPATH="$ROOT" \ + "$PY" -m gunicorn --bind "0.0.0.0:$ANNOTATIONS_PORT" --workers 1 --timeout 120 \ 'e2e.mocks.annotations.app:app' >/dev/null 2>&1 & PIDS+=($!) @@ -43,7 +47,7 @@ echo "Starting detections service on :$DETECTIONS_PORT ..." LOADER_URL="http://localhost:$LOADER_PORT" \ ANNOTATIONS_URL="http://localhost:$ANNOTATIONS_PORT" \ PYTHONPATH="$ROOT/src" \ - python -m uvicorn main:app --host 0.0.0.0 --port "$DETECTIONS_PORT" \ + "$PY" -m uvicorn main:app --host 0.0.0.0 --port "$DETECTIONS_PORT" \ --log-level warning >/dev/null 2>&1 & PIDS+=($!) @@ -66,4 +70,5 @@ BASE_URL="http://localhost:$DETECTIONS_PORT" \ MOCK_LOADER_URL="http://localhost:$LOADER_PORT" \ MOCK_ANNOTATIONS_URL="http://localhost:$ANNOTATIONS_PORT" \ MEDIA_DIR="$FIXTURES" \ - python -m pytest e2e/tests/ -v --tb=short "$@" +PYTHONPATH="$ROOT/src" \ + "$PY" -m pytest e2e/tests/ tests/ -v --tb=short "$@" diff --git a/src/ai_config.pxd b/src/ai_config.pxd index 399444a..d69247e 100644 --- a/src/ai_config.pxd +++ b/src/ai_config.pxd @@ -10,7 +10,6 @@ cdef class AIRecognitionConfig: cdef public int big_image_tile_overlap_percent - cdef public list[str] paths cdef public int model_batch_size cdef public double altitude diff --git a/src/ai_config.pyx b/src/ai_config.pyx index 14f6cc0..0a8c4ef 100644 --- a/src/ai_config.pyx +++ b/src/ai_config.pyx @@ -7,7 +7,6 @@ cdef class AIRecognitionConfig: tracking_distance_confidence, tracking_probability_increase, tracking_intersection_threshold, - paths, model_batch_size, big_image_tile_overlap_percent, altitude, @@ -22,7 +21,6 @@ cdef class AIRecognitionConfig: self.tracking_probability_increase = tracking_probability_increase self.tracking_intersection_threshold = tracking_intersection_threshold - self.paths = paths self.model_batch_size = model_batch_size self.big_image_tile_overlap_percent = big_image_tile_overlap_percent @@ -37,7 +35,6 @@ cdef class AIRecognitionConfig: f'intersection_threshold : {self.tracking_intersection_threshold}, ' f'frame_period_recognition : {self.frame_period_recognition}, ' f'big_image_tile_overlap_percent: {self.big_image_tile_overlap_percent}, ' - f'paths: {self.paths}, ' f'model_batch_size: {self.model_batch_size}, ' f'altitude: {self.altitude}, ' f'focal_length: {self.focal_length}, ' @@ -55,7 +52,6 @@ cdef class AIRecognitionConfig: data.get("tracking_probability_increase", 0.0), data.get("tracking_intersection_threshold", 0.6), - data.get("paths", []), data.get("model_batch_size", 8), data.get("big_image_tile_overlap_percent", 20), diff --git a/src/inference.pyx b/src/inference.pyx index 4bc9996..c528084 100644 --- a/src/inference.pyx +++ b/src/inference.pyx @@ -1,7 +1,11 @@ +import io import mimetypes +import threading from pathlib import Path +import av import cv2 +import numpy as np cimport constants_inf from ai_availability_status cimport AIAvailabilityEnum, AIAvailabilityStatus @@ -13,6 +17,18 @@ from threading import Thread from engines import EngineClass +def ai_config_from_dict(dict data): + return AIRecognitionConfig.from_dict(data) + + +def _write_video_bytes_to_path(str path, bytes data, object done_event): + try: + with open(path, 'wb') as f: + f.write(data) + finally: + done_event.set() + + cdef class Inference: cdef LoaderHttpClient loader_client cdef InferenceEngine engine @@ -135,6 +151,7 @@ cdef class Inference: cpdef run_detect(self, dict config_dict, object annotation_callback, object status_callback=None): cdef list[str] videos = [] cdef list[str] images = [] + cdef object media_paths = config_dict.get("paths", []) cdef AIRecognitionConfig ai_config = AIRecognitionConfig.from_dict(config_dict) if ai_config is None: raise Exception('ai recognition config is empty') @@ -148,7 +165,7 @@ cdef class Inference: return self.detection_counts = {} - for p in ai_config.paths: + for p in media_paths: media_name = Path(p).stem.replace(" ", "") self.detection_counts[media_name] = 0 if self.is_video(p): @@ -163,22 +180,147 @@ cdef class Inference: constants_inf.log(f'run inference on {v}...') self._process_video(ai_config, v) + cpdef run_detect_image(self, bytes image_bytes, AIRecognitionConfig ai_config, str media_name, + object annotation_callback, object status_callback=None): + cdef list all_frame_data = [] + cdef str original_media_name + self._annotation_callback = annotation_callback + self._status_callback = status_callback + self.stop_signal = False + self.init_ai() + if self.engine is None: + constants_inf.log( "AI engine not available. Conversion may be in progress. Skipping inference.") + return + if not image_bytes: + return + frame = cv2.imdecode(np.frombuffer(image_bytes, dtype=np.uint8), cv2.IMREAD_COLOR) + if frame is None: + constants_inf.logerror('Failed to decode image bytes') + return + original_media_name = media_name.replace(" ", "") + self.detection_counts = {} + self.detection_counts[original_media_name] = 0 + self._tile_detections = {} + self._append_image_frame_entries(ai_config, all_frame_data, frame, original_media_name) + self._finalize_image_inference(ai_config, all_frame_data) + + cpdef run_detect_video(self, bytes video_bytes, AIRecognitionConfig ai_config, str media_name, str save_path, + object annotation_callback, object status_callback=None): + cdef str original_media_name + self._annotation_callback = annotation_callback + self._status_callback = status_callback + self.stop_signal = False + self.init_ai() + if self.engine is None: + constants_inf.log( "AI engine not available. Conversion may be in progress. Skipping inference.") + return + if not video_bytes: + return + original_media_name = media_name.replace(" ", "") + self.detection_counts = {} + self.detection_counts[original_media_name] = 0 + writer_done = threading.Event() + wt = threading.Thread( + target=_write_video_bytes_to_path, + args=(save_path, video_bytes, writer_done), + daemon=True, + ) + wt.start() + try: + bio = io.BytesIO(video_bytes) + container = av.open(bio) + try: + self._process_video_pyav(ai_config, original_media_name, container) + finally: + container.close() + finally: + writer_done.wait() + wt.join(timeout=3600) + + cdef _process_video_pyav(self, AIRecognitionConfig ai_config, str original_media_name, object container): + cdef int frame_count = 0 + cdef int batch_count = 0 + cdef list batch_frames = [] + cdef list[long] batch_timestamps = [] + cdef int model_h, model_w + cdef int total_frames + cdef int tf + cdef double duration_sec + cdef double fps + self._previous_annotation = None + model_h, model_w = self.engine.get_input_shape() + streams = container.streams.video + if not streams: + constants_inf.logerror('No video stream in container') + self.send_detection_status() + return + vstream = streams[0] + total_frames = 0 + if vstream.frames is not None and int(vstream.frames) > 0: + total_frames = int(vstream.frames) + else: + duration_sec = 0.0 + if vstream.duration is not None and vstream.time_base is not None: + duration_sec = float(vstream.duration * vstream.time_base) + fps = 25.0 + if vstream.average_rate is not None: + fps = float(vstream.average_rate) + if duration_sec > 0: + total_frames = int(duration_sec * fps) + if total_frames < 1: + total_frames = 1 + tf = total_frames + constants_inf.log(f'Video (PyAV): ~{tf} frames est, {vstream.width}x{vstream.height}') + cdef int effective_batch = min(self.engine.max_batch_size, ai_config.model_batch_size) + if effective_batch < 1: + effective_batch = 1 + for av_frame in container.decode(vstream): + if self.stop_signal: + break + frame_count += 1 + arr = av_frame.to_ndarray(format='bgr24') + if frame_count % ai_config.frame_period_recognition == 0: + ts_ms = 0 + if av_frame.time is not None: + ts_ms = int(av_frame.time * 1000) + elif av_frame.pts is not None and vstream.time_base is not None: + ts_ms = int(float(av_frame.pts) * float(vstream.time_base) * 1000) + batch_frames.append(arr) + batch_timestamps.append(ts_ms) + if len(batch_frames) >= effective_batch: + batch_count += 1 + tf = total_frames if total_frames > 0 else max(frame_count, 1) + constants_inf.log(f'Video batch {batch_count}: frame {frame_count}/{tf} ({frame_count*100//tf}%)') + self._process_video_batch(ai_config, batch_frames, batch_timestamps, original_media_name, frame_count, tf, model_w) + batch_frames = [] + batch_timestamps = [] + if batch_frames: + batch_count += 1 + tf = total_frames if total_frames > 0 else max(frame_count, 1) + constants_inf.log(f'Video batch {batch_count} (flush): {len(batch_frames)} remaining frames') + self._process_video_batch(ai_config, batch_frames, batch_timestamps, original_media_name, frame_count, tf, model_w) + constants_inf.log(f'Video done: {frame_count} frames read, {batch_count} batches processed') + self.send_detection_status() + cdef _process_video(self, AIRecognitionConfig ai_config, str video_name): cdef int frame_count = 0 cdef int batch_count = 0 cdef list batch_frames = [] cdef list[long] batch_timestamps = [] - cdef Annotation annotation cdef int model_h, model_w + cdef str original_media_name self._previous_annotation = None model_h, model_w = self.engine.get_input_shape() + original_media_name = Path(video_name).stem.replace(" ", "") v_input = cv2.VideoCapture(video_name) if not v_input.isOpened(): constants_inf.logerror(f'Failed to open video: {video_name}') return total_frames = int(v_input.get(cv2.CAP_PROP_FRAME_COUNT)) + if total_frames < 1: + total_frames = 1 fps = v_input.get(cv2.CAP_PROP_FPS) width = int(v_input.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(v_input.get(cv2.CAP_PROP_FRAME_HEIGHT)) @@ -201,21 +343,21 @@ cdef class Inference: if len(batch_frames) >= effective_batch: batch_count += 1 constants_inf.log(f'Video batch {batch_count}: frame {frame_count}/{total_frames} ({frame_count*100//total_frames}%)') - self._process_video_batch(ai_config, batch_frames, batch_timestamps, video_name, frame_count, total_frames, model_w) + self._process_video_batch(ai_config, batch_frames, batch_timestamps, original_media_name, frame_count, total_frames, model_w) batch_frames = [] batch_timestamps = [] if batch_frames: batch_count += 1 constants_inf.log(f'Video batch {batch_count} (flush): {len(batch_frames)} remaining frames') - self._process_video_batch(ai_config, batch_frames, batch_timestamps, video_name, frame_count, total_frames, model_w) + self._process_video_batch(ai_config, batch_frames, batch_timestamps, original_media_name, frame_count, total_frames, model_w) v_input.release() constants_inf.log(f'Video done: {frame_count} frames read, {batch_count} batches processed') self.send_detection_status() cdef _process_video_batch(self, AIRecognitionConfig ai_config, list batch_frames, - list batch_timestamps, str video_name, + list batch_timestamps, str original_media_name, int frame_count, int total_frames, int model_w): cdef Annotation annotation list_detections = self.engine.process_frames(batch_frames, ai_config) @@ -225,7 +367,6 @@ cdef class Inference: for i in range(len(list_detections)): detections = list_detections[i] - original_media_name = Path(video_name).stem.replace(" ", "") name = f'{original_media_name}_{constants_inf.format_time(batch_timestamps[i])}' annotation = Annotation(name, original_media_name, batch_timestamps[i], detections) @@ -247,56 +388,54 @@ cdef class Inference: cb = self._annotation_callback cb(annotation, percent) - cdef _process_images(self, AIRecognitionConfig ai_config, list[str] image_paths): - cdef list all_frame_data = [] + cdef _append_image_frame_entries(self, AIRecognitionConfig ai_config, list all_frame_data, frame, str original_media_name): cdef double ground_sampling_distance cdef int model_h, model_w - + cdef int img_h, img_w model_h, model_w = self.engine.get_input_shape() - self._tile_detections = {} - - for path in image_paths: - frame = cv2.imread(path) - if frame is None: - constants_inf.logerror(f'Failed to read image {path}') - continue - img_h, img_w, _ = frame.shape - original_media_name = Path( path).stem.replace(" ", "") - - ground_sampling_distance = ai_config.sensor_width * ai_config.altitude / (ai_config.focal_length * img_w) - constants_inf.log(f'ground sampling distance: {ground_sampling_distance}') - - if img_h <= 1.5 * model_h and img_w <= 1.5 * model_w: - all_frame_data.append((frame, original_media_name, f'{original_media_name}_000000', ground_sampling_distance)) - else: - tile_size = int(constants_inf.METERS_IN_TILE / ground_sampling_distance) - constants_inf.log( f'calc tile size: {tile_size}') - res = self.split_to_tiles(frame, path, tile_size, ai_config.big_image_tile_overlap_percent) - for tile_frame, omn, tile_name in res: - all_frame_data.append((tile_frame, omn, tile_name, ground_sampling_distance)) + img_h, img_w, _ = frame.shape + ground_sampling_distance = ai_config.sensor_width * ai_config.altitude / (ai_config.focal_length * img_w) + constants_inf.log(f'ground sampling distance: {ground_sampling_distance}') + if img_h <= 1.5 * model_h and img_w <= 1.5 * model_w: + all_frame_data.append((frame, original_media_name, f'{original_media_name}_000000', ground_sampling_distance)) + else: + tile_size = int(constants_inf.METERS_IN_TILE / ground_sampling_distance) + constants_inf.log( f'calc tile size: {tile_size}') + res = self.split_to_tiles(frame, original_media_name, tile_size, ai_config.big_image_tile_overlap_percent) + for tile_frame, omn, tile_name in res: + all_frame_data.append((tile_frame, omn, tile_name, ground_sampling_distance)) + cdef _finalize_image_inference(self, AIRecognitionConfig ai_config, list all_frame_data): if not all_frame_data: return - frames = [fd[0] for fd in all_frame_data] all_dets = self.engine.process_frames(frames, ai_config) - for i in range(len(all_dets)): frame_entry = all_frame_data[i] f = frame_entry[0] original_media_name = frame_entry[1] name = frame_entry[2] gsd = frame_entry[3] - annotation = Annotation(name, original_media_name, 0, all_dets[i]) if self.is_valid_image_annotation(annotation, gsd, f.shape): constants_inf.log( f'Detected {annotation}') _, image = cv2.imencode('.jpg', f) annotation.image = image.tobytes() self.on_annotation(annotation) - self.send_detection_status() + cdef _process_images(self, AIRecognitionConfig ai_config, list[str] image_paths): + cdef list all_frame_data = [] + self._tile_detections = {} + for path in image_paths: + frame = cv2.imread(path) + if frame is None: + constants_inf.logerror(f'Failed to read image {path}') + continue + original_media_name = Path( path).stem.replace(" ", "") + self._append_image_frame_entries(ai_config, all_frame_data, frame, original_media_name) + self._finalize_image_inference(ai_config, all_frame_data) + cdef send_detection_status(self): if self._status_callback is not None: cb = self._status_callback @@ -304,14 +443,14 @@ cdef class Inference: cb(media_name, self.detection_counts[media_name]) self.detection_counts.clear() - cdef split_to_tiles(self, frame, path, tile_size, overlap_percent): - constants_inf.log(f'splitting image {path} to tiles...') + cdef split_to_tiles(self, frame, str media_stem, tile_size, overlap_percent): + constants_inf.log(f'splitting image {media_stem} to tiles...') img_h, img_w, _ = frame.shape stride_w = int(tile_size * (1 - overlap_percent / 100)) stride_h = int(tile_size * (1 - overlap_percent / 100)) results = [] - original_media_name = Path( path).stem.replace(" ", "") + original_media_name = media_stem for y in range(0, img_h, stride_h): for x in range(0, img_w, stride_w): x_end = min(x + tile_size, img_w) diff --git a/src/loader_http_client.pxd b/src/loader_http_client.pxd index d0946fc..d60b367 100644 --- a/src/loader_http_client.pxd +++ b/src/loader_http_client.pxd @@ -6,3 +6,5 @@ cdef class LoaderHttpClient: cdef str base_url cdef LoadResult load_big_small_resource(self, str filename, str directory) cdef LoadResult upload_big_small_resource(self, bytes content, str filename, str directory) + cpdef object fetch_user_ai_settings(self, str user_id, str bearer_token) + cpdef object fetch_media_path(self, str media_id, str bearer_token) diff --git a/src/loader_http_client.pyx b/src/loader_http_client.pyx index 2a275da..d65aa17 100644 --- a/src/loader_http_client.pyx +++ b/src/loader_http_client.pyx @@ -41,3 +41,38 @@ cdef class LoaderHttpClient: except Exception as e: logger.error(f"LoaderHttpClient.upload_big_small_resource failed: {e}") return LoadResult(str(e)) + + cpdef object fetch_user_ai_settings(self, str user_id, str bearer_token): + try: + headers = {} + if bearer_token: + headers["Authorization"] = f"Bearer {bearer_token}" + response = requests.get( + f"{self.base_url}/api/users/{user_id}/ai-settings", + headers=headers, + timeout=30, + ) + if response.status_code != 200: + return None + return response.json() + except Exception as e: + logger.error(f"LoaderHttpClient.fetch_user_ai_settings failed: {e}") + return None + + cpdef object fetch_media_path(self, str media_id, str bearer_token): + try: + headers = {} + if bearer_token: + headers["Authorization"] = f"Bearer {bearer_token}" + response = requests.get( + f"{self.base_url}/api/media/{media_id}", + headers=headers, + timeout=30, + ) + if response.status_code != 200: + return None + data = response.json() + return data.get("path") + except Exception as e: + logger.error(f"LoaderHttpClient.fetch_media_path failed: {e}") + return None diff --git a/src/main.py b/src/main.py index 67dda84..c5d5848 100644 --- a/src/main.py +++ b/src/main.py @@ -4,10 +4,10 @@ import json import os import time from concurrent.futures import ThreadPoolExecutor -from typing import Optional +from typing import Annotated, Optional import requests as http_requests -from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request +from fastapi import Body, FastAPI, UploadFile, File, Form, HTTPException, Request from fastapi.responses import StreamingResponse from pydantic import BaseModel @@ -20,6 +20,7 @@ LOADER_URL = os.environ.get("LOADER_URL", "http://loader:8080") ANNOTATIONS_URL = os.environ.get("ANNOTATIONS_URL", "http://annotations:8080") loader_client = LoaderHttpClient(LOADER_URL) +annotations_client = LoaderHttpClient(ANNOTATIONS_URL) inference = None _event_queues: list[asyncio.Queue] = [] _active_detections: dict[str, asyncio.Task] = {} @@ -60,6 +61,29 @@ class TokenManager: except Exception: return None + @staticmethod + def decode_user_id(token: str) -> Optional[str]: + try: + payload = token.split(".")[1] + padding = 4 - len(payload) % 4 + if padding != 4: + payload += "=" * padding + data = json.loads(base64.urlsafe_b64decode(payload)) + uid = ( + data.get("sub") + or data.get("userId") + or data.get("user_id") + or data.get("nameid") + or data.get( + "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/nameidentifier" + ) + ) + if uid is None: + return None + return str(uid) + except Exception: + return None + def get_inference(): global inference @@ -105,7 +129,115 @@ class AIConfigDto(BaseModel): altitude: float = 400 focal_length: float = 24 sensor_width: float = 23.5 - paths: list[str] = [] + + +_AI_SETTINGS_FIELD_KEYS = ( + ( + "frame_period_recognition", + ("frame_period_recognition", "framePeriodRecognition", "FramePeriodRecognition"), + ), + ( + "frame_recognition_seconds", + ("frame_recognition_seconds", "frameRecognitionSeconds", "FrameRecognitionSeconds"), + ), + ( + "probability_threshold", + ("probability_threshold", "probabilityThreshold", "ProbabilityThreshold"), + ), + ( + "tracking_distance_confidence", + ( + "tracking_distance_confidence", + "trackingDistanceConfidence", + "TrackingDistanceConfidence", + ), + ), + ( + "tracking_probability_increase", + ( + "tracking_probability_increase", + "trackingProbabilityIncrease", + "TrackingProbabilityIncrease", + ), + ), + ( + "tracking_intersection_threshold", + ( + "tracking_intersection_threshold", + "trackingIntersectionThreshold", + "TrackingIntersectionThreshold", + ), + ), + ( + "model_batch_size", + ("model_batch_size", "modelBatchSize", "ModelBatchSize"), + ), + ( + "big_image_tile_overlap_percent", + ( + "big_image_tile_overlap_percent", + "bigImageTileOverlapPercent", + "BigImageTileOverlapPercent", + ), + ), + ( + "altitude", + ("altitude", "Altitude"), + ), + ( + "focal_length", + ("focal_length", "focalLength", "FocalLength"), + ), + ( + "sensor_width", + ("sensor_width", "sensorWidth", "SensorWidth"), + ), +) + + +def _merged_annotation_settings_payload(raw: object) -> dict: + if not raw or not isinstance(raw, dict): + return {} + merged = dict(raw) + inner = raw.get("aiRecognitionSettings") + if isinstance(inner, dict): + merged.update(inner) + cam = raw.get("cameraSettings") + if isinstance(cam, dict): + merged.update(cam) + out = {} + for snake, aliases in _AI_SETTINGS_FIELD_KEYS: + for key in aliases: + if key in merged and merged[key] is not None: + out[snake] = merged[key] + break + return out + + +def _build_media_detect_config_dict( + media_id: str, + token_mgr: Optional[TokenManager], + override: Optional[AIConfigDto], +) -> dict: + cfg: dict = {} + bearer = "" + if token_mgr: + bearer = token_mgr.get_valid_token() + uid = TokenManager.decode_user_id(token_mgr.access_token) + if uid: + raw = annotations_client.fetch_user_ai_settings(uid, bearer) + cfg.update(_merged_annotation_settings_payload(raw)) + if override is not None: + for k, v in override.model_dump(exclude_defaults=True).items(): + cfg[k] = v + media_path = annotations_client.fetch_media_path(media_id, bearer) + if not media_path: + raise HTTPException( + status_code=503, + detail="Could not resolve media path from annotations service", + ) + cfg["paths"] = [media_path] + return cfg def detection_to_dto(det) -> DetectionDto: @@ -150,9 +282,11 @@ async def detect_image( file: UploadFile = File(...), config: Optional[str] = Form(None), ): - import tempfile import cv2 import numpy as np + from pathlib import Path + + from inference import ai_config_from_dict image_bytes = await file.read() if not image_bytes: @@ -166,21 +300,21 @@ async def detect_image( if config: config_dict = json.loads(config) - suffix = os.path.splitext(file.filename or "upload.jpg")[1] or ".jpg" - tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix) + media_name = Path(file.filename or "upload.jpg").stem.replace(" ", "") + loop = asyncio.get_event_loop() + inf = get_inference() + results = [] + + def on_annotation(annotation, percent): + results.extend(annotation.detections) + + ai_cfg = ai_config_from_dict(config_dict) + + def run_img(): + inf.run_detect_image(image_bytes, ai_cfg, media_name, on_annotation) + try: - tmp.write(image_bytes) - tmp.close() - config_dict["paths"] = [tmp.name] - - loop = asyncio.get_event_loop() - inf = get_inference() - results = [] - - def on_annotation(annotation, percent): - results.extend(annotation.detections) - - await loop.run_in_executor(executor, inf.run_detect, config_dict, on_annotation) + await loop.run_in_executor(executor, run_img) return [detection_to_dto(d) for d in results] except RuntimeError as e: if "not available" in str(e): @@ -188,8 +322,6 @@ async def detect_image( raise HTTPException(status_code=422, detail=str(e)) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) - finally: - os.unlink(tmp.name) def _post_annotation_to_service(token_mgr: TokenManager, media_id: str, @@ -216,7 +348,11 @@ def _post_annotation_to_service(token_mgr: TokenManager, media_id: str, @app.post("/detect/{media_id}") -async def detect_media(media_id: str, request: Request, config: Optional[AIConfigDto] = None): +async def detect_media( + media_id: str, + request: Request, + config: Annotated[Optional[AIConfigDto], Body()] = None, +): existing = _active_detections.get(media_id) if existing is not None and not existing.done(): raise HTTPException(status_code=409, detail="Detection already in progress for this media") @@ -226,8 +362,7 @@ async def detect_media(media_id: str, request: Request, config: Optional[AIConfi refresh_token = request.headers.get("x-refresh-token", "") token_mgr = TokenManager(access_token, refresh_token) if access_token else None - cfg = config or AIConfigDto() - config_dict = cfg.model_dump() + config_dict = _build_media_detect_config_dict(media_id, token_mgr, config) async def run_detection(): loop = asyncio.get_event_loop() diff --git a/tests/test_ai_config_from_dict.py b/tests/test_ai_config_from_dict.py new file mode 100644 index 0000000..be19a60 --- /dev/null +++ b/tests/test_ai_config_from_dict.py @@ -0,0 +1,15 @@ +def test_ai_config_from_dict_defaults(): + from inference import ai_config_from_dict + + cfg = ai_config_from_dict({}) + assert cfg.model_batch_size == 8 + assert cfg.frame_period_recognition == 4 + assert cfg.frame_recognition_seconds == 2 + + +def test_ai_config_from_dict_overrides(): + from inference import ai_config_from_dict + + cfg = ai_config_from_dict({"model_batch_size": 4, "probability_threshold": 0.5}) + assert cfg.model_batch_size == 4 + assert cfg.probability_threshold == 0.5 diff --git a/tests/test_az174_db_driven_config.py b/tests/test_az174_db_driven_config.py new file mode 100644 index 0000000..15319ef --- /dev/null +++ b/tests/test_az174_db_driven_config.py @@ -0,0 +1,126 @@ +import base64 +import json +import time +from unittest.mock import MagicMock, patch + +import pytest +from fastapi import HTTPException + + +def _access_jwt(sub: str = "u1") -> str: + raw = json.dumps( + {"exp": int(time.time()) + 3600, "sub": sub}, separators=(",", ":") + ).encode() + payload = base64.urlsafe_b64encode(raw).decode().rstrip("=") + return f"h.{payload}.s" + + +def test_token_manager_decode_user_id_sub(): + # Arrange + from main import TokenManager + + raw = json.dumps( + {"sub": "user-abc", "exp": int(time.time()) + 3600}, separators=(",", ":") + ).encode() + payload = base64.urlsafe_b64encode(raw).decode().rstrip("=") + token = f"hdr.{payload}.sig" + # Act + uid = TokenManager.decode_user_id(token) + # Assert + assert uid == "user-abc" + + +def test_token_manager_decode_user_id_invalid(): + # Arrange + from main import TokenManager + + # Act + uid = TokenManager.decode_user_id("not-a-jwt") + # Assert + assert uid is None + + +def test_merged_annotation_settings_pascal_case(): + # Arrange + from main import _merged_annotation_settings_payload + + raw = { + "FramePeriodRecognition": 5, + "ProbabilityThreshold": 0.4, + "Altitude": 300, + "FocalLength": 35, + "SensorWidth": 36, + } + # Act + out = _merged_annotation_settings_payload(raw) + # Assert + assert out["frame_period_recognition"] == 5 + assert out["probability_threshold"] == 0.4 + assert out["altitude"] == 300 + + +def test_merged_annotation_nested_sections(): + # Arrange + from main import _merged_annotation_settings_payload + + raw = { + "aiRecognitionSettings": {"modelBatchSize": 4}, + "cameraSettings": {"altitude": 100}, + } + # Act + out = _merged_annotation_settings_payload(raw) + # Assert + assert out["model_batch_size"] == 4 + assert out["altitude"] == 100 + + +def test_build_media_detect_config_uses_api_path_and_defaults_when_api_empty(): + # Arrange + import main + + tm = main.TokenManager(_access_jwt(), "") + mock_ann = MagicMock() + mock_ann.fetch_user_ai_settings.return_value = None + mock_ann.fetch_media_path.return_value = "/m/file.jpg" + with patch("main.annotations_client", mock_ann): + # Act + cfg = main._build_media_detect_config_dict("mid-1", tm, None) + # Assert + assert cfg["paths"] == ["/m/file.jpg"] + assert "probability_threshold" not in cfg + + +def test_build_media_detect_config_override_wins(): + # Arrange + import main + + tm = main.TokenManager(_access_jwt(), "") + override = main.AIConfigDto(probability_threshold=0.99) + mock_ann = MagicMock() + mock_ann.fetch_user_ai_settings.return_value = { + "probabilityThreshold": 0.2, + "altitude": 500, + } + mock_ann.fetch_media_path.return_value = "/m/v.mp4" + with patch("main.annotations_client", mock_ann): + # Act + cfg = main._build_media_detect_config_dict("vid-1", tm, override) + # Assert + assert cfg["probability_threshold"] == 0.99 + assert cfg["altitude"] == 500 + assert cfg["paths"] == ["/m/v.mp4"] + + +def test_build_media_detect_config_raises_when_no_media_path(): + # Arrange + import main + + tm = main.TokenManager(_access_jwt(), "") + mock_ann = MagicMock() + mock_ann.fetch_user_ai_settings.return_value = {} + mock_ann.fetch_media_path.return_value = None + with patch("main.annotations_client", mock_ann): + # Act / Assert + with pytest.raises(HTTPException) as exc: + main._build_media_detect_config_dict("missing", tm, None) + assert exc.value.status_code == 503