mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 22:26:33 +00:00
Refactor inference engine and task management: Remove obsolete inference engine and ONNX engine files, update inference processing to utilize batch handling, and enhance task management structure in documentation. Adjust paths for task specifications to align with new directory organization.
This commit is contained in:
+34
-45
@@ -8,45 +8,10 @@ cimport constants_inf
|
||||
from ai_availability_status cimport AIAvailabilityEnum, AIAvailabilityStatus
|
||||
from annotation cimport Detection, Annotation
|
||||
from ai_config cimport AIRecognitionConfig
|
||||
import pynvml
|
||||
from threading import Thread
|
||||
|
||||
cdef int tensor_gpu_index
|
||||
|
||||
cdef int check_tensor_gpu_index():
    # Probe the NVIDIA driver through NVML and return the index of the first
    # GPU whose CUDA compute capability is at least 6.1.
    #
    # Returns:
    #   the zero-based NVML device index of the first matching GPU, or -1 when
    #   NVML reports no devices, no device meets the threshold, or any NVML
    #   call raises pynvml.NVMLError.
    #
    # NVML is shut down again before returning, on a best-effort basis.
    try:
        pynvml.nvmlInit()
        deviceCount = pynvml.nvmlDeviceGetCount()

        if deviceCount == 0:
            constants_inf.logerror(<str>'No NVIDIA GPUs found.')
            return -1

        for i in range(deviceCount):
            handle = pynvml.nvmlDeviceGetHandleByIndex(i)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)

            # NOTE(review): judging by the error message below, 6.1 is taken
            # as the minimum capability for TensorRT -- confirm the threshold.
            if major > 6 or (major == 6 and minor >= 1):
                constants_inf.log(<str>'found NVIDIA GPU!')
                return i

        constants_inf.logerror(<str>'NVIDIA GPU doesnt support TensorRT!')
        return -1

    except pynvml.NVMLError:
        # Driver or library missing, or a device query failed: report "no GPU".
        return -1
    finally:
        try:
            pynvml.nvmlShutdown()
        except Exception:
            # Shutdown is best-effort, so the failure is only logged.
            # `Exception` (not a bare `except:`) lets KeyboardInterrupt and
            # SystemExit through instead of silently discarding them.
            constants_inf.logerror(<str>'Failed to shutdown pynvml cause probably no NVIDIA GPU')
|
||||
|
||||
tensor_gpu_index = check_tensor_gpu_index()
|
||||
from engines import tensor_gpu_index, create_engine
|
||||
if tensor_gpu_index > -1:
|
||||
from tensorrt_engine import TensorRTEngine
|
||||
else:
|
||||
from onnx_engine import OnnxEngine
|
||||
from engines.tensorrt_engine import TensorRTEngine
|
||||
|
||||
|
||||
|
||||
@@ -67,6 +32,10 @@ cdef class Inference:
|
||||
self._converted_model_bytes = None
|
||||
self.init_ai()
|
||||
|
||||
@property
def is_engine_ready(self):
    """Return True when ``self.engine`` is set (not None), False otherwise."""
    engine_missing = self.engine is None
    return not engine_missing
