mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 22:36:32 +00:00
33 lines
980 B
Python
33 lines
980 B
Python
def _check_tensor_gpu_index():
|
|
try:
|
|
import pynvml
|
|
pynvml.nvmlInit()
|
|
device_count = pynvml.nvmlDeviceGetCount()
|
|
if device_count == 0:
|
|
return -1
|
|
for i in range(device_count):
|
|
handle = pynvml.nvmlDeviceGetHandleByIndex(i)
|
|
major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
|
|
if major > 6 or (major == 6 and minor >= 1):
|
|
return i
|
|
return -1
|
|
except Exception:
|
|
return -1
|
|
finally:
|
|
try:
|
|
import pynvml
|
|
pynvml.nvmlShutdown()
|
|
except Exception:
|
|
pass
|
|
|
|
|
|
# Probe once, at import time, so every create_engine() call sees the same
# backend decision. -1 means the probe found no qualifying GPU (or failed).
tensor_gpu_index: int = _check_tensor_gpu_index()
|
|
|
|
|
|
def create_engine(model_bytes: bytes, batch_size: int = 1):
    """Build an inference engine for ``model_bytes``.

    Picks the backend from the module-level ``tensor_gpu_index`` probe:
    ``OnnxEngine`` when it is ``-1`` or lower, otherwise ``TensorRTEngine``.
    Both constructors receive ``(model_bytes, batch_size)``.

    The backend module is imported only on the branch that uses it. Presumably
    this keeps TensorRT's dependencies from being imported on hosts without a
    suitable GPU; confirm against the engines package.
    """
    if tensor_gpu_index <= -1:
        # No qualifying GPU: fall back to the ONNX engine.
        from engines.onnx_engine import OnnxEngine

        return OnnxEngine(model_bytes, batch_size)

    from engines.tensorrt_engine import TensorRTEngine

    return TensorRTEngine(model_bytes, batch_size)
|