mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 11:06:32 +00:00
fix some cython code
This commit is contained in:
@@ -1,6 +1,10 @@
|
||||
from engines.inference_engine cimport InferenceEngine
|
||||
cimport constants_inf
|
||||
import numpy as np
|
||||
import io
|
||||
import os
|
||||
import tempfile
|
||||
import zipfile
|
||||
|
||||
|
||||
cdef class CoreMLEngine(InferenceEngine):
|
||||
@@ -11,9 +15,7 @@ cdef class CoreMLEngine(InferenceEngine):
|
||||
|
||||
model_path = kwargs.get('model_path')
|
||||
if model_path is None:
|
||||
raise ValueError(
|
||||
"CoreMLEngine requires model_path kwarg "
|
||||
"pointing to a .mlpackage or .mlmodel")
|
||||
model_path = self._extract_from_zip(model_bytes)
|
||||
|
||||
self.model = ct.models.MLModel(
|
||||
model_path, compute_units=ct.ComputeUnit.ALL)
|
||||
@@ -32,6 +34,57 @@ cdef class CoreMLEngine(InferenceEngine):
|
||||
constants_inf.log(<str>f'CoreML model: input={self.input_name} shape={self.input_shape}')
|
||||
constants_inf.log(<str>f'CoreML outputs: {self._output_names}')
|
||||
|
||||
@property
|
||||
def engine_name(self):
|
||||
return "coreml"
|
||||
|
||||
@staticmethod
|
||||
def get_engine_filename():
|
||||
return "azaion_coreml.zip"
|
||||
|
||||
@staticmethod
|
||||
def convert_from_onnx(bytes onnx_bytes):
|
||||
import coremltools as ct
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f:
|
||||
f.write(onnx_bytes)
|
||||
onnx_path = f.name
|
||||
|
||||
try:
|
||||
constants_inf.log(<str>'Converting ONNX to CoreML...')
|
||||
model = ct.convert(
|
||||
onnx_path,
|
||||
compute_units=ct.ComputeUnit.ALL,
|
||||
minimum_deployment_target=ct.target.macOS13,
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pkg_path = os.path.join(tmpdir, "azaion.mlpackage")
|
||||
model.save(pkg_path)
|
||||
|
||||
buf = io.BytesIO()
|
||||
with zipfile.ZipFile(buf, 'w', zipfile.ZIP_DEFLATED) as zf:
|
||||
for root, dirs, files in os.walk(pkg_path):
|
||||
for fname in files:
|
||||
file_path = os.path.join(root, fname)
|
||||
arcname = os.path.relpath(file_path, tmpdir)
|
||||
zf.write(file_path, arcname)
|
||||
constants_inf.log(<str>'CoreML conversion done!')
|
||||
return buf.getvalue()
|
||||
finally:
|
||||
os.unlink(onnx_path)
|
||||
|
||||
@staticmethod
|
||||
def _extract_from_zip(model_bytes):
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
buf = io.BytesIO(model_bytes)
|
||||
with zipfile.ZipFile(buf, 'r') as zf:
|
||||
zf.extractall(tmpdir)
|
||||
for item in os.listdir(tmpdir):
|
||||
if item.endswith('.mlpackage') or item.endswith('.mlmodel'):
|
||||
return os.path.join(tmpdir, item)
|
||||
raise ValueError("No .mlpackage or .mlmodel found in zip")
|
||||
|
||||
cdef tuple get_input_shape(self):
|
||||
return self.input_shape[2], self.input_shape[3]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user