mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 05:26:32 +00:00
50 lines
2.1 KiB
Cython
from engines.inference_engine cimport InferenceEngine
|
|
import onnxruntime as onnx
|
|
cimport constants_inf
|
|
|
|
import os
|
|
|
|
def _select_providers():
    """Pick the ONNX Runtime execution providers to request, best first.

    Candidates are tried in the order CoreML, CUDA, CPU, and only those that
    ``onnx.get_available_providers()`` reports are kept. If the environment
    variable ``SKIP_COREML`` is ``1``, ``true`` or ``yes`` (case-insensitive),
    CoreML is excluded. When no candidate survives, ``["CPUExecutionProvider"]``
    is returned so the caller always gets a non-empty list.
    """
    installed = set(onnx.get_available_providers())
    coreml_disabled = os.environ.get("SKIP_COREML", "").lower() in ("1", "true", "yes")

    candidates = ("CoreMLExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider")

    chosen = []
    for name in candidates:
        if coreml_disabled and name == "CoreMLExecutionProvider":
            continue
        if name in installed:
            chosen.append(name)

    if not chosen:
        return ["CPUExecutionProvider"]
    return chosen
|
|
|
|
cdef class OnnxEngine(InferenceEngine):
    """InferenceEngine implementation backed by ONNX Runtime.

    Builds an ``onnx.InferenceSession`` from in-memory model bytes using the
    providers returned by ``_select_providers()``. If CoreML ends up among the
    session's active providers, a second CPU-only session is also created and
    used by ``run`` as a fallback when the primary session raises.
    """

    def __init__(self, model_bytes: bytes, batch_size: int = 1, **kwargs):
        """Create the ONNX Runtime session(s) and read the model's input metadata.

        Args:
            model_bytes: Serialized ONNX model, passed directly to
                ``onnx.InferenceSession``.
            batch_size: Batch size used when the model's first input dimension
                is not a fixed positive integer (i.e. the batch axis is dynamic).
            **kwargs: Accepted for interface compatibility; ignored here.
        """
        super().__init__(model_bytes, batch_size)

        providers = _select_providers()
        constants_inf.log(<str>f'ONNX providers: {providers}')
        self.session = onnx.InferenceSession(model_bytes, providers=providers)

        # Only the first model input is used by this engine.
        self.model_inputs = self.session.get_inputs()
        self.input_name = self.model_inputs[0].name
        self.input_shape = self.model_inputs[0].shape

        # ONNX Runtime reports a dynamic dimension as a str (symbolic dim_param)
        # or None -- not -1 -- so the old `!= -1` test let 'batch'/None through
        # as the batch size. Only a positive int is a fixed batch size; anything
        # else (str, None, -1, 0, or an empty shape) falls back to the argument.
        batch_dim = self.input_shape[0] if self.input_shape else None
        if isinstance(batch_dim, int) and batch_dim > 0:
            self.batch_size = batch_dim
        else:
            self.batch_size = batch_size

        constants_inf.log(f'AI detection model input: {self.model_inputs} {self.input_shape}')
        model_meta = self.session.get_modelmeta()
        constants_inf.log(f"Metadata: {model_meta.custom_metadata_map}")

        # When a CoreML provider is active, keep a CPU-only session around so
        # run() can retry there if a run on the primary session raises.
        self._cpu_session = None
        if any("CoreML" in p for p in self.session.get_providers()):
            constants_inf.log(<str>'CoreML active — creating CPU fallback session')
            self._cpu_session = onnx.InferenceSession(
                model_bytes, providers=["CPUExecutionProvider"])

    cdef tuple get_input_shape(self):
        # Return (shape[2], shape[3]) of the first model input.
        # NOTE(review): assumes a 4-D NCHW-style input so this is (height, width);
        # a dynamic dim here would come back as str/None -- confirm callers cope.
        shape = self.input_shape
        return shape[2], shape[3]

    cdef int get_batch_size(self):
        # Batch size resolved in __init__ (fixed model dim or caller's value).
        return self.batch_size

    cdef run(self, input_data):
        # Run the primary session, feeding input_data to the first model input
        # and requesting all outputs (None). If that raises and a CPU fallback
        # session exists, log the failure and retry on CPU; otherwise re-raise.
        try:
            return self.session.run(None, {self.input_name: input_data})
        except Exception as exc:
            if self._cpu_session is None:
                raise
            # Previously the primary-session error was swallowed silently;
            # surface it so provider failures are visible in the logs.
            constants_inf.log(<str>f'ONNX run failed ({type(exc).__name__}: {exc}); retrying on CPU session')
            return self._cpu_session.run(None, {self.input_name: input_data})