mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 06:46:32 +00:00
Fixed dynamic ONNX input
Fix dynamic ONNX input Update docs with correct file name for tests
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
cimport constants_inf
|
||||
from loader_http_client cimport LoaderHttpClient, LoadResult
|
||||
|
||||
|
||||
@@ -42,7 +43,6 @@ class EngineFactory:
|
||||
|
||||
def build_and_cache(self, bytes source_bytes, LoaderHttpClient loader_client, str models_dir):
|
||||
cdef LoadResult res
|
||||
import constants_inf
|
||||
engine_bytes, engine_filename = self.build_from_source(source_bytes, loader_client, models_dir)
|
||||
res = loader_client.upload_big_small_resource(engine_bytes, engine_filename, models_dir)
|
||||
if res.err is not None:
|
||||
@@ -56,7 +56,6 @@ class OnnxEngineFactory(EngineFactory):
|
||||
return OnnxEngine(model_bytes)
|
||||
|
||||
def get_source_filename(self):
|
||||
import constants_inf
|
||||
return constants_inf.AI_ONNX_MODEL_FILE
|
||||
|
||||
|
||||
@@ -81,7 +80,6 @@ class TensorRTEngineFactory(EngineFactory):
|
||||
return TensorRTEngine.get_engine_filename()
|
||||
|
||||
def get_source_filename(self):
|
||||
import constants_inf
|
||||
return constants_inf.AI_ONNX_MODEL_FILE
|
||||
|
||||
def build_from_source(self, onnx_bytes, loader_client, models_dir):
|
||||
|
||||
@@ -8,6 +8,8 @@ cdef class OnnxEngine(InferenceEngine):
|
||||
cdef object model_inputs
|
||||
cdef str input_name
|
||||
cdef object input_shape
|
||||
cdef object _resolved_input_hw
|
||||
|
||||
cdef tuple _resolve_input_hw(self, object metadata)
|
||||
cdef tuple get_input_shape(self)
|
||||
cdef run(self, input_data)
|
||||
|
||||
@@ -2,6 +2,7 @@ from engines.inference_engine cimport InferenceEngine
|
||||
import onnxruntime as onnx
|
||||
cimport constants_inf
|
||||
|
||||
import ast
|
||||
import os
|
||||
|
||||
def _select_providers():
|
||||
@@ -29,15 +30,40 @@ cdef class OnnxEngine(InferenceEngine):
|
||||
model_meta = self.session.get_modelmeta()
|
||||
constants_inf.log(f"Metadata: {model_meta.custom_metadata_map}")
|
||||
|
||||
self._resolved_input_hw = self._resolve_input_hw(model_meta.custom_metadata_map)
|
||||
|
||||
self._cpu_session = None
|
||||
if any("CoreML" in p for p in self.session.get_providers()):
|
||||
constants_inf.log(<str>'CoreML active — creating CPU fallback session')
|
||||
self._cpu_session = onnx.InferenceSession(
|
||||
model_bytes, providers=["CPUExecutionProvider"])
|
||||
|
||||
cdef tuple _resolve_input_hw(self, object metadata):
|
||||
cdef object h = self.input_shape[2] if len(self.input_shape) > 2 else None
|
||||
cdef object w = self.input_shape[3] if len(self.input_shape) > 3 else None
|
||||
cdef int resolved_h
|
||||
cdef int resolved_w
|
||||
|
||||
if isinstance(h, int) and h > 0 and isinstance(w, int) and w > 0:
|
||||
return <tuple>(h, w)
|
||||
|
||||
try:
|
||||
imgsz = metadata.get("imgsz") if metadata is not None else None
|
||||
if imgsz:
|
||||
parsed = ast.literal_eval(imgsz)
|
||||
if isinstance(parsed, (list, tuple)) and len(parsed) == 2:
|
||||
resolved_h = int(parsed[0])
|
||||
resolved_w = int(parsed[1])
|
||||
if resolved_h > 0 and resolved_w > 0:
|
||||
return <tuple>(resolved_h, resolved_w)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Dynamic ONNX models are expected to use the project's canonical 1280x1280 input.
|
||||
return <tuple>(1280, 1280)
|
||||
|
||||
cdef tuple get_input_shape(self):
|
||||
shape = self.input_shape
|
||||
return <tuple>(shape[2], shape[3])
|
||||
return self._resolved_input_hw
|
||||
|
||||
cdef run(self, input_data):
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user