mirror of
https://github.com/azaion/annotations.git
synced 2026-04-22 21:46:30 +00:00
62623b7123
rewrite zmq to DEALER and ROUTER add GET_USER command to get CurrentUser from Python all auth is on the python side inference run and validate annotations on python
126 lines
4.7 KiB
Cython
126 lines
4.7 KiB
Cython
import ai_config
|
|
import msgpack
|
|
from ultralytics import YOLO
|
|
import mimetypes
|
|
import cv2
|
|
from ultralytics.engine.results import Boxes
|
|
from remote_command cimport RemoteCommand
|
|
from annotation cimport Detection, Annotation
|
|
from secure_model cimport SecureModelLoader
|
|
from ai_config cimport AIRecognitionConfig
|
|
|
|
cdef class Inference:
    """Runs YOLO inference over images/videos and reports validated annotations.

    The model is materialized from encrypted/packed bytes via SecureModelLoader;
    every annotation that passes validation is delivered through the
    ``on_annotation(cmd, annotation)`` callback supplied by the caller.
    """

    def __init__(self, model_bytes, on_annotation):
        # SecureModelLoader turns the in-memory model bytes into a loadable
        # path for ultralytics' YOLO constructor.
        loader = SecureModelLoader()
        model_path = loader.load_model(model_bytes)
        self.model = YOLO(<str>model_path)
        # Callback invoked for every annotation that passes validation.
        self.on_annotation = on_annotation
        # Last accepted annotation; consulted by is_valid_annotation to decide
        # whether a new frame adds information. Explicitly start unset.
        self._previous_annotation = None

    cdef bint is_video(self, str filepath):
        """Return True when the file's guessed MIME type is video/*."""
        mime_type, _ = mimetypes.guess_type(<str>filepath)
        # Explicit boolean: guess_type returns None for unknown extensions.
        return mime_type is not None and mime_type.startswith("video")

    cdef run_inference(self, RemoteCommand cmd, int batch_size=8):
        """Dispatch cmd.filename to the video or still-image pipeline."""
        print('run inference..')

        if self.is_video(cmd.filename):
            return self._process_video(cmd, batch_size)
        else:
            return self._process_image(cmd)

    def _run_batch(self, cmd, batch_frame):
        """Track one batch of (frame, ms) pairs and emit valid annotations.

        Plain ``def`` (not ``cdef``) so no .pxd declaration is needed.
        """
        frames = [pair[0] for pair in batch_frame]
        # persist=True keeps tracker state alive across successive batches.
        results = self.model.track(frames, persist=True)

        for (frame, ms), res in zip(batch_frame, results):
            annotation = self.frame_to_annotation(int(ms), frame, res.boxes)
            if self.is_valid_annotation(<Annotation>annotation):
                self._previous_annotation = annotation
                self.on_annotation(cmd, annotation)

    cdef _process_video(self, RemoteCommand cmd, int batch_size):
        """Sample frames from the video and run tracking in batches.

        Every ``frame_period_recognition``-th frame is sampled and collected
        into batches of ``batch_size`` before inference.
        """
        frame_count = 0
        batch_frame = []
        v_input = cv2.VideoCapture(<str>cmd.filename)
        self.ai_config = AIRecognitionConfig.from_msgpack(cmd.data)

        try:
            while v_input.isOpened():
                ret, frame = v_input.read()
                # Timestamp (ms) of the frame just decoded.
                ms = v_input.get(cv2.CAP_PROP_POS_MSEC)
                if not ret or frame is None:
                    break

                frame_count += 1
                if frame_count % self.ai_config.frame_period_recognition == 0:
                    batch_frame.append((frame, ms))

                if len(batch_frame) == batch_size:
                    self._run_batch(cmd, batch_frame)
                    batch_frame.clear()

            # Fix: flush the trailing partial batch. Previously any frames
            # left over when the video ended were silently dropped.
            if batch_frame:
                self._run_batch(cmd, batch_frame)
        finally:
            # Always release the capture handle, even if inference raises.
            v_input.release()

    cdef _process_image(self, RemoteCommand cmd):
        """Run tracking on a single still image and emit its annotation.

        Still images use timestamp 0 and bypass is_valid_annotation.
        """
        frame = cv2.imread(<str>cmd.filename)
        res = self.model.track(frame)
        annotation = self.frame_to_annotation(0, frame, res[0].boxes)
        self.on_annotation(cmd, annotation)

    cdef frame_to_annotation(self, long time, frame, boxes: Boxes):
        """Convert YOLO boxes plus a frame into an Annotation.

        The frame is JPEG-encoded into the annotation; box coordinates are
        taken from xywhn (normalized center-x, center-y, width, height).
        """
        detections = []
        for box in boxes:
            b = box.xywhn[0].cpu().numpy()
            cls = int(box.cls[0].cpu().numpy().item())
            confidence = box.conf[0].cpu().numpy().item()
            det = Detection(<double> b[0], <double> b[1], <double> b[2], <double> b[3], cls, confidence)
            detections.append(det)

        _, encoded_image = cv2.imencode('.jpg', frame)
        image_bytes = encoded_image.tobytes()
        return Annotation(image_bytes, time, detections)

    cdef bint is_valid_annotation(self, Annotation annotation):
        """Return True when ``annotation`` is novel enough to report.

        Accepted when any of the following holds:
          * it is the first annotation that has detections;
          * enough time has elapsed since the last accepted annotation;
          * it contains more detections than the last accepted annotation;
          * some detection moved beyond the tracking distance; or
          * some detection's confidence rose by the configured margin
            relative to its nearest detection in the previous frame.
        """
        # No detections, invalid.
        if not annotation.detections:
            return False

        # First valid annotation, always accept.
        if self._previous_annotation is None:
            return True

        # Enough time has passed since the last accepted annotation
        # (frame_recognition_seconds is converted to milliseconds).
        if annotation.time >= self._previous_annotation.time + <long>(self.ai_config.frame_recognition_seconds * 1000):
            return True

        # More objects detected than before.
        if len(annotation.detections) > len(self._previous_annotation.detections):
            return True

        cdef:
            Detection current_det, prev_det
            double dx, dy, distance_sq, min_distance_sq
            Detection closest_det

        # Compare each detection against its nearest neighbour in the
        # previously accepted frame.
        for current_det in annotation.detections:
            min_distance_sq = 1e18  # sentinel larger than any real distance
            closest_det = None

            # Find the closest previous detection (guaranteed non-empty:
            # only annotations with detections are ever stored as previous).
            for prev_det in self._previous_annotation.detections:
                dx = current_det.x - prev_det.x
                dy = current_det.y - prev_det.y
                distance_sq = dx * dx + dy * dy

                if distance_sq < min_distance_sq:
                    min_distance_sq = distance_sq
                    closest_det = prev_det

            # Object moved beyond the tracking distance.
            # NOTE(review): min_distance_sq is a *squared* distance — confirm
            # tracking_distance_confidence is expressed as a squared value too.
            if min_distance_sq > self.ai_config.tracking_distance_confidence:
                return True

            # Confidence rose enough relative to the nearest previous box.
            if current_det.confidence >= closest_det.confidence + self.ai_config.tracking_probability_increase:
                return True

        # No validation criteria met.
        return False