mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 07:06:32 +00:00
[AZ-180] Refactor detection event handling and improve SSE support
- Updated the detection image endpoint to require a channel ID for event streaming.
- Introduced a new endpoint for streaming detection events, allowing clients to receive real-time updates.
- Enhanced the internal buffering mechanism for detection events to manage multiple channels.
- Refactored the inference module to support the new event handling structure.

Made-with: Cursor
This commit is contained in:
@@ -419,7 +419,10 @@ class TestDetectVideoEndpoint:
|
||||
from fastapi.testclient import TestClient
|
||||
client = TestClient(main.app)
|
||||
token = _access_jwt()
|
||||
with patch.object(main, "get_inference", return_value=_CaptureInf()):
|
||||
with (
|
||||
patch.object(main, "JWT_SECRET", _TEST_JWT_SECRET),
|
||||
patch.object(main, "get_inference", return_value=_CaptureInf()),
|
||||
):
|
||||
# Act
|
||||
r = client.post(
|
||||
"/detect/video",
|
||||
|
||||
@@ -96,3 +96,149 @@ def test_convert_from_source_uses_fp16_when_no_cache():
|
||||
mock_config.set_flag.assert_any_call("FP16")
|
||||
int8_calls = [c for c in mock_config.set_flag.call_args_list if c == call("INT8")]
|
||||
assert len(int8_calls) == 0
|
||||
|
||||
|
||||
@requires_tensorrt
def test_trt_factory_build_from_source_uses_fp16():
    """Generic TRT factory builds an engine from ONNX with FP16 enabled and INT8 never set."""
    # Arrange
    from engines.engine_factory import TensorRTEngineFactory
    from engines.tensorrt_engine import TensorRTEngine
    import engines.tensorrt_engine as trt_mod

    mock_trt, mock_builder, mock_config = _make_mock_trt()
    factory = TensorRTEngineFactory()

    with patch.object(trt_mod, "trt", mock_trt), \
            patch.object(TensorRTEngine, "get_gpu_memory_bytes", return_value=4 * 1024**3):
        # Act
        engine_bytes, filename = factory.build_from_source(b"onnx", MagicMock(), "models")

    # Assert
    assert engine_bytes == b"engine_bytes"
    assert filename is not None
    # The test name promises FP16 is *used*, so assert the flag was actually
    # requested (consistent with test_convert_from_source_uses_fp16_when_no_cache),
    # not merely that INT8 was avoided.
    mock_config.set_flag.assert_any_call("FP16")
    int8_calls = [c for c in mock_config.set_flag.call_args_list if c == call("INT8")]
    assert len(int8_calls) == 0
|
||||
|
||||
|
||||
@requires_tensorrt
def test_jetson_factory_build_from_source_uses_int8_when_cache_available():
    """Jetson factory enables INT8 when the calibration-cache resource loads cleanly."""
    # Arrange
    from engines.engine_factory import JetsonTensorRTEngineFactory
    from engines.tensorrt_engine import TensorRTEngine
    import engines.tensorrt_engine as trt_mod

    mock_trt, _mock_builder, mock_config = _make_mock_trt()
    factory = JetsonTensorRTEngineFactory()

    # A loader whose calibration-cache lookup succeeds and yields cache bytes.
    calib_result = MagicMock()
    calib_result.err = None
    calib_result.data = b"calib_data"
    loader = MagicMock()
    loader.load_big_small_resource.return_value = calib_result

    with (
        patch.object(trt_mod, "trt", mock_trt),
        patch.object(TensorRTEngine, "get_gpu_memory_bytes", return_value=4 * 1024**3),
    ):
        # Act
        engine_bytes, filename = factory.build_from_source(b"onnx", loader, "models")

    # Assert
    assert engine_bytes == b"engine_bytes"
    assert "int8" in filename
    mock_config.set_flag.assert_any_call("INT8")
|
||||
|
||||
|
||||
@requires_tensorrt
def test_jetson_factory_build_from_source_falls_back_to_fp16_when_no_cache():
    """Jetson factory falls back to FP16 (no INT8) when the calibration cache is missing."""
    # Arrange
    from engines.engine_factory import JetsonTensorRTEngineFactory
    from engines.tensorrt_engine import TensorRTEngine
    import engines.tensorrt_engine as trt_mod

    mock_trt, mock_builder, mock_config = _make_mock_trt()
    factory = JetsonTensorRTEngineFactory()

    # A loader whose calibration-cache lookup fails, forcing the FP16 fallback.
    loader = MagicMock()
    result = MagicMock()
    result.err = "not found"
    loader.load_big_small_resource.return_value = result

    with patch.object(trt_mod, "trt", mock_trt), \
            patch.object(TensorRTEngine, "get_gpu_memory_bytes", return_value=4 * 1024**3):
        # Act
        engine_bytes, filename = factory.build_from_source(b"onnx", loader, "models")

    # Assert
    assert engine_bytes == b"engine_bytes"
    # The test name promises an FP16 fallback, so assert the flag was actually
    # requested (consistent with test_convert_from_source_uses_fp16_when_no_cache),
    # not merely that INT8 was avoided.
    mock_config.set_flag.assert_any_call("FP16")
    int8_calls = [c for c in mock_config.set_flag.call_args_list if c == call("INT8")]
    assert len(int8_calls) == 0
|
||||
|
||||
|
||||
@requires_tensorrt
def test_jetson_factory_cleans_up_cache_tempfile_after_build():
    """Every calibration-cache temp file written during a build is deleted afterwards."""
    # Arrange
    from engines.engine_factory import JetsonTensorRTEngineFactory
    from engines.tensorrt_engine import TensorRTEngine
    import engines.tensorrt_engine as trt_mod

    mock_trt, _, _ = _make_mock_trt()
    factory = JetsonTensorRTEngineFactory()

    calib_result = MagicMock()
    calib_result.err = None
    calib_result.data = b"calib_data"
    loader = MagicMock()
    loader.load_big_small_resource.return_value = calib_result

    # Wrap the real download helper so every path it writes is recorded.
    written_paths = []
    original_download = factory._download_calib_cache

    def tracking_download(lc, md):
        path = original_download(lc, md)
        if path:
            written_paths.append(path)
        return path

    with (
        patch.object(trt_mod, "trt", mock_trt),
        patch.object(TensorRTEngine, "get_gpu_memory_bytes", return_value=4 * 1024**3),
        patch.object(factory, "_download_calib_cache", side_effect=tracking_download),
    ):
        factory.build_from_source(b"onnx", loader, "models")

    # Assert: temp file was deleted after build
    assert all(not os.path.exists(p) for p in written_paths)
|
||||
|
||||
|
||||
def test_is_jetson_false_on_non_aarch64():
    """_is_jetson reports False when the CPU architecture is not aarch64."""
    # Arrange
    import engines as eng

    with (
        patch("engines.platform") as mock_platform,
        patch("engines.tensor_gpu_index", 0),
        patch("engines.os.path.isfile", return_value=True),
    ):
        mock_platform.machine.return_value = "x86_64"

        # Assert
        assert eng._is_jetson() is False
|
||||
|
||||
|
||||
def test_is_jetson_false_when_no_gpu():
    """_is_jetson reports False when no GPU index is available (tensor_gpu_index == -1)."""
    # Arrange
    import engines as eng

    with (
        patch("engines.platform") as mock_platform,
        patch("engines.tensor_gpu_index", -1),
        patch("engines.os.path.isfile", return_value=True),
    ):
        mock_platform.machine.return_value = "aarch64"

        # Assert
        assert eng._is_jetson() is False
|
||||
|
||||
|
||||
def test_is_jetson_false_when_no_tegra_release():
    """_is_jetson reports False when the Tegra release marker file does not exist."""
    # Arrange
    import engines as eng

    with (
        patch("engines.platform") as mock_platform,
        patch("engines.tensor_gpu_index", 0),
        patch("engines.os.path.isfile", return_value=False),
    ):
        mock_platform.machine.return_value = "aarch64"

        # Assert
        assert eng._is_jetson() is False
|
||||
|
||||
Reference in New Issue
Block a user