mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 23:06:36 +00:00
add onnx inference
This commit is contained in:
@@ -0,0 +1,227 @@
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from enum import Enum
|
||||
from os.path import join, dirname
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import onnxruntime as onnx
|
||||
|
||||
|
||||
class Detection:
    """A single detected object in normalized (0..1) image coordinates.

    (x, y) is the box center; (w, h) its width/height. `cls` is the class id
    and `confidence` the model score.
    """

    def __init__(self, x, y, w, h, cls, confidence):
        self.x = x
        self.y = y
        self.w = w
        self.h = h
        self.cls = cls
        self.confidence = confidence

    def overlaps(self, det2):
        """Return True if this box and `det2` overlap substantially.

        The intersection area is compared against the *smaller* of the two
        box areas; more than 60% coverage counts as an overlap.
        """
        # Intersection extent along each axis (negative when disjoint).
        overlap_x = 0.5 * (self.w + det2.w) - abs(self.x - det2.x)
        overlap_y = 0.5 * (self.h + det2.h) - abs(self.y - det2.y)
        overlap_area = max(0, overlap_x) * max(0, overlap_y)
        min_area = min(self.w * self.h, det2.w * det2.h)

        # Bug fix: a degenerate zero-area box previously raised
        # ZeroDivisionError here; it cannot meaningfully overlap anything.
        if min_area <= 0:
            return False
        return overlap_area / min_area > 0.6
|
||||
|
||||
|
||||
class Annotation:
    """One annotated video frame.

    Holds the raw frame, its JPEG-encoded bytes, the timestamp in
    milliseconds into the video, and the detections found in the frame.
    """

    def __init__(self, frame, image_bytes, time, detections: list[Detection]):
        self.frame = frame
        self.image = image_bytes
        self.time = time
        # A missing detection list is normalized to an empty one.
        self.detections = [] if detections is None else detections
|
||||
|
||||
|
||||
class WeatherMode(Enum):
    """Weather/lighting variants of the annotation classes.

    The numeric value is an id offset: AnnotationClass.read_json builds each
    variant's class id as ``mode.value + base_class_id``.
    # NOTE(review): this implies base class ids stay below 20 — confirm
    # against classes.json.
    """
    Norm = 0    # normal conditions (base class ids unchanged)
    Wint = 20   # winter variant: base id + 20
    Night = 40  # night variant: base id + 40
|
||||
|
||||
|
||||
class AnnotationClass:
    """A detection class: numeric id, display name and '#RRGGBB' color."""

    def __init__(self, id, name, color):
        self.id = id
        self.name = name
        self.color = color
        hex_part = color.lstrip('#')
        # OpenCV uses BGR channel order, so reverse the RGB hex pairs.
        self.opencv_color = (int(hex_part[4:6], 16), int(hex_part[2:4], 16), int(hex_part[0:2], 16))

    @staticmethod
    def read_json():
        """Load classes.json and expand it with one variant per WeatherMode.

        Returns a dict mapping class id -> AnnotationClass; a non-normal
        variant's id is the base id shifted by the mode's value and its name
        gains a '(<mode>)' suffix.
        """
        classes_path = join(dirname(dirname(__file__)), 'classes.json')
        with open(classes_path, 'r', encoding='utf-8') as f:
            j = json.load(f)  # idiomatic: parse straight from the file object
        annotations_dict = {}
        for mode in WeatherMode:
            for cl in j:
                class_id = mode.value + cl['Id']  # renamed: don't shadow builtin `id`
                name = cl['Name'] if mode.value == 0 else f'{cl["Name"]}({mode.name})'
                annotations_dict[class_id] = AnnotationClass(class_id, name, cl['Color'])
        return annotations_dict

    @property
    def color_tuple(self):
        """Return the color as an (R, G, B) tuple of ints.

        Bug fix: the original sliced ``self.color[3:]``, which for a
        '#RRGGBB' string drops the leading red hex digit and yields a
        malformed 4-tuple of single digits. Parse the same way __init__
        does (strip the '#', split into three equal hex groups).
        """
        hex_part = self.color.lstrip('#')
        step = len(hex_part) // 3
        return tuple(int(hex_part[i:i + step], 16) for i in range(0, len(hex_part), step))
