mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 08:56:32 +00:00
2149cd6c08
- Dockerfile.jetson: JetPack 6.x L4T base image (aarch64), TensorRT and PyCUDA from apt - requirements-jetson.txt: derived from requirements.txt, no pip tensorrt/pycuda - docker-compose.jetson.yml: runtime: nvidia for NVIDIA Container Runtime - tensorrt_engine.pyx: convert_from_source accepts optional calib_cache_path; INT8 used when cache present, FP16 fallback; get_engine_filename encodes precision suffix to avoid engine cache confusion - inference.pyx: init_ai tries INT8 engine then FP16 on lookup; downloads calibration cache before conversion thread; passes cache path through to convert_from_source - constants_inf: add INT8_CALIB_CACHE_FILE constant - Unit tests for AC-3 (INT8 flag set when cache provided) and AC-4 (FP16 when no cache) Made-with: Cursor
98 lines
3.4 KiB
Python
98 lines
3.4 KiB
Python
import os
|
|
import tempfile
|
|
from unittest.mock import MagicMock, call, patch
|
|
|
|
import pytest
|
|
|
|
try:
    # GPU-only dependencies: present on Jetson / CUDA hosts, absent on CI.
    # Imported solely to probe availability; names are otherwise unused.
    import tensorrt  # noqa: F401
    import pycuda.driver  # noqa: F401

    HAS_TENSORRT = True
except ImportError:
    HAS_TENSORRT = False


# Skip marker for tests that need a real TensorRT/PyCUDA installation.
requires_tensorrt = pytest.mark.skipif(
    not HAS_TENSORRT,
    reason="TensorRT and PyCUDA required (GPU / Jetson environment)",
)
|
|
|
|
|
|
def _make_mock_trt():
|
|
mock_trt = MagicMock()
|
|
mock_trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH = 0
|
|
mock_trt.Logger.WARNING = "WARNING"
|
|
mock_trt.MemoryPoolType.WORKSPACE = "WORKSPACE"
|
|
mock_trt.BuilderFlag.INT8 = "INT8"
|
|
mock_trt.BuilderFlag.FP16 = "FP16"
|
|
|
|
mock_builder = MagicMock()
|
|
mock_builder.platform_has_fast_fp16 = True
|
|
mock_config = MagicMock()
|
|
mock_network = MagicMock()
|
|
mock_parser = MagicMock()
|
|
mock_parser.parse.return_value = True
|
|
mock_input = MagicMock()
|
|
mock_input.shape = [1, 3, 640, 640]
|
|
mock_input.name = "images"
|
|
mock_network.get_input.return_value = mock_input
|
|
mock_builder.create_network.return_value.__enter__ = MagicMock(return_value=mock_network)
|
|
mock_builder.create_network.return_value.__exit__ = MagicMock(return_value=False)
|
|
mock_builder.create_builder_config.return_value.__enter__ = MagicMock(return_value=mock_config)
|
|
mock_builder.create_builder_config.return_value.__exit__ = MagicMock(return_value=False)
|
|
mock_builder.__enter__ = MagicMock(return_value=mock_builder)
|
|
mock_builder.__exit__ = MagicMock(return_value=False)
|
|
mock_trt.Builder.return_value = mock_builder
|
|
|
|
mock_onnx_parser = MagicMock()
|
|
mock_onnx_parser.__enter__ = MagicMock(return_value=mock_parser)
|
|
mock_onnx_parser.__exit__ = MagicMock(return_value=False)
|
|
mock_trt.OnnxParser.return_value = mock_onnx_parser
|
|
|
|
mock_trt.IInt8EntropyCalibrator2 = object
|
|
mock_builder.build_serialized_network.return_value = b"engine_bytes"
|
|
|
|
return mock_trt, mock_builder, mock_config
|
|
|
|
|
|
@requires_tensorrt
def test_convert_from_source_uses_int8_when_cache_provided():
    """AC-3: supplying a calibration cache path must enable INT8.

    Expects the INT8 builder flag to be set and an int8 calibrator to be
    attached to the builder config.
    """
    # Arrange
    from engines.tensorrt_engine import TensorRTEngine
    import engines.tensorrt_engine as trt_mod

    mock_trt, _builder, mock_config = _make_mock_trt()

    # A real file on disk so the engine's cache-existence check passes.
    with tempfile.NamedTemporaryFile(suffix=".cache", delete=False) as cache_file:
        cache_file.write(b"calibration_cache_data")
        cache_path = cache_file.name

    try:
        with patch.object(trt_mod, "trt", mock_trt):
            with patch.object(
                TensorRTEngine, "get_gpu_memory_bytes", return_value=4 * 1024**3
            ):
                # Act
                TensorRTEngine.convert_from_source(b"onnx_model", cache_path)

        # Assert (mock call records survive patch teardown)
        mock_config.set_flag.assert_any_call("INT8")
        assert mock_config.int8_calibrator is not None
    finally:
        os.unlink(cache_path)
|
|
|
|
|
|
@requires_tensorrt
def test_convert_from_source_uses_fp16_when_no_cache():
    """AC-4: without a calibration cache the engine falls back to FP16.

    Expects the FP16 builder flag to be set and the INT8 flag never to be
    requested.
    """
    # Arrange
    from engines.tensorrt_engine import TensorRTEngine
    import engines.tensorrt_engine as trt_mod

    mock_trt, _builder, mock_config = _make_mock_trt()

    with patch.object(trt_mod, "trt", mock_trt):
        with patch.object(
            TensorRTEngine, "get_gpu_memory_bytes", return_value=4 * 1024**3
        ):
            # Act
            TensorRTEngine.convert_from_source(b"onnx_model", None)

            # Assert
            mock_config.set_flag.assert_any_call("FP16")
            assert call("INT8") not in mock_config.set_flag.call_args_list
|