mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 05:26:32 +00:00
7d897df380
Fix dynamic ONNX input Update docs with correct file name for tests
75 lines
3.2 KiB
Cython
from engines.inference_engine cimport InferenceEngine
|
|
import onnxruntime as onnx
|
|
cimport constants_inf
|
|
|
|
import ast
|
|
import os
|
|
|
|
def _select_providers():
    """Choose ONNX Runtime execution providers in preference order.

    Preference is CoreML, then CUDA, then CPU, filtered down to the
    providers actually available in this onnxruntime build. Setting the
    SKIP_COREML environment variable to 1/true/yes drops CoreML from
    consideration. Always returns at least ["CPUExecutionProvider"].
    """
    candidates = ["CoreMLExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"]

    # Allow operators to opt out of CoreML via the environment.
    if os.environ.get("SKIP_COREML", "").lower() in ("1", "true", "yes"):
        candidates = [name for name in candidates if name != "CoreMLExecutionProvider"]

    installed = set(onnx.get_available_providers())
    chosen = [name for name in candidates if name in installed]
    if not chosen:
        # Nothing from the preference list is installed — CPU always works.
        return ["CPUExecutionProvider"]
    return chosen
|
|
|
|
cdef class OnnxEngine(InferenceEngine):
    """ONNX Runtime-backed inference engine.

    Loads a serialized ONNX model, selects execution providers (optionally
    skipping CoreML via the SKIP_COREML env var), and resolves the model's
    expected input height/width from the input shape or model metadata.
    When a CoreML provider is active, a CPU-only fallback session is also
    created so inference can be retried if CoreML fails at run time.
    """

    def __init__(self, model_bytes: bytes, max_batch_size: int = 8, **kwargs):
        """Create an inference session from raw ONNX model bytes.

        model_bytes: serialized ONNX model.
        max_batch_size: upper batch limit; overridden by the model's own
            static batch dimension when one is present.
        """
        InferenceEngine.__init__(self, model_bytes, max_batch_size)

        providers = _select_providers()
        # f-strings are already `str`; the redundant <str> cast is dropped
        # for consistency with the other log calls in this method.
        constants_inf.log(f'ONNX providers: {providers}')
        self.session = onnx.InferenceSession(model_bytes, providers=providers)
        self.model_inputs = self.session.get_inputs()
        self.input_name = self.model_inputs[0].name
        self.input_shape = self.model_inputs[0].shape
        # A static (positive int) batch dimension baked into the model wins
        # over the caller-supplied max_batch_size.
        if isinstance(self.input_shape[0], int) and self.input_shape[0] > 0:
            self.max_batch_size = self.input_shape[0]
        constants_inf.log(f'AI detection model input: {self.model_inputs} {self.input_shape}')
        model_meta = self.session.get_modelmeta()
        constants_inf.log(f"Metadata: {model_meta.custom_metadata_map}")

        self._resolved_input_hw = self._resolve_input_hw(model_meta.custom_metadata_map)

        self._cpu_session = None
        if any("CoreML" in p for p in self.session.get_providers()):
            constants_inf.log('CoreML active — creating CPU fallback session')
            self._cpu_session = onnx.InferenceSession(
                model_bytes, providers=["CPUExecutionProvider"])

    cdef tuple _resolve_input_hw(self, object metadata):
        """Return the (height, width) tuple the model expects.

        Resolution order:
          1. Static H/W taken from dims 2 and 3 of the ONNX input shape.
          2. The 'imgsz' entry of the model's custom metadata map — a
             Python-literal string such as "[640, 640]".
          3. The project's canonical 1280x1280 for fully dynamic models.
        """
        # Dynamic dims come through as strings (e.g. 'height'), so these
        # stay `object` until proven to be positive ints.
        cdef object h = self.input_shape[2] if len(self.input_shape) > 2 else None
        cdef object w = self.input_shape[3] if len(self.input_shape) > 3 else None
        cdef int resolved_h
        cdef int resolved_w

        if isinstance(h, int) and h > 0 and isinstance(w, int) and w > 0:
            return <tuple>(h, w)

        # Best-effort metadata parse; any missing/malformed value falls
        # through to the canonical default below.
        try:
            imgsz = metadata.get("imgsz") if metadata is not None else None
            if imgsz:
                parsed = ast.literal_eval(imgsz)
                if isinstance(parsed, (list, tuple)) and len(parsed) == 2:
                    resolved_h = int(parsed[0])
                    resolved_w = int(parsed[1])
                    if resolved_h > 0 and resolved_w > 0:
                        return <tuple>(resolved_h, resolved_w)
        except Exception:
            pass

        # Dynamic ONNX models are expected to use the project's canonical 1280x1280 input.
        return <tuple>(1280, 1280)

    cdef tuple get_input_shape(self):
        """Return the resolved (height, width) input size."""
        return self._resolved_input_hw

    cdef run(self, input_data):
        """Run inference on input_data.

        Tries the primary session first; if it raises (e.g. a CoreML
        provider failure) and a CPU fallback session exists, retries on
        CPU, otherwise re-raises the original error.
        """
        try:
            return self.session.run(None, {self.input_name: input_data})
        except Exception:
            if self._cpu_session is not None:
                return self._cpu_session.run(None, {self.input_name: input_data})
            raise
|