mirror of
https://github.com/azaion/detections.git
synced 2026-06-23 14:31:09 +00:00
This commit is contained in:
@@ -114,13 +114,21 @@ cdef class TensorRTEngine(InferenceEngine):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def convert_from_source(bytes onnx_model, str calib_cache_path=None):
|
||||
def convert_from_source(bytes onnx_model, str calib_cache_path=None, bint force_static_input=False):
|
||||
gpu_mem = TensorRTEngine.get_gpu_memory_bytes(0)
|
||||
workspace_bytes = int(gpu_mem * 0.9)
|
||||
|
||||
explicit_batch_flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
||||
trt_logger = trt.Logger(trt.Logger.WARNING)
|
||||
|
||||
if force_static_input:
|
||||
try:
|
||||
from engines.onnx_tensorrt_compat import prepare_for_tensorrt
|
||||
onnx_model = prepare_for_tensorrt(onnx_model)
|
||||
constants_inf.log(<str>'Prepared ONNX model for TensorRT static Jetson build')
|
||||
except Exception as e:
|
||||
constants_inf.logerror(<str>f'ONNX TensorRT compatibility preparation failed: {str(e)}')
|
||||
|
||||
with trt.Builder(trt_logger) as builder, \
|
||||
builder.create_network(explicit_batch_flag) as network, \
|
||||
trt.OnnxParser(network, trt_logger) as parser, \
|
||||
@@ -129,6 +137,8 @@ cdef class TensorRTEngine(InferenceEngine):
|
||||
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
|
||||
|
||||
if not parser.parse(onnx_model):
|
||||
for i in range(parser.num_errors):
|
||||
constants_inf.logerror(<str>f'TensorRT ONNX parser error: {parser.get_error(i)}')
|
||||
return None
|
||||
|
||||
input_tensor = network.get_input(0)
|
||||
@@ -137,7 +147,9 @@ cdef class TensorRTEngine(InferenceEngine):
|
||||
H = max(shape[2], 1280) if shape[2] != -1 else 1280
|
||||
W = max(shape[3], 1280) if shape[3] != -1 else 1280
|
||||
|
||||
if shape[0] == -1:
|
||||
if force_static_input:
|
||||
input_tensor.shape = (1, C, H, W)
|
||||
elif shape[0] == -1 or shape[2] == -1 or shape[3] == -1:
|
||||
max_batch = TensorRTEngine.calculate_max_batch_size(gpu_mem, H, W)
|
||||
profile = builder.create_optimization_profile()
|
||||
profile.set_shape(
|
||||
|
||||
Reference in New Issue
Block a user