"""AZ-297 — C7 inference runtime Protocol + DTO + error + factory conformance.

Covers all 8 ACs of AZ-297 plus NFR-perf-factory and
NFR-reliability-error-family. The factory ACs (AC-4 / AC-5) substitute fake
strategy modules at ``sys.modules`` boundaries so the test never touches
TensorRT, ONNX Runtime, or PyTorch.
"""

from __future__ import annotations

import dataclasses
import re
import sys
import time
import types
from pathlib import Path

import pytest

from gps_denied_onboard.components.c7_inference import (
    BuildConfig,
    C7InferenceConfig,
    CalibrationCacheError,
    EngineBuildError,
    EngineCacheEntry,
    EngineDeserializeError,
    EngineHandle,
    EngineHashMismatchError,
    EngineSchemaMismatchError,
    EngineSidecarMissingError,
    InferenceError,
    InferenceRuntime,
    OptimizationProfile,
    OutOfMemoryError,
    PrecisionMode,
    TelemetryUnavailableError,
    ThermalState,
)
from gps_denied_onboard.components.c7_inference import (
    RuntimeError as C7RuntimeError,
)
from gps_denied_onboard.components.c7_inference.config import KNOWN_RUNTIMES
from gps_denied_onboard.config.schema import Config, ConfigError
from gps_denied_onboard.runtime_root.errors import RuntimeNotAvailableError
from gps_denied_onboard.runtime_root.inference_factory import (
    build_inference_runtime,
)

# Markdown contract read by the AC-8 tests. The path is anchored at
# ``parents[3]`` of this file's resolved location.
_CONTRACT_PATH = (
    Path(__file__).resolve().parents[3]
    / "_docs/02_document/contracts/c7_inference/inference_runtime_protocol.md"
)

# Runtime label -> (strategy module name, strategy class name, BUILD_* flag).
# The tests in this file set/unset the flag through ``monkeypatch.setenv`` /
# ``monkeypatch.delenv`` and plant fake modules under the module name.
_STRATEGY_MODULES: dict[str, tuple[str, str, str]] = {
    "tensorrt": (
        "gps_denied_onboard.components.c7_inference.tensorrt_runtime",
        "TensorrtRuntime",
        "BUILD_TENSORRT_RUNTIME",
    ),
    "onnx_trt_ep": (
        "gps_denied_onboard.components.c7_inference.onnx_trt_ep_runtime",
        "OnnxTrtEpRuntime",
        "BUILD_ONNX_TRT_EP_RUNTIME",
    ),
    "pytorch_fp16": (
        "gps_denied_onboard.components.c7_inference.pytorch_fp16_runtime",
        "PytorchFp16Runtime",
        "BUILD_PYTORCH_FP16_RUNTIME",
    ),
}


# ----------------------------------------------------------------------
# Fakes that structurally satisfy the InferenceRuntime Protocol.
class _FullInferenceRuntime:
    """Fake runtime exposing the complete method set exercised by these tests.

    The engine operations are stubs that raise ``NotImplementedError``; only
    ``release_engine``, ``thermal_state`` and ``current_runtime_label``
    return normally.
    """

    def __init__(self, config: Config) -> None:
        self.config = config
        # Cache the configured label so ``current_runtime_label`` echoes it.
        self._label = config.components["c7_inference"].runtime

    def compile_engine(self, model_path, build_config):
        raise NotImplementedError

    def deserialize_engine(self, entry):
        raise NotImplementedError

    def infer(self, handle, inputs):
        raise NotImplementedError

    def release_engine(self, handle):
        return None

    def thermal_state(self):
        # A "telemetry unavailable" snapshot: no readings, no throttling.
        snapshot = ThermalState(
            cpu_temp_c=None,
            gpu_temp_c=None,
            thermal_throttle_active=False,
            measured_clock_mhz=None,
            measured_at_ns=0,
            is_telemetry_available=False,
        )
        return snapshot

    def current_runtime_label(self):
        return self._label


