mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 09:26:32 +00:00
fc57d677b4
- Updated various Cython files to explicitly cast types, enhancing type safety and readability. - Adjusted the `engine_name` property in `InferenceEngine` and its subclasses to be set directly in the constructor. - Modified the `request` method in `_SessionWithBase` to accept `*args` for better flexibility. - Ensured proper type casting for return values in methods across multiple classes, including `Inference`, `CoreMLEngine`, and `TensorRTEngine`. These changes aim to streamline the codebase and improve maintainability by enforcing consistent type usage.
53 lines
1.3 KiB
Python
53 lines
1.3 KiB
Python
import platform
|
|
import sys
|
|
|
|
|
|
def _check_tensor_gpu_index(min_capability=(6, 1)):
    """Return the index of the first NVIDIA GPU meeting a compute-capability floor.

    Devices are queried through NVML (``pynvml``) in index order; the first
    one whose CUDA compute capability is at least ``min_capability`` wins.
    Detection is best-effort: any failure is reported as "no usable GPU".

    Args:
        min_capability: ``(major, minor)`` lower bound on the compute
            capability, compared lexicographically. Defaults to ``(6, 1)``.

    Returns:
        The device index, or ``-1`` if ``pynvml`` cannot be imported, NVML
        fails to initialise, there are no devices, none meets
        ``min_capability``, or any NVML query raises.
    """
    # Normalise the threshold before touching NVML so a bad argument cannot
    # leave an initialised NVML session behind.
    threshold = tuple(min_capability)

    try:
        import pynvml

        pynvml.nvmlInit()
    except Exception:
        # Package missing, driver/library missing, etc. Nothing was
        # initialised, so there is no session to shut down.
        return -1

    try:
        for index in range(pynvml.nvmlDeviceGetCount()):
            handle = pynvml.nvmlDeviceGetHandleByIndex(index)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            if (major, minor) >= threshold:
                return index
        return -1
    except Exception:
        return -1
    finally:
        # Balance the successful nvmlInit(); a failing shutdown must not
        # override the detection result.
        try:
            pynvml.nvmlShutdown()
        except Exception:
            pass
|
|
|
|
|
|
def _is_apple_silicon():
    """Report whether this process can use the Core ML engine.

    True only when running on macOS (``sys.platform == "darwin"``) on an
    ``arm64`` machine *and* the ``coremltools`` package imports cleanly.
    Any other import-time failure than ``ImportError`` propagates.
    """
    on_arm_mac = sys.platform == "darwin" and platform.machine() == "arm64"
    if not on_arm_mac:
        return False

    try:
        import coremltools  # noqa: F401  (imported only to probe availability)
    except ImportError:
        return False
    return True
|
|
|
|
|
|
# Probed once, when this module is imported: index of the first NVIDIA GPU
# whose compute capability passes _check_tensor_gpu_index()'s threshold, or
# -1 if there is none (or NVML is unavailable). _select_engine_class() reads
# this to decide whether to use TensorRT.
tensor_gpu_index: int = _check_tensor_gpu_index()
|
|
|
|
|
|
def _select_engine_class():
    """Choose the inference engine class for this host.

    Preference order:
      1. ``TensorRTEngine`` when a suitable NVIDIA GPU was detected
         (module-level ``tensor_gpu_index`` is non-negative);
      2. ``CoreMLEngine`` on Apple Silicon with ``coremltools`` available;
      3. ``OnnxEngine`` otherwise.

    Each engine module is imported inside its own branch, so only the
    selected backend's module is loaded.
    """
    if tensor_gpu_index >= 0:
        from engines.tensorrt_engine import TensorRTEngine as engine_cls  # pyright: ignore[reportMissingImports]
    elif _is_apple_silicon():
        from engines.coreml_engine import CoreMLEngine as engine_cls
    else:
        from engines.onnx_engine import OnnxEngine as engine_cls
    return engine_cls
|
|
|
|
|
|
# Engine class (TensorRT, Core ML, or ONNX) selected once at import time.
# NOTE(review): presumably instantiated by callers elsewhere; not visible here.
EngineClass: type = _select_engine_class()
|