import abc
import ast
from typing import List, Tuple

import numpy as np
import onnxruntime as onnx


class InferenceEngine(abc.ABC):
    """Abstract interface that every inference backend implements."""

    @abc.abstractmethod
    def __init__(self, model_path: str, batch_size: int = 1, **kwargs):
        """Load a model and prepare it for inference.

        Args:
            model_path: Location of the model to load.
            batch_size: Requested batch size.
            **kwargs: Backend-specific options.
        """

    @abc.abstractmethod
    def get_input_shape(self) -> Tuple[int, int]:
        """Return the two-element input shape the model expects."""

    @abc.abstractmethod
    def get_batch_size(self) -> int:
        """Return the batch size the engine operates with."""

    @abc.abstractmethod
    def run(self, input_data: np.ndarray) -> List[np.ndarray]:
        """Run inference on ``input_data`` and return the model outputs."""


class OnnxEngine(InferenceEngine):
    """Inference backend that executes a model through ONNX Runtime.

    Attributes set by ``__init__``: ``batch_size``, ``session``,
    ``model_inputs``, ``input_name``, ``input_shape`` and ``class_names``.
    """

    def __init__(self, model_bytes, batch_size: int = 1, **kwargs):
        """Create the ONNX Runtime session and read the model description.

        Args:
            model_bytes: Model handed straight to ``onnx.InferenceSession``.
            batch_size: Batch size to use unless the model's first input
                declares a fixed leading dimension, which then wins.
            **kwargs: Accepted for interface compatibility; unused here.

        Raises:
            KeyError: If the model metadata has no ``"names"`` entry.
            ValueError: If the ``"names"`` entry is not a Python literal.
            SyntaxError: If the ``"names"`` entry cannot be parsed.
        """
        self.batch_size = batch_size
        # CUDA is listed first so it is preferred over the CPU provider.
        self.session = onnx.InferenceSession(
            model_bytes,
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
        )

        self.model_inputs = self.session.get_inputs()
        first_input = self.model_inputs[0]
        self.input_name = first_input.name
        self.input_shape = first_input.shape

        # A dynamic leading dimension is reported as a symbolic name (str),
        # None or -1 rather than a size; only a concrete positive integer is
        # allowed to override the requested batch size.
        leading_dim = self.input_shape[0]
        if isinstance(leading_dim, int) and leading_dim > 0:
            self.batch_size = leading_dim

        model_meta = self.session.get_modelmeta()
        print("Metadata:", model_meta.custom_metadata_map)
        # NOTE(security): this string comes from the model file, so it is
        # parsed as a literal instead of being passed to eval(), which would
        # execute arbitrary code embedded in an untrusted model.
        self.class_names = ast.literal_eval(
            model_meta.custom_metadata_map["names"]
        )

    def get_input_shape(self) -> Tuple[int, int]:
        """Return dimensions 2 and 3 of the first model input.

        NOTE(review): presumably (height, width) of an NCHW input; confirm
        against the exported model.
        """
        return self.input_shape[2], self.input_shape[3]

    def get_batch_size(self) -> int:
        """Return the batch size resolved in ``__init__``."""
        return self.batch_size

    def run(self, input_data: np.ndarray) -> List[np.ndarray]:
        """Feed ``input_data`` to the first model input and return all outputs."""
        # Passing None as the output list requests every model output.
        return self.session.run(None, {self.input_name: input_data})