"""YOLO-style ONNX video inference: decode a video, batch frames through an
ONNX model, suppress overlapping detections, and draw annotated boxes."""

import json
import sys
import time
from enum import Enum
from os.path import join, dirname

import cv2
import numpy as np
import onnxruntime as onnx


class Detection:
    """One detected object in normalized (0..1) center/size coordinates."""

    def __init__(self, x, y, w, h, cls, confidence):
        self.x = x                    # box center x, normalized to image width
        self.y = y                    # box center y, normalized to image height
        self.w = w                    # box width, normalized
        self.h = h                    # box height, normalized
        self.cls = cls                # class id (already offset by WeatherMode value)
        self.confidence = confidence  # model confidence, rounded to 2 decimals upstream

    def overlaps(self, det2):
        """Return True when the intersection covers more than 60% of the
        smaller of the two boxes (intersection-over-minimum, not IoU)."""
        overlap_x = 0.5 * (self.w + det2.w) - abs(self.x - det2.x)
        overlap_y = 0.5 * (self.h + det2.h) - abs(self.y - det2.y)
        overlap_area = max(0, overlap_x) * max(0, overlap_y)
        min_area = min(self.w * self.h, det2.w * det2.h)
        # Guard: a degenerate zero-area box would otherwise raise ZeroDivisionError.
        if min_area <= 0:
            return False
        return overlap_area / min_area > 0.6


class Annotation:
    """A frame together with its JPEG bytes, timestamp (ms) and detections."""

    def __init__(self, frame, image_bytes, time, detections: list[Detection]):
        self.frame = frame            # raw BGR frame (numpy array)
        self.image = image_bytes      # JPEG-encoded copy of the frame
        self.time = time              # video position in milliseconds
        self.detections = detections if detections is not None else []


class WeatherMode(Enum):
    """Class-id offset per weather condition: a model class id is shifted by
    the mode's value to form a distinct annotation-class id per mode."""
    Norm = 0
    Wint = 20
    Night = 40


class AnnotationClass:
    """A display class: numeric id, human-readable name and '#RRGGBB' color."""

    def __init__(self, id, name, color):
        self.id = id
        self.name = name
        self.color = color
        # OpenCV wants BGR, so reverse the RGB hex pairs.
        color_str = color.lstrip('#')
        self.opencv_color = (int(color_str[4:6], 16),
                             int(color_str[2:4], 16),
                             int(color_str[0:2], 16))

    @staticmethod
    def read_json():
        """Load classes.json (one dir above this file) and build the id ->
        AnnotationClass map, replicated once per WeatherMode offset."""
        classes_path = join(dirname(dirname(__file__)), 'classes.json')
        with open(classes_path, 'r', encoding='utf-8') as f:
            j = json.loads(f.read())
        annotations_dict = {}
        for mode in WeatherMode:
            for cl in j:
                class_id = mode.value + cl['Id']  # renamed: 'id' shadowed the builtin
                name = cl['Name'] if mode.value == 0 else f'{cl["Name"]}({mode.name})'
                annotations_dict[class_id] = AnnotationClass(class_id, name, cl['Color'])
        return annotations_dict

    @property
    def color_tuple(self):
        """(R, G, B) ints parsed from the color string.

        Fix: the original sliced self.color[3:], which drops the first hex
        pair of a '#RRGGBB' value; strip only the leading '#' so this agrees
        with how opencv_color parses the same string.
        """
        color = self.color.lstrip('#')
        lv = len(color)
        step = lv // 3
        return tuple(int(color[i:i + step], 16) for i in range(0, lv, step))


