mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 05:26:32 +00:00
50 lines
1.7 KiB
Cython
50 lines
1.7 KiB
Cython
from engines.inference_engine cimport InferenceEngine
|
|
cimport constants_inf
|
|
import numpy as np
|
|
|
|
|
|
cdef class CoreMLEngine(InferenceEngine):
    """Inference engine backed by a Core ML model loaded through coremltools.

    The model is loaded from the ``model_path`` keyword argument (a
    ``.mlpackage`` or ``.mlmodel``); ``model_bytes`` is forwarded to the base
    class but is not used to load the model here. Only the model's first
    declared input is fed by ``run``; every declared output is returned.
    """

    def __init__(self, model_bytes: bytes, batch_size: int = 1, **kwargs):
        """Load the Core ML model and read its input/output description.

        Args:
            model_bytes: Forwarded to ``InferenceEngine.__init__``.
            batch_size: Forwarded to the base class; also used as the batch
                size when the first input is 4-D with a leading dimension of 0.
            **kwargs:
                model_path (required): Path to a ``.mlpackage`` / ``.mlmodel``.
                compute_units (optional): A ``coremltools.ComputeUnit`` member
                    or its name, case-insensitive (e.g. ``'CPU_ONLY'``).
                    Defaults to ``ComputeUnit.ALL``.

        Raises:
            ValueError: If ``model_path`` is missing, or ``compute_units`` is a
                string that does not name a ``ComputeUnit`` member.
        """
        super().__init__(model_bytes, batch_size)

        # Imported here so coremltools is only required when this engine is built.
        import coremltools as ct

        model_path = kwargs.get('model_path')
        if model_path is None:
            raise ValueError(
                "CoreMLEngine requires model_path kwarg "
                "pointing to a .mlpackage or .mlmodel")

        compute_units = kwargs.get('compute_units', ct.ComputeUnit.ALL)
        if isinstance(compute_units, str):
            # Let config-driven callers pass the enum member's name.
            try:
                compute_units = ct.ComputeUnit[compute_units.upper()]
            except KeyError:
                valid_units = [unit.name for unit in ct.ComputeUnit]
                raise ValueError(
                    f"Unknown CoreML compute unit {compute_units!r}; "
                    f"expected one of {valid_units}") from None

        self.model = ct.models.MLModel(model_path, compute_units=compute_units)
        spec = self.model.get_spec()

        # Only the first declared input is used by run().
        input_desc = spec.description.input[0]
        self.input_name = input_desc.name

        # For an input that is not a multiarray (e.g. an image input) this
        # shape is expected to be empty; get_input_shape() reports that case.
        array_type = input_desc.type.multiArrayType
        self.input_shape = tuple(int(s) for s in array_type.shape)
        if len(self.input_shape) == 4:
            # A leading dimension of 0 is treated as unspecified: fall back to
            # the caller-supplied batch size.
            self.batch_size = self.input_shape[0] if self.input_shape[0] > 0 else batch_size

        self._output_names = [o.name for o in spec.description.output]

        constants_inf.log(<str>f'CoreML model: input={self.input_name} shape={self.input_shape}')
        constants_inf.log(<str>f'CoreML outputs: {self._output_names}')

    cdef tuple get_input_shape(self):
        # Return dimensions 2 and 3 of the first input's shape — (H, W) if the
        # input is laid out as NCHW (assumed; TODO confirm against exporter).
        if len(self.input_shape) < 4:
            # Without this check the caller gets a bare IndexError.
            raise ValueError(
                f'CoreML input {self.input_name!r} has shape {self.input_shape}; '
                f'expected a 4-D multiarray input')
        return self.input_shape[2], self.input_shape[3]

    cdef int get_batch_size(self):
        # Batch size set by the base class, or overridden in __init__ from a
        # 4-D input's leading dimension.
        return self.batch_size

    cdef run(self, input_data):
        # Run one prediction, feeding input_data as the first model input, and
        # return the outputs as a list ordered like the spec's outputs.
        prediction = self.model.predict({self.input_name: input_data})
        results = []
        for name in self._output_names:
            val = prediction[name]
            # Normalize non-array outputs so callers always get ndarrays.
            if not isinstance(val, np.ndarray):
                val = np.array(val)
            results.append(val)
        return results
|