mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 11:16:35 +00:00
add download big engine file to cdn manager
revise onnx export process fixes
This commit is contained in:
@@ -8,28 +8,29 @@ import numpy as np
|
||||
import tensorrt as trt
|
||||
import pycuda.driver as cuda
|
||||
from inference.onnx_engine import InferenceEngine
|
||||
import pycuda.autoinit # required for automatically initialize CUDA, do not remove.
|
||||
# required for automatically initialize CUDA, do not remove.
|
||||
import pycuda.autoinit
|
||||
import pynvml
|
||||
|
||||
# TODO: 2. Convert onnx model with 4 batch and make sure it is working
|
||||
|
||||
class TensorRTEngine(InferenceEngine):
|
||||
def __init__(self, model_bytes: bytes, batch_size: int = 4, **kwargs):
|
||||
self.batch_size = batch_size
|
||||
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
|
||||
|
||||
def __init__(self, model_bytes: bytes, **kwargs):
|
||||
try:
|
||||
logger = trt.Logger(trt.Logger.WARNING)
|
||||
# metadata_len = struct.unpack("<I", model_bytes[:4])[0]
|
||||
# try:
|
||||
# self.metadata = json.loads(model_bytes[4:4 + metadata_len])
|
||||
# self.class_names = self.metadata['names']
|
||||
# print(f"Model metadata: {json.dumps(self.metadata, indent=2)}")
|
||||
# except json.JSONDecodeError as err:
|
||||
# print(f"Failed to parse metadata")
|
||||
# return
|
||||
# engine_data = model_bytes[4 + metadata_len:]
|
||||
|
||||
metadata_len = struct.unpack("<I", model_bytes[:4])[0]
|
||||
try:
|
||||
self.metadata = json.loads(model_bytes[4:4 + metadata_len])
|
||||
self.class_names = self.metadata['names']
|
||||
print(f"Model metadata: {json.dumps(self.metadata, indent=2)}")
|
||||
except json.JSONDecodeError as err:
|
||||
print(f"Failed to parse metadata")
|
||||
return
|
||||
engine_data = model_bytes[4 + metadata_len:]
|
||||
|
||||
runtime = trt.Runtime(logger)
|
||||
self.engine = runtime.deserialize_cuda_engine(engine_data)
|
||||
runtime = trt.Runtime(self.TRT_LOGGER)
|
||||
self.engine = runtime.deserialize_cuda_engine(model_bytes)
|
||||
|
||||
if self.engine is None:
|
||||
raise RuntimeError(f"Failed to load TensorRT engine!")
|
||||
@@ -55,7 +56,7 @@ class TensorRTEngine(InferenceEngine):
|
||||
self.output_name = self.engine.get_tensor_name(1)
|
||||
engine_output_shape = tuple(self.engine.get_tensor_shape(self.output_name))
|
||||
self.output_shape = [
|
||||
batch_size if self.input_shape[0] == -1 else self.input_shape[0],
|
||||
4 if self.input_shape[0] == -1 else self.input_shape[0], # by default, batch size is 4
|
||||
300 if engine_output_shape[1] == -1 else engine_output_shape[1], # max detections number
|
||||
6 if engine_output_shape[2] == -1 else engine_output_shape[2] # x1 y1 x2 y2 conf cls
|
||||
]
|
||||
@@ -73,7 +74,61 @@ class TensorRTEngine(InferenceEngine):
|
||||
def get_batch_size(self) -> int:
|
||||
return self.batch_size
|
||||
|
||||
# In tensorrt_engine.py, modify the run method:
|
||||
@staticmethod
|
||||
def get_gpu_memory_bytes(device_id=0) -> int:
|
||||
total_memory = None
|
||||
try:
|
||||
pynvml.nvmlInit()
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
|
||||
mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||
total_memory = mem_info.total
|
||||
except pynvml.NVMLError:
|
||||
total_memory = None
|
||||
finally:
|
||||
try:
|
||||
pynvml.nvmlShutdown()
|
||||
except pynvml.NVMLError:
|
||||
pass
|
||||
return 2 * 1024 * 1024 * 1024 if total_memory is None else total_memory # default 2 Gb
|
||||
|
||||
@staticmethod
|
||||
def get_engine_filename(device_id=0) -> str | None:
|
||||
try:
|
||||
device = cuda.Device(device_id)
|
||||
sm_count = device.multiprocessor_count
|
||||
cc_major, cc_minor = device.compute_capability()
|
||||
return f"azaion.cc_{cc_major}.{cc_minor}_sm_{sm_count}.engine"
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def convert_from_onnx(onnx_model: bytes) -> bytes | None:
|
||||
workspace_bytes = int(TensorRTEngine.get_gpu_memory_bytes() * 0.9)
|
||||
|
||||
explicit_batch_flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
||||
|
||||
with trt.Builder(TensorRTEngine.TRT_LOGGER) as builder, \
|
||||
builder.create_network(explicit_batch_flag) as network, \
|
||||
trt.OnnxParser(network, TensorRTEngine.TRT_LOGGER) as parser, \
|
||||
builder.create_builder_config() as config:
|
||||
|
||||
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
|
||||
|
||||
if not parser.parse(onnx_model):
|
||||
return None
|
||||
|
||||
if builder.platform_has_fast_fp16:
|
||||
print('Converting to supported fp16')
|
||||
config.set_flag(trt.BuilderFlag.FP16)
|
||||
else:
|
||||
print('Converting to supported fp32. (fp16 is not supported)')
|
||||
plan = builder.build_serialized_network(network, config)
|
||||
|
||||
if plan is None:
|
||||
print('Conversion failed.')
|
||||
return None
|
||||
|
||||
return bytes(plan)
|
||||
|
||||
def run(self, input_data: np.ndarray) -> List[np.ndarray]:
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user