From 1250542800e876c2e688a0bd810b8e56151fb7be Mon Sep 17 00:00:00 2001 From: Roman Meshko Date: Sat, 23 May 2026 22:11:57 +0300 Subject: [PATCH] Changed to use additional calibration parameters --- scripts/jetson/generate_int8_cache.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/scripts/jetson/generate_int8_cache.py b/scripts/jetson/generate_int8_cache.py index 622df9e..5506404 100644 --- a/scripts/jetson/generate_int8_cache.py +++ b/scripts/jetson/generate_int8_cache.py @@ -30,6 +30,8 @@ def parse_args(): parser.add_argument("--output", default="azaion.int8_calib.cache") parser.add_argument("--input-size", type=int, default=1280, help="Model input H=W (default 1280)") parser.add_argument("--num-samples", type=int, default=500) + parser.add_argument("--workspace-gb", type=float, default=4.0) + parser.add_argument("--no-fp16", action="store_true", help="Do not enable FP16 fallback during INT8 calibration") return parser.parse_args() @@ -66,7 +68,7 @@ def main(): if not images: print(f"No images found in {args.images_dir}", file=sys.stderr) sys.exit(1) - print(f"Using {len(images)} calibration images") + print(f"Using {len(images)} calibration images", flush=True) H = W = args.input_size @@ -102,7 +104,7 @@ def main(): from engines.onnx_tensorrt_compat import prepare_for_tensorrt onnx_data = prepare_for_tensorrt(onnx_data) - print("Prepared ONNX model for TensorRT static Jetson build") + print("Prepared ONNX model for TensorRT static Jetson build", flush=True) except Exception as e: print(f"WARNING: ONNX TensorRT compatibility preparation failed: {e}", file=sys.stderr) logger = trt.Logger(trt.Logger.INFO) @@ -114,9 +116,11 @@ def main(): trt.OnnxParser(network, logger) as parser, builder.create_builder_config() as config, ): - config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 * 1024 ** 3) + config.set_memory_pool_limit( + trt.MemoryPoolType.WORKSPACE, int(args.workspace_gb * 1024 ** 3) + ) config.set_flag(trt.BuilderFlag.INT8) - if builder.platform_has_fast_fp16: + if not args.no_fp16 and builder.platform_has_fast_fp16: config.set_flag(trt.BuilderFlag.FP16) calibrator = _ImageCalibrator() @@ -135,7 +139,7 @@ def main(): profile.set_shape(inp.name, (1, C, H, W), (1, C, H, W), (1, C, H, W)) config.add_optimization_profile(profile) - print("Building TensorRT engine with INT8 calibration (several minutes on Jetson)…") + print("Building TensorRT engine with INT8 calibration (several minutes on Jetson)...", flush=True) plan = builder.build_serialized_network(network, config) if plan is None: print("Engine build failed", file=sys.stderr)