mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 10:36:32 +00:00
Add AIAvailabilityStatus and AIRecognitionConfig classes for AI model management
- Introduced `AIAvailabilityStatus` class to manage the availability status of AI models, including methods for setting status and logging messages. - Added `AIRecognitionConfig` class to encapsulate configuration parameters for AI recognition, with a static method for creating instances from dictionaries. - Implemented enums for AI availability states to enhance clarity and maintainability. - Updated related Cython files to support the new classes and ensure proper type handling. These changes aim to improve the structure and functionality of the AI model management system, facilitating better status tracking and configuration handling.
This commit is contained in:
@@ -0,0 +1,106 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from annotation cimport Detection
|
||||
|
||||
cdef class InferenceEngine:
|
||||
def __init__(self, model_bytes: bytes, max_batch_size: int = 8, **kwargs):
|
||||
self.max_batch_size = max_batch_size
|
||||
self.engine_name = <str>kwargs.get('engine_name', "onnx")
|
||||
|
||||
@staticmethod
|
||||
def get_engine_filename():
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_source_filename():
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def convert_from_source(bytes source_bytes):
|
||||
return source_bytes
|
||||
|
||||
cdef tuple get_input_shape(self):
|
||||
raise NotImplementedError("Subclass must implement get_input_shape")
|
||||
|
||||
cdef run(self, input_data):
|
||||
raise NotImplementedError("Subclass must implement run")
|
||||
|
||||
cdef preprocess(self, list frames):
|
||||
cdef int h, w
|
||||
h, w = self.get_input_shape()
|
||||
blobs = [cv2.dnn.blobFromImage(frame,
|
||||
scalefactor=1.0 / 255.0,
|
||||
size=(w, h),
|
||||
mean=(0, 0, 0),
|
||||
swapRB=True,
|
||||
crop=False)
|
||||
for frame in frames]
|
||||
return np.vstack(blobs)
|
||||
|
||||
cdef list postprocess(self, output, object ai_config):
|
||||
cdef list[Detection] detections
|
||||
cdef int ann_index
|
||||
cdef float x1, y1, x2, y2, conf
|
||||
cdef int class_id
|
||||
cdef list results = []
|
||||
cdef int h, w
|
||||
h, w = self.get_input_shape()
|
||||
|
||||
for ann_index in range(len(output[0])):
|
||||
detections = []
|
||||
for det in output[0][ann_index]:
|
||||
if det[4] == 0:
|
||||
break
|
||||
x1 = det[0] / w
|
||||
y1 = det[1] / h
|
||||
x2 = det[2] / w
|
||||
y2 = det[3] / h
|
||||
conf = round(det[4], 2)
|
||||
class_id = int(det[5])
|
||||
|
||||
x = (x1 + x2) / 2
|
||||
y = (y1 + y2) / 2
|
||||
bw = x2 - x1
|
||||
bh = y2 - y1
|
||||
if conf >= ai_config.probability_threshold:
|
||||
detections.append(Detection(x, y, bw, bh, class_id, conf))
|
||||
filtered = self.remove_overlapping(detections, ai_config.tracking_intersection_threshold)
|
||||
results.append(filtered)
|
||||
return results
|
||||
|
||||
cdef list remove_overlapping(self, list[Detection] detections, float threshold):
|
||||
cdef Detection det1, det2
|
||||
filtered_output = []
|
||||
filtered_out_indexes = []
|
||||
|
||||
for det1_index in range(len(detections)):
|
||||
if det1_index in filtered_out_indexes:
|
||||
continue
|
||||
det1 = detections[det1_index]
|
||||
res = det1_index
|
||||
for det2_index in range(det1_index + 1, len(detections)):
|
||||
det2 = detections[det2_index]
|
||||
if det1.overlaps(det2, threshold):
|
||||
if det1.confidence > det2.confidence or (
|
||||
det1.confidence == det2.confidence and det1.cls < det2.cls):
|
||||
filtered_out_indexes.append(det2_index)
|
||||
else:
|
||||
filtered_out_indexes.append(res)
|
||||
res = det2_index
|
||||
filtered_output.append(detections[res])
|
||||
filtered_out_indexes.append(res)
|
||||
return filtered_output
|
||||
|
||||
cpdef list process_frames(self, list frames, object ai_config):
|
||||
cdef int effective_batch = min(self.max_batch_size, ai_config.model_batch_size)
|
||||
if effective_batch < 1:
|
||||
effective_batch = 1
|
||||
cdef list all_detections = []
|
||||
cdef int i
|
||||
for i in range(0, len(frames), effective_batch):
|
||||
chunk = frames[i:i + effective_batch]
|
||||
input_blob = self.preprocess(chunk)
|
||||
raw_output = self.run(input_blob)
|
||||
batch_dets = self.postprocess(raw_output, ai_config)
|
||||
all_detections.extend(batch_dets)
|
||||
return all_detections
|
||||
Reference in New Issue
Block a user