mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 08:56:35 +00:00
5b89a21b36
add inference with possibility to have different
43 lines
1.2 KiB
Python
43 lines
1.2 KiB
Python
import abc
|
|
from typing import List, Tuple
|
|
import numpy as np
|
|
import onnxruntime as onnx
|
|
|
|
|
|
class InferenceEngine(abc.ABC):
    """Abstract interface every inference backend in this module implements.

    An engine is constructed from a model path and exposes the model's input
    size, the batch size it works with, and a method that executes the model.
    Instantiating this class directly raises ``TypeError``; subclasses must
    override all four abstract methods.
    """

    @abc.abstractmethod
    def __init__(self, model_path: str, batch_size: int = 1, **kwargs):
        """Set up the engine for the model stored at ``model_path``.

        ``batch_size`` is the requested batch size (default 1). How extra
        keyword options are used is left to each implementation.
        """

    @abc.abstractmethod
    def get_input_shape(self) -> Tuple[int, int]:
        """Return the model's 2-D input size as a pair."""

    @abc.abstractmethod
    def get_batch_size(self) -> int:
        """Return the batch size this engine operates with."""

    @abc.abstractmethod
    def run(self, input_data: np.ndarray) -> List[np.ndarray]:
        """Execute the model on ``input_data`` and return its outputs."""
|
|
|
|
|
|
|
|
class OnnxEngine(InferenceEngine):
    """ONNX Runtime implementation of ``InferenceEngine``.

    The model is loaded into an ``onnxruntime.InferenceSession`` that lists the
    CUDA execution provider first and the CPU provider second. Every ``run``
    call feeds its array into the model's first input.
    """

    def __init__(self, model_path: str, batch_size: int = 1, **kwargs):
        """Create the inference session for ``model_path``.

        Args:
            model_path: Model file path, passed straight to
                ``onnxruntime.InferenceSession``.
            batch_size: Requested batch size. It is used only when the model's
                batch dimension is dynamic. If the model was exported with a
                fixed batch dimension, that fixed value is used instead,
                because ONNX Runtime rejects inputs whose fixed dimensions do
                not match.
            **kwargs: Accepted for interface compatibility; currently ignored.
        """
        self.model_path = model_path
        self.session = onnx.InferenceSession(
            model_path,
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
        )
        self.model_inputs = self.session.get_inputs()
        self.input_name = self.model_inputs[0].name
        self.input_shape = self.model_inputs[0].shape

        # ONNX Runtime reports a fixed dimension as an int and a dynamic one as
        # a symbolic name (str) or None. A fixed batch dimension cannot be
        # changed at run time, so honour it rather than the requested size;
        # otherwise callers batching get_batch_size() items would fail in run().
        batch_dim = self.input_shape[0] if self.input_shape else None
        if isinstance(batch_dim, int) and batch_dim > 0:
            self.batch_size = batch_dim
        else:
            self.batch_size = batch_size

    def get_input_shape(self) -> Tuple[int, int]:
        """Return dimensions 2 and 3 of the model's first input.

        For an NCHW image model these are (height, width); the layout itself is
        an assumption this method does not check. NOTE: for a model exported
        with dynamic spatial axes these entries are not ints (ONNX Runtime
        reports them as symbolic names or None).
        """
        shape = self.input_shape
        return shape[2], shape[3]

    def get_batch_size(self) -> int:
        """Return the effective batch size chosen in ``__init__``."""
        return self.batch_size

    def run(self, input_data: np.ndarray) -> List[np.ndarray]:
        """Run the model on ``input_data`` and return every model output.

        ``input_data`` is bound to the model's first input. Passing ``None`` as
        the output names makes ONNX Runtime return all outputs in the model's
        declared order.
        """
        return self.session.run(None, {self.input_name: input_data})