class _PartialInferenceRuntime:
    """Fake runtime that deliberately omits ``current_runtime_label``."""

    def compile_engine(self, model_path, build_config):
        raise NotImplementedError

    def deserialize_engine(self, entry):
        raise NotImplementedError

    def infer(self, handle, inputs):
        raise NotImplementedError

    def release_engine(self, handle):
        return None

    def thermal_state(self):
        raise NotImplementedError


def _config_with_runtime(runtime: str) -> Config:
    """Build a ``Config`` whose ``c7_inference`` block selects *runtime*."""
    c7_block = C7InferenceConfig(runtime=runtime)
    return Config.with_blocks(c7_inference=c7_block)


def _install_fake_strategy(runtime_label: str) -> type:
    """Plant a stand-in strategy module for *runtime_label* in ``sys.modules``.

    The module is registered under the name listed in ``_STRATEGY_MODULES``
    and exposes a ``_FullInferenceRuntime`` subclass under the listed class
    name. Returns that subclass so callers can ``isinstance``-check against it.
    """
    module_name, class_name = _STRATEGY_MODULES[runtime_label][:2]

    class _FakeStrategy(_FullInferenceRuntime):
        pass

    _FakeStrategy.__name__ = class_name

    fake_module = types.ModuleType(module_name)
    setattr(fake_module, class_name, _FakeStrategy)
    sys.modules[module_name] = fake_module
    return _FakeStrategy


@pytest.fixture
def strategy_module_cleanup():
    """Evict every strategy module from ``sys.modules`` around a test."""

    def _evict_strategy_modules() -> None:
        for entry in _STRATEGY_MODULES.values():
            sys.modules.pop(entry[0], None)

    _evict_strategy_modules()
    yield
    _evict_strategy_modules()


# ----------------------------------------------------------------------
# AC-1: Protocol is conformance-checkable.
def test_ac1_inference_runtime_conformance_full() -> None:
    """A fake with every method passes the ``InferenceRuntime`` check."""
    config = _config_with_runtime("pytorch_fp16")
    assert isinstance(_FullInferenceRuntime(config), InferenceRuntime)


def test_ac1_inference_runtime_conformance_partial_missing_label() -> None:
    """A fake lacking ``current_runtime_label`` fails the check."""
    partial = _PartialInferenceRuntime()
    assert not isinstance(partial, InferenceRuntime)


# ----------------------------------------------------------------------
# AC-2: frozen DTOs reject mutation.


# One (instance, attribute to overwrite, replacement value) triple per DTO.
_FROZEN_DTO_MUTATIONS = [
    (
        BuildConfig(
            precision=PrecisionMode.FP16,
            workspace_mb=512,
            calibration_dataset=None,
            optimization_profiles=(),
        ),
        "precision",
        PrecisionMode.INT8,
    ),
    (
        EngineCacheEntry(
            engine_path=Path("/var/lib/x.engine"),
            sha256_hex="a" * 64,
            sm=87,
            jp="6.2",
            trt="10.3",
            precision=PrecisionMode.FP16,
            extras={},
        ),
        "sha256_hex",
        "b" * 64,
    ),
    (
        ThermalState(
            cpu_temp_c=40.0,
            gpu_temp_c=45.0,
            thermal_throttle_active=False,
            measured_clock_mhz=918,
            measured_at_ns=1_000_000,
            is_telemetry_available=True,
        ),
        "thermal_throttle_active",
        True,
    ),
    (
        OptimizationProfile(
            input_name="input",
            min_shape=(1, 3, 224, 224),
            opt_shape=(1, 3, 384, 384),
            max_shape=(1, 3, 512, 512),
        ),
        "input_name",
        "renamed",
    ),
]


@pytest.mark.parametrize("dto, field_name, new_value", _FROZEN_DTO_MUTATIONS)
def test_ac2_frozen_dtos_reject_mutation(dto, field_name: str, new_value) -> None:
    """Assigning to a DTO field raises and leaves the field untouched."""
    before = getattr(dto, field_name)

    with pytest.raises(dataclasses.FrozenInstanceError):
        setattr(dto, field_name, new_value)

    after = getattr(dto, field_name)
    assert after == before


