mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 19:26:36 +00:00
upload model to cdn and api
switch to yolov11
This commit is contained in:
@@ -1,6 +1,33 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from ultralytics import YOLO
|
||||
import yaml
|
||||
|
||||
|
||||
class Predictor(ABC):
|
||||
@abstractmethod
|
||||
def predict(self, frame):
|
||||
pass
|
||||
|
||||
|
||||
class OnnxPredictor(Predictor):
|
||||
def __init__(self):
|
||||
self.model = YOLO('azaion.onnx')
|
||||
self.model.task = 'detect'
|
||||
with open('data.yaml', 'r') as f:
|
||||
data_yaml = yaml.safe_load(f)
|
||||
class_names = data_yaml['names']
|
||||
|
||||
names = self.model.names
|
||||
|
||||
def predict(self, frame):
|
||||
results = self.model.track(frame, persist=True, tracker='bytetrack.yaml')
|
||||
return results[0].plot()
|
||||
|
||||
|
||||
class YoloPredictor(Predictor):
|
||||
def __init__(self):
|
||||
self.model = YOLO('azaion.pt')
|
||||
|
||||
def predict(self, frame):
|
||||
results = self.model.track(frame, persist=True, tracker='bytetrack.yaml')
|
||||
return results[0].plot()
|
||||
|
||||
Reference in New Issue
Block a user