mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 09:06:31 +00:00
Fixed dynamic ONNX input
Fix dynamic ONNX input Update docs with correct file name for tests
This commit is contained in:
@@ -2,6 +2,7 @@ from engines.inference_engine cimport InferenceEngine
|
||||
import onnxruntime as onnx
|
||||
cimport constants_inf
|
||||
|
||||
import ast
|
||||
import os
|
||||
|
||||
def _select_providers():
|
||||
@@ -29,15 +30,40 @@ cdef class OnnxEngine(InferenceEngine):
|
||||
model_meta = self.session.get_modelmeta()
|
||||
constants_inf.log(f"Metadata: {model_meta.custom_metadata_map}")
|
||||
|
||||
self._resolved_input_hw = self._resolve_input_hw(model_meta.custom_metadata_map)
|
||||
|
||||
self._cpu_session = None
|
||||
if any("CoreML" in p for p in self.session.get_providers()):
|
||||
constants_inf.log(<str>'CoreML active — creating CPU fallback session')
|
||||
self._cpu_session = onnx.InferenceSession(
|
||||
model_bytes, providers=["CPUExecutionProvider"])
|
||||
|
||||
cdef tuple _resolve_input_hw(self, object metadata):
|
||||
cdef object h = self.input_shape[2] if len(self.input_shape) > 2 else None
|
||||
cdef object w = self.input_shape[3] if len(self.input_shape) > 3 else None
|
||||
cdef int resolved_h
|
||||
cdef int resolved_w
|
||||
|
||||
if isinstance(h, int) and h > 0 and isinstance(w, int) and w > 0:
|
||||
return <tuple>(h, w)
|
||||
|
||||
try:
|
||||
imgsz = metadata.get("imgsz") if metadata is not None else None
|
||||
if imgsz:
|
||||
parsed = ast.literal_eval(imgsz)
|
||||
if isinstance(parsed, (list, tuple)) and len(parsed) == 2:
|
||||
resolved_h = int(parsed[0])
|
||||
resolved_w = int(parsed[1])
|
||||
if resolved_h > 0 and resolved_w > 0:
|
||||
return <tuple>(resolved_h, resolved_w)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Dynamic ONNX models are expected to use the project's canonical 1280x1280 input.
|
||||
return <tuple>(1280, 1280)
|
||||
|
||||
cdef tuple get_input_shape(self):
|
||||
shape = self.input_shape
|
||||
return <tuple>(shape[2], shape[3])
|
||||
return self._resolved_input_hw
|
||||
|
||||
cdef run(self, input_data):
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user