# ----------------------------------------------------------------------
# AC-3: error hierarchy catchable as a single family.
@pytest.mark.parametrize(
    "exc_factory",
    [
        EngineBuildError,
        EngineDeserializeError,
        EngineHashMismatchError,
        EngineSchemaMismatchError,
        EngineSidecarMissingError,
        CalibrationCacheError,
        InferenceError,
        OutOfMemoryError,
        TelemetryUnavailableError,
    ],
)
def test_ac3_all_runtime_errors_caught_as_family(exc_factory) -> None:
    """Each concrete error is catchable through the shared c7 base class."""
    with pytest.raises(C7RuntimeError):
        error = exc_factory("boom")
        raise error


def test_ac3_unrelated_exception_not_caught_as_family() -> None:
    """A plain ``ValueError`` slips past an ``except C7RuntimeError`` clause."""
    with pytest.raises(ValueError):
        try:
            raise ValueError("not us")
        except C7RuntimeError:
            # Reaching this handler means the family is too broad.
            pytest.fail("ValueError must not be caught as c7 RuntimeError")


def test_ac3_runtime_not_available_outside_family() -> None:
    """``RuntimeNotAvailableError`` is not swallowed by the c7 family handler."""
    with pytest.raises(RuntimeNotAvailableError):
        try:
            raise RuntimeNotAvailableError("composition-time")
        except C7RuntimeError:
            pytest.fail(
                "RuntimeNotAvailableError is a composition-root error and "
                "MUST NOT be in the c7 runtime family"
            )


# ----------------------------------------------------------------------
# AC-4 + AC-5: factory honours config + BUILD flag gate.
@pytest.mark.parametrize("runtime", sorted(_STRATEGY_MODULES))
def test_ac4_build_inference_runtime_returns_protocol_impl(
    monkeypatch, strategy_module_cleanup, runtime
) -> None:
    """With the flag ``ON`` the factory instantiates the planted strategy."""
    build_flag = _STRATEGY_MODULES[runtime][2]
    monkeypatch.setenv(build_flag, "ON")
    fake_cls = _install_fake_strategy(runtime)

    built = build_inference_runtime(_config_with_runtime(runtime))

    assert isinstance(built, fake_cls)
    assert isinstance(built, InferenceRuntime)


@pytest.mark.parametrize("runtime", sorted(_STRATEGY_MODULES))
def test_ac5_build_inference_runtime_flag_off_no_import(
    monkeypatch, strategy_module_cleanup, runtime
) -> None:
    """With the flag unset the factory refuses and imports nothing."""
    module_name, _class_name, build_flag = _STRATEGY_MODULES[runtime]
    monkeypatch.delenv(build_flag, raising=False)
    config = _config_with_runtime(runtime)

    with pytest.raises(RuntimeNotAvailableError) as exc_info:
        build_inference_runtime(config)

    # The error text must name both the runtime and the gating flag.
    assert runtime in str(exc_info.value)
    assert build_flag in str(exc_info.value)
    assert module_name not in sys.modules


@pytest.mark.parametrize(
    "runtime",
    sorted(rt for rt in _STRATEGY_MODULES if rt != "pytorch_fp16"),
)
def test_ac5_build_inference_runtime_flag_on_but_module_missing(
    monkeypatch, strategy_module_cleanup, runtime
) -> None:
    """Flag is ``ON`` but no strategy module exists for the runtime yet.

    ``pytorch_fp16`` is left out: AZ-300 shipped its concrete module, and
    that path is covered by
    ``test_pytorch_fp16_runtime.test_ac1_protocol_conformance``, which
    constructs the real strategy. The TRT / ORT runtimes (AZ-298 / AZ-299)
    are still pending, so this test keeps guarding their factory path.
    """
    build_flag = _STRATEGY_MODULES[runtime][2]
    monkeypatch.setenv(build_flag, "ON")
    config = _config_with_runtime(runtime)

    with pytest.raises(RuntimeNotAvailableError) as exc_info:
        build_inference_runtime(config)

    assert runtime in str(exc_info.value)