class Inference:
    """Runs an ONNX detection model over a video in fixed-size frame batches."""

    def __init__(self, onnx_model, batch_size, confidence_thres, iou_thres):
        self.onnx_model = onnx_model              # path to the .onnx file
        self.batch_size = batch_size              # frames per inference call
        self.confidence_thres = confidence_thres  # kept for API compatibility (unused here)
        self.iou_thres = iou_thres                # kept for API compatibility (unused here)
        self.model_width = None                   # set from the model input shape in process()
        self.model_height = None
        self.classes = AnnotationClass.read_json()

    def draw(self, annotation: Annotation):
        """Draw boxes and labels for one annotation and show it in a window."""
        img = annotation.frame
        img_height, img_width = img.shape[:2]
        for d in annotation.detections:
            # Convert normalized center/size back to pixel corner coordinates.
            x1 = int(img_width * (d.x - d.w / 2))
            y1 = int(img_height * (d.y - d.h / 2))
            x2 = int(x1 + img_width * d.w)
            y2 = int(y1 + img_height * d.h)
            color = self.classes[d.cls].opencv_color
            cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
            label = f"{self.classes[d.cls].name}: {d.confidence:.2f}"
            (label_width, label_height), _ = cv2.getTextSize(
                label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
            # Put the label above the box unless it would fall off the top edge.
            label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10
            cv2.rectangle(
                img,
                (x1, label_y - label_height),
                (x1 + label_width, label_y + label_height),
                color,
                cv2.FILLED,
            )
            cv2.putText(img, label, (x1, label_y), cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, (0, 0, 0), 1, cv2.LINE_AA)
        cv2.imshow('Video', img)

    def preprocess(self, frames):
        """Convert BGR frames to one NCHW float blob scaled to [0, 1] with
        channels swapped to RGB, resized to the model's input size."""
        blobs = [cv2.dnn.blobFromImage(frame,
                                       scalefactor=1.0 / 255.0,
                                       size=(self.model_width, self.model_height),
                                       mean=(0, 0, 0),
                                       swapRB=True,
                                       crop=False)
                 for frame in frames]
        return np.vstack(blobs)

    def postprocess(self, batch_frames, batch_timestamps, output):
        """Turn raw model output into Annotation objects.

        Assumes output[0][i] rows are [x1, y1, x2, y2, conf, class_id] in
        model-input pixels, padded with zero-confidence rows — TODO confirm
        against the exported model.
        """
        anns = []
        for i in range(len(output[0])):
            frame = batch_frames[i]
            timestamp = batch_timestamps[i]
            detections = []
            for det in output[0][i]:
                if det[4] == 0:
                    # Zero confidence marks the end of valid rows (padding).
                    break
                # Normalize corners to [0, 1] and clamp to the image.
                x1 = max(0, det[0] / self.model_width)
                y1 = max(0, det[1] / self.model_height)
                x2 = min(1, det[2] / self.model_width)
                y2 = min(1, det[3] / self.model_height)
                conf = round(det[4], 2)
                class_id = int(det[5])
                x = (x1 + x2) / 2
                y = (y1 + y2) / 2
                w = x2 - x1
                h = y2 - y1
                detections.append(Detection(x, y, w, h, class_id, conf))
            filtered_detections = self.remove_overlapping_detections(detections)
            if len(filtered_detections) > 0:
                _, image = cv2.imencode('.jpg', frame)
                image_bytes = image.tobytes()
                annotation = Annotation(frame, image_bytes, timestamp, filtered_detections)
                anns.append(annotation)
        return anns

    def process(self, video):
        """Decode the video, run every 4th frame through the model in batches,
        and draw the resulting annotations. Press 'q' to stop."""
        session = onnx.InferenceSession(
            self.onnx_model,
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
        model_inputs = session.get_inputs()
        input_name = model_inputs[0].name
        input_shape = model_inputs[0].shape
        # Fix: ONNX image inputs are NCHW, so dim 2 is HEIGHT and dim 3 is
        # WIDTH. The original had them swapped, which only went unnoticed
        # because square inputs make them equal.
        self.model_height = input_shape[2]
        self.model_width = input_shape[3]
        frame_count = 0
        batch_frames = []
        batch_timestamps = []
        v_input = cv2.VideoCapture(video)
        quit_requested = False
        try:
            while v_input.isOpened() and not quit_requested:
                ret, frame = v_input.read()
                if not ret or frame is None:
                    break
                frame_count += 1
                if frame_count % 4 != 0:  # sample every 4th frame
                    continue
                batch_frames.append(frame)
                batch_timestamps.append(int(v_input.get(cv2.CAP_PROP_POS_MSEC)))
                if len(batch_frames) == self.batch_size:
                    input_blob = self.preprocess(batch_frames)
                    outputs = session.run(None, {input_name: input_blob})
                    annotations = self.postprocess(batch_frames, batch_timestamps, outputs)
                    for annotation in annotations:
                        self.draw(annotation)
                        print(f'video: {annotation.time/1000:.3f}s')
                        if cv2.waitKey(1) & 0xFF == ord('q'):
                            # Fix: the original 'break' only left the inner
                            # for loop, so 'q' never actually stopped playback.
                            quit_requested = True
                            break
                    batch_frames.clear()
                    batch_timestamps.clear()
        finally:
            # Fix: the capture was never released (resource leak).
            v_input.release()

    def remove_overlapping_detections(self, detections):
        """Greedy suppression: among detections whose boxes overlap (>60% of
        the smaller box), keep the one with higher confidence, breaking ties
        by lower class id. Returns the surviving detections in input order.

        NOTE(review): as in the original, comparisons within one pass keep
        using the first detection of the group (det1) even after the current
        winner moves — preserved deliberately to avoid changing results.
        """
        filtered_output = []
        suppressed = set()  # set instead of list: O(1) membership tests
        for i in range(len(detections)):
            if i in suppressed:
                continue
            det1 = detections[i]
            winner = i
            for j in range(i + 1, len(detections)):
                det2 = detections[j]
                if not det1.overlaps(det2):
                    continue
                if det1.confidence > det2.confidence or (
                        det1.confidence == det2.confidence and det1.cls < det2.cls):
                    # det1 wins: higher confidence, or equal confidence and lower class id.
                    suppressed.add(j)
                else:
                    suppressed.add(winner)
                    winner = j
            filtered_output.append(detections[winner])
            suppressed.add(winner)  # don't revisit the winner as a group seed later
        return filtered_output

    def overlap_tests(self):
        """Self-checks for remove_overlapping_detections (now asserts instead
        of silently discarding the results)."""
        detections = [
            Detection(10, 10, 200, 200, 0, 0.5),
            Detection(10, 10, 200, 200, 0, 0.6),
            Detection(10, 10, 200, 200, 0, 0.4),
            Detection(10, 10, 200, 200, 0, 0.8),
            Detection(10, 10, 200, 200, 0, 0.3),
        ]
        result = self.remove_overlapping_detections(detections)
        # All five boxes coincide, so only the most confident one survives.
        assert len(result) == 1 and result[0].confidence == 0.8
        detections = [
            Detection(10, 10, 100, 100, 0, 0.5),
            Detection(50, 50, 120, 110, 0, 0.6),
        ]
        result2 = self.remove_overlapping_detections(detections)
        # Intersection is ~45% of the smaller box — below the 60% threshold,
        # so both detections are kept.
        assert len(result2) == 2


if __name__ == "__main__":
    model = 'azaion-2024-10-26.onnx'
    input_video = 'ForAI_test.mp4'
    inf = Inference(model, batch_size=2, confidence_thres=0.5, iou_thres=0.35)
    # inf.overlap_tests()
    inf.process(input_video)
    cv2.waitKey(0)