"""Tests for precision selection in ``TensorRTEngine.convert_from_source``.

The TensorRT API is fully mocked, but the engine module imports
``tensorrt`` at import time, so the tests are skipped on hosts without
the TensorRT / PyCUDA stack (i.e. anything that is not a GPU / Jetson
environment).
"""

import os
import tempfile
from unittest.mock import MagicMock, call, patch

import pytest

try:
    import tensorrt  # noqa: F401
    import pycuda.driver  # noqa: F401

    HAS_TENSORRT = True
except ImportError:
    HAS_TENSORRT = False

# Applied to every test that needs the real TensorRT stack importable.
requires_tensorrt = pytest.mark.skipif(
    not HAS_TENSORRT,
    reason="TensorRT and PyCUDA required (GPU / Jetson environment)",
)


def _make_mock_trt():
    """Build a mock ``tensorrt`` module wired for ``convert_from_source``.

    Returns:
        tuple: ``(mock_trt, mock_builder, mock_config)`` where
        ``mock_trt`` stands in for the ``tensorrt`` module,
        ``mock_builder`` is the builder it hands out, and
        ``mock_config`` is the builder config whose flags/calibrator the
        tests inspect.
    """
    mock_trt = MagicMock()
    # Enum-ish constants the engine reads off the tensorrt module.
    mock_trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH = 0
    mock_trt.Logger.WARNING = "WARNING"
    mock_trt.MemoryPoolType.WORKSPACE = "WORKSPACE"
    mock_trt.BuilderFlag.INT8 = "INT8"
    mock_trt.BuilderFlag.FP16 = "FP16"

    mock_builder = MagicMock()
    mock_builder.platform_has_fast_fp16 = True

    mock_config = MagicMock()
    # BUGFIX: pre-seed the calibrator attribute with None. A MagicMock
    # auto-creates attributes on access, so without this the INT8 test's
    # `int8_calibrator is not None` assertion could never fail, even if
    # convert_from_source never installed a calibrator at all.
    mock_config.int8_calibrator = None

    mock_network = MagicMock()

    mock_parser = MagicMock()
    mock_parser.parse.return_value = True  # ONNX parse succeeds

    # Single network input resembling a YOLO-style model.
    mock_input = MagicMock()
    mock_input.shape = [1, 3, 640, 640]
    mock_input.name = "images"
    mock_network.get_input.return_value = mock_input

    # The engine uses these objects as context managers; wire __enter__ /
    # __exit__ explicitly so `with` yields the mocks above.
    mock_builder.create_network.return_value.__enter__ = MagicMock(return_value=mock_network)
    mock_builder.create_network.return_value.__exit__ = MagicMock(return_value=False)
    mock_builder.create_builder_config.return_value.__enter__ = MagicMock(return_value=mock_config)
    mock_builder.create_builder_config.return_value.__exit__ = MagicMock(return_value=False)
    mock_builder.__enter__ = MagicMock(return_value=mock_builder)
    mock_builder.__exit__ = MagicMock(return_value=False)
    mock_trt.Builder.return_value = mock_builder

    mock_onnx_parser = MagicMock()
    mock_onnx_parser.__enter__ = MagicMock(return_value=mock_parser)
    mock_onnx_parser.__exit__ = MagicMock(return_value=False)
    mock_trt.OnnxParser.return_value = mock_onnx_parser

    # Real base class so the engine's calibrator subclass can inherit it.
    mock_trt.IInt8EntropyCalibrator2 = object
    mock_builder.build_serialized_network.return_value = b"engine_bytes"

    return mock_trt, mock_builder, mock_config


@requires_tensorrt
def test_convert_from_source_uses_int8_when_cache_provided():
    """With a calibration cache, INT8 is enabled and a calibrator is set."""
    # Arrange
    from engines.tensorrt_engine import TensorRTEngine
    import engines.tensorrt_engine as trt_mod

    mock_trt, _mock_builder, mock_config = _make_mock_trt()
    # delete=False so the file survives the `with` for the code under test;
    # cleaned up in the finally block below.
    with tempfile.NamedTemporaryFile(suffix=".cache", delete=False) as f:
        f.write(b"calibration_cache_data")
        cache_path = f.name
    try:
        with patch.object(trt_mod, "trt", mock_trt), \
                patch.object(TensorRTEngine, "get_gpu_memory_bytes", return_value=4 * 1024**3):
            # Act
            TensorRTEngine.convert_from_source(b"onnx_model", cache_path)
            # Assert
            mock_config.set_flag.assert_any_call("INT8")
            # Meaningful only because _make_mock_trt pre-seeds this to None.
            assert mock_config.int8_calibrator is not None
    finally:
        os.unlink(cache_path)


@requires_tensorrt
def test_convert_from_source_uses_fp16_when_no_cache():
    """Without a calibration cache, FP16 is enabled and INT8 is not."""
    # Arrange
    from engines.tensorrt_engine import TensorRTEngine
    import engines.tensorrt_engine as trt_mod

    mock_trt, _mock_builder, mock_config = _make_mock_trt()
    with patch.object(trt_mod, "trt", mock_trt), \
            patch.object(TensorRTEngine, "get_gpu_memory_bytes", return_value=4 * 1024**3):
        # Act
        TensorRTEngine.convert_from_source(b"onnx_model", None)
        # Assert
        mock_config.set_flag.assert_any_call("FP16")
        assert call("INT8") not in mock_config.set_flag.call_args_list