from engines.inference_engine cimport InferenceEngine

cimport constants_inf

import ast
import os

import onnxruntime as onnx


def _select_providers():
    """Pick ONNX Runtime execution providers in preference order.

    Preference is CoreML > CUDA > CPU, filtered down to what this
    onnxruntime build actually reports as available. Setting the
    SKIP_COREML environment variable to 1/true/yes removes CoreML from
    consideration. Always returns a non-empty list: CPU is the
    guaranteed last-resort provider.
    """
    available = set(onnx.get_available_providers())
    skip_coreml = os.environ.get("SKIP_COREML", "").lower() in ("1", "true", "yes")
    preferred = ["CoreMLExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"]
    if skip_coreml:
        preferred = [p for p in preferred if p != "CoreMLExecutionProvider"]
    selected = [p for p in preferred if p in available]
    return selected or ["CPUExecutionProvider"]


cdef class OnnxEngine(InferenceEngine):
    """ONNX Runtime-backed inference engine with an optional CPU fallback
    session when the CoreML execution provider is active."""

    def __init__(self, model_bytes: bytes, max_batch_size: int = 8, **kwargs):
        """Create an ONNX Runtime session from in-memory model bytes.

        A concrete (positive int) batch dimension declared by the model
        overrides the caller-supplied ``max_batch_size``. When CoreML ends
        up among the active providers, a second CPU-only session is kept
        so ``run`` can retry there if the primary session raises.
        Extra ``**kwargs`` are accepted for interface compatibility and
        ignored here.
        """
        InferenceEngine.__init__(self, model_bytes, max_batch_size)
        providers = _select_providers()
        constants_inf.log(f'ONNX providers: {providers}')
        self.session = onnx.InferenceSession(model_bytes, providers=providers)
        self.model_inputs = self.session.get_inputs()
        self.input_name = self.model_inputs[0].name
        self.input_shape = self.model_inputs[0].shape
        # Dynamic-batch models report a symbolic (str) dim here; only a
        # fixed positive int batch dim overrides the configured value.
        if isinstance(self.input_shape[0], int) and self.input_shape[0] > 0:
            self.max_batch_size = self.input_shape[0]
        constants_inf.log(f'AI detection model input: {self.model_inputs} {self.input_shape}')
        model_meta = self.session.get_modelmeta()
        constants_inf.log(f"Metadata: {model_meta.custom_metadata_map}")
        self._resolved_input_hw = self._resolve_input_hw(model_meta.custom_metadata_map)
        self._cpu_session = None
        if any("CoreML" in p for p in self.session.get_providers()):
            constants_inf.log('CoreML active — creating CPU fallback session')
            self._cpu_session = onnx.InferenceSession(
                model_bytes, providers=["CPUExecutionProvider"])

    cdef tuple _resolve_input_hw(self, object metadata):
        """Resolve the model's (height, width) input size.

        Resolution order:
        1. Fixed positive int H/W in the model's declared input shape
           (assumed NCHW layout: dims 2 and 3).
        2. The ``imgsz`` entry of the model's custom metadata map,
           parsed with ``ast.literal_eval`` (expects a 2-element
           list/tuple, e.g. Ultralytics exports).
        3. The project's canonical 1280x1280 for dynamic models.
        """
        cdef object h = self.input_shape[2] if len(self.input_shape) > 2 else None
        cdef object w = self.input_shape[3] if len(self.input_shape) > 3 else None
        cdef int resolved_h
        cdef int resolved_w
        if isinstance(h, int) and h > 0 and isinstance(w, int) and w > 0:
            return (h, w)
        # Best-effort metadata parse: any failure falls through to the
        # canonical default rather than aborting engine construction.
        try:
            imgsz = metadata.get("imgsz") if metadata is not None else None
            if imgsz:
                parsed = ast.literal_eval(imgsz)
                if isinstance(parsed, (list, tuple)) and len(parsed) == 2:
                    resolved_h = int(parsed[0])
                    resolved_w = int(parsed[1])
                    if resolved_h > 0 and resolved_w > 0:
                        return (resolved_h, resolved_w)
        except Exception:
            pass
        # Dynamic ONNX models are expected to use the project's canonical 1280x1280 input.
        return (1280, 1280)

    cdef tuple get_input_shape(self):
        """Return the resolved (height, width) computed at construction."""
        return self._resolved_input_hw

    cdef run(self, input_data):
        """Run inference, retrying once on the CPU fallback session.

        The retry path only exists when CoreML was active at construction
        time; without a fallback session the original exception is
        re-raised unchanged.
        """
        try:
            return self.session.run(None, {self.input_name: input_data})
        except Exception as exc:
            if self._cpu_session is not None:
                # Fix: log the primary failure instead of swallowing it,
                # so CoreML run-time errors remain diagnosable.
                constants_inf.log(f'Primary ONNX session failed ({exc!r}); retrying on CPU fallback')
                return self._cpu_session.run(None, {self.input_name: input_data})
            raise