"""Tests for TensorRT engine conversion and the engine factories.

The real ``tensorrt`` / ``pycuda`` stacks are only present on GPU / Jetson
hosts, so every engine test is gated behind ``requires_tensorrt`` and the
TensorRT module itself is replaced with a MagicMock built by
``_make_mock_trt``.
"""

import os
import tempfile
from unittest.mock import MagicMock, call, patch

import pytest

try:
    import tensorrt  # noqa: F401
    import pycuda.driver  # noqa: F401

    HAS_TENSORRT = True
except ImportError:
    HAS_TENSORRT = False

requires_tensorrt = pytest.mark.skipif(
    not HAS_TENSORRT,
    reason="TensorRT and PyCUDA required (GPU / Jetson environment)",
)


def _make_mock_trt():
    """Build a mocked ``tensorrt`` module plus its builder/config mocks.

    Returns:
        tuple: ``(mock_trt, mock_builder, mock_config)`` where ``mock_trt``
        stands in for the ``tensorrt`` module, ``mock_builder`` is the
        builder it hands out, and ``mock_config`` is the builder config the
        conversion code sets precision flags on.
    """
    mock_trt = MagicMock()
    # Enum-like constants the conversion code reads off the trt module.
    mock_trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH = 0
    mock_trt.Logger.WARNING = "WARNING"
    mock_trt.MemoryPoolType.WORKSPACE = "WORKSPACE"
    mock_trt.BuilderFlag.INT8 = "INT8"
    mock_trt.BuilderFlag.FP16 = "FP16"

    mock_builder = MagicMock()
    mock_builder.platform_has_fast_fp16 = True
    mock_config = MagicMock()
    mock_network = MagicMock()
    mock_parser = MagicMock()
    mock_parser.parse.return_value = True

    # A single dummy network input so shape/name lookups succeed.
    mock_input = MagicMock()
    mock_input.shape = [1, 3, 640, 640]
    mock_input.name = "images"
    mock_network.get_input.return_value = mock_input

    # The conversion code uses these objects as context managers, so the
    # __enter__/__exit__ protocol is wired up explicitly on each mock.
    mock_builder.create_network.return_value.__enter__ = MagicMock(return_value=mock_network)
    mock_builder.create_network.return_value.__exit__ = MagicMock(return_value=False)
    mock_builder.create_builder_config.return_value.__enter__ = MagicMock(return_value=mock_config)
    mock_builder.create_builder_config.return_value.__exit__ = MagicMock(return_value=False)
    mock_builder.__enter__ = MagicMock(return_value=mock_builder)
    mock_builder.__exit__ = MagicMock(return_value=False)
    mock_trt.Builder.return_value = mock_builder

    mock_onnx_parser = MagicMock()
    mock_onnx_parser.__enter__ = MagicMock(return_value=mock_parser)
    mock_onnx_parser.__exit__ = MagicMock(return_value=False)
    mock_trt.OnnxParser.return_value = mock_onnx_parser

    # Calibrator base class must be a real type so subclassing works.
    mock_trt.IInt8EntropyCalibrator2 = object
    mock_builder.build_serialized_network.return_value = b"engine_bytes"
    return mock_trt, mock_builder, mock_config


@requires_tensorrt
def test_convert_from_source_uses_int8_when_cache_provided():
    """INT8 (plus FP16 fallback layers) is enabled when a calibration cache exists."""
    # Arrange
    from engines.tensorrt_engine import TensorRTEngine
    import engines.tensorrt_engine as trt_mod

    mock_trt, mock_builder, mock_config = _make_mock_trt()

    with tempfile.NamedTemporaryFile(suffix=".cache", delete=False) as f:
        f.write(b"calibration_cache_data")
        cache_path = f.name

    try:
        with patch.object(trt_mod, "trt", mock_trt), \
                patch.object(TensorRTEngine, "get_gpu_memory_bytes", return_value=4 * 1024**3):
            # Act
            TensorRTEngine.convert_from_source(b"onnx_model", cache_path)

            # Assert
            mock_config.set_flag.assert_any_call("INT8")
            mock_config.set_flag.assert_any_call("FP16")
            assert mock_config.int8_calibrator is not None
    finally:
        os.unlink(cache_path)


@requires_tensorrt
def test_convert_from_source_uses_fp16_when_no_cache():
    """Only FP16 is enabled when no calibration cache path is supplied."""
    # Arrange
    from engines.tensorrt_engine import TensorRTEngine
    import engines.tensorrt_engine as trt_mod

    mock_trt, mock_builder, mock_config = _make_mock_trt()

    with patch.object(trt_mod, "trt", mock_trt), \
            patch.object(TensorRTEngine, "get_gpu_memory_bytes", return_value=4 * 1024**3):
        # Act
        TensorRTEngine.convert_from_source(b"onnx_model", None)

        # Assert
        mock_config.set_flag.assert_any_call("FP16")
        int8_calls = [c for c in mock_config.set_flag.call_args_list if c == call("INT8")]
        assert len(int8_calls) == 0


