mirror of
https://github.com/azaion/detections.git
synced 2026-06-21 07:01:09 +00:00
Changed to use additional calibration parameters
ci/woodpecker/push/02-build-push Pipeline was successful
ci/woodpecker/push/02-build-push Pipeline was successful
This commit is contained in:
@@ -30,6 +30,8 @@ def parse_args():
|
||||
parser.add_argument("--output", default="azaion.int8_calib.cache")
|
||||
parser.add_argument("--input-size", type=int, default=1280, help="Model input H=W (default 1280)")
|
||||
parser.add_argument("--num-samples", type=int, default=500)
|
||||
parser.add_argument("--workspace-gb", type=float, default=4.0)
|
||||
parser.add_argument("--no-fp16", action="store_true", help="Do not enable FP16 fallback during INT8 calibration")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -66,7 +68,7 @@ def main():
|
||||
if not images:
|
||||
print(f"No images found in {args.images_dir}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
print(f"Using {len(images)} calibration images")
|
||||
print(f"Using {len(images)} calibration images", flush=True)
|
||||
|
||||
H = W = args.input_size
|
||||
|
||||
@@ -102,7 +104,7 @@ def main():
|
||||
from engines.onnx_tensorrt_compat import prepare_for_tensorrt
|
||||
|
||||
onnx_data = prepare_for_tensorrt(onnx_data)
|
||||
print("Prepared ONNX model for TensorRT static Jetson build")
|
||||
print("Prepared ONNX model for TensorRT static Jetson build", flush=True)
|
||||
except Exception as e:
|
||||
print(f"WARNING: ONNX TensorRT compatibility preparation failed: {e}", file=sys.stderr)
|
||||
logger = trt.Logger(trt.Logger.INFO)
|
||||
@@ -114,9 +116,11 @@ def main():
|
||||
trt.OnnxParser(network, logger) as parser,
|
||||
builder.create_builder_config() as config,
|
||||
):
|
||||
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 * 1024 ** 3)
|
||||
config.set_memory_pool_limit(
|
||||
trt.MemoryPoolType.WORKSPACE, int(args.workspace_gb * 1024 ** 3)
|
||||
)
|
||||
config.set_flag(trt.BuilderFlag.INT8)
|
||||
if builder.platform_has_fast_fp16:
|
||||
if not args.no_fp16 and builder.platform_has_fast_fp16:
|
||||
config.set_flag(trt.BuilderFlag.FP16)
|
||||
|
||||
calibrator = _ImageCalibrator()
|
||||
@@ -135,7 +139,7 @@ def main():
|
||||
profile.set_shape(inp.name, (1, C, H, W), (1, C, H, W), (1, C, H, W))
|
||||
config.add_optimization_profile(profile)
|
||||
|
||||
print("Building TensorRT engine with INT8 calibration (several minutes on Jetson)…")
|
||||
print("Building TensorRT engine with INT8 calibration (several minutes on Jetson)...", flush=True)
|
||||
plan = builder.build_serialized_network(network, config)
|
||||
if plan is None:
|
||||
print("Engine build failed", file=sys.stderr)
|
||||
|
||||
Reference in New Issue
Block a user