mirror of
https://github.com/azaion/annotations.git
synced 2026-04-22 16:46:31 +00:00
refactor external clients
put model batch size as parameter in config
This commit is contained in:
@@ -9,6 +9,7 @@ cdef class AIRecognitionConfig:
|
||||
|
||||
cdef public bytes file_data
|
||||
cdef public list[str] paths
|
||||
cdef public int model_batch_size
|
||||
|
||||
@staticmethod
|
||||
cdef from_msgpack(bytes data)
|
||||
@@ -11,7 +11,8 @@ cdef class AIRecognitionConfig:
|
||||
tracking_intersection_threshold,
|
||||
|
||||
file_data,
|
||||
paths
|
||||
paths,
|
||||
model_batch_size
|
||||
):
|
||||
self.frame_period_recognition = frame_period_recognition
|
||||
self.frame_recognition_seconds = frame_recognition_seconds
|
||||
@@ -23,26 +24,29 @@ cdef class AIRecognitionConfig:
|
||||
|
||||
self.file_data = file_data
|
||||
self.paths = paths
|
||||
self.model_batch_size = model_batch_size
|
||||
|
||||
def __str__(self):
|
||||
return (f'frame_seconds : {self.frame_recognition_seconds}, distance_confidence : {self.tracking_distance_confidence}, '
|
||||
f'probability_increase : {self.tracking_probability_increase}, '
|
||||
f'intersection_threshold : {self.tracking_intersection_threshold}, '
|
||||
f'frame_period_recognition : {self.frame_period_recognition}, '
|
||||
f'paths: {self.paths}')
|
||||
f'paths: {self.paths}, '
|
||||
f'model_batch_size: {self.model_batch_size}')
|
||||
|
||||
@staticmethod
|
||||
cdef from_msgpack(bytes data):
|
||||
unpacked = unpackb(data, strict_map_key=False)
|
||||
return AIRecognitionConfig(
|
||||
unpacked.get("FramePeriodRecognition", 0),
|
||||
unpacked.get("FrameRecognitionSeconds", 0.0),
|
||||
unpacked.get("ProbabilityThreshold", 0.0),
|
||||
unpacked.get("f_pr", 0),
|
||||
unpacked.get("f_rs", 0.0),
|
||||
unpacked.get("pt", 0.0),
|
||||
|
||||
unpacked.get("TrackingDistanceConfidence", 0.0),
|
||||
unpacked.get("TrackingProbabilityIncrease", 0.0),
|
||||
unpacked.get("TrackingIntersectionThreshold", 0.0),
|
||||
unpacked.get("t_dc", 0.0),
|
||||
unpacked.get("t_pi", 0.0),
|
||||
unpacked.get("t_it", 0.0),
|
||||
|
||||
unpacked.get("Data", b''),
|
||||
unpacked.get("Paths", []),
|
||||
unpacked.get("d", b''),
|
||||
unpacked.get("p", []),
|
||||
unpacked.get("m_bs")
|
||||
)
|
||||
@@ -1 +1 @@
|
||||
zmq_port: 5128
|
||||
zmq_port: 5127
|
||||
@@ -10,7 +10,6 @@ cdef str AI_MODEL_FILE_BIG # AI Model file (BIG part)
|
||||
cdef str AI_MODEL_FILE_SMALL # AI Model file (small part)
|
||||
|
||||
cdef bytes DONE_SIGNAL
|
||||
cdef int MODEL_BATCH_SIZE
|
||||
|
||||
|
||||
cdef log(str log_message, bytes client_id=*)
|
||||
@@ -12,7 +12,6 @@ cdef str AI_MODEL_FILE_BIG = "azaion.onnx.big"
|
||||
cdef str AI_MODEL_FILE_SMALL = "azaion.onnx.small"
|
||||
|
||||
cdef bytes DONE_SIGNAL = b"DONE"
|
||||
cdef int MODEL_BATCH_SIZE = 4
|
||||
|
||||
cdef log(str log_message, bytes client_id=None):
|
||||
local_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
|
||||
|
||||
@@ -132,7 +132,7 @@ cdef class Inference:
|
||||
images.append(m)
|
||||
# images first, it's faster
|
||||
if len(images) > 0:
|
||||
for chunk in self.split_list_extend(images, constants.MODEL_BATCH_SIZE):
|
||||
for chunk in self.split_list_extend(images, ai_config.model_batch_size):
|
||||
print(f'run inference on {" ".join(chunk)}...')
|
||||
self._process_images(cmd, ai_config, chunk)
|
||||
if len(videos) > 0:
|
||||
@@ -158,7 +158,7 @@ cdef class Inference:
|
||||
batch_frames.append(frame)
|
||||
batch_timestamps.append(int(v_input.get(cv2.CAP_PROP_POS_MSEC)))
|
||||
|
||||
if len(batch_frames) == constants.MODEL_BATCH_SIZE:
|
||||
if len(batch_frames) == ai_config.model_batch_size:
|
||||
input_blob = self.preprocess(batch_frames)
|
||||
outputs = self.session.run(None, {self.model_input: input_blob})
|
||||
list_detections = self.postprocess(outputs, ai_config)
|
||||
|
||||
@@ -7,4 +7,5 @@ psutil
|
||||
msgpack
|
||||
pyjwt
|
||||
zmq
|
||||
requests
|
||||
requests
|
||||
pyyaml
|
||||
Reference in New Issue
Block a user