@requires_tensorrt
def test_trt_factory_build_from_source_uses_fp16():
    """The generic TensorRT factory builds an FP16 engine (never INT8)."""
    # Arrange
    from engines.engine_factory import TensorRTEngineFactory
    from engines.tensorrt_engine import TensorRTEngine
    import engines.tensorrt_engine as trt_mod

    mock_trt, mock_builder, mock_config = _make_mock_trt()
    factory = TensorRTEngineFactory()

    with patch.object(trt_mod, "trt", mock_trt), \
            patch.object(TensorRTEngine, "get_gpu_memory_bytes", return_value=4 * 1024**3):
        # Act
        engine_bytes, filename = factory.build_from_source(b"onnx", MagicMock(), "models")

        # Assert
        assert engine_bytes == b"engine_bytes"
        assert filename is not None
        int8_calls = [c for c in mock_config.set_flag.call_args_list if c == call("INT8")]
        assert len(int8_calls) == 0


@requires_tensorrt
def test_jetson_factory_build_from_source_uses_int8_when_cache_available():
    """The Jetson factory downloads a calibration cache and builds INT8."""
    # Arrange
    from engines.engine_factory import JetsonTensorRTEngineFactory
    from engines.tensorrt_engine import TensorRTEngine
    import engines.tensorrt_engine as trt_mod

    mock_trt, mock_builder, mock_config = _make_mock_trt()
    factory = JetsonTensorRTEngineFactory()

    loader = MagicMock()
    result = MagicMock()
    result.err = None
    result.data = b"calib_data"
    loader.load_big_small_resource.return_value = result

    with patch.object(trt_mod, "trt", mock_trt), \
            patch.object(TensorRTEngine, "get_gpu_memory_bytes", return_value=4 * 1024**3):
        # Act
        engine_bytes, filename = factory.build_from_source(b"onnx", loader, "models")

        # Assert
        assert engine_bytes == b"engine_bytes"
        assert "int8" in filename
        mock_config.set_flag.assert_any_call("INT8")


@requires_tensorrt
def test_jetson_factory_build_from_source_falls_back_to_fp16_when_no_cache():
    """The Jetson factory falls back to FP16 when the cache download fails."""
    # Arrange
    from engines.engine_factory import JetsonTensorRTEngineFactory
    from engines.tensorrt_engine import TensorRTEngine
    import engines.tensorrt_engine as trt_mod

    mock_trt, mock_builder, mock_config = _make_mock_trt()
    factory = JetsonTensorRTEngineFactory()

    loader = MagicMock()
    result = MagicMock()
    result.err = "not found"
    loader.load_big_small_resource.return_value = result

    with patch.object(trt_mod, "trt", mock_trt), \
            patch.object(TensorRTEngine, "get_gpu_memory_bytes", return_value=4 * 1024**3):
        # Act
        engine_bytes, filename = factory.build_from_source(b"onnx", loader, "models")

        # Assert
        assert engine_bytes == b"engine_bytes"
        int8_calls = [c for c in mock_config.set_flag.call_args_list if c == call("INT8")]
        assert len(int8_calls) == 0


@requires_tensorrt
def test_jetson_factory_cleans_up_cache_tempfile_after_build():
    """The temporary calibration-cache file is removed once the build finishes."""
    # Arrange
    from engines.engine_factory import JetsonTensorRTEngineFactory
    from engines.tensorrt_engine import TensorRTEngine
    import engines.tensorrt_engine as trt_mod

    mock_trt, _, _ = _make_mock_trt()
    factory = JetsonTensorRTEngineFactory()

    loader = MagicMock()
    result = MagicMock()
    result.err = None
    result.data = b"calib_data"
    loader.load_big_small_resource.return_value = result

    # Wrap the real download helper so we can record where it wrote the cache.
    written_paths = []
    original_download = factory._download_calib_cache

    def tracking_download(lc, md):
        path = original_download(lc, md)
        if path:
            written_paths.append(path)
        return path

    with patch.object(trt_mod, "trt", mock_trt), \
            patch.object(TensorRTEngine, "get_gpu_memory_bytes", return_value=4 * 1024**3), \
            patch.object(factory, "_download_calib_cache", side_effect=tracking_download):
        factory.build_from_source(b"onnx", loader, "models")

    # Assert: temp file was deleted after build
    for p in written_paths:
        assert not os.path.exists(p)


def test_is_jetson_false_on_non_aarch64():
    """Non-ARM machines are never classified as Jetson."""
    # Arrange
    import engines as eng

    with patch("engines.platform") as mock_platform, \
            patch("engines.tensor_gpu_index", 0), \
            patch("engines.os.path.isfile", return_value=True):
        mock_platform.machine.return_value = "x86_64"

        # Assert
        assert eng._is_jetson() is False


def test_is_jetson_false_when_no_gpu():
    """A negative GPU index means no usable GPU, hence not Jetson."""
    # Arrange
    import engines as eng

    with patch("engines.platform") as mock_platform, \
            patch("engines.tensor_gpu_index", -1), \
            patch("engines.os.path.isfile", return_value=True):
        mock_platform.machine.return_value = "aarch64"

        # Assert
        assert eng._is_jetson() is False


def test_is_jetson_false_when_no_tegra_release():
    """aarch64 with a GPU but no tegra release file is not Jetson."""
    # Arrange
    import engines as eng

    with patch("engines.platform") as mock_platform, \
            patch("engines.tensor_gpu_index", 0), \
            patch("engines.os.path.isfile", return_value=False):
        mock_platform.machine.return_value = "aarch64"

        # Assert
        assert eng._is_jetson() is False