# Mirror of https://github.com/azaion/gps-denied-onboard.git
# Synced 2026-04-23 08:36:37 +00:00
import pytest
|
|
import os
|
|
import time
|
|
import tempfile
|
|
import shutil
|
|
from unittest.mock import Mock, patch
|
|
|
|
from f16_model_manager import ModelManager, TensorRTInferenceEngine, ONNXInferenceEngine
|
|
|
|
@pytest.fixture
def models_dir():
    """Yield a throwaway directory for dummy model files.

    Uses ``tempfile.TemporaryDirectory`` so the directory is removed in
    teardown even if cleanup code would otherwise be skipped, instead of
    the manual mkdtemp()/rmtree() pair.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        yield temp_dir
|
|
|
|
@pytest.fixture
def mm(models_dir):
    """Provide a ModelManager rooted at the per-test models directory."""
    manager = ModelManager(models_dir=models_dir)
    return manager
|
|
|
|
class TestModelManager:
    """Unit tests for ModelManager loading, warmup and engine provisioning.

    All model files written here are dummy byte blobs; the manager is
    exercised through its mock/fallback paths, so no real TensorRT or
    ONNX runtime is required.
    """

    # --- 16.01 Feature: Model Lifecycle Management ---

    def test_load_model_tensorrt_mock(self, mm, models_dir):
        # Create a dummy file to bypass file existence checks.
        dummy_path = os.path.join(models_dir, "superpoint.engine")
        with open(dummy_path, "wb") as f:
            f.write(b"data")

        assert mm.load_model("SuperPoint", "tensorrt", dummy_path) is True
        engine = mm.get_inference_engine("SuperPoint")
        assert engine is not None
        assert engine.format == "tensorrt"

    def test_load_model_onnx_mock(self, mm, models_dir):
        dummy_path = os.path.join(models_dir, "lightglue.onnx")
        with open(dummy_path, "wb") as f:
            f.write(b"data")

        assert mm.load_model("LightGlue", "onnx", dummy_path) is True
        engine = mm.get_inference_engine("LightGlue")
        assert isinstance(engine, ONNXInferenceEngine)

    def test_load_already_cached_model(self, mm, models_dir):
        dummy_path = os.path.join(models_dir, "dinov2.onnx")
        with open(dummy_path, "wb") as f:
            f.write(b"data")

        mm.load_model("DINOv2", "onnx", dummy_path)

        # Second load should return True immediately (cache hit) and must
        # not re-run the warmup pass.
        with patch.object(mm, 'warmup_model') as mock_warmup:
            assert mm.load_model("DINOv2", "onnx", dummy_path) is True
            mock_warmup.assert_not_called()

    def test_load_invalid_format(self, mm):
        # An unsupported backend identifier is rejected, not raised.
        assert mm.load_model("SuperPoint", "pytorch_pth") is False

    def test_warmup_model(self, mm, models_dir):
        dummy_path = os.path.join(models_dir, "litesam.onnx")
        with open(dummy_path, "wb") as f:
            f.write(b"data")

        mm.load_model("LiteSAM", "onnx", dummy_path)

        # Replace the engine's infer with a mock to track warmup calls.
        engine = mm.get_inference_engine("LiteSAM")
        engine.infer = Mock()

        assert mm.warmup_model("LiteSAM") is True
        # Warmup is expected to run exactly 3 dry inference passes.
        assert engine.infer.call_count == 3

    def test_warmup_unloaded_model(self, mm):
        assert mm.warmup_model("MissingModel") is False

    # --- 16.02 Feature: Inference Engine Provisioning ---

    def test_get_inference_engine_missing(self, mm):
        # Nothing was loaded, so no engine should be available.
        assert mm.get_inference_engine("SuperPoint") is None

    def test_optimize_to_tensorrt(self, mm, models_dir):
        onnx_path = os.path.join(models_dir, "test.onnx")
        with open(onnx_path, "wb") as f:
            f.write(b"data")

        trt_path = mm.optimize_to_tensorrt("TestModel", onnx_path)

        assert trt_path != ""
        assert os.path.exists(trt_path)
        assert trt_path.endswith(".engine")

    def test_optimize_to_tensorrt_missing_source(self, mm):
        # A missing source model is reported via an empty path, not an
        # exception.
        assert mm.optimize_to_tensorrt("Missing", "does_not_exist.onnx") == ""

    def test_fallback_to_onnx_triggered_on_trt_failure(self, mm, models_dir):
        # Ensure the TensorRT file doesn't exist, which triggers fallback.
        onnx_path = os.path.join(models_dir, "superpoint.onnx")
        with open(onnx_path, "wb") as f:
            f.write(b"data")

        import f16_model_manager

        # BUGFIX: patch TRT_AVAILABLE instead of assigning the module global
        # directly, so the flag is restored on exit and does not leak into
        # other tests in the session.
        with patch.object(f16_model_manager, "TRT_AVAILABLE", False):
            # Load TRT -> Fails -> Falls back to ONNX
            assert mm.load_model("SuperPoint", "tensorrt") is True

            engine = mm.get_inference_engine("SuperPoint")
            assert engine.format == "onnx"

    def test_inference_performance_latency_mock(self, mm, models_dir):
        dummy_path = os.path.join(models_dir, "test.onnx")
        with open(dummy_path, "wb") as f:
            f.write(b"data")
        mm.load_model("TestPerf", "onnx", dummy_path)

        engine = mm.get_inference_engine("TestPerf")
        start = time.time()
        engine.infer()
        duration = time.time() - start

        # Simulated latency shouldn't exceed 100 ms heavily in mock mode.
        assert duration < 0.100

    def test_cold_start_all_models(self, mm, models_dir):
        # Mock engine files for all 4 models so initialization can succeed.
        names = ["superpoint.engine", "lightglue.engine",
                 "dinov2.engine", "litesam.engine"]
        for name in names:
            with open(os.path.join(models_dir, name), "wb") as f:
                f.write(b"data")

        assert mm.initialize_models() is True
        assert mm.get_inference_engine("SuperPoint") is not None
        assert mm.get_inference_engine("LightGlue") is not None
        assert mm.get_inference_engine("DINOv2") is not None
        assert mm.get_inference_engine("LiteSAM") is not None