mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 05:26:32 +00:00
Initial commit
Made-with: Cursor
This commit is contained in:
+11
@@ -0,0 +1,11 @@
|
||||
*.pyc
|
||||
__pycache__/
|
||||
*.so
|
||||
*.o
|
||||
build/
|
||||
dist/
|
||||
*.egg-info/
|
||||
.env
|
||||
*.onnx
|
||||
*.trt
|
||||
*.engine
|
||||
@@ -0,0 +1,9 @@
|
||||
FROM python:3.11-slim
|
||||
RUN apt-get update && apt-get install -y python3-dev gcc libgl1 libglib2.0-0 && rm -rf /var/lib/apt/lists/*
|
||||
WORKDIR /app
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
COPY . .
|
||||
RUN python setup.py build_ext --inplace
|
||||
EXPOSE 8080
|
||||
CMD ["python", "-m", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8080"]
|
||||
@@ -0,0 +1,9 @@
|
||||
FROM nvidia/cuda:12.2.0-runtime-ubuntu22.04
|
||||
RUN apt-get update && apt-get install -y python3 python3-pip python3-dev gcc libgl1 libglib2.0-0 && rm -rf /var/lib/apt/lists/*
|
||||
WORKDIR /app
|
||||
COPY requirements.txt requirements-gpu.txt ./
|
||||
RUN pip3 install --no-cache-dir -r requirements-gpu.txt
|
||||
COPY . .
|
||||
RUN python3 setup.py build_ext --inplace
|
||||
EXPOSE 8080
|
||||
CMD ["python3", "-m", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8080"]
|
||||
@@ -0,0 +1,3 @@
|
||||
# Azaion.Detections
|
||||
|
||||
Cython/Python service for YOLO inference (TensorRT / ONNX Runtime). GPU-enabled container.
|
||||
@@ -0,0 +1,16 @@
|
||||
cdef enum AIAvailabilityEnum:
|
||||
NONE = 0
|
||||
DOWNLOADING = 10
|
||||
CONVERTING = 20
|
||||
UPLOADING = 30
|
||||
ENABLED = 200
|
||||
WARNING = 300
|
||||
ERROR = 500
|
||||
|
||||
cdef class AIAvailabilityStatus:
|
||||
cdef AIAvailabilityEnum status
|
||||
cdef str error_message
|
||||
cdef object _lock
|
||||
|
||||
cdef bytes serialize(self)
|
||||
cdef set_status(self, AIAvailabilityEnum status, str error_message=*)
|
||||
@@ -0,0 +1,55 @@
|
||||
cimport constants_inf
|
||||
import msgpack
|
||||
from threading import Lock
|
||||
|
||||
AIStatus2Text = {
|
||||
AIAvailabilityEnum.NONE: "None",
|
||||
AIAvailabilityEnum.DOWNLOADING: "Downloading",
|
||||
AIAvailabilityEnum.CONVERTING: "Converting",
|
||||
AIAvailabilityEnum.UPLOADING: "Uploading",
|
||||
AIAvailabilityEnum.ENABLED: "Enabled",
|
||||
AIAvailabilityEnum.WARNING: "Warning",
|
||||
AIAvailabilityEnum.ERROR: "Error",
|
||||
}
|
||||
|
||||
cdef class AIAvailabilityStatus:
|
||||
def __init__(self):
|
||||
self.status = AIAvailabilityEnum.NONE
|
||||
self.error_message = None
|
||||
self._lock = Lock()
|
||||
|
||||
def __str__(self):
|
||||
self._lock.acquire()
|
||||
try:
|
||||
status_text = AIStatus2Text.get(self.status, "Unknown")
|
||||
error_text = self.error_message if self.error_message else ""
|
||||
return f"{status_text} {error_text}"
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
cdef bytes serialize(self):
|
||||
self._lock.acquire()
|
||||
try:
|
||||
return msgpack.packb({
|
||||
"s": self.status,
|
||||
"m": self.error_message
|
||||
})
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
cdef set_status(self, AIAvailabilityEnum status, str error_message=None):
|
||||
log_message = ""
|
||||
self._lock.acquire()
|
||||
try:
|
||||
self.status = status
|
||||
self.error_message = error_message
|
||||
status_text = AIStatus2Text.get(self.status, "Unknown")
|
||||
error_text = self.error_message if self.error_message else ""
|
||||
log_message = f"{status_text} {error_text}"
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
if error_message is not None:
|
||||
constants_inf.logerror(<str>error_message)
|
||||
else:
|
||||
constants_inf.log(<str>log_message)
|
||||
@@ -0,0 +1,22 @@
|
||||
cdef class AIRecognitionConfig:
|
||||
|
||||
cdef public double frame_recognition_seconds
|
||||
cdef public int frame_period_recognition
|
||||
cdef public double probability_threshold
|
||||
|
||||
cdef public double tracking_distance_confidence
|
||||
cdef public double tracking_probability_increase
|
||||
cdef public double tracking_intersection_threshold
|
||||
|
||||
cdef public int big_image_tile_overlap_percent
|
||||
|
||||
cdef public bytes file_data
|
||||
cdef public list[str] paths
|
||||
cdef public int model_batch_size
|
||||
|
||||
cdef public double altitude
|
||||
cdef public double focal_length
|
||||
cdef public double sensor_width
|
||||
|
||||
@staticmethod
|
||||
cdef from_msgpack(bytes data)
|
||||
@@ -0,0 +1,97 @@
|
||||
from msgpack import unpackb
|
||||
|
||||
cdef class AIRecognitionConfig:
|
||||
def __init__(self,
|
||||
frame_period_recognition,
|
||||
frame_recognition_seconds,
|
||||
probability_threshold,
|
||||
|
||||
tracking_distance_confidence,
|
||||
tracking_probability_increase,
|
||||
tracking_intersection_threshold,
|
||||
|
||||
file_data,
|
||||
paths,
|
||||
model_batch_size,
|
||||
|
||||
big_image_tile_overlap_percent,
|
||||
|
||||
altitude,
|
||||
focal_length,
|
||||
sensor_width
|
||||
):
|
||||
self.frame_period_recognition = frame_period_recognition
|
||||
self.frame_recognition_seconds = frame_recognition_seconds
|
||||
self.probability_threshold = probability_threshold
|
||||
|
||||
self.tracking_distance_confidence = tracking_distance_confidence
|
||||
self.tracking_probability_increase = tracking_probability_increase
|
||||
self.tracking_intersection_threshold = tracking_intersection_threshold
|
||||
|
||||
self.file_data = file_data
|
||||
self.paths = paths
|
||||
self.model_batch_size = model_batch_size
|
||||
|
||||
self.big_image_tile_overlap_percent = big_image_tile_overlap_percent
|
||||
|
||||
self.altitude = altitude
|
||||
self.focal_length = focal_length
|
||||
self.sensor_width = sensor_width
|
||||
|
||||
def __str__(self):
|
||||
return (f'frame_seconds : {self.frame_recognition_seconds}, distance_confidence : {self.tracking_distance_confidence}, '
|
||||
f'probability_increase : {self.tracking_probability_increase}, '
|
||||
f'intersection_threshold : {self.tracking_intersection_threshold}, '
|
||||
f'frame_period_recognition : {self.frame_period_recognition}, '
|
||||
f'big_image_tile_overlap_percent: {self.big_image_tile_overlap_percent}, '
|
||||
f'paths: {self.paths}, '
|
||||
f'model_batch_size: {self.model_batch_size}, '
|
||||
f'altitude: {self.altitude}, '
|
||||
f'focal_length: {self.focal_length}, '
|
||||
f'sensor_width: {self.sensor_width}'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
cdef from_msgpack(bytes data):
|
||||
unpacked = unpackb(data, strict_map_key=False)
|
||||
return AIRecognitionConfig(
|
||||
unpacked.get("f_pr", 0),
|
||||
unpacked.get("f_rs", 0.0),
|
||||
unpacked.get("pt", 0.0),
|
||||
|
||||
unpacked.get("t_dc", 0.0),
|
||||
unpacked.get("t_pi", 0.0),
|
||||
unpacked.get("t_it", 0.0),
|
||||
|
||||
unpacked.get("d", b''),
|
||||
unpacked.get("p", []),
|
||||
unpacked.get("m_bs"),
|
||||
|
||||
unpacked.get("ov_p", 20),
|
||||
|
||||
unpacked.get("cam_a", 400),
|
||||
unpacked.get("cam_fl", 24),
|
||||
unpacked.get("cam_sw", 23.5)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_dict(dict data):
|
||||
return AIRecognitionConfig(
|
||||
data.get("frame_period_recognition", 4),
|
||||
data.get("frame_recognition_seconds", 2),
|
||||
data.get("probability_threshold", 0.25),
|
||||
|
||||
data.get("tracking_distance_confidence", 0.0),
|
||||
data.get("tracking_probability_increase", 0.0),
|
||||
data.get("tracking_intersection_threshold", 0.6),
|
||||
|
||||
data.get("file_data", b''),
|
||||
data.get("paths", []),
|
||||
data.get("model_batch_size", 1),
|
||||
|
||||
data.get("big_image_tile_overlap_percent", 20),
|
||||
|
||||
data.get("altitude", 400),
|
||||
data.get("focal_length", 24),
|
||||
data.get("sensor_width", 23.5)
|
||||
)
|
||||
@@ -0,0 +1,15 @@
|
||||
cdef class Detection:
|
||||
cdef public double x, y, w, h, confidence
|
||||
cdef public str annotation_name
|
||||
cdef public int cls
|
||||
|
||||
cdef public overlaps(self, Detection det2, float confidence_threshold)
|
||||
|
||||
cdef class Annotation:
|
||||
cdef public str name
|
||||
cdef public str original_media_name
|
||||
cdef long time
|
||||
cdef public list[Detection] detections
|
||||
cdef public bytes image
|
||||
|
||||
cdef bytes serialize(self)
|
||||
@@ -0,0 +1,73 @@
|
||||
import msgpack
|
||||
cimport constants_inf
|
||||
|
||||
cdef class Detection:
|
||||
def __init__(self, double x, double y, double w, double h, int cls, double confidence):
|
||||
self.annotation_name = None
|
||||
self.x = x
|
||||
self.y = y
|
||||
self.w = w
|
||||
self.h = h
|
||||
self.cls = cls
|
||||
self.confidence = confidence
|
||||
|
||||
def __str__(self):
|
||||
return f'{self.cls}: {self.x:.2f} {self.y:.2f} {self.w:.2f} {self.h:.2f}, prob: {(self.confidence*100):.1f}%'
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Detection):
|
||||
return False
|
||||
|
||||
if max(abs(self.x - other.x),
|
||||
abs(self.y - other.y),
|
||||
abs(self.w - other.w),
|
||||
abs(self.h - other.h)) > constants_inf.TILE_DUPLICATE_CONFIDENCE_THRESHOLD:
|
||||
return False
|
||||
return True
|
||||
|
||||
cdef overlaps(self, Detection det2, float confidence_threshold):
|
||||
cdef double overlap_x = 0.5 * (self.w + det2.w) - abs(self.x - det2.x)
|
||||
cdef double overlap_y = 0.5 * (self.h + det2.h) - abs(self.y - det2.y)
|
||||
cdef double overlap_area = max(0.0, overlap_x) * max(0.0, overlap_y)
|
||||
cdef double min_area = min(self.w * self.h, det2.w * det2.h)
|
||||
|
||||
return overlap_area / min_area > confidence_threshold
|
||||
|
||||
cdef class Annotation:
|
||||
def __init__(self, str name, str original_media_name, long ms, list[Detection] detections):
|
||||
self.name = name
|
||||
self.original_media_name = original_media_name
|
||||
self.time = ms
|
||||
self.detections = detections if detections is not None else []
|
||||
for d in self.detections:
|
||||
d.annotation_name = self.name
|
||||
self.image = b''
|
||||
|
||||
def __str__(self):
|
||||
if not self.detections:
|
||||
return f"{self.name}: No detections"
|
||||
|
||||
detections_str = ", ".join(
|
||||
f"class: {d.cls} {d.confidence * 100:.1f}% ({d.x:.2f}, {d.y:.2f}) ({d.w:.2f}, {d.h:.2f})"
|
||||
for d in self.detections
|
||||
)
|
||||
return f"{self.name}: {detections_str}"
|
||||
|
||||
cdef bytes serialize(self):
|
||||
return msgpack.packb({
|
||||
"n": self.name,
|
||||
"mn": self.original_media_name,
|
||||
"i": self.image, # "i" = image
|
||||
"t": self.time, # "t" = time
|
||||
"d": [ # "d" = detections
|
||||
{
|
||||
"an": det.annotation_name,
|
||||
"x": det.x,
|
||||
"y": det.y,
|
||||
"w": det.w,
|
||||
"h": det.h,
|
||||
"c": det.cls,
|
||||
"p": det.confidence
|
||||
} for det in self.detections
|
||||
]
|
||||
})
|
||||
@@ -0,0 +1,21 @@
|
||||
[
|
||||
{ "Id": 0, "Name": "ArmorVehicle", "ShortName": "Броня", "Color": "#ff0000", "MaxSizeM": 8 },
|
||||
{ "Id": 1, "Name": "Truck", "ShortName": "Вантаж.", "Color": "#00ff00", "MaxSizeM": 8 },
|
||||
{ "Id": 2, "Name": "Vehicle", "ShortName": "Машина", "Color": "#0000ff", "MaxSizeM": 7 },
|
||||
{ "Id": 3, "Name": "Atillery", "ShortName": "Арта", "Color": "#ffff00", "MaxSizeM": 14 },
|
||||
{ "Id": 4, "Name": "Shadow", "ShortName": "Тінь", "Color": "#ff00ff", "MaxSizeM": 9 },
|
||||
{ "Id": 5, "Name": "Trenches", "ShortName": "Окопи", "Color": "#00ffff", "MaxSizeM": 10 },
|
||||
{ "Id": 6, "Name": "MilitaryMan", "ShortName": "Військов", "Color": "#188021", "MaxSizeM": 2 },
|
||||
{ "Id": 7, "Name": "TyreTracks", "ShortName": "Накати", "Color": "#800000", "MaxSizeM": 5 },
|
||||
{ "Id": 8, "Name": "AdditArmoredTank", "ShortName": "Танк.захист", "Color": "#008000", "MaxSizeM": 7 },
|
||||
{ "Id": 9, "Name": "Smoke", "ShortName": "Дим", "Color": "#000080", "MaxSizeM": 8 },
|
||||
{ "Id": 10, "Name": "Plane", "ShortName": "Літак", "Color": "#a52a2a", "MaxSizeM": 12 },
|
||||
{ "Id": 11, "Name": "Moto", "ShortName": "Мото", "Color": "#808000", "MaxSizeM": 3 },
|
||||
{ "Id": 12, "Name": "CamouflageNet", "ShortName": "Сітка", "Color": "#87ceeb", "MaxSizeM": 14 },
|
||||
{ "Id": 13, "Name": "CamouflageBranches", "ShortName": "Гілки", "Color": "#2f4f4f", "MaxSizeM": 8 },
|
||||
{ "Id": 14, "Name": "Roof", "ShortName": "Дах", "Color": "#1e90ff", "MaxSizeM": 15 },
|
||||
{ "Id": 15, "Name": "Building", "ShortName": "Будівля", "Color": "#ffb6c1", "MaxSizeM": 20 },
|
||||
{ "Id": 16, "Name": "Caponier", "ShortName": "Капонір", "Color": "#ffa500", "MaxSizeM": 10 },
|
||||
{ "Id": 17, "Name": "Ammo", "ShortName": "БК", "Color": "#33658a", "MaxSizeM": 2 },
|
||||
{ "Id": 18, "Name": "Protect.Struct", "ShortName": "Зуби.драк", "Color": "#969647", "MaxSizeM": 2 }
|
||||
]
|
||||
@@ -0,0 +1,55 @@
|
||||
/* Generated by Cython 3.1.2 */
|
||||
|
||||
#ifndef __PYX_HAVE__constants_inf
|
||||
#define __PYX_HAVE__constants_inf
|
||||
|
||||
#include "Python.h"
|
||||
|
||||
#ifndef __PYX_HAVE_API__constants_inf
|
||||
|
||||
#ifdef CYTHON_EXTERN_C
|
||||
#undef __PYX_EXTERN_C
|
||||
#define __PYX_EXTERN_C CYTHON_EXTERN_C
|
||||
#elif defined(__PYX_EXTERN_C)
|
||||
#ifdef _MSC_VER
|
||||
#pragma message ("Please do not define the '__PYX_EXTERN_C' macro externally. Use 'CYTHON_EXTERN_C' instead.")
|
||||
#else
|
||||
#warning Please do not define the '__PYX_EXTERN_C' macro externally. Use 'CYTHON_EXTERN_C' instead.
|
||||
#endif
|
||||
#else
|
||||
#ifdef __cplusplus
|
||||
#define __PYX_EXTERN_C extern "C"
|
||||
#else
|
||||
#define __PYX_EXTERN_C extern
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef DL_IMPORT
|
||||
#define DL_IMPORT(_T) _T
|
||||
#endif
|
||||
|
||||
__PYX_EXTERN_C int TILE_DUPLICATE_CONFIDENCE_THRESHOLD;
|
||||
|
||||
#endif /* !__PYX_HAVE_API__constants_inf */
|
||||
|
||||
/* WARNING: the interface of the module init function changed in CPython 3.5. */
|
||||
/* It now returns a PyModuleDef instance instead of a PyModule instance. */
|
||||
|
||||
/* WARNING: Use PyImport_AppendInittab("constants_inf", PyInit_constants_inf) instead of calling PyInit_constants_inf directly from Python 3.5 */
|
||||
PyMODINIT_FUNC PyInit_constants_inf(void);
|
||||
|
||||
#if PY_VERSION_HEX >= 0x03050000 && (defined(__GNUC__) || defined(__clang__) || defined(_MSC_VER) || (defined(__cplusplus) && __cplusplus >= 201402L))
|
||||
#if defined(__cplusplus) && __cplusplus >= 201402L
|
||||
[[deprecated("Use PyImport_AppendInittab(\"constants_inf\", PyInit_constants_inf) instead of calling PyInit_constants_inf directly.")]] inline
|
||||
#elif defined(__GNUC__) || defined(__clang__)
|
||||
__attribute__ ((__deprecated__("Use PyImport_AppendInittab(\"constants_inf\", PyInit_constants_inf) instead of calling PyInit_constants_inf directly."), __unused__)) __inline__
|
||||
#elif defined(_MSC_VER)
|
||||
__declspec(deprecated("Use PyImport_AppendInittab(\"constants_inf\", PyInit_constants_inf) instead of calling PyInit_constants_inf directly.")) __inline
|
||||
#endif
|
||||
static PyObject* __PYX_WARN_IF_PyInit_constants_inf_INIT_CALLED(PyObject* res) {
|
||||
return res;
|
||||
}
|
||||
#define PyInit_constants_inf() __PYX_WARN_IF_PyInit_constants_inf_INIT_CALLED(PyInit_constants_inf())
|
||||
#endif
|
||||
|
||||
#endif /* !__PYX_HAVE__constants_inf */
|
||||
@@ -0,0 +1,36 @@
|
||||
cdef str CONFIG_FILE # Port for the zmq
|
||||
|
||||
cdef int QUEUE_MAXSIZE # Maximum size of the command queue
|
||||
cdef str COMMANDS_QUEUE # Name of the commands queue in rabbit
|
||||
cdef str ANNOTATIONS_QUEUE # Name of the annotations queue in rabbit
|
||||
|
||||
cdef str QUEUE_CONFIG_FILENAME # queue config filename to load from api
|
||||
|
||||
cdef str AI_ONNX_MODEL_FILE
|
||||
|
||||
cdef str CDN_CONFIG
|
||||
cdef str MODELS_FOLDER
|
||||
|
||||
cdef int SMALL_SIZE_KB
|
||||
|
||||
cdef str SPLIT_SUFFIX
|
||||
cdef double TILE_DUPLICATE_CONFIDENCE_THRESHOLD
|
||||
cdef int METERS_IN_TILE
|
||||
|
||||
cdef log(str log_message)
|
||||
cdef logerror(str error)
|
||||
cdef format_time(int ms)
|
||||
|
||||
cdef dict[int, AnnotationClass] annotations_dict
|
||||
|
||||
cdef class AnnotationClass:
|
||||
cdef public int id
|
||||
cdef public str name
|
||||
cdef public str color
|
||||
cdef public int max_object_size_meters
|
||||
|
||||
cdef enum WeatherMode:
|
||||
Norm = 0
|
||||
Wint = 20
|
||||
Night = 40
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
import json
|
||||
import sys
|
||||
|
||||
from loguru import logger
|
||||
|
||||
cdef str CONFIG_FILE = "config.yaml" # Port for the zmq
|
||||
|
||||
cdef str QUEUE_CONFIG_FILENAME = "secured-config.json"
|
||||
cdef str AI_ONNX_MODEL_FILE = "azaion.onnx"
|
||||
|
||||
cdef str CDN_CONFIG = "cdn.yaml"
|
||||
cdef str MODELS_FOLDER = "models"
|
||||
|
||||
cdef int SMALL_SIZE_KB = 3
|
||||
|
||||
cdef str SPLIT_SUFFIX = "!split!"
|
||||
cdef double TILE_DUPLICATE_CONFIDENCE_THRESHOLD = 0.01
|
||||
cdef int METERS_IN_TILE = 25
|
||||
|
||||
cdef class AnnotationClass:
|
||||
def __init__(self, id, name, color, max_object_size_meters):
|
||||
self.id = id
|
||||
self.name = name
|
||||
self.color = color
|
||||
self.max_object_size_meters = max_object_size_meters
|
||||
|
||||
def __str__(self):
|
||||
return f'{self.id} {self.name} {self.color} {self.max_object_size_meters}'
|
||||
|
||||
cdef int weather_switcher_increase = 20
|
||||
|
||||
WEATHER_MODE_NAMES = {
|
||||
Norm: "Norm",
|
||||
Wint: "Wint",
|
||||
Night: "Night"
|
||||
}
|
||||
|
||||
with open('classes.json', 'r', encoding='utf-8') as f:
|
||||
j = json.loads(f.read())
|
||||
annotations_dict = {}
|
||||
|
||||
for i in range(0, weather_switcher_increase * 3, weather_switcher_increase):
|
||||
for cl in j:
|
||||
id = i + cl['Id']
|
||||
mode_name = WEATHER_MODE_NAMES.get(i, "Unknown")
|
||||
name = cl['Name'] if i == 0 else f'{cl["Name"]}({mode_name})'
|
||||
annotations_dict[id] = AnnotationClass(id, name, cl['Color'], cl['MaxSizeM'])
|
||||
|
||||
logger.remove()
|
||||
log_format = "[{time:HH:mm:ss} {level}] {message}"
|
||||
logger.add(
|
||||
sink="Logs/log_inference_{time:YYYYMMDD}.txt",
|
||||
level="INFO",
|
||||
format=log_format,
|
||||
enqueue=True,
|
||||
rotation="1 day",
|
||||
retention="30 days",
|
||||
)
|
||||
logger.add(
|
||||
sys.stdout,
|
||||
level="DEBUG",
|
||||
format=log_format,
|
||||
filter=lambda record: record["level"].name in ("INFO", "DEBUG", "SUCCESS"),
|
||||
colorize=True
|
||||
)
|
||||
logger.add(
|
||||
sys.stderr,
|
||||
level="WARNING",
|
||||
format=log_format,
|
||||
colorize=True
|
||||
)
|
||||
|
||||
cdef log(str log_message):
|
||||
logger.info(log_message)
|
||||
|
||||
cdef logerror(str error):
|
||||
logger.error(error)
|
||||
|
||||
cdef format_time(int ms):
|
||||
# Calculate hours, minutes, seconds, and hundreds of milliseconds.
|
||||
h = ms // 3600000 # Total full hours.
|
||||
ms_remaining = ms % 3600000
|
||||
m = ms_remaining // 60000 # Full minutes.
|
||||
ms_remaining %= 60000
|
||||
s = ms_remaining // 1000 # Full seconds.
|
||||
f = (ms_remaining % 1000) // 100 # Hundreds of milliseconds.
|
||||
h = h % 10
|
||||
return f"{h}{m:02}{s:02}{f}"
|
||||
@@ -0,0 +1,46 @@
|
||||
from ai_availability_status cimport AIAvailabilityStatus
|
||||
from annotation cimport Annotation, Detection
|
||||
from ai_config cimport AIRecognitionConfig
|
||||
from inference_engine cimport InferenceEngine
|
||||
|
||||
cdef class Inference:
|
||||
cdef object loader_client
|
||||
cdef InferenceEngine engine
|
||||
cdef object _annotation_callback
|
||||
cdef object _status_callback
|
||||
cdef Annotation _previous_annotation
|
||||
cdef dict[str, list(Detection)] _tile_detections
|
||||
cdef dict[str, int] detection_counts
|
||||
cdef AIRecognitionConfig ai_config
|
||||
cdef bint stop_signal
|
||||
cdef public AIAvailabilityStatus ai_availability_status
|
||||
|
||||
cdef str model_input
|
||||
cdef int model_width
|
||||
cdef int model_height
|
||||
|
||||
cdef bytes _converted_model_bytes
|
||||
cdef bytes get_onnx_engine_bytes(self)
|
||||
cdef convert_and_upload_model(self, bytes onnx_engine_bytes, str engine_filename)
|
||||
cdef init_ai(self)
|
||||
cdef bint is_building_engine
|
||||
cdef bint is_video(self, str filepath)
|
||||
|
||||
cpdef run_detect(self, dict config_dict, object annotation_callback, object status_callback=*)
|
||||
cpdef list detect_single_image(self, bytes image_bytes, dict config_dict)
|
||||
cdef _process_video(self, AIRecognitionConfig ai_config, str video_name)
|
||||
cdef _process_images(self, AIRecognitionConfig ai_config, list[str] image_paths)
|
||||
cdef _process_images_inner(self, AIRecognitionConfig ai_config, list frame_data, double ground_sampling_distance)
|
||||
cdef on_annotation(self, Annotation annotation, int frame_count=*, int total_frames=*)
|
||||
cdef split_to_tiles(self, frame, path, tile_size, overlap_percent)
|
||||
cpdef stop(self)
|
||||
|
||||
cdef preprocess(self, frames)
|
||||
cdef send_detection_status(self)
|
||||
cdef remove_overlapping_detections(self, list[Detection] detections, float confidence_threshold=?)
|
||||
cdef postprocess(self, output, ai_config)
|
||||
cdef split_list_extend(self, lst, chunk_size)
|
||||
|
||||
cdef bint is_valid_video_annotation(self, Annotation annotation, AIRecognitionConfig ai_config)
|
||||
cdef bint is_valid_image_annotation(self, Annotation annotation, double ground_sampling_distance, frame_shape)
|
||||
cdef remove_tiled_duplicates(self, Annotation annotation)
|
||||
+499
@@ -0,0 +1,499 @@
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
cimport constants_inf
|
||||
|
||||
from ai_availability_status cimport AIAvailabilityEnum, AIAvailabilityStatus
|
||||
from annotation cimport Detection, Annotation
|
||||
from ai_config cimport AIRecognitionConfig
|
||||
import pynvml
|
||||
from threading import Thread
|
||||
|
||||
cdef int tensor_gpu_index
|
||||
|
||||
cdef int check_tensor_gpu_index():
|
||||
try:
|
||||
pynvml.nvmlInit()
|
||||
deviceCount = pynvml.nvmlDeviceGetCount()
|
||||
|
||||
if deviceCount == 0:
|
||||
constants_inf.logerror(<str>'No NVIDIA GPUs found.')
|
||||
return -1
|
||||
|
||||
for i in range(deviceCount):
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(i)
|
||||
major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
|
||||
|
||||
if major > 6 or (major == 6 and minor >= 1):
|
||||
constants_inf.log(<str>'found NVIDIA GPU!')
|
||||
return i
|
||||
|
||||
constants_inf.logerror(<str>'NVIDIA GPU doesnt support TensorRT!')
|
||||
return -1
|
||||
|
||||
except pynvml.NVMLError:
|
||||
return -1
|
||||
finally:
|
||||
try:
|
||||
pynvml.nvmlShutdown()
|
||||
except:
|
||||
constants_inf.logerror(<str>'Failed to shutdown pynvml cause probably no NVIDIA GPU')
|
||||
pass
|
||||
|
||||
tensor_gpu_index = check_tensor_gpu_index()
|
||||
if tensor_gpu_index > -1:
|
||||
from tensorrt_engine import TensorRTEngine
|
||||
else:
|
||||
from onnx_engine import OnnxEngine
|
||||
|
||||
|
||||
|
||||
|
||||
cdef class Inference:
|
||||
def __init__(self, loader_client):
|
||||
self.loader_client = loader_client
|
||||
self._annotation_callback = None
|
||||
self._status_callback = None
|
||||
self.stop_signal = False
|
||||
self.model_input = None
|
||||
self.model_width = 0
|
||||
self.model_height = 0
|
||||
self.detection_counts = {}
|
||||
self.engine = None
|
||||
self.is_building_engine = False
|
||||
self.ai_availability_status = AIAvailabilityStatus()
|
||||
self._converted_model_bytes = None
|
||||
self.init_ai()
|
||||
|
||||
|
||||
cdef bytes get_onnx_engine_bytes(self):
|
||||
models_dir = constants_inf.MODELS_FOLDER
|
||||
self.ai_availability_status.set_status(AIAvailabilityEnum.DOWNLOADING)
|
||||
res = self.loader_client.load_big_small_resource(constants_inf.AI_ONNX_MODEL_FILE, models_dir)
|
||||
if res.err is not None:
|
||||
raise Exception(res.err)
|
||||
return res.data
|
||||
|
||||
cdef convert_and_upload_model(self, bytes onnx_engine_bytes, str engine_filename):
|
||||
try:
|
||||
self.ai_availability_status.set_status(AIAvailabilityEnum.CONVERTING)
|
||||
models_dir = constants_inf.MODELS_FOLDER
|
||||
model_bytes = TensorRTEngine.convert_from_onnx(onnx_engine_bytes)
|
||||
|
||||
self.ai_availability_status.set_status(AIAvailabilityEnum.UPLOADING)
|
||||
res = self.loader_client.upload_big_small_resource(model_bytes, engine_filename, models_dir)
|
||||
if res.err is not None:
|
||||
self.ai_availability_status.set_status(AIAvailabilityEnum.WARNING, <str>f"Failed to upload converted model: {res.err}")
|
||||
|
||||
self._converted_model_bytes = model_bytes
|
||||
self.ai_availability_status.set_status(AIAvailabilityEnum.ENABLED)
|
||||
except Exception as e:
|
||||
self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, <str> str(e))
|
||||
self._converted_model_bytes = None
|
||||
finally:
|
||||
self.is_building_engine = False
|
||||
|
||||
cdef init_ai(self):
|
||||
constants_inf.log(<str> 'init AI...')
|
||||
try:
|
||||
if self.engine is not None:
|
||||
return
|
||||
if self.is_building_engine:
|
||||
return
|
||||
|
||||
if self._converted_model_bytes is not None:
|
||||
try:
|
||||
self.engine = TensorRTEngine(self._converted_model_bytes)
|
||||
self.ai_availability_status.set_status(AIAvailabilityEnum.ENABLED)
|
||||
self.model_height, self.model_width = self.engine.get_input_shape()
|
||||
except Exception as e:
|
||||
self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, <str> str(e))
|
||||
finally:
|
||||
self._converted_model_bytes = None # Consume the bytes
|
||||
return
|
||||
|
||||
models_dir = constants_inf.MODELS_FOLDER
|
||||
if tensor_gpu_index > -1:
|
||||
try:
|
||||
engine_filename = TensorRTEngine.get_engine_filename(0)
|
||||
self.ai_availability_status.set_status(AIAvailabilityEnum.DOWNLOADING)
|
||||
res = self.loader_client.load_big_small_resource(engine_filename, models_dir)
|
||||
if res.err is not None:
|
||||
raise Exception(res.err)
|
||||
self.engine = TensorRTEngine(res.data)
|
||||
self.ai_availability_status.set_status(AIAvailabilityEnum.ENABLED)
|
||||
except Exception as e:
|
||||
self.ai_availability_status.set_status(AIAvailabilityEnum.WARNING, <str>str(e))
|
||||
onnx_engine_bytes = self.get_onnx_engine_bytes()
|
||||
self.is_building_engine = True
|
||||
|
||||
thread = Thread(target=self.convert_and_upload_model, args=(onnx_engine_bytes, engine_filename))
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
return
|
||||
else:
|
||||
self.engine = OnnxEngine(<bytes>self.get_onnx_engine_bytes())
|
||||
self.is_building_engine = False
|
||||
|
||||
self.model_height, self.model_width = self.engine.get_input_shape()
|
||||
except Exception as e:
|
||||
self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, <str>str(e))
|
||||
self.is_building_engine = False
|
||||
|
||||
|
||||
cdef preprocess(self, frames):
|
||||
blobs = [cv2.dnn.blobFromImage(frame,
|
||||
scalefactor=1.0 / 255.0,
|
||||
size=(self.model_width, self.model_height),
|
||||
mean=(0, 0, 0),
|
||||
swapRB=True,
|
||||
crop=False)
|
||||
for frame in frames]
|
||||
return np.vstack(blobs)
|
||||
|
||||
cdef postprocess(self, output, ai_config):
|
||||
cdef list[Detection] detections = []
|
||||
cdef int ann_index
|
||||
cdef float x1, y1, x2, y2, conf, cx, cy, w, h
|
||||
cdef int class_id
|
||||
cdef list[list[Detection]] results = []
|
||||
try:
|
||||
for ann_index in range(len(output[0])):
|
||||
detections.clear()
|
||||
for det in output[0][ann_index]:
|
||||
if det[4] == 0: # if confidence is 0 then valid points are over.
|
||||
break
|
||||
x1 = det[0] / self.model_width
|
||||
y1 = det[1] / self.model_height
|
||||
x2 = det[2] / self.model_width
|
||||
y2 = det[3] / self.model_height
|
||||
conf = round(det[4], 2)
|
||||
class_id = int(det[5])
|
||||
|
||||
x = (x1 + x2) / 2
|
||||
y = (y1 + y2) / 2
|
||||
w = x2 - x1
|
||||
h = y2 - y1
|
||||
if conf >= ai_config.probability_threshold:
|
||||
detections.append(Detection(x, y, w, h, class_id, conf))
|
||||
filtered_detections = self.remove_overlapping_detections(detections, ai_config.tracking_intersection_threshold)
|
||||
results.append(filtered_detections)
|
||||
return results
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to postprocess: {str(e)}")
|
||||
|
||||
cdef remove_overlapping_detections(self, list[Detection] detections, float confidence_threshold=0.6):
|
||||
cdef Detection det1, det2
|
||||
filtered_output = []
|
||||
filtered_out_indexes = []
|
||||
|
||||
for det1_index in range(len(detections)):
|
||||
if det1_index in filtered_out_indexes:
|
||||
continue
|
||||
det1 = detections[det1_index]
|
||||
res = det1_index
|
||||
for det2_index in range(det1_index + 1, len(detections)):
|
||||
det2 = detections[det2_index]
|
||||
if det1.overlaps(det2, confidence_threshold):
|
||||
if det1.confidence > det2.confidence or (
|
||||
det1.confidence == det2.confidence and det1.cls < det2.cls): # det1 has higher confidence or lower class_id
|
||||
filtered_out_indexes.append(det2_index)
|
||||
else:
|
||||
filtered_out_indexes.append(res)
|
||||
res = det2_index
|
||||
filtered_output.append(detections[res])
|
||||
filtered_out_indexes.append(res)
|
||||
return filtered_output
|
||||
|
||||
cdef bint is_video(self, str filepath):
|
||||
mime_type, _ = mimetypes.guess_type(<str>filepath)
|
||||
return mime_type and mime_type.startswith("video")
|
||||
|
||||
cdef split_list_extend(self, lst, chunk_size):
|
||||
chunks = [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
|
||||
|
||||
# If the last chunk is smaller than the desired chunk_size, extend it by duplicating its last element.
|
||||
last_chunk = chunks[len(chunks) - 1]
|
||||
if len(last_chunk) < chunk_size:
|
||||
last_elem = last_chunk[len(last_chunk)-1]
|
||||
while len(last_chunk) < chunk_size:
|
||||
last_chunk.append(last_elem)
|
||||
return chunks
|
||||
|
||||
cpdef run_detect(self, dict config_dict, object annotation_callback, object status_callback=None):
|
||||
cdef list[str] videos = []
|
||||
cdef list[str] images = []
|
||||
cdef AIRecognitionConfig ai_config = AIRecognitionConfig.from_dict(config_dict)
|
||||
if ai_config is None:
|
||||
raise Exception('ai recognition config is empty')
|
||||
|
||||
self._annotation_callback = annotation_callback
|
||||
self._status_callback = status_callback
|
||||
self.stop_signal = False
|
||||
self.init_ai()
|
||||
if self.engine is None:
|
||||
constants_inf.log(<str> "AI engine not available. Conversion may be in progress. Skipping inference.")
|
||||
return
|
||||
|
||||
self.detection_counts = {}
|
||||
for p in ai_config.paths:
|
||||
media_name = Path(<str>p).stem.replace(" ", "")
|
||||
self.detection_counts[media_name] = 0
|
||||
if self.is_video(p):
|
||||
videos.append(p)
|
||||
else:
|
||||
images.append(p)
|
||||
if len(images) > 0:
|
||||
constants_inf.log(<str>f'run inference on {" ".join(images)}...')
|
||||
self._process_images(ai_config, images)
|
||||
if len(videos) > 0:
|
||||
for v in videos:
|
||||
constants_inf.log(<str>f'run inference on {v}...')
|
||||
self._process_video(ai_config, v)
|
||||
|
||||
cpdef list detect_single_image(self, bytes image_bytes, dict config_dict):
|
||||
cdef AIRecognitionConfig ai_config = AIRecognitionConfig.from_dict(config_dict)
|
||||
self.init_ai()
|
||||
if self.engine is None:
|
||||
raise RuntimeError("AI engine not available")
|
||||
|
||||
img_array = np.frombuffer(image_bytes, dtype=np.uint8)
|
||||
frame = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
|
||||
if frame is None:
|
||||
raise ValueError("Invalid image data")
|
||||
|
||||
input_blob = self.preprocess([frame])
|
||||
outputs = self.engine.run(input_blob)
|
||||
list_detections = self.postprocess(outputs, ai_config)
|
||||
if list_detections:
|
||||
return list_detections[0]
|
||||
return []
|
||||
|
||||
cdef _process_video(self, AIRecognitionConfig ai_config, str video_name):
|
||||
cdef int frame_count = 0
|
||||
cdef list batch_frames = []
|
||||
cdef list[int] batch_timestamps = []
|
||||
cdef Annotation annotation
|
||||
self._previous_annotation = None
|
||||
|
||||
|
||||
v_input = cv2.VideoCapture(<str>video_name)
|
||||
total_frames = int(v_input.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
while v_input.isOpened() and not self.stop_signal:
|
||||
ret, frame = v_input.read()
|
||||
if not ret or frame is None:
|
||||
break
|
||||
|
||||
frame_count += 1
|
||||
if frame_count % ai_config.frame_period_recognition == 0:
|
||||
batch_frames.append(frame)
|
||||
batch_timestamps.append(int(v_input.get(cv2.CAP_PROP_POS_MSEC)))
|
||||
|
||||
if len(batch_frames) == self.engine.get_batch_size():
|
||||
input_blob = self.preprocess(batch_frames)
|
||||
|
||||
outputs = self.engine.run(input_blob)
|
||||
|
||||
list_detections = self.postprocess(outputs, ai_config)
|
||||
for i in range(len(list_detections)):
|
||||
detections = list_detections[i]
|
||||
|
||||
original_media_name = Path(<str>video_name).stem.replace(" ", "")
|
||||
name = f'{original_media_name}_{constants_inf.format_time(batch_timestamps[i])}'
|
||||
annotation = Annotation(name, original_media_name, batch_timestamps[i], detections)
|
||||
|
||||
if self.is_valid_video_annotation(annotation, ai_config):
|
||||
_, image = cv2.imencode('.jpg', batch_frames[i])
|
||||
annotation.image = image.tobytes()
|
||||
self._previous_annotation = annotation
|
||||
self.on_annotation(annotation, frame_count, total_frames)
|
||||
|
||||
batch_frames.clear()
|
||||
batch_timestamps.clear()
|
||||
v_input.release()
|
||||
self.send_detection_status()
|
||||
|
||||
cdef on_annotation(self, Annotation annotation, int frame_count=0, int total_frames=0):
|
||||
self.detection_counts[annotation.original_media_name] = self.detection_counts.get(annotation.original_media_name, 0) + 1
|
||||
if self._annotation_callback is not None:
|
||||
percent = int(frame_count * 100 / total_frames) if total_frames > 0 else 0
|
||||
self._annotation_callback(annotation, percent)
|
||||
|
||||
cdef _process_images(self, AIRecognitionConfig ai_config, list[str] image_paths):
|
||||
cdef list frame_data
|
||||
self._tile_detections = {}
|
||||
for path in image_paths:
|
||||
frame_data = []
|
||||
frame = cv2.imread(<str>path)
|
||||
img_h, img_w, _ = frame.shape
|
||||
if frame is None:
|
||||
constants_inf.logerror(<str>f'Failed to read image {path}')
|
||||
continue
|
||||
original_media_name = Path(<str> path).stem.replace(" ", "")
|
||||
|
||||
ground_sampling_distance = ai_config.sensor_width * ai_config.altitude / (ai_config.focal_length * img_w)
|
||||
constants_inf.log(<str>f'ground sampling distance: {ground_sampling_distance}')
|
||||
|
||||
if img_h <= 1.5 * self.model_height and img_w <= 1.5 * self.model_width:
|
||||
frame_data.append((frame, original_media_name, f'{original_media_name}_000000'))
|
||||
else:
|
||||
tile_size = int(constants_inf.METERS_IN_TILE / ground_sampling_distance)
|
||||
constants_inf.log(<str> f'calc tile size: {tile_size}')
|
||||
res = self.split_to_tiles(frame, path, tile_size, ai_config.big_image_tile_overlap_percent)
|
||||
frame_data.extend(res)
|
||||
if len(frame_data) > self.engine.get_batch_size():
|
||||
for chunk in self.split_list_extend(frame_data, self.engine.get_batch_size()):
|
||||
self._process_images_inner(ai_config, chunk, ground_sampling_distance)
|
||||
self.send_detection_status()
|
||||
|
||||
for chunk in self.split_list_extend(frame_data, self.engine.get_batch_size()):
|
||||
self._process_images_inner(ai_config, chunk, ground_sampling_distance)
|
||||
self.send_detection_status()
|
||||
|
||||
cdef send_detection_status(self):
|
||||
if self._status_callback is not None:
|
||||
for media_name in self.detection_counts.keys():
|
||||
self._status_callback(media_name, self.detection_counts[media_name])
|
||||
self.detection_counts.clear()
|
||||
|
||||
cdef split_to_tiles(self, frame, path, tile_size, overlap_percent):
|
||||
constants_inf.log(<str>f'splitting image {path} to tiles...')
|
||||
img_h, img_w, _ = frame.shape
|
||||
stride_w = int(tile_size * (1 - overlap_percent / 100))
|
||||
stride_h = int(tile_size * (1 - overlap_percent / 100))
|
||||
|
||||
results = []
|
||||
original_media_name = Path(<str> path).stem.replace(" ", "")
|
||||
for y in range(0, img_h, stride_h):
|
||||
for x in range(0, img_w, stride_w):
|
||||
x_end = min(x + tile_size, img_w)
|
||||
y_end = min(y + tile_size, img_h)
|
||||
|
||||
# correct x,y for the close-to-border tiles
|
||||
if x_end - x < tile_size:
|
||||
if img_w - (x - stride_w) <= tile_size:
|
||||
continue # the previous tile already covered the last gap
|
||||
x = img_w - tile_size
|
||||
if y_end - y < tile_size:
|
||||
if img_h - (y - stride_h) <= tile_size:
|
||||
continue # the previous tile already covered the last gap
|
||||
y = img_h - tile_size
|
||||
|
||||
tile = frame[y:y_end, x:x_end]
|
||||
name = f'{original_media_name}{constants_inf.SPLIT_SUFFIX}{tile_size:04d}_{x:04d}_{y:04d}!_000000'
|
||||
results.append((tile, original_media_name, name))
|
||||
return results
|
||||
|
||||
cdef _process_images_inner(self, AIRecognitionConfig ai_config, list frame_data, double ground_sampling_distance):
|
||||
cdef list frames, original_media_names, names
|
||||
cdef Annotation annotation
|
||||
cdef int i
|
||||
frames, original_media_names, names = map(list, zip(*frame_data))
|
||||
|
||||
input_blob = self.preprocess(frames)
|
||||
outputs = self.engine.run(input_blob)
|
||||
|
||||
list_detections = self.postprocess(outputs, ai_config)
|
||||
for i in range(len(list_detections)):
|
||||
annotation = Annotation(names[i], original_media_names[i], 0, list_detections[i])
|
||||
if self.is_valid_image_annotation(annotation, ground_sampling_distance, frames[i].shape):
|
||||
constants_inf.log(<str> f'Detected {annotation}')
|
||||
_, image = cv2.imencode('.jpg', frames[i])
|
||||
annotation.image = image.tobytes()
|
||||
self.on_annotation(annotation)
|
||||
|
||||
cpdef stop(self):
|
||||
self.stop_signal = True
|
||||
|
||||
cdef remove_tiled_duplicates(self, Annotation annotation):
|
||||
right = annotation.name.rindex('!')
|
||||
left = annotation.name.index(constants_inf.SPLIT_SUFFIX) + len(constants_inf.SPLIT_SUFFIX)
|
||||
tile_size_str, x_str, y_str = annotation.name[left:right].split('_')
|
||||
tile_size = int(tile_size_str)
|
||||
x = int(x_str)
|
||||
y = int(y_str)
|
||||
|
||||
cdef list[Detection] unique_detections = []
|
||||
|
||||
existing_abs_detections = self._tile_detections.setdefault(annotation.original_media_name, [])
|
||||
|
||||
for det in annotation.detections:
|
||||
x1 = det.x * tile_size
|
||||
y1 = det.y * tile_size
|
||||
det_abs = Detection(x + x1, y + y1, det.w * tile_size, det.h * tile_size, det.cls, det.confidence)
|
||||
|
||||
if det_abs not in existing_abs_detections:
|
||||
unique_detections.append(det)
|
||||
existing_abs_detections.append(det_abs)
|
||||
|
||||
annotation.detections = unique_detections
|
||||
|
||||
cdef bint is_valid_image_annotation(self, Annotation annotation, double ground_sampling_distance, frame_shape):
|
||||
if constants_inf.SPLIT_SUFFIX in annotation.name:
|
||||
self.remove_tiled_duplicates(annotation)
|
||||
img_h, img_w, _ = frame_shape
|
||||
if annotation.detections:
|
||||
constants_inf.log(<str> f'Initial ann: {annotation}')
|
||||
|
||||
cdef list[Detection] valid_detections = []
|
||||
for det in annotation.detections:
|
||||
m_w = det.w * img_w * ground_sampling_distance
|
||||
m_h = det.h * img_h * ground_sampling_distance
|
||||
max_size = constants_inf.annotations_dict[det.cls].max_object_size_meters
|
||||
|
||||
if m_w <= max_size and m_h <= max_size:
|
||||
valid_detections.append(det)
|
||||
constants_inf.log(<str> f'Kept ({m_w} {m_h}) <= {max_size}. class: {constants_inf.annotations_dict[det.cls].name}')
|
||||
else:
|
||||
constants_inf.log(<str> f'Removed ({m_w} {m_h}) > {max_size}. class: {constants_inf.annotations_dict[det.cls].name}')
|
||||
|
||||
annotation.detections = valid_detections
|
||||
|
||||
if not annotation.detections:
|
||||
return False
|
||||
return True
|
||||
|
||||
cdef bint is_valid_video_annotation(self, Annotation annotation, AIRecognitionConfig ai_config):
|
||||
if constants_inf.SPLIT_SUFFIX in annotation.name:
|
||||
self.remove_tiled_duplicates(annotation)
|
||||
if not annotation.detections:
|
||||
return False
|
||||
|
||||
if self._previous_annotation is None:
|
||||
return True
|
||||
|
||||
if annotation.time >= self._previous_annotation.time + <long>(ai_config.frame_recognition_seconds * 1000):
|
||||
return True
|
||||
|
||||
if len(annotation.detections) > len(self._previous_annotation.detections):
|
||||
return True
|
||||
|
||||
cdef:
|
||||
Detection current_det, prev_det
|
||||
double dx, dy, distance_sq, min_distance_sq
|
||||
Detection closest_det
|
||||
|
||||
for current_det in annotation.detections:
|
||||
min_distance_sq = 1e18
|
||||
closest_det = None
|
||||
|
||||
for prev_det in self._previous_annotation.detections:
|
||||
dx = current_det.x - prev_det.x
|
||||
dy = current_det.y - prev_det.y
|
||||
distance_sq = dx * dx + dy * dy
|
||||
|
||||
if distance_sq < min_distance_sq:
|
||||
min_distance_sq = distance_sq
|
||||
closest_det = prev_det
|
||||
|
||||
dist_px = ai_config.tracking_distance_confidence * self.model_width
|
||||
dist_px_sq = dist_px * dist_px
|
||||
if min_distance_sq > dist_px_sq:
|
||||
return True
|
||||
|
||||
if current_det.confidence >= closest_det.confidence + ai_config.tracking_probability_increase:
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -0,0 +1,9 @@
|
||||
from typing import List, Tuple
|
||||
import numpy as np
|
||||
|
||||
|
||||
cdef class InferenceEngine:
|
||||
cdef public int batch_size
|
||||
cdef tuple get_input_shape(self)
|
||||
cdef int get_batch_size(self)
|
||||
cdef run(self, input_data)
|
||||
@@ -0,0 +1,12 @@
|
||||
cdef class InferenceEngine:
|
||||
def __init__(self, model_bytes: bytes, batch_size: int = 1, **kwargs):
|
||||
self.batch_size = batch_size
|
||||
|
||||
cdef tuple get_input_shape(self):
|
||||
raise NotImplementedError("Subclass must implement get_input_shape")
|
||||
|
||||
cdef int get_batch_size(self):
|
||||
return self.batch_size
|
||||
|
||||
cdef run(self, input_data):
|
||||
raise NotImplementedError("Subclass must implement run")
|
||||
@@ -0,0 +1,42 @@
|
||||
import requests
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class LoadResult:
|
||||
def __init__(self, err, data=None):
|
||||
self.err = err
|
||||
self.data = data
|
||||
|
||||
|
||||
class LoaderHttpClient:
|
||||
def __init__(self, base_url: str):
|
||||
self.base_url = base_url.rstrip("/")
|
||||
|
||||
def load_big_small_resource(self, filename: str, directory: str) -> LoadResult:
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self.base_url}/load/{filename}",
|
||||
json={"filename": filename, "folder": directory},
|
||||
stream=True,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return LoadResult(None, response.content)
|
||||
except Exception as e:
|
||||
logger.error(f"LoaderHttpClient.load_big_small_resource failed: {e}")
|
||||
return LoadResult(str(e))
|
||||
|
||||
def upload_big_small_resource(self, content: bytes, filename: str, directory: str) -> LoadResult:
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self.base_url}/upload/{filename}",
|
||||
files={"data": (filename, content)},
|
||||
data={"folder": directory},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return LoadResult(None)
|
||||
except Exception as e:
|
||||
logger.error(f"LoaderHttpClient.upload_big_small_resource failed: {e}")
|
||||
return LoadResult(str(e))
|
||||
|
||||
def stop(self):
|
||||
pass
|
||||
@@ -0,0 +1,291 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional
|
||||
|
||||
import requests as http_requests
|
||||
from fastapi import FastAPI, UploadFile, File, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from loader_http_client import LoaderHttpClient, LoadResult
|
||||
|
||||
app = FastAPI(title="Azaion.Detections")
|
||||
executor = ThreadPoolExecutor(max_workers=2)
|
||||
|
||||
LOADER_URL = os.environ.get("LOADER_URL", "http://loader:8080")
|
||||
ANNOTATIONS_URL = os.environ.get("ANNOTATIONS_URL", "http://annotations:8080")
|
||||
|
||||
loader_client = LoaderHttpClient(LOADER_URL)
|
||||
inference = None
|
||||
_event_queues: list[asyncio.Queue] = []
|
||||
_active_detections: dict[str, bool] = {}
|
||||
|
||||
|
||||
class TokenManager:
|
||||
def __init__(self, access_token: str, refresh_token: str):
|
||||
self.access_token = access_token
|
||||
self.refresh_token = refresh_token
|
||||
|
||||
def get_valid_token(self) -> str:
|
||||
exp = self._decode_exp(self.access_token)
|
||||
if exp and exp - time.time() < 60:
|
||||
self._refresh()
|
||||
return self.access_token
|
||||
|
||||
def _refresh(self):
|
||||
try:
|
||||
resp = http_requests.post(
|
||||
f"{ANNOTATIONS_URL}/auth/refresh",
|
||||
json={"refreshToken": self.refresh_token},
|
||||
timeout=10,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
self.access_token = resp.json()["token"]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _decode_exp(token: str) -> Optional[float]:
|
||||
try:
|
||||
payload = token.split(".")[1]
|
||||
padding = 4 - len(payload) % 4
|
||||
if padding != 4:
|
||||
payload += "=" * padding
|
||||
data = json.loads(base64.urlsafe_b64decode(payload))
|
||||
return float(data.get("exp", 0))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def get_inference():
|
||||
global inference
|
||||
if inference is None:
|
||||
from inference import Inference
|
||||
inference = Inference(loader_client)
|
||||
return inference
|
||||
|
||||
|
||||
class DetectionDto(BaseModel):
|
||||
centerX: float
|
||||
centerY: float
|
||||
width: float
|
||||
height: float
|
||||
classNum: int
|
||||
label: str
|
||||
confidence: float
|
||||
|
||||
|
||||
class DetectionEvent(BaseModel):
|
||||
annotations: list[DetectionDto]
|
||||
mediaId: str
|
||||
mediaStatus: str
|
||||
mediaPercent: int
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
status: str
|
||||
aiAvailability: str
|
||||
errorMessage: Optional[str] = None
|
||||
|
||||
|
||||
class AIConfigDto(BaseModel):
|
||||
frame_period_recognition: int = 4
|
||||
frame_recognition_seconds: int = 2
|
||||
probability_threshold: float = 0.25
|
||||
tracking_distance_confidence: float = 0.0
|
||||
tracking_probability_increase: float = 0.0
|
||||
tracking_intersection_threshold: float = 0.6
|
||||
model_batch_size: int = 1
|
||||
big_image_tile_overlap_percent: int = 20
|
||||
altitude: float = 400
|
||||
focal_length: float = 24
|
||||
sensor_width: float = 23.5
|
||||
paths: list[str] = []
|
||||
|
||||
|
||||
def detection_to_dto(det) -> DetectionDto:
|
||||
import constants_inf
|
||||
label = ""
|
||||
if det.cls in constants_inf.annotations_dict:
|
||||
label = constants_inf.annotations_dict[det.cls].name
|
||||
return DetectionDto(
|
||||
centerX=det.x,
|
||||
centerY=det.y,
|
||||
width=det.w,
|
||||
height=det.h,
|
||||
classNum=det.cls,
|
||||
label=label,
|
||||
confidence=det.confidence,
|
||||
)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health() -> HealthResponse:
|
||||
try:
|
||||
inf = get_inference()
|
||||
status = inf.ai_availability_status
|
||||
status_str = str(status).split()[0] if str(status).strip() else "None"
|
||||
error_msg = status.error_message if hasattr(status, 'error_message') else None
|
||||
return HealthResponse(
|
||||
status="healthy",
|
||||
aiAvailability=status_str,
|
||||
errorMessage=error_msg,
|
||||
)
|
||||
except Exception:
|
||||
return HealthResponse(
|
||||
status="healthy",
|
||||
aiAvailability="None",
|
||||
errorMessage=None,
|
||||
)
|
||||
|
||||
|
||||
@app.post("/detect")
|
||||
async def detect_image(
|
||||
file: UploadFile = File(...),
|
||||
config: Optional[str] = None,
|
||||
):
|
||||
image_bytes = await file.read()
|
||||
if not image_bytes:
|
||||
raise HTTPException(status_code=400, detail="Image is empty")
|
||||
|
||||
config_dict = {}
|
||||
if config:
|
||||
config_dict = json.loads(config)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
inf = get_inference()
|
||||
detections = await loop.run_in_executor(
|
||||
executor, inf.detect_single_image, image_bytes, config_dict
|
||||
)
|
||||
except RuntimeError as e:
|
||||
if "not available" in str(e):
|
||||
raise HTTPException(status_code=503, detail=str(e))
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
return [detection_to_dto(d) for d in detections]
|
||||
|
||||
|
||||
def _post_annotation_to_service(token_mgr: TokenManager, media_id: str,
|
||||
annotation, dtos: list[DetectionDto]):
|
||||
try:
|
||||
token = token_mgr.get_valid_token()
|
||||
image_b64 = base64.b64encode(annotation.image).decode() if annotation.image else None
|
||||
payload = {
|
||||
"mediaId": media_id,
|
||||
"source": 0,
|
||||
"videoTime": f"00:00:{annotation.time // 1000:02d}" if annotation.time else "00:00:00",
|
||||
"detections": [d.model_dump() for d in dtos],
|
||||
}
|
||||
if image_b64:
|
||||
payload["image"] = image_b64
|
||||
http_requests.post(
|
||||
f"{ANNOTATIONS_URL}/annotations",
|
||||
json=payload,
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
timeout=30,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@app.post("/detect/{media_id}")
|
||||
async def detect_media(media_id: str, request: Request, config: Optional[AIConfigDto] = None):
|
||||
if media_id in _active_detections:
|
||||
raise HTTPException(status_code=409, detail="Detection already in progress for this media")
|
||||
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
access_token = auth_header.removeprefix("Bearer ").strip() if auth_header else ""
|
||||
refresh_token = request.headers.get("x-refresh-token", "")
|
||||
token_mgr = TokenManager(access_token, refresh_token) if access_token else None
|
||||
|
||||
cfg = config or AIConfigDto()
|
||||
config_dict = cfg.model_dump()
|
||||
|
||||
_active_detections[media_id] = True
|
||||
|
||||
async def run_detection():
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
inf = get_inference()
|
||||
if inf.engine is None:
|
||||
raise RuntimeError("Detection service unavailable")
|
||||
|
||||
def on_annotation(annotation, percent):
|
||||
dtos = [detection_to_dto(d) for d in annotation.detections]
|
||||
event = DetectionEvent(
|
||||
annotations=dtos,
|
||||
mediaId=media_id,
|
||||
mediaStatus="AIProcessing",
|
||||
mediaPercent=percent,
|
||||
)
|
||||
for q in _event_queues:
|
||||
try:
|
||||
q.put_nowait(event)
|
||||
except asyncio.QueueFull:
|
||||
pass
|
||||
|
||||
if token_mgr and dtos:
|
||||
_post_annotation_to_service(token_mgr, media_id, annotation, dtos)
|
||||
|
||||
def on_status(media_name, count):
|
||||
event = DetectionEvent(
|
||||
annotations=[],
|
||||
mediaId=media_id,
|
||||
mediaStatus="AIProcessed",
|
||||
mediaPercent=100,
|
||||
)
|
||||
for q in _event_queues:
|
||||
try:
|
||||
q.put_nowait(event)
|
||||
except asyncio.QueueFull:
|
||||
pass
|
||||
|
||||
await loop.run_in_executor(
|
||||
executor, inf.run_detect, config_dict, on_annotation, on_status
|
||||
)
|
||||
except Exception:
|
||||
error_event = DetectionEvent(
|
||||
annotations=[],
|
||||
mediaId=media_id,
|
||||
mediaStatus="Error",
|
||||
mediaPercent=0,
|
||||
)
|
||||
for q in _event_queues:
|
||||
try:
|
||||
q.put_nowait(error_event)
|
||||
except asyncio.QueueFull:
|
||||
pass
|
||||
finally:
|
||||
_active_detections.pop(media_id, None)
|
||||
|
||||
asyncio.create_task(run_detection())
|
||||
return {"status": "started", "mediaId": media_id}
|
||||
|
||||
|
||||
@app.get("/detect/stream")
|
||||
async def detect_stream():
|
||||
queue: asyncio.Queue = asyncio.Queue(maxsize=100)
|
||||
_event_queues.append(queue)
|
||||
|
||||
async def event_generator():
|
||||
try:
|
||||
while True:
|
||||
event = await queue.get()
|
||||
yield f"data: {event.model_dump_json()}\n\n"
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
_event_queues.remove(queue)
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
@@ -0,0 +1,26 @@
|
||||
from inference_engine cimport InferenceEngine
|
||||
import onnxruntime as onnx
|
||||
cimport constants_inf
|
||||
|
||||
cdef class OnnxEngine(InferenceEngine):
|
||||
def __init__(self, model_bytes: bytes, batch_size: int = 1, **kwargs):
|
||||
super().__init__(model_bytes, batch_size)
|
||||
|
||||
self.session = onnx.InferenceSession(model_bytes, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
|
||||
self.model_inputs = self.session.get_inputs()
|
||||
self.input_name = self.model_inputs[0].name
|
||||
self.input_shape = self.model_inputs[0].shape
|
||||
self.batch_size = self.input_shape[0] if self.input_shape[0] != -1 else batch_size
|
||||
constants_inf.log(f'AI detection model input: {self.model_inputs} {self.input_shape}')
|
||||
model_meta = self.session.get_modelmeta()
|
||||
constants_inf.log(f"Metadata: {model_meta.custom_metadata_map}")
|
||||
|
||||
cdef tuple get_input_shape(self):
|
||||
shape = self.input_shape
|
||||
return shape[2], shape[3]
|
||||
|
||||
cdef int get_batch_size(self):
|
||||
return self.batch_size
|
||||
|
||||
cdef run(self, input_data):
|
||||
return self.session.run(None, {self.input_name: input_data})
|
||||
@@ -0,0 +1,4 @@
|
||||
-r requirements.txt
|
||||
onnxruntime-gpu==1.22.0
|
||||
pycuda==2025.1.1
|
||||
tensorrt==10.11.0.33
|
||||
@@ -0,0 +1,11 @@
|
||||
fastapi
|
||||
uvicorn[standard]
|
||||
Cython==3.1.3
|
||||
opencv-python==4.10.0.84
|
||||
numpy==2.3.0
|
||||
onnxruntime==1.22.0
|
||||
pynvml==12.0.0
|
||||
requests==2.32.4
|
||||
loguru==0.7.3
|
||||
python-multipart
|
||||
msgpack==1.1.1
|
||||
@@ -0,0 +1,36 @@
|
||||
from setuptools import setup, Extension
|
||||
from Cython.Build import cythonize
|
||||
import numpy as np
|
||||
|
||||
extensions = [
|
||||
Extension('constants_inf', ['constants_inf.pyx']),
|
||||
Extension('ai_availability_status', ['ai_availability_status.pyx']),
|
||||
Extension('annotation', ['annotation.pyx']),
|
||||
Extension('ai_config', ['ai_config.pyx']),
|
||||
Extension('onnx_engine', ['onnx_engine.pyx'], include_dirs=[np.get_include()]),
|
||||
Extension('inference_engine', ['inference_engine.pyx'], include_dirs=[np.get_include()]),
|
||||
Extension('inference', ['inference.pyx'], include_dirs=[np.get_include()]),
|
||||
]
|
||||
|
||||
try:
|
||||
import tensorrt
|
||||
extensions.append(
|
||||
Extension('tensorrt_engine', ['tensorrt_engine.pyx'], include_dirs=[np.get_include()])
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
setup(
|
||||
name="azaion.detections",
|
||||
ext_modules=cythonize(
|
||||
extensions,
|
||||
compiler_directives={
|
||||
"language_level": 3,
|
||||
"emit_code_comments": False,
|
||||
"binding": True,
|
||||
'boundscheck': False,
|
||||
'wraparound': False,
|
||||
}
|
||||
),
|
||||
zip_safe=False
|
||||
)
|
||||
@@ -0,0 +1,24 @@
|
||||
from inference_engine cimport InferenceEngine
|
||||
|
||||
|
||||
cdef class TensorRTEngine(InferenceEngine):
|
||||
|
||||
cdef public object context
|
||||
|
||||
cdef public object d_input
|
||||
cdef public object d_output
|
||||
cdef str input_name
|
||||
cdef object input_shape
|
||||
|
||||
cdef object h_output
|
||||
cdef str output_name
|
||||
cdef object output_shape
|
||||
|
||||
cdef object stream
|
||||
|
||||
|
||||
cdef tuple get_input_shape(self)
|
||||
|
||||
cdef int get_batch_size(self)
|
||||
|
||||
cdef run(self, input_data)
|
||||
@@ -0,0 +1,136 @@
|
||||
from inference_engine cimport InferenceEngine
|
||||
import tensorrt as trt
|
||||
import pycuda.driver as cuda
|
||||
import pycuda.autoinit # required for automatically initialize CUDA, do not remove.
|
||||
import pynvml
|
||||
import numpy as np
|
||||
cimport constants_inf
|
||||
|
||||
|
||||
cdef class TensorRTEngine(InferenceEngine):
|
||||
def __init__(self, model_bytes: bytes, batch_size: int = 4, **kwargs):
|
||||
super().__init__(model_bytes, batch_size)
|
||||
try:
|
||||
logger = trt.Logger(trt.Logger.WARNING)
|
||||
|
||||
runtime = trt.Runtime(logger)
|
||||
engine = runtime.deserialize_cuda_engine(model_bytes)
|
||||
|
||||
if engine is None:
|
||||
raise RuntimeError(f"Failed to load TensorRT engine from bytes")
|
||||
|
||||
self.context = engine.create_execution_context()
|
||||
|
||||
# input
|
||||
self.input_name = engine.get_tensor_name(0)
|
||||
engine_input_shape = engine.get_tensor_shape(self.input_name)
|
||||
if engine_input_shape[0] != -1:
|
||||
self.batch_size = engine_input_shape[0]
|
||||
else:
|
||||
self.batch_size = batch_size
|
||||
|
||||
self.input_shape = [
|
||||
self.batch_size,
|
||||
engine_input_shape[1], # Channels (usually fixed at 3 for RGB)
|
||||
1280 if engine_input_shape[2] == -1 else engine_input_shape[2], # Height
|
||||
1280 if engine_input_shape[3] == -1 else engine_input_shape[3] # Width
|
||||
]
|
||||
self.context.set_input_shape(self.input_name, self.input_shape)
|
||||
input_size = trt.volume(self.input_shape) * np.dtype(np.float32).itemsize
|
||||
self.d_input = cuda.mem_alloc(input_size)
|
||||
|
||||
# output
|
||||
self.output_name = engine.get_tensor_name(1)
|
||||
engine_output_shape = tuple(engine.get_tensor_shape(self.output_name))
|
||||
self.output_shape = [
|
||||
self.batch_size,
|
||||
300 if engine_output_shape[1] == -1 else engine_output_shape[1], # max detections number
|
||||
6 if engine_output_shape[2] == -1 else engine_output_shape[2] # x1 y1 x2 y2 conf cls
|
||||
]
|
||||
self.h_output = cuda.pagelocked_empty(tuple(self.output_shape), dtype=np.float32)
|
||||
self.d_output = cuda.mem_alloc(self.h_output.nbytes)
|
||||
|
||||
self.stream = cuda.Stream()
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to initialize TensorRT engine: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def get_gpu_memory_bytes(int device_id):
|
||||
total_memory = None
|
||||
try:
|
||||
pynvml.nvmlInit()
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
|
||||
mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||
total_memory = mem_info.total
|
||||
except pynvml.NVMLError:
|
||||
total_memory = None
|
||||
finally:
|
||||
try:
|
||||
pynvml.nvmlShutdown()
|
||||
except pynvml.NVMLError:
|
||||
pass
|
||||
return 2 * 1024 * 1024 * 1024 if total_memory is None else total_memory # default 2 Gb
|
||||
|
||||
@staticmethod
|
||||
def get_engine_filename(int device_id):
|
||||
try:
|
||||
device = cuda.Device(device_id)
|
||||
sm_count = device.multiprocessor_count
|
||||
cc_major, cc_minor = device.compute_capability()
|
||||
return f"azaion.cc_{cc_major}.{cc_minor}_sm_{sm_count}.engine"
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def convert_from_onnx(bytes onnx_model):
|
||||
workspace_bytes = int(TensorRTEngine.get_gpu_memory_bytes(0) * 0.9)
|
||||
|
||||
explicit_batch_flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
||||
trt_logger = trt.Logger(trt.Logger.WARNING)
|
||||
|
||||
with trt.Builder(trt_logger) as builder, \
|
||||
builder.create_network(explicit_batch_flag) as network, \
|
||||
trt.OnnxParser(network, trt_logger) as parser, \
|
||||
builder.create_builder_config() as config:
|
||||
|
||||
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
|
||||
|
||||
if not parser.parse(onnx_model):
|
||||
return None
|
||||
|
||||
if builder.platform_has_fast_fp16:
|
||||
constants_inf.log(<str>'Converting to supported fp16')
|
||||
config.set_flag(trt.BuilderFlag.FP16)
|
||||
else:
|
||||
constants_inf.log(<str>'Converting to supported fp32. (fp16 is not supported)')
|
||||
plan = builder.build_serialized_network(network, config)
|
||||
|
||||
if plan is None:
|
||||
constants_inf.logerror(<str>'Conversion failed.')
|
||||
return None
|
||||
constants_inf.log('conversion done!')
|
||||
return bytes(plan)
|
||||
|
||||
cdef tuple get_input_shape(self):
|
||||
return self.input_shape[2], self.input_shape[3]
|
||||
|
||||
cdef int get_batch_size(self):
|
||||
return self.batch_size
|
||||
|
||||
cdef run(self, input_data):
|
||||
try:
|
||||
cuda.memcpy_htod_async(self.d_input, input_data, self.stream)
|
||||
self.context.set_tensor_address(self.input_name, int(self.d_input)) # input buffer
|
||||
self.context.set_tensor_address(self.output_name, int(self.d_output)) # output buffer
|
||||
|
||||
self.context.execute_async_v3(stream_handle=self.stream.handle)
|
||||
self.stream.synchronize()
|
||||
|
||||
# Fix: Remove the stream parameter from memcpy_dtoh
|
||||
cuda.memcpy_dtoh(self.h_output, self.d_output)
|
||||
output = self.h_output.reshape(self.output_shape)
|
||||
return [output]
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to run TensorRT inference: {str(e)}")
|
||||
Reference in New Issue
Block a user