# Mirror of https://github.com/azaion/ai-training.git
# (synced 2026-04-22 22:36:36 +00:00; 139 lines, 5.6 KiB, Python)
import cv2
|
|
import numpy as np
|
|
from inference.dto import Annotation, Detection, AnnotationClass
|
|
from inference.onnx_engine import InferenceEngine
|
|
|
|
|
|
class Inference:
    """Run an object-detection engine over video frames and visualise the results.

    Frames are packed into a batch blob (``preprocess``), passed to ``engine.run``,
    converted into ``Annotation`` objects (``postprocess``), de-duplicated
    (``remove_overlapping_detections``) and drawn on screen (``draw``).
    """

    def __init__(self, engine: "InferenceEngine", confidence_threshold, iou_threshold):
        """Bind the engine and the detection thresholds.

        Args:
            engine: inference backend; must provide ``get_batch_size()``,
                ``get_input_shape()`` returning ``(height, width)`` and ``run(blob)``.
            confidence_threshold: detections scoring below this are dropped.
            iou_threshold: overlap threshold handed to ``Detection.overlaps`` when
                suppressing duplicate boxes.
        """
        self.engine = engine
        self.confidence_threshold = confidence_threshold
        self.iou_threshold = iou_threshold
        self.batch_size = engine.get_batch_size()
        self.model_height, self.model_width = engine.get_input_shape()
        # Class metadata indexed by class id; draw() reads .name and .opencv_color.
        self.classes = AnnotationClass.read_json()

    def draw(self, annotation: "Annotation"):
        """Draw the annotation's boxes and labels onto its frame and display it.

        The frame is modified in place and shown in the ``'Video'`` window.
        Detection coordinates are centre/size values normalised to the frame size.
        """
        img = annotation.frame
        img_height, img_width = img.shape[:2]
        for d in annotation.detections:
            # Normalised centre/size -> top-left / bottom-right pixel corners.
            x1 = int(img_width * (d.x - d.w / 2))
            y1 = int(img_height * (d.y - d.h / 2))
            x2 = int(x1 + img_width * d.w)
            y2 = int(y1 + img_height * d.h)

            cls_info = self.classes[d.cls]
            color = cls_info.opencv_color
            cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)

            label = f"{cls_info.name}: {d.confidence:.2f}"
            (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
            # Put the label above the box, or just below its top edge when there is no room above.
            label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10
            cv2.rectangle(
                img, (x1, label_y - label_height), (x1 + label_width, label_y + label_height), color, cv2.FILLED
            )
            cv2.putText(img, label, (x1, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
        cv2.imshow('Video', img)

    def preprocess(self, frames):
        """Pack frames into a single input batch for the engine.

        Each frame is resized to the model input size, multiplied by 1/255 and has
        its R and B channels swapped (OpenCV frames are BGR). The per-frame blobs
        are stacked along the first (batch) axis.
        """
        blobs = [
            cv2.dnn.blobFromImage(
                frame,
                scalefactor=1.0 / 255.0,
                size=(self.model_width, self.model_height),
                mean=(0, 0, 0),
                swapRB=True,
                crop=False,
            )
            for frame in frames
        ]
        return np.vstack(blobs)

    def postprocess(self, batch_frames, batch_timestamps, output):
        """Turn raw engine output into one ``Annotation`` per frame.

        Args:
            batch_frames: frames that were fed to the engine, in batch order.
            batch_timestamps: one timestamp per frame (``process`` passes milliseconds).
            output: engine result; ``output[0][i]`` holds rows
                ``[x1, y1, x2, y2, confidence, class_id, ...]`` for frame ``i``, with
                corner coordinates in model-input pixels.

        Returns:
            A list of annotations whose detections are normalised centre/size boxes
            that passed the confidence threshold and overlap suppression. Frames,
            timestamps and output rows are paired positionally; any surplus on one
            side is ignored instead of raising IndexError.
        """
        annotations = []
        for frame, timestamp, frame_output in zip(batch_frames, batch_timestamps, output[0]):
            detections = []
            for det in frame_output:
                if det[4] == 0:
                    # A zero-confidence row ends this frame's list (presumably engine padding).
                    break
                if det[4] < self.confidence_threshold:
                    continue

                # Normalise corners by the model input size; top-left is clamped at 0,
                # bottom-right at 1.
                x1 = max(0, det[0] / self.model_width)
                y1 = max(0, det[1] / self.model_height)
                x2 = min(1, det[2] / self.model_width)
                y2 = min(1, det[3] / self.model_height)
                conf = round(det[4], 2)
                class_id = int(det[5])

                detections.append(
                    Detection((x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1, class_id, conf)
                )

            filtered_detections = self.remove_overlapping_detections(detections)
            annotations.append(Annotation(frame, timestamp, filtered_detections))
        return annotations

    def process(self, video, frame_step=4):
        """Detect objects on every ``frame_step``-th frame of a video and display them.

        Sampled frames are grouped into batches of ``self.batch_size``. Each full
        batch, and a final partial batch if any, is run through the engine. Every
        resulting annotation is drawn and its timestamp printed. Pressing ``q`` in
        the display window stops processing entirely. The capture is always released.

        Args:
            video: source passed unchanged to ``cv2.VideoCapture``.
            frame_step: analyse only frames whose 1-based index is a multiple of this
                value (default 4, the previously hard-coded stride).

        Raises:
            ValueError: if ``frame_step`` is smaller than 1.
        """
        if frame_step < 1:
            raise ValueError("frame_step must be >= 1")

        batch_frames = []
        batch_timestamps = []
        v_input = cv2.VideoCapture(video)
        try:
            frame_count = 0
            while v_input.isOpened():
                ret, frame = v_input.read()
                if not ret or frame is None:
                    break

                frame_count += 1
                if frame_count % frame_step != 0:
                    continue
                batch_frames.append(frame)
                batch_timestamps.append(int(v_input.get(cv2.CAP_PROP_POS_MSEC)))

                if len(batch_frames) == self.batch_size:
                    if not self._run_batch(batch_frames, batch_timestamps):
                        return  # user pressed 'q'
                    batch_frames.clear()
                    batch_timestamps.clear()

            if batch_frames:
                self._run_batch(batch_frames, batch_timestamps)
        finally:
            v_input.release()

    def _run_batch(self, batch_frames, batch_timestamps):
        """Infer, draw and report one batch; return False if the user pressed 'q'."""
        input_blob = self.preprocess(batch_frames)
        outputs = self.engine.run(input_blob)
        for annotation in self.postprocess(batch_frames, batch_timestamps, outputs):
            self.draw(annotation)
            print(f'video: {annotation.time / 1000:.3f}s')
            if cv2.waitKey(1) & 0xFF == ord('q'):
                return False
        return True

    def remove_overlapping_detections(self, detections):
        """Suppress overlapping detections, keeping the strongest one (greedy NMS).

        Candidates are visited from highest to lowest confidence. Ties go to the
        lower class id, then to the earlier position. A candidate is dropped when an
        already-kept detection ``overlaps`` it at ``self.iou_threshold``. Suppression
        is class-agnostic.

        Args:
            detections: list of objects exposing ``confidence``, ``cls`` and
                ``overlaps(other, threshold)``.

        Returns:
            The surviving detections in their original input order, each at most once.
        """
        order = sorted(
            range(len(detections)),
            key=lambda i: (-detections[i].confidence, detections[i].cls, i),
        )
        kept = []
        for idx in order:
            candidate = detections[idx]
            if any(detections[k].overlaps(candidate, self.iou_threshold) for k in kept):
                continue
            kept.append(idx)
        # Report survivors in input order, as the previous implementation roughly did.
        return [detections[i] for i in sorted(kept)]