# ----------------------------------------------------------------------
# AC-6: unknown runtime label rejected at config load.
@pytest.mark.parametrize(
    "bad_label",
    ["tensorflow_lite", "onnx", "trt", "TENSORRT", ""],
)
def test_ac6_unknown_runtime_rejected_at_config_load(bad_label: str) -> None:
    """An unknown label raises ``ConfigError`` listing every known runtime."""
    with pytest.raises(ConfigError) as exc_info:
        C7InferenceConfig(runtime=bad_label)

    msg = str(exc_info.value)
    assert bad_label in msg or "runtime" in msg
    assert all(valid in msg for valid in KNOWN_RUNTIMES)


# ----------------------------------------------------------------------
# AC-7: current_runtime_label() matches config exactly.


@pytest.mark.parametrize("runtime", sorted(_STRATEGY_MODULES))
def test_ac7_current_runtime_label_matches_config(
    monkeypatch, strategy_module_cleanup, runtime
) -> None:
    """The built runtime reports the same label the config requested."""
    build_flag = _STRATEGY_MODULES[runtime][2]
    monkeypatch.setenv(build_flag, "ON")
    _install_fake_strategy(runtime)
    config = _config_with_runtime(runtime)

    built = build_inference_runtime(config)

    assert built.current_runtime_label() == runtime
    configured_label = config.components["c7_inference"].runtime
    assert built.current_runtime_label() == configured_label


# ----------------------------------------------------------------------
# AC-8: contract file matches Protocol shape.
# Matches a Markdown table row whose first cell is a back-ticked lowercase
# identifier, capturing the identifier as the named group ``name``.
# The group must be spelled ``(?P<name>...)``: ``_methods_from_contract`` reads
# it via ``m.group("name")`` and a bare ``(?P[...)`` does not compile.
_METHOD_TABLE_RE = re.compile(
    r"^\|\s*`(?P<name>[a-z_][a-z0-9_]*)`\s*\|", re.MULTILINE
)


def _methods_from_contract() -> set[str]:
    """Return the method names listed in the contract's Protocol surface table.

    Reads ``_CONTRACT_PATH``, isolates the text from the
    ``### Protocol surface`` heading up to the next ``### `` heading (or the
    end of the file when there is none), and collects every identifier matched
    by ``_METHOD_TABLE_RE`` in that slice.

    Raises:
        ValueError: if the ``### Protocol surface`` heading is absent
            (propagated from ``str.index``).
    """
    heading = "### Protocol surface"
    text = _CONTRACT_PATH.read_text(encoding="utf-8")
    surface_start = text.index(heading)
    # Search for the next sibling heading strictly after the current one.
    next_section = text.find("\n### ", surface_start + len(heading))
    if next_section != -1:
        section = text[surface_start:next_section]
    else:
        section = text[surface_start:]
    return {m.group("name") for m in _METHOD_TABLE_RE.finditer(section)}


def _protocol_methods(proto: type) -> set[str]:
    """Return the names of all public callable attributes found on *proto*."""
    return {
        name
        for name in dir(proto)
        if not name.startswith("_") and callable(getattr(proto, name))
    }


def test_ac8_contract_methods_match_protocol() -> None:
    """The contract table and the Protocol declare the same method names."""
    contract_methods = _methods_from_contract()
    protocol_methods = _protocol_methods(InferenceRuntime)

    missing_in_protocol = contract_methods - protocol_methods
    missing_in_contract = protocol_methods - contract_methods

    assert not missing_in_protocol, (
        "Methods declared in inference_runtime_protocol.md Shape section "
        f"but missing from the Protocol: {sorted(missing_in_protocol)}"
    )
    assert not missing_in_contract, (
        "Methods present on the Protocol but missing from the contract "
        f"Shape section: {sorted(missing_in_contract)}"
    )