|
||||
|
||||
class Inference:
    """Runs a YOLO-style ONNX detector over a video file and displays results.

    Every 4th frame is sampled, batched, run through onnxruntime, decoded
    into normalized Detection boxes, de-duplicated, then drawn and shown
    with OpenCV.
    """

    def __init__(self, onnx_model, batch_size, confidence_thres, iou_thres):
        self.onnx_model = onnx_model    # path to the .onnx model file
        self.batch_size = batch_size    # frames per inference call
        self.confidence_thres = confidence_thres  # NOTE(review): currently unused
        self.iou_thres = iou_thres      # NOTE(review): unused — overlaps() uses a fixed 0.6
        # Model input resolution; filled in by process() from the ONNX input shape.
        self.model_width = None
        self.model_height = None

        self.classes = AnnotationClass.read_json()

    def draw(self, annotation: "Annotation"):
        """Draw boxes and labels for one annotation onto its frame and show it."""
        img = annotation.frame
        img_height, img_width = img.shape[:2]
        for d in annotation.detections:
            # Convert normalized center/size back to pixel corner coordinates.
            x1 = int(img_width * (d.x - d.w / 2))
            y1 = int(img_height * (d.y - d.h / 2))
            x2 = int(x1 + img_width * d.w)
            y2 = int(y1 + img_height * d.h)

            color = self.classes[d.cls].opencv_color
            cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
            label = f"{self.classes[d.cls].name}: {d.confidence:.2f}"
            (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)

            # Put the label above the box unless that would leave the image.
            label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10

            cv2.rectangle(
                img, (x1, label_y - label_height), (x1 + label_width, label_y + label_height), color, cv2.FILLED
            )
            cv2.putText(img, label, (x1, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
        cv2.imshow('Video', img)

    def preprocess(self, frames):
        """Convert BGR frames into one NCHW float blob scaled to [0, 1]."""
        blobs = [cv2.dnn.blobFromImage(frame,
                                       scalefactor=1.0 / 255.0,
                                       size=(self.model_width, self.model_height),
                                       mean=(0, 0, 0),
                                       swapRB=True,  # BGR -> RGB
                                       crop=False)
                 for frame in frames]
        return np.vstack(blobs)

    def postprocess(self, batch_frames, batch_timestamps, output):
        """Decode raw model output into Annotation objects.

        Each detection row is (x1, y1, x2, y2, confidence, class_id) in model
        pixels; boxes are normalized to [0, 1] and converted to center/size.
        Frames without surviving detections produce no annotation.
        """
        anns = []
        for i in range(len(output[0])):
            frame = batch_frames[i]
            timestamp = batch_timestamps[i]
            detections = []
            for det in output[0][i]:
                if det[4] == 0:  # if confidence is 0 then valid points are over.
                    break
                # Clamp to the image and normalize by the model input size.
                x1 = max(0, det[0] / self.model_width)
                y1 = max(0, det[1] / self.model_height)
                x2 = min(1, det[2] / self.model_width)
                y2 = min(1, det[3] / self.model_height)
                conf = round(det[4], 2)
                class_id = int(det[5])

                # Corner -> center/size representation.
                x = (x1 + x2) / 2
                y = (y1 + y2) / 2
                w = x2 - x1
                h = y2 - y1
                detections.append(Detection(x, y, w, h, class_id, conf))

            filtered_detections = self.remove_overlapping_detections(detections)

            if len(filtered_detections) > 0:
                _, image = cv2.imencode('.jpg', frame)
                image_bytes = image.tobytes()
                annotation = Annotation(frame, image_bytes, timestamp, filtered_detections)
                anns.append(annotation)
        return anns

    def process(self, video):
        """Run the detector over `video`, drawing every annotated frame.

        Press 'q' in the display window to stop early.
        """
        session = onnx.InferenceSession(self.onnx_model, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
        model_inputs = session.get_inputs()
        input_name = model_inputs[0].name
        input_shape = model_inputs[0].shape
        # Bug fix: ONNX image inputs are NCHW, so shape[2] is HEIGHT and
        # shape[3] is WIDTH. The original assigned them swapped, which only
        # happened to work for square models.
        self.model_height = input_shape[2]
        self.model_width = input_shape[3]

        frame_count = 0
        batch_frames = []
        batch_timestamps = []
        v_input = cv2.VideoCapture(video)
        try:
            quit_requested = False
            while v_input.isOpened() and not quit_requested:
                ret, frame = v_input.read()
                if not ret or frame is None:
                    break

                frame_count += 1
                # Sample every 4th frame to reduce load.
                if frame_count % 4 == 0:
                    batch_frames.append(frame)
                    batch_timestamps.append(int(v_input.get(cv2.CAP_PROP_POS_MSEC)))

                if len(batch_frames) == self.batch_size:
                    input_blob = self.preprocess(batch_frames)
                    outputs = session.run(None, {input_name: input_blob})
                    annotations = self.postprocess(batch_frames, batch_timestamps, outputs)
                    for annotation in annotations:
                        self.draw(annotation)
                        print(f'video: {annotation.time/1000:.3f}s')
                        if cv2.waitKey(1) & 0xFF == ord('q'):
                            # Bug fix: the original `break` only left this
                            # inner loop; propagate the quit to the while loop.
                            quit_requested = True
                            break
                    batch_frames.clear()
                    batch_timestamps.clear()
        finally:
            # Bug fix: release the capture device (leaked by the original).
            v_input.release()

    def remove_overlapping_detections(self, detections):
        """Greedy de-duplication: among overlapping boxes keep the one with
        the highest confidence (ties broken by lower class id).

        NOTE(review): overlap comparisons always use the outer-loop detection
        rather than the current best candidate; kept as-is to preserve the
        existing behavior.
        """
        filtered_output = []
        filtered_out_indexes = []

        for det1_index in range(len(detections)):
            if det1_index in filtered_out_indexes:
                continue
            det1 = detections[det1_index]
            res = det1_index
            for det2_index in range(det1_index + 1, len(detections)):
                det2 = detections[det2_index]
                if det1.overlaps(det2):
                    if det1.confidence > det2.confidence or (det1.confidence == det2.confidence and det1.cls < det2.cls):  # det1 has higher confidence or lower class_id
                        filtered_out_indexes.append(det2_index)
                    else:
                        filtered_out_indexes.append(res)
                        res = det2_index
            filtered_output.append(detections[res])
            filtered_out_indexes.append(res)
        return filtered_output

    def overlap_tests(self):
        """Ad-hoc manual checks for remove_overlapping_detections."""
        detections = [
            Detection(10, 10, 200, 200, 0, 0.5),
            Detection(10, 10, 200, 200, 0, 0.6),
            Detection(10, 10, 200, 200, 0, 0.4),
            Detection(10, 10, 200, 200, 0, 0.8),
            Detection(10, 10, 200, 200, 0, 0.3),
        ]
        result = self.remove_overlapping_detections(detections)

        detections = [
            Detection(10, 10, 100, 100, 0, 0.5),
            Detection(50, 50, 120, 110, 0, 0.6)
        ]
        result2 = self.remove_overlapping_detections(detections)
        pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Model and video paths are resolved relative to the working directory.
    model_path = 'azaion-2024-10-26.onnx'
    video_path = 'ForAI_test.mp4'
    detector = Inference(model_path, batch_size=2, confidence_thres=0.5, iou_thres=0.35)
    # detector.overlap_tests()
    detector.process(video_path)

    # Keep the last frame on screen until a key is pressed.
    cv2.waitKey(0)
|
||||
Reference in New Issue
Block a user