mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 09:06:31 +00:00
Refactor inference engine and task management: Remove obsolete inference engine and ONNX engine files, update inference processing to utilize batch handling, and enhance task management structure in documentation. Adjust paths for task specifications to align with new directory organization.
This commit is contained in:
@@ -0,0 +1,50 @@
|
||||
from engines.inference_engine cimport InferenceEngine
|
||||
import onnxruntime as onnx
|
||||
cimport constants_inf
|
||||
|
||||
import os
|
||||
|
||||
def _select_providers():
|
||||
available = set(onnx.get_available_providers())
|
||||
skip_coreml = os.environ.get("SKIP_COREML", "").lower() in ("1", "true", "yes")
|
||||
preferred = ["CoreMLExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"]
|
||||
if skip_coreml:
|
||||
preferred = [p for p in preferred if p != "CoreMLExecutionProvider"]
|
||||
selected = [p for p in preferred if p in available]
|
||||
return selected or ["CPUExecutionProvider"]
|
||||
|
||||
cdef class OnnxEngine(InferenceEngine):
|
||||
def __init__(self, model_bytes: bytes, batch_size: int = 1, **kwargs):
|
||||
super().__init__(model_bytes, batch_size)
|
||||
|
||||
providers = _select_providers()
|
||||
constants_inf.log(<str>f'ONNX providers: {providers}')
|
||||
self.session = onnx.InferenceSession(model_bytes, providers=providers)
|
||||
self.model_inputs = self.session.get_inputs()
|
||||
self.input_name = self.model_inputs[0].name
|
||||
self.input_shape = self.model_inputs[0].shape
|
||||
self.batch_size = self.input_shape[0] if self.input_shape[0] != -1 else batch_size
|
||||
constants_inf.log(f'AI detection model input: {self.model_inputs} {self.input_shape}')
|
||||
model_meta = self.session.get_modelmeta()
|
||||
constants_inf.log(f"Metadata: {model_meta.custom_metadata_map}")
|
||||
|
||||
self._cpu_session = None
|
||||
if any("CoreML" in p for p in self.session.get_providers()):
|
||||
constants_inf.log(<str>'CoreML active — creating CPU fallback session')
|
||||
self._cpu_session = onnx.InferenceSession(
|
||||
model_bytes, providers=["CPUExecutionProvider"])
|
||||
|
||||
cdef tuple get_input_shape(self):
|
||||
shape = self.input_shape
|
||||
return shape[2], shape[3]
|
||||
|
||||
cdef int get_batch_size(self):
|
||||
return self.batch_size
|
||||
|
||||
cdef run(self, input_data):
|
||||
try:
|
||||
return self.session.run(None, {self.input_name: input_data})
|
||||
except Exception:
|
||||
if self._cpu_session is not None:
|
||||
return self._cpu_session.run(None, {self.input_name: input_data})
|
||||
raise
|
||||
Reference in New Issue
Block a user