#!/usr/bin/env python3
r"""
Generate an INT8 calibration cache for TensorRT on Jetson.

Run this INSIDE the Jetson Docker container:

    docker compose -f docker-compose.demo-jetson.yml run --rm \
        -v /path/to/images:/calibration \
        detections \
        python3 scripts/generate_int8_cache.py \
            --images-dir /calibration \
            --onnx /models/azaion.onnx \
            --output /models/azaion.int8_calib.cache

The cache file must be in the loader's models volume so the detections
service can download it on startup via the Loader API.
"""

import argparse
import sys
from pathlib import Path
from typing import NoReturn

import cv2
import numpy as np

# File suffixes accepted as calibration images. They are compared
# case-insensitively, so "IMG_001.JPG" is picked up as well as "img_001.jpg".
_IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png")


def _fail(message: str) -> NoReturn:
    """Print *message* to stderr and terminate the script with exit status 1."""
    print(message, file=sys.stderr)
    sys.exit(1)


def parse_args():
    """Parse and validate the command-line arguments.

    Returns:
        An ``argparse.Namespace`` with ``images_dir``, ``onnx``, ``output``,
        ``input_size``, ``num_samples``, ``workspace_gb``, ``no_fp16`` and
        ``softmax_fp32``.

    Exits with a usage error (status 2) when a numeric option is not positive.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--images-dir", required=True, help="Directory with calibration images (JPG/PNG)")
    parser.add_argument("--onnx", required=True, help="Path to azaion.onnx")
    parser.add_argument("--output", default="azaion.int8_calib.cache")
    parser.add_argument("--input-size", type=int, default=1280, help="Model input H=W (default 1280)")
    parser.add_argument("--num-samples", type=int, default=500)
    parser.add_argument("--workspace-gb", type=float, default=4.0)
    parser.add_argument("--no-fp16", action="store_true", help="Do not enable FP16 fallback during INT8 calibration")
    parser.add_argument(
        "--softmax-fp32",
        action="store_true",
        help="Force TensorRT SoftMax layers to FP32 as a workaround for Jetson INT8 calibration failures",
    )
    args = parser.parse_args()

    # Reject values that would otherwise misbehave silently: a negative
    # --num-samples would slice images off the END of the list.
    if args.input_size <= 0:
        parser.error("--input-size must be a positive integer")
    if args.num_samples <= 0:
        parser.error("--num-samples must be a positive integer")
    if args.workspace_gb <= 0:
        parser.error("--workspace-gb must be positive")
    return args


def collect_images(images_dir: str, num_samples: int) -> list[Path]:
    """Return up to *num_samples* image paths found recursively under *images_dir*.

    Files are matched by suffix (``.jpg``, ``.jpeg``, ``.png``, in any letter
    case) and returned in one sorted path order. The selection is therefore
    deterministic and is not biased toward one file format when the directory
    mixes them.

    Returns an empty list when nothing matches, including when *images_dir*
    does not exist.
    """
    root = Path(images_dir)
    images = sorted(
        path
        for path in root.rglob("*")
        # Cheap suffix test first; only then stat the path.
        if path.suffix.lower() in _IMAGE_SUFFIXES and path.is_file()
    )
    return images[:num_samples]


def preprocess(path: Path, h: int, w: int) -> np.ndarray | None:
    """Load an image and convert it to a normalized NCHW float32 tensor.

    Args:
        path: Image file to read with OpenCV.
        h: Target height in pixels.
        w: Target width in pixels.

    Returns:
        A C-contiguous array of shape ``(1, 3, h, w)`` with RGB values scaled
        to ``[0, 1]``, or ``None`` if OpenCV cannot decode the file.
    """
    img = cv2.imread(str(path))  # BGR, 3 channels, or None on failure
    if img is None:
        return None
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # NOTE(review): plain stretch-resize (no letterboxing). Confirm this matches
    # the detections service's runtime preprocessing; otherwise the calibration
    # statistics will not reflect real inputs.
    img = cv2.resize(img, (w, h))  # cv2 dsize is (width, height)
    img = img.astype(np.float32) / 255.0
    return np.ascontiguousarray(img.transpose(2, 0, 1)[np.newaxis])  # HWC -> NCHW


def main():
    """Build a TensorRT engine with INT8 calibration so the calibrator writes the cache.

    The serialized engine itself is discarded. The useful output is the
    calibration cache written to ``--output`` by the calibrator callback.
    """
    args = parse_args()

    try:
        import pycuda.autoinit  # noqa: F401  (imported for its CUDA-initialization side effect)
        import pycuda.driver as cuda
        import tensorrt as trt
    except ImportError as e:
        _fail(f"ERROR: {e}\nRun this script inside the Jetson Docker container.")

    # Validate paths up front. Otherwise a bad path only surfaces after
    # minutes of engine building, or as a raw traceback.
    onnx_path = Path(args.onnx)
    if not onnx_path.is_file():
        _fail(f"ONNX model not found: {args.onnx}")
    output_path = Path(args.output)
    if not output_path.parent.is_dir():
        _fail(f"Output directory does not exist: {output_path.parent}")

    images = collect_images(args.images_dir, args.num_samples)
    if not images:
        _fail(f"No images found in {args.images_dir}")
    print(f"Using {len(images)} calibration images", flush=True)

    H = W = args.input_size
    # Size of one preprocessed (1, 3, H, W) float32 batch.
    input_nbytes = 3 * H * W * np.dtype(np.float32).itemsize

    class _ImageCalibrator(trt.IInt8EntropyCalibrator2):
        """Feed the collected images one per batch and persist the resulting cache."""

        def __init__(self):
            super().__init__()
            self._idx = 0
            self._buf = cuda.mem_alloc(input_nbytes)
            self.batches_fed = 0  # images successfully handed to TensorRT
            self.skipped = 0  # images OpenCV could not decode
            self.cache_written = False

        def get_batch_size(self):
            return 1

        def get_batch(self, names):
            # Return one device pointer per network input (this script only
            # supports single-input networks), or None when images run out.
            while self._idx < len(images):
                path = images[self._idx]
                self._idx += 1
                arr = preprocess(path, H, W)
                if arr is None:
                    self.skipped += 1
                    print(f"WARNING: could not read {path}, skipping", file=sys.stderr)
                    continue
                cuda.memcpy_htod(self._buf, arr)
                self.batches_fed += 1
                return [int(self._buf)]
            return None

        def read_calibration_cache(self):
            # Never reuse an existing cache; this script exists to regenerate it.
            return None

        def write_calibration_cache(self, cache):
            with open(output_path, "wb") as f:
                f.write(cache)
            self.cache_written = True
            print(f"Cache written → {args.output}")

    onnx_data = onnx_path.read_bytes()
    try:
        from engines.onnx_tensorrt_compat import prepare_for_tensorrt

        onnx_data = prepare_for_tensorrt(onnx_data)
        print("Prepared ONNX model for TensorRT static Jetson build", flush=True)
    except Exception as e:
        # Deliberate best effort: keep the unmodified model when the project
        # helper is unavailable or fails. Parsing below reports real problems.
        print(f"WARNING: ONNX TensorRT compatibility preparation failed: {e}", file=sys.stderr)

    logger = trt.Logger(trt.Logger.INFO)
    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

    with (
        trt.Builder(logger) as builder,
        builder.create_network(explicit_batch) as network,
        trt.OnnxParser(network, logger) as parser,
        builder.create_builder_config() as config,
    ):
        # Parse first so model problems are reported before anything else.
        if not parser.parse(onnx_data):
            for i in range(parser.num_errors):
                print(parser.get_error(i), file=sys.stderr)
            sys.exit(1)

        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, int(args.workspace_gb * 1024**3))
        config.set_flag(trt.BuilderFlag.INT8)
        if not args.no_fp16 and builder.platform_has_fast_fp16:
            config.set_flag(trt.BuilderFlag.FP16)

        if args.softmax_fp32:
            constrained = 0
            for i in range(network.num_layers):
                layer = network.get_layer(i)
                if layer.type == trt.LayerType.SOFTMAX:
                    layer.precision = trt.float32
                    for j in range(layer.num_outputs):
                        layer.set_output_type(j, trt.float32)
                    constrained += 1
            if constrained:
                # TensorRT only honours per-layer precisions when a
                # precision-constraint flag is set. Use the first one this
                # TensorRT build provides, preferring the softer PREFER_* flag.
                for flag_name in ("PREFER_PRECISION_CONSTRAINTS", "OBEY_PRECISION_CONSTRAINTS"):
                    flag = getattr(trt.BuilderFlag, flag_name, None)
                    if flag is not None:
                        config.set_flag(flag)
                        break
            print(f"Forced {constrained} SoftMax layers to FP32", flush=True)

        # The calibrator feeds exactly one (1, 3, H, W) float32 tensor per batch.
        # Reject networks whose static input shape could never accept it.
        if network.num_inputs != 1:
            _fail(f"Expected a single-input network, found {network.num_inputs} inputs")
        inp = network.get_input(0)
        shape = tuple(inp.shape)
        if len(shape) != 4:
            _fail(f"Expected a 4-D NCHW input, got shape {shape} for '{inp.name}'")
        n, c, h, w = shape
        if n not in (-1, 1) or c not in (-1, 3) or h not in (-1, H) or w not in (-1, W):
            _fail(
                f"Input '{inp.name}' has shape {shape}, incompatible with calibration "
                f"shape (1, 3, {H}, {W}); check --input-size"
            )
        if -1 in shape:
            # Pin every dynamic dimension (not just batch) to the calibration
            # shape. INT8 calibration of a network with runtime dimensions also
            # requires an explicit calibration profile.
            calib_shape = (1, 3, H, W)
            profile = builder.create_optimization_profile()
            profile.set_shape(inp.name, calib_shape, calib_shape, calib_shape)
            config.add_optimization_profile(profile)
            config.set_calibration_profile(profile)

        calibrator = _ImageCalibrator()
        # Keep a Python reference alive for the whole build.
        config.int8_calibrator = calibrator

        print("Building TensorRT engine with INT8 calibration (several minutes on Jetson)...", flush=True)
        plan = builder.build_serialized_network(network, config)
        if plan is None:
            _fail("Engine build failed")

    if calibrator.skipped:
        print(
            f"Skipped {calibrator.skipped} unreadable images; calibrated on {calibrator.batches_fed}",
            flush=True,
        )
    if not calibrator.cache_written:
        _fail(f"Engine built but TensorRT did not write a calibration cache to {args.output}")
    print("Done. Upload the cache to the Loader before (re)starting the detections service.")


if __name__ == "__main__":
    main()