def test_ac8_contract_lists_all_nine_error_subtypes() -> None:
    """Every documented error subtype name appears in the contract text."""
    text = _CONTRACT_PATH.read_text(encoding="utf-8")
    expected = {
        "EngineBuildError",
        "EngineDeserializeError",
        "EngineHashMismatchError",
        "EngineSchemaMismatchError",
        "EngineSidecarMissingError",
        "CalibrationCacheError",
        "InferenceError",
        "OutOfMemoryError",
        "TelemetryUnavailableError",
    }
    # Sorted so the first reported missing name is stable across runs.
    for name in sorted(expected):
        assert name in text, (
            f"Contract file is missing the documented error subtype {name!r}"
        )


# ----------------------------------------------------------------------
# NFRs.
@pytest.mark.parametrize(
    "exc_type",
    [
        EngineBuildError,
        EngineDeserializeError,
        EngineHashMismatchError,
        EngineSchemaMismatchError,
        EngineSidecarMissingError,
        CalibrationCacheError,
        InferenceError,
        OutOfMemoryError,
        TelemetryUnavailableError,
    ],
)
def test_nfr_reliability_all_runtime_errors_subclass_family(exc_type) -> None:
    """Each concrete error type derives from the shared c7 base class."""
    assert issubclass(exc_type, C7RuntimeError)


def test_nfr_reliability_runtime_not_available_not_in_family() -> None:
    """``RuntimeNotAvailableError`` does not derive from the c7 base class."""
    assert not issubclass(RuntimeNotAvailableError, C7RuntimeError)


def test_nfr_perf_factory_under_200ms_p99(
    monkeypatch, strategy_module_cleanup
) -> None:
    """Factory p99 ≤ 200 ms over 1000 invocations (NFR-perf-factory)."""
    runtime = "pytorch_fp16"
    monkeypatch.setenv(_STRATEGY_MODULES[runtime][2], "ON")
    _install_fake_strategy(runtime)
    config = _config_with_runtime(runtime)

    samples_ms: list[float] = []
    for _ in range(1000):
        # Time only the factory call itself.
        started = time.perf_counter()
        build_inference_runtime(config)
        elapsed_s = time.perf_counter() - started
        samples_ms.append(elapsed_s * 1000.0)

    ordered = sorted(samples_ms)
    p99 = ordered[int(0.99 * len(ordered))]
    assert p99 <= 200.0, (
        f"build_inference_runtime() p99={p99:.3f} ms exceeds 200 ms NFR"
    )


# ----------------------------------------------------------------------
# Surface coverage.


def test_engine_handle_is_class_not_protocol() -> None:
    """C7 ``EngineHandle`` is an opaque class, not a runtime_checkable Protocol.

    This sets it apart from the LightGlue ``_types.manifests.EngineHandle``
    Protocol (intentional dual-name design; see manifests.py docstring).
    """
    assert isinstance(EngineHandle, type)
    assert not hasattr(EngineHandle, "_is_runtime_protocol")


def test_c7_config_thermal_poll_hz_validation() -> None:
    """Zero and negative ``thermal_poll_hz`` values raise ``ConfigError``."""
    for bad_hz in (0.0, -1.0):
        with pytest.raises(ConfigError):
            C7InferenceConfig(thermal_poll_hz=bad_hz)


def test_c7_config_engine_cache_dir_validation() -> None:
    """An empty ``engine_cache_dir`` raises ``ConfigError``."""
    with pytest.raises(ConfigError):
        C7InferenceConfig(engine_cache_dir="")


def test_precision_mode_enum_surface() -> None:
    """``PrecisionMode`` exposes exactly the three expected values."""
    values = {member.value for member in PrecisionMode}
    assert values == {"fp16", "int8", "mixed"}


def test_thermal_state_invariant_i6_default_safe() -> None:
    """With telemetry unavailable, throttle MUST be False (Invariant I-6)."""
    state = ThermalState(
        cpu_temp_c=None,
        gpu_temp_c=None,
        thermal_throttle_active=False,
        measured_clock_mhz=None,
        measured_at_ns=0,
        is_telemetry_available=False,
    )
    assert state.thermal_throttle_active is False
    assert state.is_telemetry_available is False