|
||||
|
||||
|
||||
cdef bytes get_onnx_engine_bytes(self):
|
||||
models_dir = constants_inf.MODELS_FOLDER
|
||||
@@ -134,7 +103,7 @@ cdef class Inference:
|
||||
thread.start()
|
||||
return
|
||||
else:
|
||||
self.engine = OnnxEngine(<bytes>self.get_onnx_engine_bytes())
|
||||
self.engine = create_engine(<bytes>self.get_onnx_engine_bytes())
|
||||
self.is_building_engine = False
|
||||
|
||||
self.model_height, self.model_width = self.engine.get_input_shape()
|
||||
@@ -264,7 +233,9 @@ cdef class Inference:
|
||||
if frame is None:
|
||||
raise ValueError("Invalid image data")
|
||||
|
||||
input_blob = self.preprocess([frame])
|
||||
cdef int batch_size = self.engine.get_batch_size()
|
||||
frames = [frame] * batch_size
|
||||
input_blob = self.preprocess(frames)
|
||||
outputs = self.engine.run(input_blob)
|
||||
list_detections = self.postprocess(outputs, ai_config)
|
||||
if list_detections:
|
||||
@@ -273,14 +244,21 @@ cdef class Inference:
|
||||
|
||||
cdef _process_video(self, AIRecognitionConfig ai_config, str video_name):
|
||||
cdef int frame_count = 0
|
||||
cdef int batch_count = 0
|
||||
cdef list batch_frames = []
|
||||
cdef list[int] batch_timestamps = []
|
||||
cdef Annotation annotation
|
||||
self._previous_annotation = None
|
||||
|
||||
|
||||
v_input = cv2.VideoCapture(<str>video_name)
|
||||
if not v_input.isOpened():
|
||||
constants_inf.logerror(<str>f'Failed to open video: {video_name}')
|
||||
return
|
||||
total_frames = int(v_input.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
fps = v_input.get(cv2.CAP_PROP_FPS)
|
||||
width = int(v_input.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(v_input.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
constants_inf.log(<str>f'Video: {total_frames} frames, {fps:.1f} fps, {width}x{height}')
|
||||
while v_input.isOpened() and not self.stop_signal:
|
||||
ret, frame = v_input.read()
|
||||
if not ret or frame is None:
|
||||
@@ -292,11 +270,16 @@ cdef class Inference:
|
||||
batch_timestamps.append(int(v_input.get(cv2.CAP_PROP_POS_MSEC)))
|
||||
|
||||
if len(batch_frames) == self.engine.get_batch_size():
|
||||
batch_count += 1
|
||||
constants_inf.log(<str>f'Video batch {batch_count}: frame {frame_count}/{total_frames} ({frame_count*100//total_frames}%)')
|
||||
input_blob = self.preprocess(batch_frames)
|
||||
|
||||
outputs = self.engine.run(input_blob)
|
||||
|
||||
list_detections = self.postprocess(outputs, ai_config)
|
||||
total_dets = sum(len(d) for d in list_detections)
|
||||
if total_dets > 0:
|
||||
constants_inf.log(<str>f'Video batch {batch_count}: {total_dets} detections from postprocess')
|
||||
for i in range(len(list_detections)):
|
||||
detections = list_detections[i]
|
||||
|
||||
@@ -304,15 +287,21 @@ cdef class Inference:
|
||||
name = f'{original_media_name}_{constants_inf.format_time(batch_timestamps[i])}'
|
||||
annotation = Annotation(name, original_media_name, batch_timestamps[i], detections)
|
||||
|
||||
if self.is_valid_video_annotation(annotation, ai_config):
|
||||
_, image = cv2.imencode('.jpg', batch_frames[i])
|
||||
annotation.image = image.tobytes()
|
||||
self._previous_annotation = annotation
|
||||
self.on_annotation(annotation, frame_count, total_frames)
|
||||
if detections:
|
||||
valid = self.is_valid_video_annotation(annotation, ai_config)
|
||||
constants_inf.log(<str>f'Video frame {name}: {len(detections)} dets, valid={valid}')
|
||||
if valid:
|
||||
_, image = cv2.imencode('.jpg', batch_frames[i])
|
||||
annotation.image = image.tobytes()
|
||||
self._previous_annotation = annotation
|
||||
self.on_annotation(annotation, frame_count, total_frames)
|
||||
else:
|
||||
self.is_valid_video_annotation(annotation, ai_config)
|
||||
|
||||
batch_frames.clear()
|
||||
batch_timestamps.clear()
|
||||
v_input.release()
|
||||
constants_inf.log(<str>f'Video done: {frame_count} frames read, {batch_count} batches processed')
|
||||
self.send_detection_status()
|
||||
|
||||
cdef on_annotation(self, Annotation annotation, int frame_count=0, int total_frames=0):
|
||||
|
||||
Reference in New Issue
Block a user