Mirror of https://github.com/azaion/annotations.git
Synced 2026-04-22 22:56:29 +00:00
25 lines · 1.1 KiB · Cython
from inference_engine cimport InferenceEngine
|
|
import onnxruntime as onnx
|
|
|
|
cdef class OnnxEngine(InferenceEngine):
    """InferenceEngine backed by an ONNX Runtime InferenceSession.

    The session is created with the CUDA execution provider first and the CPU
    provider as a fallback.
    """

    def __init__(self, model_bytes: bytes, batch_size: int = 1, **kwargs):
        """Load the model and read its first input's name and shape.

        Args:
            model_bytes: the model passed straight to onnxruntime.InferenceSession
                (serialized ONNX bytes).
            batch_size: batch size to use when the model's first input does not
                declare a fixed, positive batch dimension.
            **kwargs: accepted for interface compatibility; not used here.
        """
        super().__init__(model_bytes, batch_size)

        self.session = onnx.InferenceSession(model_bytes, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
        self.model_inputs = self.session.get_inputs()
        self.input_name = self.model_inputs[0].name
        self.input_shape = self.model_inputs[0].shape

        # ONNX Runtime reports a dynamic dimension as a symbolic name (str) or
        # None rather than -1. The original `!= -1` check therefore let a
        # string/None leak into batch_size. Only trust the model's value when
        # it is a concrete positive int; otherwise use the caller's batch_size.
        batch_dim = self.input_shape[0]
        if isinstance(batch_dim, int) and batch_dim > 0:
            self.batch_size = batch_dim
        else:
            self.batch_size = batch_size

        print(f'AI detection model input: {self.model_inputs} {self.input_shape}')
        model_meta = self.session.get_modelmeta()
        print("Metadata:", model_meta.custom_metadata_map)

    cdef tuple get_input_shape(self):
        # Returns dimensions 2 and 3 of the model's first input shape.
        # NOTE(review): presumably (height, width) for an NCHW input; confirm.
        # Either value may be a str/None if the model declares it dynamic.
        shape = self.input_shape
        return shape[2], shape[3]

    cdef int get_batch_size(self):
        # Batch size resolved in __init__: the model's fixed batch dimension
        # if it declares one, otherwise the batch_size argument.
        return self.batch_size

    cpdef run(self, input_data):
        """Run inference, feeding input_data to the model's first input.

        Passing None as output_names makes ONNX Runtime return every model
        output as a list.
        """
        return self.session.run(None, {self.input_name: input_data})