Files
detections/engines/coreml_engine.pyx
T

50 lines
1.7 KiB
Cython

from engines.inference_engine cimport InferenceEngine
cimport constants_inf
import numpy as np
cdef class CoreMLEngine(InferenceEngine):
    """Inference engine that runs an Apple Core ML model through coremltools.

    The Core ML model is loaded from the ``model_path`` keyword argument
    (a ``.mlpackage`` or ``.mlmodel``); ``model_bytes`` is only forwarded to
    ``InferenceEngine.__init__`` and is not used to load the model here.
    Only the model's first input is fed; it must be a multi-array (tensor)
    input. Its shape is read from the model spec and its dimensions 2 and 3
    are reported as the (height, width) by ``get_input_shape``.
    """

    def __init__(self, model_bytes: bytes, batch_size: int = 1, **kwargs):
        """Load the Core ML model and cache its input/output description.

        Args:
            model_bytes: forwarded to ``InferenceEngine.__init__``; not used
                to load the Core ML model.
            batch_size: forwarded to the base class; also used as the batch
                size when the model input is 4-D but its leading dimension
                is not positive.
            **kwargs:
                model_path (required): path to a ``.mlpackage``/``.mlmodel``.
                compute_units (optional): a ``coremltools.ComputeUnit`` member
                    or its name (case-insensitive, e.g. ``"CPU_ONLY"``).
                    Defaults to ``ComputeUnit.ALL``, the previous fixed value.

        Raises:
            ValueError: if ``model_path`` is missing, ``compute_units`` names
                an unknown compute unit, or the model's first input is not a
                multi-array (e.g. an image-typed input).
        """
        super().__init__(model_bytes, batch_size)
        import coremltools as ct
        model_path = kwargs.get('model_path')
        if model_path is None:
            raise ValueError(
                "CoreMLEngine requires model_path kwarg "
                "pointing to a .mlpackage or .mlmodel")

        # Resolve compute units before the (comparatively slow) model load so
        # a bad configuration fails fast.
        compute_units = kwargs.get('compute_units', ct.ComputeUnit.ALL)
        if isinstance(compute_units, str):
            try:
                compute_units = ct.ComputeUnit[compute_units.upper()]
            except KeyError:
                raise ValueError(
                    f"Unknown CoreML compute_units {compute_units!r}; "
                    f"expected one of {[u.name for u in ct.ComputeUnit]}") from None

        self.model = ct.models.MLModel(model_path, compute_units=compute_units)
        spec = self.model.get_spec()
        input_desc = spec.description.input[0]
        self.input_name = input_desc.name

        # FeatureType is a protobuf oneof named 'Type'. Only multi-array inputs
        # carry a usable `shape`, and `run` feeds raw arrays; an image-typed
        # input would leave input_shape empty and fail later with IndexError.
        input_kind = input_desc.type.WhichOneof('Type')
        if input_kind != 'multiArrayType':
            raise ValueError(
                f"CoreMLEngine expects the first model input to be a "
                f"multi-array (tensor), but input '{self.input_name}' is of "
                f"type {input_kind!r}; re-export the model with a tensor input")

        array_type = input_desc.type.multiArrayType
        self.input_shape = tuple(int(s) for s in array_type.shape)
        if len(self.input_shape) == 4:
            # A non-positive leading dimension is treated as "not fixed by the
            # model", so the caller-supplied batch size is kept.
            self.batch_size = self.input_shape[0] if self.input_shape[0] > 0 else batch_size
        # Output names in spec order; `run` returns results in this order.
        self._output_names = [o.name for o in spec.description.output]
        constants_inf.log(<str>f'CoreML model: input={self.input_name} shape={self.input_shape}')
        constants_inf.log(<str>f'CoreML outputs: {self._output_names}')

    cdef tuple get_input_shape(self):
        # Returns (input_shape[2], input_shape[3]); for an NCHW input this is
        # (height, width). Assumes a rank-4 input — other ranks raise
        # IndexError or return unrelated dimensions.
        return self.input_shape[2], self.input_shape[3]

    cdef int get_batch_size(self):
        # Batch size: taken from the model's leading input dimension when it
        # is 4-D and positive, otherwise whatever the base class / caller set.
        return self.batch_size

    cdef run(self, input_data):
        # Run one prediction, feeding `input_data` under the model's first
        # input name (presumably an array shaped like self.input_shape —
        # confirm with caller). Returns a list of numpy arrays, one per model
        # output, in spec output order.
        prediction = self.model.predict({self.input_name: input_data})
        results = []
        for name in self._output_names:
            val = prediction[name]
            # Normalise non-array outputs so callers always get ndarrays.
            if not isinstance(val, np.ndarray):
                val = np.array(val)
            results.append(val)
        return results