diff --git a/scripts/jetson/generate_int8_cache.py b/scripts/jetson/generate_int8_cache.py index 5506404..5ec323e 100644 --- a/scripts/jetson/generate_int8_cache.py +++ b/scripts/jetson/generate_int8_cache.py @@ -32,6 +32,11 @@ def parse_args(): parser.add_argument("--num-samples", type=int, default=500) parser.add_argument("--workspace-gb", type=float, default=4.0) parser.add_argument("--no-fp16", action="store_true", help="Do not enable FP16 fallback during INT8 calibration") + parser.add_argument( + "--softmax-fp32", + action="store_true", + help="Force TensorRT SoftMax layers to FP32 as a workaround for Jetson INT8 calibration failures", + ) return parser.parse_args() @@ -131,6 +136,23 @@ def main(): print(parser.get_error(i), file=sys.stderr) sys.exit(1) + if args.softmax_fp32: + constrained = 0 + for i in range(network.num_layers): + layer = network.get_layer(i) + if layer.type == trt.LayerType.SOFTMAX: + layer.precision = trt.float32 + for j in range(layer.num_outputs): + layer.set_output_type(j, trt.float32) + constrained += 1 + if constrained: + for flag_name in ("PREFER_PRECISION_CONSTRAINTS", "OBEY_PRECISION_CONSTRAINTS"): + flag = getattr(trt.BuilderFlag, flag_name, None) + if flag is not None: + config.set_flag(flag) + break + print(f"Forced {constrained} SoftMax layers to FP32", flush=True) + inp = network.get_input(0) shape = inp.shape C = shape[1]