Files
detections/src/engines/inference_engine.pyx
T
Oleksandr Bezdieniezhnykh 8ce40a9385 Add AIAvailabilityStatus and AIRecognitionConfig classes for AI model management
- Introduced `AIAvailabilityStatus` class to manage the availability status of AI models, including methods for setting status and logging messages.
- Added `AIRecognitionConfig` class to encapsulate configuration parameters for AI recognition, with a static method for creating instances from dictionaries.
- Implemented enums for AI availability states to enhance clarity and maintainability.
- Updated related Cython files to support the new classes and ensure proper type handling.

These changes aim to improve the structure and functionality of the AI model management system, facilitating better status tracking and configuration handling.
2026-03-31 05:49:51 +03:00

107 lines
3.9 KiB
Cython

import cv2
import numpy as np
from annotation cimport Detection
cdef class InferenceEngine:
    """Base class for detection inference backends.

    Subclasses provide ``get_input_shape`` and ``run``; this class supplies
    the shared pipeline around them: chunking frames into batches, turning
    frames into an input blob, decoding raw output rows into ``Detection``
    objects and suppressing overlapping detections.
    """

    def __init__(self, model_bytes: bytes, max_batch_size: int = 8, **kwargs):
        """Store the settings shared by every engine.

        Args:
            model_bytes: Serialized model. Not used by this base
                implementation (presumably consumed by subclasses).
            max_batch_size: Upper bound on frames sent to ``run`` at once.
            **kwargs: ``engine_name`` is read from here, default ``"onnx"``.
        """
        self.max_batch_size = max_batch_size
        self.engine_name = <str>kwargs.get('engine_name', "onnx")

    @staticmethod
    def get_engine_filename():
        """Return the engine file name; the base class has none (``None``)."""
        return None

    @staticmethod
    def get_source_filename():
        """Return the source model file name; the base class has none (``None``)."""
        return None

    @staticmethod
    def convert_from_source(bytes source_bytes):
        """Convert source model bytes to engine bytes.

        The base implementation performs no conversion and returns
        ``source_bytes`` unchanged.
        """
        return source_bytes

    cdef tuple get_input_shape(self):
        """Return the model input size; callers unpack it as ``(h, w)``.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError("Subclass must implement get_input_shape")

    cdef run(self, input_data):
        """Run the model on a preprocessed blob and return its raw output.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError("Subclass must implement run")

    cdef preprocess(self, list frames):
        """Convert a list of frames into one stacked input blob.

        Each frame goes through ``cv2.dnn.blobFromImage``: pixel values are
        scaled by 1/255, the image is resized to the model input size without
        cropping, and the R and B channels are swapped. The per-frame blobs
        are then stacked along the first axis.
        """
        cdef int h, w
        h, w = self.get_input_shape()
        blobs = [cv2.dnn.blobFromImage(frame,
                                       scalefactor=1.0 / 255.0,
                                       size=(w, h),
                                       mean=(0, 0, 0),
                                       swapRB=True,
                                       crop=False)
                 for frame in frames]
        return np.vstack(blobs)

    cdef list postprocess(self, output, object ai_config):
        """Decode raw engine output into filtered detections.

        ``output[0]`` is iterated entry by entry (presumably one entry per
        frame of the batch - confirm against the engine implementations).
        Every row is read as ``[x1, y1, x2, y2, confidence, class_id]`` in
        input-pixel coordinates.

        Args:
            output: Raw result of ``run``; only ``output[0]`` is used.
            ai_config: Supplies ``probability_threshold`` and
                ``tracking_intersection_threshold``.

        Returns:
            One list of ``Detection`` objects per entry of ``output[0]``,
            with centre/size coordinates normalised by the input size.
        """
        cdef list[Detection] detections
        cdef int ann_index
        cdef float x1, y1, x2, y2
        # Kept in double precision: a C float would turn a rounded 0.70 into
        # 0.699999988, which then fails ``>= 0.7`` against the Python-float
        # threshold.
        cdef double conf
        cdef int class_id
        cdef list results = []
        cdef int h, w
        h, w = self.get_input_shape()

        batch_rows = output[0]
        for ann_index in range(len(batch_rows)):
            detections = []
            for det in batch_rows[ann_index]:
                # A zero confidence ends the scan of this entry (presumably
                # the remaining rows are padding - confirm with the engines).
                if det[4] == 0:
                    break
                # Corner coordinates, normalised by the model input size.
                x1 = det[0] / w
                y1 = det[1] / h
                x2 = det[2] / w
                y2 = det[3] / h
                # float() first, so round() yields a Python float even when
                # the row holds numpy float32 scalars.
                conf = round(float(det[4]), 2)
                class_id = int(det[5])
                # Convert corners to centre point plus width/height.
                x = (x1 + x2) / 2
                y = (y1 + y2) / 2
                bw = x2 - x1
                bh = y2 - y1
                if conf >= ai_config.probability_threshold:
                    detections.append(Detection(x, y, bw, bh, class_id, conf))
            filtered = self.remove_overlapping(detections, ai_config.tracking_intersection_threshold)
            results.append(filtered)
        return results

    cdef list remove_overlapping(self, list[Detection] detections, float threshold):
        """Suppress overlapping detections, keeping the strongest of each group.

        Greedy non-maximum suppression: detections are visited from highest
        to lowest confidence and one is kept only if it does not overlap a
        detection that was already kept. Unlike a pairwise scan anchored on
        the first detection of a group, this never discards the
        highest-confidence box in favour of a weaker one, never returns the
        same detection twice, and does not depend on input order.

        Ties are broken as before: on equal confidence the lower ``cls``
        wins, and on a full tie the later detection wins.

        Args:
            detections: Candidate detections for one frame.
            threshold: Forwarded to ``Detection.overlaps``.

        Returns:
            The surviving detections, in their original relative order.
        """
        cdef Detection candidate, keeper
        cdef Py_ssize_t index, kept_index
        cdef Py_ssize_t count = len(detections)
        cdef bint suppressed
        cdef list ranking = []
        cdef list kept_indexes = []
        cdef list survivors = []

        # Sort keys are plain tuples, so no lambda/closure is needed inside
        # a cdef method: highest confidence first, then lowest class id,
        # then the later index.
        for index in range(count):
            candidate = detections[index]
            ranking.append((-candidate.confidence, candidate.cls, -index))
        ranking.sort()

        for entry in ranking:
            index = -entry[2]
            candidate = detections[index]
            suppressed = False
            for kept_index in kept_indexes:
                keeper = detections[kept_index]
                # overlaps() is defined outside this file; the earlier-indexed
                # detection stays the receiver, as in the original pairwise
                # scan, in case the predicate is not symmetric.
                if kept_index < index:
                    suppressed = keeper.overlaps(candidate, threshold)
                else:
                    suppressed = candidate.overlaps(keeper, threshold)
                if suppressed:
                    break
            if not suppressed:
                kept_indexes.append(index)

        # Report survivors in input order rather than confidence order.
        kept_indexes.sort()
        for kept_index in kept_indexes:
            survivors.append(detections[kept_index])
        return survivors

    cpdef list process_frames(self, list frames, object ai_config):
        """Run the full pipeline over ``frames`` in batches.

        The batch size is the smaller of ``self.max_batch_size`` and
        ``ai_config.model_batch_size``, clamped to at least 1. Each chunk is
        preprocessed, run through the engine and postprocessed.

        Returns:
            The concatenated per-chunk results of ``postprocess``; an empty
            list when ``frames`` is empty.
        """
        cdef int effective_batch
        cdef int i
        cdef list all_detections = []

        effective_batch = min(self.max_batch_size, ai_config.model_batch_size)
        if effective_batch < 1:
            # A zero or negative size would make the range() step invalid.
            effective_batch = 1

        for i in range(0, len(frames), effective_batch):
            chunk = frames[i:i + effective_batch]
            input_blob = self.preprocess(chunk)
            raw_output = self.run(input_blob)
            batch_dets = self.postprocess(raw_output, ai_config)
            all_detections.extend(batch_dets)
        return all_detections