mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 22:36:32 +00:00
53 lines
1.3 KiB
Python
53 lines
1.3 KiB
Python
import platform
|
|
import sys
|
|
|
|
|
|
def _check_tensor_gpu_index(min_compute_capability=(6, 1)):
    """Return the index of the first GPU meeting a minimum compute capability.

    Devices are probed through NVML (``pynvml``). The probe is best-effort:
    if ``pynvml`` cannot be imported, NVML fails to initialise, or any NVML
    query raises, the function reports "no usable GPU" instead of raising.

    Args:
        min_compute_capability: Inclusive ``(major, minor)`` lower bound on the
            device compute capability. Defaults to ``(6, 1)``.
            NOTE(review): presumably the minimum the TensorRT engine supports
            — confirm.

    Returns:
        The index of the first device whose compute capability is
        ``>= min_compute_capability``, or ``-1`` if no device qualifies or the
        probe failed.
    """
    wanted = tuple(min_compute_capability)

    try:
        import pynvml

        pynvml.nvmlInit()
    except Exception:
        # pynvml is missing or NVML could not start (e.g. no driver).
        # Nothing was initialised, so there is nothing to shut down.
        return -1

    try:
        for index in range(pynvml.nvmlDeviceGetCount()):
            handle = pynvml.nvmlDeviceGetHandleByIndex(index)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            # Lexicographic tuple comparison: (7, 0) >= (6, 1) and (6, 1) >= (6, 1).
            if (major, minor) >= wanted:
                return index
        return -1
    except Exception:
        # Best-effort probe: any NVML query failure means "no usable GPU".
        return -1
    finally:
        # Only reached after a successful nvmlInit(), so the shutdown is balanced.
        try:
            pynvml.nvmlShutdown()
        except Exception:
            pass
|
|
|
|
|
|
def _is_apple_silicon():
    """Report whether CoreML inference is available on this host.

    True only when running on macOS (``sys.platform == "darwin"``) on an
    ``arm64`` machine *and* ``coremltools`` can be imported; False otherwise.
    """
    on_arm_mac = sys.platform == "darwin" and platform.machine() == "arm64"
    if not on_arm_mac:
        return False

    try:
        import coremltools  # noqa: F401  -- imported only to probe availability
    except ImportError:
        return False
    return True
|
|
|
|
|
|
# Index of the first GPU that passed the compute-capability probe, or -1 if none.
# Computed once at import time; `_select_engine_class` consults it.
tensor_gpu_index: int = _check_tensor_gpu_index()
|
|
|
|
|
|
def _select_engine_class():
    """Choose the inference-engine class for this host.

    Preference order:
      1. ``TensorRTEngine`` when ``tensor_gpu_index`` found a qualifying GPU;
      2. ``CoreMLEngine`` when ``_is_apple_silicon()`` is true;
      3. ``OnnxEngine`` otherwise.

    Engine modules are imported lazily so only the chosen backend is loaded.
    """
    if tensor_gpu_index > -1:
        from engines.tensorrt_engine import TensorRTEngine as engine_cls
    elif _is_apple_silicon():
        from engines.coreml_engine import CoreMLEngine as engine_cls
    else:
        from engines.onnx_engine import OnnxEngine as engine_cls
    return engine_cls
|
|
|
|
|
|
# Engine class picked once at import time for this host
# (TensorRT, CoreML or ONNX — see `_select_engine_class`).
EngineClass = _select_engine_class()
|