Files
ai-training/tests/onnx_inference.py
T
Alex Bezdieniezhnykh 7d99e377f1 add onnx inference
2025-03-24 21:48:22 +02:00

227 lines
8.2 KiB
Python

import json
import sys
import time
from enum import Enum
from os.path import join, dirname
import cv2
import numpy as np
import onnxruntime as onnx
class Detection:
    """One detected object box in normalized (0..1) image coordinates.

    (x, y) is the box center, (w, h) its size; `cls` is the class id used to
    look up the label/color, `confidence` the model score.
    """

    def __init__(self, x, y, w, h, cls, confidence):
        self.x = x
        self.y = y
        self.w = w
        self.h = h
        self.cls = cls
        self.confidence = confidence

    def overlaps(self, det2):
        """Return True if this box and `det2` overlap significantly.

        The intersection is computed from the center distance of the two
        axis-aligned boxes; the ratio is taken against the SMALLER box's area
        (not IoU), with a hard-coded 0.6 threshold.
        """
        overlap_x = 0.5 * (self.w + det2.w) - abs(self.x - det2.x)
        overlap_y = 0.5 * (self.h + det2.h) - abs(self.y - det2.y)
        overlap_area = max(0, overlap_x) * max(0, overlap_y)
        min_area = min(self.w * self.h, det2.w * det2.h)
        # Bug fix: a degenerate zero-size box used to raise ZeroDivisionError;
        # a box with no area cannot meaningfully overlap anything.
        if min_area == 0:
            return False
        return overlap_area / min_area > 0.6
class Annotation:
    """Detection results attached to a single decoded video frame."""

    def __init__(self, frame, image_bytes, time, detections: list[Detection]):
        # The raw frame (as read by cv2) and its JPEG-encoded bytes.
        self.frame = frame
        self.image = image_bytes
        # Frame position within the video, in milliseconds (CAP_PROP_POS_MSEC).
        self.time = time
        # Normalize None to an empty list so consumers can iterate blindly.
        self.detections = [] if detections is None else detections
class WeatherMode(Enum):
    # Weather variants of the detection classes. The numeric value is an id
    # OFFSET: AnnotationClass.read_json builds one class entry per
    # (mode, base class) pair with id = mode.value + base id, so each mode
    # owns its own id range (assumes fewer than 20 base classes — TODO confirm
    # against classes.json).
    Norm = 0
    Wint = 20
    Night = 40
class AnnotationClass:
    """A detection class (id, display name, color) loaded from classes.json."""

    def __init__(self, id, name, color):
        self.id = id
        self.name = name
        self.color = color  # '#RRGGBB' hex string
        color_str = color.lstrip('#')
        # OpenCV draws in BGR order, so the channels are reversed here.
        self.opencv_color = (int(color_str[4:6], 16), int(color_str[2:4], 16), int(color_str[0:2], 16))

    @staticmethod
    def read_json():
        """Load classes.json from the parent directory of this file and return
        {id: AnnotationClass}, with one entry per (weather mode, base class)
        pair; non-Norm modes get the mode name appended to the label."""
        classes_path = join(dirname(dirname(__file__)), 'classes.json')
        with open(classes_path, 'r', encoding='utf-8') as f:
            j = json.loads(f.read())
        annotations_dict = {}
        for mode in WeatherMode:
            for cl in j:
                id = mode.value + cl['Id']
                name = cl['Name'] if mode.value == 0 else f'{cl["Name"]}({mode.name})'
                annotations_dict[id] = AnnotationClass(id, name, cl['Color'])
        return annotations_dict

    @property
    def color_tuple(self):
        """(R, G, B) integer tuple parsed from the '#RRGGBB' color string.

        Bug fix: the previous code sliced self.color[3:], which dropped '#RR'
        and split the remaining 4 characters into 4 single-hex-digit values;
        now the '#' is stripped the same way __init__ does it.
        """
        color = self.color.lstrip('#')
        lv = len(color)
        step = lv // 3
        return tuple(int(color[i:i + step], 16) for i in range(0, lv, step))
