fix some cython code

This commit is contained in:
Oleksandr Bezdieniezhnykh
2026-03-29 21:18:18 +03:00
parent ad5530b9ef
commit 6269a7485c
32 changed files with 17108 additions and 2728 deletions
+56 -3
View File
@@ -1,6 +1,10 @@
from engines.inference_engine cimport InferenceEngine
cimport constants_inf
import numpy as np
import io
import os
import tempfile
import zipfile
cdef class CoreMLEngine(InferenceEngine):
@@ -11,9 +15,7 @@ cdef class CoreMLEngine(InferenceEngine):
model_path = kwargs.get('model_path')
if model_path is None:
raise ValueError(
"CoreMLEngine requires model_path kwarg "
"pointing to a .mlpackage or .mlmodel")
model_path = self._extract_from_zip(model_bytes)
self.model = ct.models.MLModel(
model_path, compute_units=ct.ComputeUnit.ALL)
@@ -32,6 +34,57 @@ cdef class CoreMLEngine(InferenceEngine):
constants_inf.log(<str>f'CoreML model: input={self.input_name} shape={self.input_shape}')
constants_inf.log(<str>f'CoreML outputs: {self._output_names}')
@property
def engine_name(self):
return "coreml"
@staticmethod
def get_engine_filename():
return "azaion_coreml.zip"
@staticmethod
def convert_from_onnx(bytes onnx_bytes):
import coremltools as ct
with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f:
f.write(onnx_bytes)
onnx_path = f.name
try:
constants_inf.log(<str>'Converting ONNX to CoreML...')
model = ct.convert(
onnx_path,
compute_units=ct.ComputeUnit.ALL,
minimum_deployment_target=ct.target.macOS13,
)
with tempfile.TemporaryDirectory() as tmpdir:
pkg_path = os.path.join(tmpdir, "azaion.mlpackage")
model.save(pkg_path)
buf = io.BytesIO()
with zipfile.ZipFile(buf, 'w', zipfile.ZIP_DEFLATED) as zf:
for root, dirs, files in os.walk(pkg_path):
for fname in files:
file_path = os.path.join(root, fname)
arcname = os.path.relpath(file_path, tmpdir)
zf.write(file_path, arcname)
constants_inf.log(<str>'CoreML conversion done!')
return buf.getvalue()
finally:
os.unlink(onnx_path)
@staticmethod
def _extract_from_zip(model_bytes):
tmpdir = tempfile.mkdtemp()
buf = io.BytesIO(model_bytes)
with zipfile.ZipFile(buf, 'r') as zf:
zf.extractall(tmpdir)
for item in os.listdir(tmpdir):
if item.endswith('.mlpackage') or item.endswith('.mlmodel'):
return os.path.join(tmpdir, item)
raise ValueError("No .mlpackage or .mlmodel found in zip")
cdef tuple get_input_shape(self):
return self.input_shape[2], self.input_shape[3]