diff --git a/src/engines/tensorrt_engine.pyx b/src/engines/tensorrt_engine.pyx index 56883a2..b1dba55 100644 --- a/src/engines/tensorrt_engine.pyx +++ b/src/engines/tensorrt_engine.pyx @@ -158,6 +158,8 @@ cdef class TensorRTEngine(InferenceEngine): constants_inf.log('Converting to INT8 with calibration cache') calibrator = _CacheCalibrator(calib_cache_path) config.set_flag(trt.BuilderFlag.INT8) + if builder.platform_has_fast_fp16: + config.set_flag(trt.BuilderFlag.FP16) config.int8_calibrator = calibrator elif builder.platform_has_fast_fp16: constants_inf.log('Converting to supported fp16') diff --git a/tests/test_az180_jetson_int8.py b/tests/test_az180_jetson_int8.py index d9186d4..a67faad 100644 --- a/tests/test_az180_jetson_int8.py +++ b/tests/test_az180_jetson_int8.py @@ -73,6 +73,7 @@ def test_convert_from_source_uses_int8_when_cache_provided(): # Assert mock_config.set_flag.assert_any_call("INT8") + mock_config.set_flag.assert_any_call("FP16") assert mock_config.int8_calibrator is not None finally: os.unlink(cache_path)