mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 09:06:31 +00:00
Add AIAvailabilityStatus and AIRecognitionConfig classes for AI model management
- Introduced `AIAvailabilityStatus` class to manage the availability status of AI models, including methods for setting status and logging messages.
- Added `AIRecognitionConfig` class to encapsulate configuration parameters for AI recognition, with a static method for creating instances from dictionaries.
- Implemented enums for AI availability states to enhance clarity and maintainability.
- Updated related Cython files to support the new classes and ensure proper type handling.

These changes aim to improve the structure and functionality of the AI model management system, facilitating better status tracking and configuration handling.
This commit is contained in:
@@ -0,0 +1,48 @@
|
||||
from engines.inference_engine cimport InferenceEngine
|
||||
import onnxruntime as onnx
|
||||
cimport constants_inf
|
||||
|
||||
import os
|
||||
|
||||
def _select_providers():
|
||||
available = set(onnx.get_available_providers())
|
||||
skip_coreml = os.environ.get("SKIP_COREML", "").lower() in ("1", "true", "yes")
|
||||
preferred = ["CoreMLExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"]
|
||||
if skip_coreml:
|
||||
preferred = [p for p in preferred if p != "CoreMLExecutionProvider"]
|
||||
selected = [p for p in preferred if p in available]
|
||||
return selected or ["CPUExecutionProvider"]
|
||||
|
||||
cdef class OnnxEngine(InferenceEngine):
|
||||
def __init__(self, model_bytes: bytes, max_batch_size: int = 8, **kwargs):
|
||||
InferenceEngine.__init__(self, model_bytes, max_batch_size)
|
||||
|
||||
providers = _select_providers()
|
||||
constants_inf.log(<str>f'ONNX providers: {providers}')
|
||||
self.session = onnx.InferenceSession(model_bytes, providers=providers)
|
||||
self.model_inputs = self.session.get_inputs()
|
||||
self.input_name = self.model_inputs[0].name
|
||||
self.input_shape = self.model_inputs[0].shape
|
||||
if self.input_shape[0] not in (-1, None, "N"):
|
||||
self.max_batch_size = self.input_shape[0]
|
||||
constants_inf.log(f'AI detection model input: {self.model_inputs} {self.input_shape}')
|
||||
model_meta = self.session.get_modelmeta()
|
||||
constants_inf.log(f"Metadata: {model_meta.custom_metadata_map}")
|
||||
|
||||
self._cpu_session = None
|
||||
if any("CoreML" in p for p in self.session.get_providers()):
|
||||
constants_inf.log(<str>'CoreML active — creating CPU fallback session')
|
||||
self._cpu_session = onnx.InferenceSession(
|
||||
model_bytes, providers=["CPUExecutionProvider"])
|
||||
|
||||
cdef tuple get_input_shape(self):
|
||||
shape = self.input_shape
|
||||
return <tuple>(shape[2], shape[3])
|
||||
|
||||
cdef run(self, input_data):
|
||||
try:
|
||||
return self.session.run(None, {self.input_name: input_data})
|
||||
except Exception:
|
||||
if self._cpu_session is not None:
|
||||
return self._cpu_session.run(None, {self.input_name: input_data})
|
||||
raise
|
||||
Reference in New Issue
Block a user