Files
ai-training/tests/yolo_predictor.py
T

21 lines
557 B
Python

import cv2
import numpy as np
import yaml
from predictor import Predictor
from ultralytics import YOLO
class YOLOPredictor(Predictor):
    """Predictor that runs an Ultralytics YOLO detector with tracking, one frame at a time.

    Each call to ``predict`` feeds a single frame through ``YOLO.track`` and
    returns the first result rendered with ``Results.plot()``.
    """

    def __init__(self, model_path='/azaion/models/azaion.onnx',
                 data_yaml_path='data.yaml', tracker='bytetrack.yaml'):
        """Load the detection model and the dataset's class names.

        Args:
            model_path: Weights file handed to ``YOLO``. Defaults to the
                exported ONNX model this class always used.
            data_yaml_path: Dataset config whose ``names`` entry is read.
                A relative path is resolved against the current working
                directory, not against this file.
            tracker: Tracker config passed to ``YOLO.track`` on every frame.

        Raises:
            OSError: if ``data_yaml_path`` cannot be opened.
            KeyError: if the dataset config has no ``names`` entry.
        """
        # Give the task to the constructor rather than assigning it afterwards:
        # ultralytics stores the constructor-time task in the model's overrides
        # when loading exported weights, so a later ``self.model.task = ...``
        # would leave those overrides holding whatever task it guessed.
        self.model = YOLO(model_path, task='detect')
        self.tracker = tracker

        with open(data_yaml_path, 'r', encoding='utf-8') as f:
            data_yaml = yaml.safe_load(f)
        # Class names declared by the dataset config (previously read, then discarded).
        self.class_names = data_yaml['names']

        # Class names reported by the loaded model (previously read, then discarded).
        # NOTE(review): for exported (non-.pt) weights, ultralytics appears to
        # build its predictor on first access of ``names``, so this line also
        # front-loads model setup; confirm against the installed version.
        self.model_names = self.model.names

    def predict(self, frame):
        """Track objects in ``frame`` and return the annotated image.

        ``persist=True`` keeps tracker state between calls, so consecutive
        frames from the same stream continue the same tracks.

        Args:
            frame: Image passed straight to ``YOLO.track``. Presumably a BGR
                numpy array as produced by cv2; confirm with callers.

        Returns:
            The annotated image produced by ``Results.plot()`` for the first
            result, which is the only one for a single input frame.
        """
        results = self.model.track(frame, persist=True, tracker=self.tracker)
        return results[0].plot()