class Inference:
    """Runs a YOLO-style ONNX detector over a video file and previews results."""

    def __init__(self, onnx_model, batch_size, confidence_thres, iou_thres):
        self.onnx_model = onnx_model  # path to the .onnx weights file
        self.batch_size = batch_size  # frames per session.run() call
        # NOTE(review): confidence_thres and iou_thres are stored but never
        # applied anywhere in this class (overlaps() hard-codes 0.6 and
        # postprocess keeps every non-zero-confidence row) — confirm intent.
        self.confidence_thres = confidence_thres
        self.iou_thres = iou_thres
        # Filled in by process() from the model's declared input shape.
        self.model_width = None
        self.model_height = None
        self.classes = AnnotationClass.read_json()

    def draw(self, annotation: "Annotation"):
        """Draw boxes and labels on annotation.frame and show it via cv2.imshow."""
        img = annotation.frame
        img_height, img_width = img.shape[:2]
        for d in annotation.detections:
            # Detection stores normalized center + size; convert to pixel corners.
            x1 = int(img_width * (d.x - d.w / 2))
            y1 = int(img_height * (d.y - d.h / 2))
            x2 = int(x1 + img_width * d.w)
            y2 = int(y1 + img_height * d.h)
            color = self.classes[d.cls].opencv_color
            cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
            label = f"{self.classes[d.cls].name}: {d.confidence:.2f}"
            (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
            # Put the label above the box unless it would run off the top edge.
            label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10
            cv2.rectangle(
                img, (x1, label_y - label_height), (x1 + label_width, label_y + label_height), color, cv2.FILLED
            )
            cv2.putText(img, label, (x1, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
        cv2.imshow('Video', img)

    def preprocess(self, frames):
        """Convert a list of BGR frames into one stacked NCHW float blob
        scaled to 0..1 and resized to the model's input size."""
        blobs = [cv2.dnn.blobFromImage(frame,
                                       scalefactor=1.0 / 255.0,
                                       size=(self.model_width, self.model_height),
                                       mean=(0, 0, 0),
                                       swapRB=True,  # BGR -> RGB
                                       crop=False)
                 for frame in frames]
        return np.vstack(blobs)

    def postprocess(self, batch_frames, batch_timestamps, output):
        """Turn raw model output into Annotation objects (one per frame that
        has at least one surviving detection).

        Assumes output[0] is (batch, max_dets, 6) with rows
        [x1, y1, x2, y2, confidence, class_id] in model-input pixel units —
        TODO confirm against the export script.
        """
        anns = []
        for i in range(len(output[0])):
            frame = batch_frames[i]
            timestamp = batch_timestamps[i]
            detections = []
            for det in output[0][i]:
                if det[4] == 0:  # a zero confidence marks the end of valid rows
                    break
                # Normalize corners to 0..1 and clamp to the image bounds.
                x1 = max(0, det[0] / self.model_width)
                y1 = max(0, det[1] / self.model_height)
                x2 = min(1, det[2] / self.model_width)
                y2 = min(1, det[3] / self.model_height)
                conf = round(det[4], 2)
                class_id = int(det[5])
                # Store as normalized center + size (the Detection convention).
                x = (x1 + x2) / 2
                y = (y1 + y2) / 2
                w = x2 - x1
                h = y2 - y1
                detections.append(Detection(x, y, w, h, class_id, conf))
            filtered_detections = self.remove_overlapping_detections(detections)
            if len(filtered_detections) > 0:
                # Keep a JPEG copy of the frame alongside the detections.
                _, image = cv2.imencode('.jpg', frame)
                image_bytes = image.tobytes()
                annotation = Annotation(frame, image_bytes, timestamp, filtered_detections)
                anns.append(annotation)
        return anns

    def process(self, video):
        """Decode `video`, run every 4th frame through the model in batches of
        self.batch_size, and preview annotated frames ('q' stops the preview
        loop for the current batch)."""
        session = onnx.InferenceSession(self.onnx_model, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
        model_inputs = session.get_inputs()
        input_name = model_inputs[0].name
        input_shape = model_inputs[0].shape
        self.model_width = input_shape[2]
        self.model_height = input_shape[3]
        frame_count = 0
        batch_frames = []
        batch_timestamps = []
        v_input = cv2.VideoCapture(video)
        try:
            while v_input.isOpened():
                ret, frame = v_input.read()
                if not ret or frame is None:
                    break
                frame_count += 1
                if frame_count % 4 == 0:  # sample every 4th frame only
                    batch_frames.append(frame)
                    batch_timestamps.append(int(v_input.get(cv2.CAP_PROP_POS_MSEC)))
                    if len(batch_frames) == self.batch_size:
                        input_blob = self.preprocess(batch_frames)
                        outputs = session.run(None, {input_name: input_blob})
                        annotations = self.postprocess(batch_frames, batch_timestamps, outputs)
                        for annotation in annotations:
                            self.draw(annotation)
                            print(f'video: {annotation.time/1000:.3f}s')
                            if cv2.waitKey(1) & 0xFF == ord('q'):
                                break
                        batch_frames.clear()
                        batch_timestamps.clear()
        finally:
            # Bug fix: the capture handle was never released before.
            v_input.release()

    def remove_overlapping_detections(self, detections):
        """Greedy NMS-like filter: within each group of mutually overlapping
        detections, keep the one with the higher confidence (ties broken by
        the lower class id). Returns the surviving detections in input order
        of their winning index.

        NOTE(review): the inner loop always compares det1 against det2, even
        after `res` has moved to a later winner — confirm this is intended.
        """
        filtered_output = []
        filtered_out_indexes = []
        for det1_index in range(len(detections)):
            if det1_index in filtered_out_indexes:
                continue
            det1 = detections[det1_index]
            res = det1_index  # index of the current winner for this group
            for det2_index in range(det1_index + 1, len(detections)):
                det2 = detections[det2_index]
                if det1.overlaps(det2):
                    if det1.confidence > det2.confidence or (det1.confidence == det2.confidence and det1.cls < det2.cls):  # det1 has higher confidence or lower class_id
                        filtered_out_indexes.append(det2_index)
                    else:
                        filtered_out_indexes.append(res)
                        res = det2_index
            filtered_output.append(detections[res])
            filtered_out_indexes.append(res)
        return filtered_output

    def overlap_tests(self):
        """Ad-hoc sanity checks for remove_overlapping_detections (invoked
        manually; see the commented-out call in __main__)."""
        # Five identical boxes: only the most confident (0.8) should survive.
        detections = [
            Detection(10, 10, 200, 200, 0, 0.5),
            Detection(10, 10, 200, 200, 0, 0.6),
            Detection(10, 10, 200, 200, 0, 0.4),
            Detection(10, 10, 200, 200, 0, 0.8),
            Detection(10, 10, 200, 200, 0, 0.3),
        ]
        result = self.remove_overlapping_detections(detections)
        assert len(result) == 1 and result[0].confidence == 0.8
        # Overlap ratio here is ~0.455, below the 0.6 threshold: both survive.
        detections = [
            Detection(10, 10, 100, 100, 0, 0.5),
            Detection(50, 50, 120, 110, 0, 0.6)
        ]
        result2 = self.remove_overlapping_detections(detections)
        assert len(result2) == 2
if __name__ == "__main__":
    # Trained detector weights and the clip used for a manual smoke run.
    model_path = 'azaion-2024-10-26.onnx'
    video_path = 'ForAI_test.mp4'
    runner = Inference(model_path, batch_size=2, confidence_thres=0.5, iou_thres=0.35)
    # runner.overlap_tests()  # uncomment to sanity-check NMS on synthetic boxes
    runner.process(video_path)
    # Keep the last preview window open until any key is pressed.
    cv2.waitKey(0)