mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 18:36:35 +00:00
Refactor inference engine and task management: Remove obsolete inference engine and ONNX engine files, update inference processing to utilize batch handling, and enhance task management structure in documentation. Adjust paths for task specifications to align with new directory organization.
This commit is contained in:
@@ -0,0 +1,49 @@
|
||||
from engines.inference_engine cimport InferenceEngine
|
||||
cimport constants_inf
|
||||
import numpy as np
|
||||
|
||||
|
||||
cdef class CoreMLEngine(InferenceEngine):
|
||||
|
||||
def __init__(self, model_bytes: bytes, batch_size: int = 1, **kwargs):
|
||||
super().__init__(model_bytes, batch_size)
|
||||
import coremltools as ct
|
||||
|
||||
model_path = kwargs.get('model_path')
|
||||
if model_path is None:
|
||||
raise ValueError(
|
||||
"CoreMLEngine requires model_path kwarg "
|
||||
"pointing to a .mlpackage or .mlmodel")
|
||||
|
||||
self.model = ct.models.MLModel(
|
||||
model_path, compute_units=ct.ComputeUnit.ALL)
|
||||
spec = self.model.get_spec()
|
||||
|
||||
input_desc = spec.description.input[0]
|
||||
self.input_name = input_desc.name
|
||||
|
||||
array_type = input_desc.type.multiArrayType
|
||||
self.input_shape = tuple(int(s) for s in array_type.shape)
|
||||
if len(self.input_shape) == 4:
|
||||
self.batch_size = self.input_shape[0] if self.input_shape[0] > 0 else batch_size
|
||||
|
||||
self._output_names = [o.name for o in spec.description.output]
|
||||
|
||||
constants_inf.log(<str>f'CoreML model: input={self.input_name} shape={self.input_shape}')
|
||||
constants_inf.log(<str>f'CoreML outputs: {self._output_names}')
|
||||
|
||||
cdef tuple get_input_shape(self):
|
||||
return self.input_shape[2], self.input_shape[3]
|
||||
|
||||
cdef int get_batch_size(self):
|
||||
return self.batch_size
|
||||
|
||||
cdef run(self, input_data):
|
||||
prediction = self.model.predict({self.input_name: input_data})
|
||||
results = []
|
||||
for name in self._output_names:
|
||||
val = prediction[name]
|
||||
if not isinstance(val, np.ndarray):
|
||||
val = np.array(val)
|
||||
results.append(val)
|
||||
return results
|
||||
Reference in New Issue
Block a user