mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 21:26:32 +00:00
fix some cython code
This commit is contained in:
+23
-3
@@ -1,3 +1,7 @@
|
||||
import platform
|
||||
import sys
|
||||
|
||||
|
||||
def _check_tensor_gpu_index():
|
||||
try:
|
||||
import pynvml
|
||||
@@ -21,12 +25,28 @@ def _check_tensor_gpu_index():
|
||||
pass
|
||||
|
||||
|
||||
def _is_apple_silicon():
|
||||
if sys.platform != "darwin" or platform.machine() != "arm64":
|
||||
return False
|
||||
try:
|
||||
import coremltools
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
tensor_gpu_index = _check_tensor_gpu_index()
|
||||
|
||||
|
||||
def create_engine(model_bytes: bytes, batch_size: int = 1):
|
||||
def _select_engine_class():
|
||||
if tensor_gpu_index > -1:
|
||||
from engines.tensorrt_engine import TensorRTEngine
|
||||
return TensorRTEngine(model_bytes, batch_size)
|
||||
return TensorRTEngine
|
||||
if _is_apple_silicon():
|
||||
from engines.coreml_engine import CoreMLEngine
|
||||
return CoreMLEngine
|
||||
from engines.onnx_engine import OnnxEngine
|
||||
return OnnxEngine(model_bytes, batch_size)
|
||||
return OnnxEngine
|
||||
|
||||
|
||||
EngineClass = _select_engine_class()
|
||||
|
||||
Reference in New Issue
Block a user