mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 07:06:32 +00:00
103 lines
3.5 KiB
Cython
from engines.inference_engine cimport InferenceEngine
|
|
cimport constants_inf
|
|
import numpy as np
|
|
import io
|
|
import os
|
|
import tempfile
|
|
import zipfile
|
|
|
|
|
|
cdef class CoreMLEngine(InferenceEngine):
    """Apple CoreML inference backend.

    Loads a CoreML model (an ``.mlpackage`` or ``.mlmodel``) from raw bytes
    — expected to be the zip archive produced by :meth:`convert_from_onnx` —
    and serves predictions through ``coremltools``.
    """

    def __init__(self, model_bytes: bytes, batch_size: int = 1, **kwargs):
        # model_bytes: zipped .mlpackage/.mlmodel (see get_engine_filename()).
        # batch_size:  fallback batch size used when the model declares a
        #              dynamic (non-positive) leading dimension.
        # kwargs:      optional 'model_path' pointing at an already-extracted
        #              model on disk, which skips the zip-extraction step.
        super().__init__(model_bytes, batch_size)
        # Imported lazily so this module can be imported on hosts that do
        # not have coremltools installed (e.g. non-macOS builds).
        import coremltools as ct

        model_path = kwargs.get('model_path')
        if model_path is None:
            # NOTE(review): the temp dir created by _extract_from_zip() is
            # never removed — confirm this is acceptable for the process
            # lifetime (MLModel loads from that path).
            model_path = self._extract_from_zip(model_bytes)

        # ComputeUnit.ALL lets CoreML pick CPU / GPU / Neural Engine.
        self.model = ct.models.MLModel(
            model_path, compute_units=ct.ComputeUnit.ALL)
        spec = self.model.get_spec()

        # A single-input model is assumed: the first declared input drives
        # the feed dict built in run().
        input_desc = spec.description.input[0]
        self.input_name = input_desc.name

        # multiArrayType carries the declared tensor shape of the input.
        array_type = input_desc.type.multiArrayType
        self.input_shape = tuple(int(s) for s in array_type.shape)
        if len(self.input_shape) == 4:
            # 4-D input (presumably NCHW — TODO confirm): honour the model's
            # fixed batch dimension, falling back to the caller's batch_size
            # when the model declares it as dynamic (<= 0).
            self.batch_size = self.input_shape[0] if self.input_shape[0] > 0 else batch_size

        # Output order here fixes the order of the list returned by run().
        self._output_names = [o.name for o in spec.description.output]

        constants_inf.log(<str>f'CoreML model: input={self.input_name} shape={self.input_shape}')
        constants_inf.log(<str>f'CoreML outputs: {self._output_names}')

    @property
    def engine_name(self):
        """Short identifier for this backend."""
        return "coreml"

    @staticmethod
    def get_engine_filename():
        """Canonical artifact filename for the packaged CoreML model."""
        return "azaion_coreml.zip"

    @staticmethod
    def convert_from_onnx(bytes onnx_bytes):
        """Convert an ONNX model to a zipped CoreML ``.mlpackage``.

        Returns the zip archive as ``bytes`` — the exact format consumed by
        ``__init__`` / :meth:`_extract_from_zip`.

        NOTE(review): ``ct.convert`` dropped direct ONNX support in
        coremltools 6.0 — presumably an older coremltools is pinned here;
        verify against the project's dependency spec.
        """
        import coremltools as ct

        # ct.convert needs a file path, so spill the ONNX bytes to disk.
        # delete=False because the file must survive the 'with' block; it is
        # removed in the 'finally' below.
        with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f:
            f.write(onnx_bytes)
            onnx_path = f.name

        try:
            constants_inf.log(<str>'Converting ONNX to CoreML...')
            model = ct.convert(
                onnx_path,
                compute_units=ct.ComputeUnit.ALL,
                minimum_deployment_target=ct.target.macOS13,
            )

            with tempfile.TemporaryDirectory() as tmpdir:
                pkg_path = os.path.join(tmpdir, "azaion.mlpackage")
                model.save(pkg_path)

                # Zip the whole .mlpackage directory tree. Arcnames are made
                # relative to tmpdir, so the archive root contains a single
                # 'azaion.mlpackage/' entry (what _extract_from_zip expects).
                buf = io.BytesIO()
                with zipfile.ZipFile(buf, 'w', zipfile.ZIP_DEFLATED) as zf:
                    for root, dirs, files in os.walk(pkg_path):
                        for fname in files:
                            file_path = os.path.join(root, fname)
                            arcname = os.path.relpath(file_path, tmpdir)
                            zf.write(file_path, arcname)
            constants_inf.log(<str>'CoreML conversion done!')
            return buf.getvalue()
        finally:
            # Always remove the temporary ONNX file created above.
            os.unlink(onnx_path)

    @staticmethod
    def _extract_from_zip(model_bytes):
        """Extract the zipped model into a fresh temp dir; return the model path.

        Raises ``ValueError`` when the archive contains no recognised model.
        NOTE(review): the mkdtemp() directory is never removed — presumably
        kept because MLModel loads from this path; confirm cleanup policy.
        """
        tmpdir = tempfile.mkdtemp()
        buf = io.BytesIO(model_bytes)
        with zipfile.ZipFile(buf, 'r') as zf:
            zf.extractall(tmpdir)
        # Pick the first top-level entry that looks like a CoreML model.
        for item in os.listdir(tmpdir):
            if item.endswith('.mlpackage') or item.endswith('.mlmodel'):
                return os.path.join(tmpdir, item)
        raise ValueError("No .mlpackage or .mlmodel found in zip")

    cdef tuple get_input_shape(self):
        # (height, width) of the model input — indexes dims 2 and 3, so this
        # assumes the 4-D layout observed in __init__; would raise for
        # lower-rank inputs. TODO confirm callers only use 4-D models.
        return self.input_shape[2], self.input_shape[3]

    cdef int get_batch_size(self):
        # Effective batch size as resolved in __init__ (model-declared value
        # when fixed, caller-supplied fallback when dynamic).
        return self.batch_size

    cdef run(self, input_data):
        """Run one prediction.

        Returns a list of numpy arrays ordered to match the model's declared
        outputs (``self._output_names``).
        """
        prediction = self.model.predict({self.input_name: input_data})
        results = []
        for name in self._output_names:
            val = prediction[name]
            # CoreML can hand back non-ndarray values; normalise to ndarray.
            if not isinstance(val, np.ndarray):
                val = np.array(val)
            results.append(val)
        return results
|