mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 23:36:36 +00:00
fix tensor rt engine
This commit is contained in:
committed by
Alex Bezdieniezhnykh
parent
5b89a21b36
commit
06a23525a6
@@ -22,15 +22,19 @@ class InferenceEngine(abc.ABC):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
class OnnxEngine(InferenceEngine):
|
||||
def __init__(self, model_path: str, batch_size: int = 1, **kwargs):
|
||||
self.model_path = model_path
|
||||
def __init__(self, model_bytes, batch_size: int = 1, **kwargs):
|
||||
self.batch_size = batch_size
|
||||
self.session = onnx.InferenceSession(model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
|
||||
self.session = onnx.InferenceSession(model_bytes, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
|
||||
self.model_inputs = self.session.get_inputs()
|
||||
self.input_name = self.model_inputs[0].name
|
||||
self.input_shape = self.model_inputs[0].shape
|
||||
if self.input_shape[0] != -1:
|
||||
self.batch_size = self.input_shape[0]
|
||||
model_meta = self.session.get_modelmeta()
|
||||
print("Metadata:", model_meta.custom_metadata_map)
|
||||
self.class_names = eval(model_meta.custom_metadata_map["names"])
|
||||
pass
|
||||
|
||||
def get_input_shape(self) -> Tuple[int, int]:
|
||||
shape = self.input_shape
|
||||
|
||||
Reference in New Issue
Block a user