add onnx inference

This commit is contained in:
Alex Bezdieniezhnykh
2025-03-24 21:48:22 +02:00
parent 3113d59a3a
commit 7d99e377f1
7 changed files with 242 additions and 89 deletions
+1 -1
View File
@@ -24,7 +24,7 @@ class AnnotationClass:
for mode in WeatherMode: for mode in WeatherMode:
for cl in j: for cl in j:
id = mode.value + cl['Id'] id = mode.value + cl['Id']
name = cl['Name'] if mode.value == 0 else f'{cl['Name']}({mode.name})' name = cl['Name'] if mode.value == 0 else f'{cl["Name"]}({mode.name})'
annotations_dict[id] = AnnotationClass(id, name, cl['Color']) annotations_dict[id] = AnnotationClass(id, name, cl['Color'])
return annotations_dict return annotations_dict
+3 -2
View File
@@ -6,7 +6,6 @@ from ultralytics import YOLO
def export_rknn(model_path): def export_rknn(model_path):
# model_onnx = export_onnx(model_path)
model = YOLO(model_path) model = YOLO(model_path)
model.export(format="rknn", name="rk3588", simplify=True) model.export(format="rknn", name="rk3588", simplify=True)
model_stem = Path(model_path).stem model_stem = Path(model_path).stem
@@ -25,10 +24,12 @@ def export_onnx(model_path):
nms=True) nms=True)
return Path(model_path).stem + '.onnx' return Path(model_path).stem + '.onnx'
def show_model(model: str = None): def show_model(model: str = None):
netron.start(model) netron.start(model)
if __name__ == '__main__': if __name__ == '__main__':
show_model('azaion_2025-03-10.rknn') export_onnx('azaion-2024-10-26.pt')
show_model('azaion-2024-10-26.onnx')
# export_rknn('azaion_2025-03-10.pt') # export_rknn('azaion_2025-03-10.pt')
+1
View File
@@ -15,4 +15,5 @@ pyyaml
boto3 boto3
msgpack msgpack
rstream rstream
onnxruntime-gpu
netron netron
+227
View File
@@ -0,0 +1,227 @@
import json
import sys
import time
from enum import Enum
from os.path import join, dirname
import cv2
import numpy as np
import onnxruntime as onnx
class Detection:
def __init__(self, x, y, w, h, cls, confidence):
self.x = x
self.y = y
self.w = w
self.h = h
self.cls = cls
self.confidence = confidence
def overlaps(self, det2):
overlap_x = 0.5 * (self.w + det2.w) - abs(self.x - det2.x)
overlap_y = 0.5 * (self.h + det2.h) - abs(self.y - det2.y)
overlap_area = max(0, overlap_x) * max(0, overlap_y)
min_area = min(self.w * self.h, det2.w * det2.h)
return overlap_area / min_area > 0.6
class Annotation:
def __init__(self, frame, image_bytes, time, detections: list[Detection]):
self.frame = frame
self.image = image_bytes
self.time = time
self.detections = detections if detections is not None else []
class WeatherMode(Enum):
Norm = 0
Wint = 20
Night = 40
class AnnotationClass:
def __init__(self, id, name, color):
self.id = id
self.name = name
self.color = color
color_str = color.lstrip('#')
self.opencv_color = (int(color_str[4:6], 16), int(color_str[2:4], 16), int(color_str[0:2], 16))
@staticmethod
def read_json():
classes_path = join(dirname(dirname(__file__)), 'classes.json')
with open(classes_path, 'r', encoding='utf-8') as f:
j = json.loads(f.read())
annotations_dict = {}
for mode in WeatherMode:
for cl in j:
id = mode.value + cl['Id']
name = cl['Name'] if mode.value == 0 else f'{cl["Name"]}({mode.name})'
annotations_dict[id] = AnnotationClass(id, name, cl['Color'])
return annotations_dict
@property
def color_tuple(self):
color = self.color[3:]
lv = len(color)
xx = range(0, lv, lv // 3)
return tuple(int(color[i:i + lv // 3], 16) for i in xx)
class Inference:
def __init__(self, onnx_model, batch_size, confidence_thres, iou_thres):
self.onnx_model = onnx_model
self.batch_size = batch_size
self.confidence_thres = confidence_thres
self.iou_thres = iou_thres
self.model_width = None
self.model_height = None
self.classes = AnnotationClass.read_json()
def draw(self, annotation: Annotation):
img = annotation.frame
img_height, img_width = img.shape[:2]
for d in annotation.detections:
x1 = int(img_width * (d.x - d.w / 2))
y1 = int(img_height * (d.y - d.h / 2))
x2 = int(x1 + img_width * d.w)
y2 = int(y1 + img_height * d.h)
color = self.classes[d.cls].opencv_color
cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
label = f"{self.classes[d.cls].name}: {d.confidence:.2f}"
(label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10
cv2.rectangle(
img, (x1, label_y - label_height), (x1 + label_width, label_y + label_height), color, cv2.FILLED
)
cv2.putText(img, label, (x1, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
cv2.imshow('Video', img)
def preprocess(self, frames):
blobs = [cv2.dnn.blobFromImage(frame,
scalefactor=1.0 / 255.0,
size=(self.model_width, self.model_height),
mean=(0, 0, 0),
swapRB=True,
crop=False)
for frame in frames]
return np.vstack(blobs)
def postprocess(self, batch_frames, batch_timestamps, output):
anns = []
for i in range(len(output[0])):
frame = batch_frames[i]
timestamp = batch_timestamps[i]
detections = []
for det in output[0][i]:
if det[4] == 0: # if confidence is 0 then valid points are over.
break
x1 = max(0, det[0] / self.model_width)
y1 = max(0, det[1] / self.model_height)
x2 = min(1, det[2] / self.model_width)
y2 = min(1, det[3] / self.model_height)
conf = round(det[4],2)
class_id = int(det[5])
x = (x1 + x2) / 2
y = (y1 + y2) / 2
w = x2 - x1
h = y2 - y1
detections.append(Detection(x, y, w, h, class_id, conf))
filtered_detections = self.remove_overlapping_detections(detections)
if len(filtered_detections) > 0:
_, image = cv2.imencode('.jpg', frame)
image_bytes = image.tobytes()
annotation = Annotation(frame, image_bytes, timestamp, filtered_detections)
anns.append(annotation)
return anns
def process(self, video):
session = onnx.InferenceSession(self.onnx_model, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
model_inputs = session.get_inputs()
input_name = model_inputs[0].name
input_shape = model_inputs[0].shape
self.model_width = input_shape[2]
self.model_height = input_shape[3]
frame_count = 0
batch_frames = []
batch_timestamps = []
v_input = cv2.VideoCapture(video)
while v_input.isOpened():
ret, frame = v_input.read()
if not ret or frame is None:
break
frame_count += 1
if frame_count % 4 == 0:
batch_frames.append(frame)
batch_timestamps.append(int(v_input.get(cv2.CAP_PROP_POS_MSEC)))
if len(batch_frames) == self.batch_size:
input_blob = self.preprocess(batch_frames)
outputs = session.run(None, {input_name: input_blob})
annotations = self.postprocess(batch_frames, batch_timestamps, outputs)
for annotation in annotations:
self.draw(annotation)
print(f'video: {annotation.time/1000:.3f}s')
if cv2.waitKey(1) & 0xFF == ord('q'):
break
batch_frames.clear()
batch_timestamps.clear()
def remove_overlapping_detections(self, detections):
filtered_output = []
filtered_out_indexes = []
for det1_index in range(len(detections)):
if det1_index in filtered_out_indexes:
continue
det1 = detections[det1_index]
res = det1_index
for det2_index in range(det1_index + 1, len(detections)):
det2 = detections[det2_index]
if det1.overlaps(det2):
if det1.confidence > det2.confidence or (det1.confidence == det2.confidence and det1.cls < det2.cls): # det1 has higher confidence or lower class_id
filtered_out_indexes.append(det2_index)
else:
filtered_out_indexes.append(res)
res = det2_index
filtered_output.append(detections[res])
filtered_out_indexes.append(res)
return filtered_output
def overlap_tests(self):
detections = [
Detection(10, 10, 200, 200, 0, 0.5),
Detection(10, 10, 200, 200, 0, 0.6),
Detection(10, 10, 200, 200, 0, 0.4),
Detection(10, 10, 200, 200, 0, 0.8),
Detection(10, 10, 200, 200, 0, 0.3),
]
result = self.remove_overlapping_detections(detections)
detections = [
Detection(10, 10, 100, 100, 0, 0.5),
Detection(50, 50, 120, 110, 0, 0.6)
]
result2 = self.remove_overlapping_detections(detections)
pass
if __name__ == "__main__":
model = 'azaion-2024-10-26.onnx'
input_video = 'ForAI_test.mp4'
inf = Inference(model, batch_size=2, confidence_thres=0.5, iou_thres=0.35)
# inf.overlap_tests()
inf.process(input_video)
cv2.waitKey(0)
-33
View File
@@ -1,33 +0,0 @@
from abc import ABC, abstractmethod
from ultralytics import YOLO
import yaml
class Predictor(ABC):
@abstractmethod
def predict(self, frame):
pass
class OnnxPredictor(Predictor):
def __init__(self):
self.model = YOLO('azaion.onnx')
self.model.task = 'detect'
with open('data.yaml', 'r') as f:
data_yaml = yaml.safe_load(f)
class_names = data_yaml['names']
names = self.model.names
def predict(self, frame):
results = self.model.track(frame, persist=True, tracker='bytetrack.yaml')
return results[0].plot()
class YoloPredictor(Predictor):
def __init__(self):
self.model = YOLO('azaion.pt')
def predict(self, frame):
results = self.model.track(frame, persist=True, tracker='bytetrack.yaml')
return results[0].plot()
-44
View File
@@ -1,44 +0,0 @@
import sys
from pathlib import Path
from ultralytics import YOLO
# from vidgear.gears import CamGear
import cv2
from time import sleep
from predictor import OnnxPredictor, YoloPredictor
# video_url = 'https://www.youtube.com/watch?v=d1n2fDOSo8c'
# stream = CamGear(source=video_url, stream_mode=True, logging=True).start()
write_output = False
predictor = YoloPredictor()
fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
input_name = 'ForAI_test.mp4'
output_name = Path(input_name).stem + '_recognised.mp4'
v_input = cv2.VideoCapture(input_name)
if write_output:
v_output = cv2.VideoWriter(output_name, fourcc, 20.0, (640, 480))
while v_input.isOpened():
ret, frame = v_input.read()
if frame is None:
break
frame_detected = predictor.predict(frame)
frame_detected = cv2.resize(frame_detected, (640, 480))
cv2.imshow('Video', frame_detected)
sleep(0.01)
if write_output:
v_output.write(frame_detected)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
v_input.release()
if write_output:
v_output.release()
cv2.destroyAllWindows()
+10 -9
View File
@@ -237,9 +237,9 @@ def validate(model_path):
pass pass
def upload_model(model_path: str): def upload_model(model_path: str, size_small_in_kb: int=3):
model = YOLO(model_path) # model = YOLO(model_path)
model.export(format="onnx", imgsz=1280, nms=True, batch=4) # model.export(format="onnx", imgsz=1280, nms=True, batch=4)
onnx_model = path.dirname(model_path) + Path(model_path).stem + '.onnx' onnx_model = path.dirname(model_path) + Path(model_path).stem + '.onnx'
with open(onnx_model, 'rb') as f_in: with open(onnx_model, 'rb') as f_in:
@@ -248,7 +248,7 @@ def upload_model(model_path: str):
key = Security.get_model_encryption_key() key = Security.get_model_encryption_key()
onnx_encrypted = Security.encrypt_to(onnx_bytes, key) onnx_encrypted = Security.encrypt_to(onnx_bytes, key)
part1_size = min(10 * 1024, int(0.9 * len(onnx_encrypted))) part1_size = min(size_small_in_kb * 1024, int(0.9 * len(onnx_encrypted)))
onnx_part_small = onnx_encrypted[:part1_size] # slice bytes for part1 onnx_part_small = onnx_encrypted[:part1_size] # slice bytes for part1
onnx_part_big = onnx_encrypted[part1_size:] onnx_part_big = onnx_encrypted[part1_size:]
@@ -264,8 +264,9 @@ def upload_model(model_path: str):
api.upload_file('azaion.onnx.small', onnx_part_small) api.upload_file('azaion.onnx.small', onnx_part_small)
if __name__ == '__main__': if __name__ == '__main__':
model_path = train_dataset(from_scratch=True) # model_path = train_dataset(from_scratch=True)
validate(path.join('runs', 'detect', 'train7', 'weights', 'best.pt')) # validate(path.join('runs', 'detect', 'train7', 'weights', 'best.pt'))
form_data_sample(500) # form_data_sample(500)
convert2rknn() # convert2rknn()
upload_model('azaion.pt')
upload_model('azaion-2024-10-26.onnx')