Files
detections/src/engines/onnx_engine.pyx
T
Roman Meshko 7d897df380 Fixed dynamic ONNX input
Fix dynamic ONNX input
Update docs with correct file name for tests
2026-04-19 20:55:51 +03:00

75 lines
3.2 KiB
Cython

from engines.inference_engine cimport InferenceEngine
import onnxruntime as onnx
cimport constants_inf
import ast
import os
def _select_providers():
    """Choose ONNX Runtime execution providers in preference order.

    Preference is CoreML > CUDA > CPU, filtered down to what this
    runtime build actually offers. Setting the SKIP_COREML environment
    variable to "1"/"true"/"yes" removes CoreML from consideration.
    Always returns a non-empty list; falls back to CPU-only when none
    of the preferred providers are available.
    """
    runtime_providers = set(onnx.get_available_providers())
    coreml_disabled = os.environ.get("SKIP_COREML", "").lower() in ("1", "true", "yes")
    order = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    if not coreml_disabled:
        order.insert(0, "CoreMLExecutionProvider")
    chosen = [name for name in order if name in runtime_providers]
    if not chosen:
        return ["CPUExecutionProvider"]
    return chosen
cdef class OnnxEngine(InferenceEngine):
    """ONNX Runtime inference backend.

    Builds an InferenceSession with the best available execution
    providers and resolves the model's expected input height/width,
    including for dynamic-shape models. When the CoreML provider is
    active, a CPU-only fallback session is created so inference can be
    retried if CoreML rejects an input at run time.
    """

    def __init__(self, model_bytes: bytes, max_batch_size: int = 8, **kwargs):
        """Create the ONNX session(s) and cache input metadata.

        model_bytes: serialized ONNX model.
        max_batch_size: upper bound on batch size; overridden by the
            model's own batch dimension when the model declares a fixed
            positive one.
        """
        InferenceEngine.__init__(self, model_bytes, max_batch_size)
        providers = _select_providers()
        constants_inf.log(<str>f'ONNX providers: {providers}')
        self.session = onnx.InferenceSession(model_bytes, providers=providers)
        self.model_inputs = self.session.get_inputs()
        self.input_name = self.model_inputs[0].name
        self.input_shape = self.model_inputs[0].shape
        # A concrete positive batch dimension baked into the model wins
        # over the caller-supplied max_batch_size. Dynamic batch dims are
        # strings (e.g. 'batch') and are ignored here.
        if isinstance(self.input_shape[0], int) and self.input_shape[0] > 0:
            self.max_batch_size = self.input_shape[0]
        constants_inf.log(f'AI detection model input: {self.model_inputs} {self.input_shape}')
        model_meta = self.session.get_modelmeta()
        constants_inf.log(f"Metadata: {model_meta.custom_metadata_map}")
        self._resolved_input_hw = self._resolve_input_hw(model_meta.custom_metadata_map)
        self._cpu_session = None
        if any("CoreML" in p for p in self.session.get_providers()):
            constants_inf.log(<str>'CoreML active — creating CPU fallback session')
            self._cpu_session = onnx.InferenceSession(
                model_bytes, providers=["CPUExecutionProvider"])

    cdef tuple _resolve_input_hw(self, object metadata):
        """Return the (height, width) tuple the model expects.

        Resolution order:
          1. static H/W baked into the ONNX input shape (dims 2 and 3,
             i.e. assumes NCHW layout — TODO confirm for all models);
          2. the 'imgsz' entry of the model metadata, parsed as a
             two-element Python literal such as "[1280, 1280]";
          3. the project's canonical 1280x1280 default.
        """
        cdef object h = self.input_shape[2] if len(self.input_shape) > 2 else None
        cdef object w = self.input_shape[3] if len(self.input_shape) > 3 else None
        cdef int resolved_h
        cdef int resolved_w
        if isinstance(h, int) and h > 0 and isinstance(w, int) and w > 0:
            return <tuple>(h, w)
        try:
            imgsz = metadata.get("imgsz") if metadata is not None else None
            if imgsz:
                # literal_eval only parses Python literals — safe on
                # untrusted metadata, unlike eval().
                parsed = ast.literal_eval(imgsz)
                if isinstance(parsed, (list, tuple)) and len(parsed) == 2:
                    resolved_h = int(parsed[0])
                    resolved_w = int(parsed[1])
                    if resolved_h > 0 and resolved_w > 0:
                        return <tuple>(resolved_h, resolved_w)
        except Exception:
            # Best-effort: malformed metadata falls through to the default.
            pass
        # Dynamic ONNX models are expected to use the project's canonical 1280x1280 input.
        return <tuple>(1280, 1280)

    cdef tuple get_input_shape(self):
        """Return the resolved (height, width) model input size."""
        return self._resolved_input_hw

    cdef run(self, input_data):
        """Run inference on the primary session; retry on CPU on failure.

        Fix: the primary-session exception was previously swallowed
        silently before the CPU retry, hiding CoreML/CUDA problems —
        it is now logged. Re-raises when no CPU fallback exists.
        """
        try:
            return self.session.run(None, {self.input_name: input_data})
        except Exception as exc:
            if self._cpu_session is not None:
                constants_inf.log(f'Primary ONNX session failed ({exc}); retrying on CPU')
                return self._cpu_session.run(None, {self.input_name: input_data})
            raise