mirror of
https://github.com/azaion/annotations.git
synced 2026-04-22 10:36:30 +00:00
use nms in the model itself, simplify and make postprocess faster.
make inference in batches, fix c# handling, add overlap handling
This commit is contained in:
+132
-81
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import mimetypes
|
||||
import time
|
||||
|
||||
@@ -5,6 +6,7 @@ import cv2
|
||||
import numpy as np
|
||||
import onnxruntime as onnx
|
||||
|
||||
cimport constants
|
||||
from remote_command cimport RemoteCommand
|
||||
from annotation cimport Detection, Annotation
|
||||
from ai_config cimport AIRecognitionConfig
|
||||
@@ -26,68 +28,117 @@ cdef class Inference:
|
||||
model_meta = self.session.get_modelmeta()
|
||||
print("Metadata:", model_meta.custom_metadata_map)
|
||||
|
||||
cdef preprocess(self, frame):
|
||||
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
img = cv2.resize(img, (self.model_width, self.model_height))
|
||||
image_data = np.array(img) / 255.0
|
||||
image_data = np.transpose(image_data, (2, 0, 1)) # Channel first
|
||||
image_data = np.expand_dims(image_data, axis=0).astype(np.float32)
|
||||
return image_data
|
||||
cdef preprocess(self, frames):
|
||||
blobs = [cv2.dnn.blobFromImage(frame,
|
||||
scalefactor=1.0 / 255.0,
|
||||
size=(self.model_width, self.model_height),
|
||||
mean=(0, 0, 0),
|
||||
swapRB=True,
|
||||
crop=False)
|
||||
for frame in frames]
|
||||
return np.vstack(blobs)
|
||||
|
||||
cdef postprocess(self, output, int img_width, int img_height):
|
||||
outputs = np.transpose(np.squeeze(output[0]))
|
||||
rows = outputs.shape[0]
|
||||
|
||||
boxes = []
|
||||
scores = []
|
||||
class_ids = []
|
||||
cdef postprocess(self, output):
|
||||
cdef list[Detection] detections = []
|
||||
cdef int ann_index
|
||||
cdef float x1, y1, x2, y2, conf, cx, cy, w, h
|
||||
cdef int class_id
|
||||
cdef list[list[Detection]] results = []
|
||||
|
||||
x_factor = img_width / self.model_width
|
||||
y_factor = img_height / self.model_height
|
||||
for ann_index in range(len(output[0])):
|
||||
detections.clear()
|
||||
for det in output[0][ann_index]:
|
||||
if det[4] == 0: # if confidence is 0 then valid points are over.
|
||||
break
|
||||
x1 = det[0] / self.model_width
|
||||
y1 = det[1] / self.model_height
|
||||
x2 = det[2] / self.model_width
|
||||
y2 = det[3] / self.model_height
|
||||
conf = round(det[4], 2)
|
||||
class_id = int(det[5])
|
||||
|
||||
for i in range(rows):
|
||||
classes_scores = outputs[i][4:]
|
||||
max_score = np.amax(classes_scores)
|
||||
x = (x1 + x2) / 2
|
||||
y = (y1 + y2) / 2
|
||||
w = x2 - x1
|
||||
h = y2 - y1
|
||||
detections.append(Detection(x, y, w, h, class_id, conf))
|
||||
filtered_detections = self.remove_overlapping_detections(detections)
|
||||
results.append(filtered_detections)
|
||||
return results
|
||||
|
||||
if max_score >= self.ai_config.probability_threshold:
|
||||
class_id = np.argmax(classes_scores)
|
||||
x, y, w, h = outputs[i][0], outputs[i][1], outputs[i][2], outputs[i][3]
|
||||
cdef remove_overlapping_detections(self, list[Detection] detections):
|
||||
cdef Detection det1, det2
|
||||
filtered_output = []
|
||||
filtered_out_indexes = []
|
||||
|
||||
left = int((x - w / 2) * x_factor)
|
||||
top = int((y - h / 2) * y_factor)
|
||||
width = int(w * x_factor)
|
||||
height = int(h * y_factor)
|
||||
|
||||
class_ids.append(class_id)
|
||||
scores.append(max_score)
|
||||
boxes.append([left, top, width, height])
|
||||
indices = cv2.dnn.NMSBoxes(boxes, scores, self.ai_config.probability_threshold, 0.45)
|
||||
detections = []
|
||||
for i in indices:
|
||||
x, y, w, h = boxes[i]
|
||||
detections.append(Detection(x, y, w, h, class_ids[i], scores[i]))
|
||||
return detections
|
||||
for det1_index in range(len(detections)):
|
||||
if det1_index in filtered_out_indexes:
|
||||
continue
|
||||
det1 = detections[det1_index]
|
||||
print(f'det1 size: {det1.w}, {det1.h}')
|
||||
res = det1_index
|
||||
for det2_index in range(det1_index + 1, len(detections)):
|
||||
det2 = detections[det2_index]
|
||||
print(f'det2 size: {det2.w}, {det2.h}')
|
||||
if det1.overlaps(det2):
|
||||
if det1.confidence > det2.confidence or (
|
||||
det1.confidence == det2.confidence and det1.cls < det2.cls): # det1 has higher confidence or lower class_id
|
||||
filtered_out_indexes.append(det2_index)
|
||||
else:
|
||||
filtered_out_indexes.append(res)
|
||||
res = det2_index
|
||||
filtered_output.append(detections[res])
|
||||
filtered_out_indexes.append(res)
|
||||
return filtered_output
|
||||
|
||||
cdef bint is_video(self, str filepath):
|
||||
mime_type, _ = mimetypes.guess_type(<str>filepath)
|
||||
return mime_type and mime_type.startswith("video")
|
||||
|
||||
cdef run_inference(self, RemoteCommand cmd, int batch_size=8):
|
||||
print('run inference..')
|
||||
cdef split_list_extend(self, lst, chunk_size):
|
||||
chunks = [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
|
||||
|
||||
# If the last chunk is smaller than the desired chunk_size, extend it by duplicating its last element.
|
||||
last_chunk = chunks[len(chunks) - 1]
|
||||
if len(last_chunk) < chunk_size:
|
||||
last_elem = last_chunk[len(last_chunk)-1]
|
||||
while len(last_chunk) < chunk_size:
|
||||
last_chunk.append(last_elem)
|
||||
return chunks
|
||||
|
||||
cdef run_inference(self, RemoteCommand cmd):
|
||||
cdef list[str] medias = json.loads(<str> cmd.filename)
|
||||
cdef list[str] videos = []
|
||||
cdef list[str] images = []
|
||||
|
||||
self.ai_config = AIRecognitionConfig.from_msgpack(cmd.data)
|
||||
self.stop_signal = False
|
||||
if self.is_video(cmd.filename):
|
||||
self._process_video(cmd, batch_size)
|
||||
else:
|
||||
self._process_image(cmd)
|
||||
|
||||
cdef _process_video(self, RemoteCommand cmd, int batch_size):
|
||||
frame_count = 0
|
||||
batch_frame = []
|
||||
for m in medias:
|
||||
if self.is_video(m):
|
||||
videos.append(m)
|
||||
else:
|
||||
images.append(m)
|
||||
|
||||
# images first, it's faster
|
||||
if len(images) > 0:
|
||||
for chunk in self.split_list_extend(images, constants.MODEL_BATCH_SIZE):
|
||||
print(f'run inference on {" ".join(chunk)}...')
|
||||
self._process_images(cmd, chunk)
|
||||
if len(videos) > 0:
|
||||
for v in videos:
|
||||
print(f'run inference on {v}...')
|
||||
self._process_video(cmd, v)
|
||||
|
||||
|
||||
cdef _process_video(self, RemoteCommand cmd, str video_name):
|
||||
cdef int frame_count = 0
|
||||
cdef list batch_frames = []
|
||||
cdef list[int] batch_timestamps = []
|
||||
self._previous_annotation = None
|
||||
self.start_video_time = time.time()
|
||||
|
||||
v_input = cv2.VideoCapture(<str>cmd.filename)
|
||||
v_input = cv2.VideoCapture(<str>video_name)
|
||||
while v_input.isOpened():
|
||||
ret, frame = v_input.read()
|
||||
if not ret or frame is None:
|
||||
@@ -95,45 +146,45 @@ cdef class Inference:
|
||||
|
||||
frame_count += 1
|
||||
if frame_count % self.ai_config.frame_period_recognition == 0:
|
||||
ms = int(v_input.get(cv2.CAP_PROP_POS_MSEC))
|
||||
annotation = self.detect_frame(frame, ms)
|
||||
if annotation is not None:
|
||||
self._previous_annotation = annotation
|
||||
self.on_annotation(annotation)
|
||||
batch_frames.append(frame)
|
||||
batch_timestamps.append(int(v_input.get(cv2.CAP_PROP_POS_MSEC)))
|
||||
|
||||
if len(batch_frames) == constants.MODEL_BATCH_SIZE:
|
||||
input_blob = self.preprocess(batch_frames)
|
||||
outputs = self.session.run(None, {self.model_input: input_blob})
|
||||
list_detections = self.postprocess(outputs)
|
||||
for i in range(len(list_detections)):
|
||||
detections = list_detections[i]
|
||||
annotation = Annotation(video_name, batch_timestamps[i], detections)
|
||||
if self.is_valid_annotation(annotation):
|
||||
_, image = cv2.imencode('.jpg', frame)
|
||||
annotation.image = image.tobytes()
|
||||
self.on_annotation(cmd, annotation)
|
||||
self._previous_annotation = annotation
|
||||
|
||||
batch_frames.clear()
|
||||
batch_timestamps.clear()
|
||||
v_input.release()
|
||||
|
||||
|
||||
cdef detect_frame(self, frame, long time):
|
||||
cdef Annotation annotation
|
||||
img_height, img_width = frame.shape[:2]
|
||||
|
||||
start_time = time.time()
|
||||
img_data = self.preprocess(frame)
|
||||
preprocess_time = time.time()
|
||||
outputs = self.session.run(None, {self.model_input: img_data})
|
||||
inference_time = time.time()
|
||||
detections = self.postprocess(outputs, img_width, img_height)
|
||||
postprocess_time = time.time()
|
||||
print(f'video time, ms: {time / 1000:.3f}. total time, s : {postprocess_time - self.start_video_time:.3f} '
|
||||
f'preprocess time: {preprocess_time - start_time:.3f}, inference time: {inference_time - preprocess_time:.3f},'
|
||||
f' postprocess time: {postprocess_time - inference_time:.3f}, total time: {postprocess_time - start_time:.3f}')
|
||||
if len(detections) > 0:
|
||||
annotation = Annotation(frame, time, detections)
|
||||
if self.is_valid_annotation(annotation):
|
||||
_, image = cv2.imencode('.jpg', frame)
|
||||
annotation.image = image.tobytes()
|
||||
return annotation
|
||||
return None
|
||||
|
||||
|
||||
cdef _process_image(self, RemoteCommand cmd):
|
||||
cdef _process_images(self, RemoteCommand cmd, list[str] image_paths):
|
||||
cdef list frames = []
|
||||
cdef list timestamps = []
|
||||
self._previous_annotation = None
|
||||
frame = cv2.imread(<str>cmd.filename)
|
||||
annotation = self.detect_frame(frame, 0)
|
||||
if annotation is None:
|
||||
_, image = cv2.imencode('.jpg', frame)
|
||||
annotation = Annotation(frame, time, [])
|
||||
for image in image_paths:
|
||||
frame = cv2.imread(image)
|
||||
frames.append(frame)
|
||||
timestamps.append(0)
|
||||
|
||||
input_blob = self.preprocess(frames)
|
||||
outputs = self.session.run(None, {self.model_input: input_blob})
|
||||
list_detections = self.postprocess(outputs)
|
||||
for i in range(len(list_detections)):
|
||||
detections = list_detections[i]
|
||||
annotation = Annotation(image_paths[i], timestamps[i], detections)
|
||||
_, image = cv2.imencode('.jpg', frames[i])
|
||||
annotation.image = image.tobytes()
|
||||
self.on_annotation(cmd, annotation)
|
||||
self.on_annotation(cmd, annotation)
|
||||
|
||||
|
||||
cdef stop(self):
|
||||
|
||||
Reference in New Issue
Block a user