"""AZ-349 — :class:`AdHoPRefiner` AC-1..AC-11 coverage.

The AdHoP TRT engine is exercised via a programmable
:class:`_ProgrammableInferenceRuntime` that returns canned ``correspondences``
arrays (or raises) per call. The :class:`RansacFilter` is replaced by a
programmable stub that returns canned :class:`RansacResult` instances. The
fake clock is a counter that only moves when a test reads or advances it,
which lets AC-7 / AC-8 walk through their time windows deterministically.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Final

import numpy as np
import pytest

from gps_denied_onboard._types.matcher import (
    CandidateMatchSet,
    MatchResult,
)
from gps_denied_onboard._types.nav import NavCameraFrame
from gps_denied_onboard.components.c3_5_adhop import (
    C3_5RefinerConfig,
    ConditionalRefiner,
)
from gps_denied_onboard.components.c3_5_adhop.adhop_refiner import (
    AdHoPRefiner,
    create as create_adhop,
)
from gps_denied_onboard.components.c3_5_adhop.errors import (
    RefinerBackboneError,
    RefinerConfigError,
)
from gps_denied_onboard.components.c3_matcher import C3MatcherConfig
from gps_denied_onboard.config.schema import Config
from gps_denied_onboard.fdr_client import EnqueueResult, FdrRecord
from gps_denied_onboard.helpers.ransac_filter import RansacResult
from gps_denied_onboard.runtime_root.refiner_factory import build_refiner_strategy

# Nanoseconds in one second; the tests use it to size clock advances and
# rate-limit windows.
_ONE_SECOND_NS: Final[int] = 1_000_000_000


# ----------------------------------------------------------------------
# Test doubles.
@dataclass
class _FakeClock:
    """Deterministic stand-in for the refiner's clock dependency.

    Time only moves when a test asks for it: each ``monotonic_ns`` read ticks
    the counter by one nanosecond and ``advance`` jumps it forward.
    """

    # Current fake time, in nanoseconds.
    _t: int = 1_700_000_000_000_000_000

    def monotonic_ns(self) -> int:
        """Tick the counter by 1 ns and return the new reading."""
        self._t = self._t + 1
        return self._t

    def time_ns(self) -> int:
        """Return the current reading without moving the clock."""
        return self._t

    def sleep_until_ns(self, target_ns: int) -> None:  # noqa: ARG002
        """Accept a wake-up target and return immediately (no real sleep)."""
        return None

    def advance(self, delta_ns: int) -> None:
        """Jump the clock forward by ``delta_ns`` nanoseconds."""
        step = int(delta_ns)
        self._t = self._t + step


class _CapturingFdrClient:
    """FDR client double that keeps every enqueued record in memory."""

    def __init__(self) -> None:
        self.records: list[FdrRecord] = []

    @property
    def producer_id(self) -> str:
        """Fixed producer identifier reported by this double."""
        return "c3_5_adhop.test"

    def enqueue(self, record: FdrRecord) -> str:
        """Remember ``record`` and report the enqueue as accepted."""
        self.records.append(record)
        return EnqueueResult.OK

    def by_kind(self, kind: str) -> list[FdrRecord]:
        """Return the captured records whose ``kind`` equals ``kind``, in order."""
        matching: list[FdrRecord] = []
        for captured in self.records:
            if captured.kind == kind:
                matching.append(captured)
        return matching


@dataclass(frozen=True)
class _DummyEngineEntry:
    """Immutable placeholder returned by the fake ``compile_engine``."""

    engine_path: Path
    sha256_hex: str
    sm: int
    jp: str
    trt: str
    precision: object
    extras: dict[str, str]


class _ProgrammableInferenceRuntime:
    """Inference runtime double driven by a FIFO of scripted outcomes.

    Each ``infer`` call consumes the oldest scripted item: exceptions are
    raised, anything else is returned as the engine output.
    """

    def __init__(self) -> None:
        self._queue: list[object] = []
        self.calls: int = 0

    def queue_refined(self, refined: np.ndarray) -> None:
        """Script a normal output carrying ``refined`` cast to float32."""
        payload = {"correspondences": refined.astype(np.float32)}
        self._queue.append(payload)

    def queue_error(self, exc: BaseException) -> None:
        """Script ``exc`` to be raised by a future ``infer`` call."""
        self._queue.append(exc)

    def queue_bad_output(self, payload: object) -> None:
        """Script an arbitrary ``payload`` to be returned verbatim."""
        self._queue.append(payload)

    def current_runtime_label(self) -> str:
        """Report the runtime label this double pretends to be."""
        return "tensorrt"

    def compile_engine(self, model_path: Path, build_config) -> object:  # noqa: ARG002
        """Return a canned engine entry for ``model_path``.

        Only ``build_config.precision`` is read from the build configuration;
        every other field of the entry is a fixed placeholder.
        """
        placeholder_digest = "0" * 64
        return _DummyEngineEntry(
            engine_path=model_path,
            sha256_hex=placeholder_digest,
            sm=87,
            jp="6.0",
            trt="10.3",
            precision=build_config.precision,
            extras={"model_name": "adhop"},
        )

    def deserialize_engine(self, entry) -> object:  # noqa: ARG002
        """Return a fresh opaque object as the engine handle."""
        return object()

    def release_engine(self, handle) -> None:  # noqa: ARG002
        """Do nothing; the fake holds no engine resources."""
        return None

    def thermal_state(self):
        """Not provided by this double."""
        raise NotImplementedError

    def infer(self, handle, inputs):  # noqa: ARG002
        """Count the call, then raise or return the oldest scripted outcome."""
        self.calls += 1
        scripted = self._queue.pop(0)
        if not isinstance(scripted, BaseException):
            return scripted
        raise scripted


class _ProgrammableRansacFilter:
    """RANSAC filter double that replays scripted ``RansacResult`` objects."""

    def __init__(self) -> None:
        self._queue: list[RansacResult] = []
        self.calls: int = 0

    def queue(self, *, inliers: int, residual: float, outliers: int = 0) -> None:
        """Script one result with ``inliers`` identical correspondence rows.

        A non-positive ``inliers`` produces an empty ``(0, 4)`` array.
        """
        correspondences = (
            np.tile(
                np.array([10.0, 20.0, 30.0, 40.0], dtype=np.float32), (inliers, 1)
            )
            if inliers > 0
            else np.zeros((0, 4), dtype=np.float32)
        )
        scripted = RansacResult(
            inlier_correspondences=correspondences,
            inlier_count=int(inliers),
            outlier_count=int(outliers),
            median_residual_px=float(residual),
        )
        self._queue.append(scripted)

    def filter_correspondences(
        self, corr, threshold, min_inl
    ):  # noqa: ARG002
        """Count the call and return the oldest scripted result."""
        self.calls += 1
        return self._queue.pop(0)


# ----------------------------------------------------------------------
# Builders.


def _make_frame(frame_id: int = 7) -> NavCameraFrame:
    """Build a frame with a blank 16x16 RGB image stamped with the current UTC time."""
    blank_image = np.zeros((16, 16, 3), dtype=np.uint8)
    captured_at = datetime.now(tz=timezone.utc)
    return NavCameraFrame(
        frame_id=frame_id,
        timestamp=captured_at,
        image=blank_image,
        camera_calibration_id="cam0",
    )


def _make_candidate(*, inliers: int = 120, residual: float = 5.0) -> CandidateMatchSet:
    """Build a candidate whose ``inliers`` correspondence rows are all 0.25."""
    rows = np.full((inliers, 4), 0.25, dtype=np.float32)
    return CandidateMatchSet(
        tile_id=(18, 49.9, 36.3),
        inlier_count=inliers,
        inlier_correspondences=rows,
        ransac_outlier_count=3,
        per_candidate_residual_px=residual,
    )


def _make_match_result(
    *,
    frame_id: int = 7,
    reprojection_residual: float = 5.0,
    inliers: int = 120,
    refinement_label: str = "passthrough",
) -> MatchResult:
    """Build a single-candidate ``MatchResult`` with the given residual."""
    only_candidate = _make_candidate(
        inliers=inliers, residual=reprojection_residual
    )
    return MatchResult(
        frame_id=frame_id,
        per_candidate=(only_candidate,),
        best_candidate_idx=0,
        reprojection_residual_px=reprojection_residual,
        matched_at=1_000_000_000,
        matcher_label="disk_lightglue",
        candidates_input=3,
        candidates_dropped=2,
        refinement_label=refinement_label,
    )


def _build_refiner(
    *,
    inference: _ProgrammableInferenceRuntime,
    ransac: _ProgrammableRansacFilter,
    fdr_client: _CapturingFdrClient | None,
    clock: _FakeClock,
    invocation_rate_warn_threshold: float = 0.25,
    ratelimited_warn_window_ns: int = 60 * _ONE_SECOND_NS,
    ransac_threshold_px: float = 3.0,
    min_inliers_threshold: int = 60,
    logger: logging.Logger | None = None,
) -> AdHoPRefiner:
    """Construct an ``AdHoPRefiner`` wired to the supplied test doubles.

    The engine handle is a fresh opaque object; when no ``logger`` is given
    the ``"test.c3_5_adhop"`` logger is used.
    """
    effective_logger = logger or logging.getLogger("test.c3_5_adhop")
    return AdHoPRefiner(
        inference_runtime=inference,
        engine_handle=object(),
        ransac_filter=ransac,
        invocation_rate_warn_threshold=invocation_rate_warn_threshold,
        ratelimited_warn_window_ns=ratelimited_warn_window_ns,
        ransac_threshold_px=ransac_threshold_px,
        min_inliers_threshold=min_inliers_threshold,
        clock=clock,
        fdr_client=fdr_client,
        logger=effective_logger,
    )


# ----------------------------------------------------------------------
# AC-1: Protocol conformance.


def test_ac1_protocol_conformance() -> None:
    """AC-1: a constructed refiner passes the ``ConditionalRefiner`` check."""
    runtime_stub = _ProgrammableInferenceRuntime()
    ransac_stub = _ProgrammableRansacFilter()

    refiner = _build_refiner(
        inference=runtime_stub,
        ransac=ransac_stub,
        fdr_client=None,
        clock=_FakeClock(),
    )

    assert isinstance(refiner, ConditionalRefiner)


# ----------------------------------------------------------------------
# AC-2: Gate inclusive semantics — `<=` is passthrough.


def test_ac2_gate_inclusive_equality_is_passthrough() -> None:
    """AC-2: a residual exactly equal to the threshold is passed through."""
    # Arrange
    fdr_sink = _CapturingFdrClient()
    refiner = _build_refiner(
        inference=_ProgrammableInferenceRuntime(),
        ransac=_ProgrammableRansacFilter(),
        fdr_client=fdr_sink,
        clock=_FakeClock(),
    )
    at_threshold = _make_match_result(reprojection_residual=2.5)

    # Act
    returned = refiner.refine_if_needed(
        _make_frame(), at_threshold, residual_threshold_px=2.5
    )

    # Assert
    assert returned is at_threshold
    assert refiner.was_invoked() is False


def test_ac2_gate_above_threshold_invokes() -> None:
    """AC-2: a residual just above the threshold triggers refinement."""
    # Arrange
    runtime_stub = _ProgrammableInferenceRuntime()
    ransac_stub = _ProgrammableRansacFilter()
    fdr_sink = _CapturingFdrClient()
    scripted_corr = np.full((100, 4), 1.0, dtype=np.float32)
    runtime_stub.queue_refined(scripted_corr)
    ransac_stub.queue(inliers=100, residual=1.2)
    refiner = _build_refiner(
        inference=runtime_stub,
        ransac=ransac_stub,
        fdr_client=fdr_sink,
        clock=_FakeClock(),
    )
    just_above = _make_match_result(reprojection_residual=2.5 + 1e-6)

    # Act
    returned = refiner.refine_if_needed(
        _make_frame(), just_above, residual_threshold_px=2.5
    )

    # Assert
    assert refiner.was_invoked() is True
    assert returned.refinement_label == "adhop"
# ----------------------------------------------------------------------
# AC-3: Successful refinement enriches MatchResult.


def test_ac3_successful_refinement_enriches_match_result() -> None:
    """AC-3: a successful pass relabels the result and swaps in new data."""
    # Arrange
    runtime_stub = _ProgrammableInferenceRuntime()
    ransac_stub = _ProgrammableRansacFilter()
    fdr_sink = _CapturingFdrClient()
    scripted_corr = np.full((100, 4), 7.7, dtype=np.float32)
    runtime_stub.queue_refined(scripted_corr)
    ransac_stub.queue(inliers=100, residual=1.2)
    refiner = _build_refiner(
        inference=runtime_stub,
        ransac=ransac_stub,
        fdr_client=fdr_sink,
        clock=_FakeClock(),
    )
    original = _make_match_result(reprojection_residual=5.0)

    # Act
    refined = refiner.refine_if_needed(
        _make_frame(), original, residual_threshold_px=2.5
    )

    # Assert
    assert refined.refinement_label == "adhop"
    assert refined.reprojection_residual_px == pytest.approx(1.2)
    assert refined.refinement_added_latency_ms > 0
    assert refiner.was_invoked() is True
    before = original.per_candidate[0].inlier_correspondences
    after = refined.per_candidate[0].inlier_correspondences
    assert not np.array_equal(before, after)


# ----------------------------------------------------------------------
# AC-4: Passthrough fall-through on RefinerBackboneError.


def test_ac4_refiner_backbone_error_falls_through(caplog) -> None:
    """AC-4: an engine ``RuntimeError`` yields the input result, one ERROR
    log record and one ``refiner.frame_done`` record flagged as an error.
    """
    # Arrange — engine raises RuntimeError → maps to RefinerBackboneError.
    runtime_stub = _ProgrammableInferenceRuntime()
    ransac_stub = _ProgrammableRansacFilter()
    fdr_sink = _CapturingFdrClient()
    runtime_stub.queue_error(RuntimeError("simulated TRT failure"))
    refiner = _build_refiner(
        inference=runtime_stub,
        ransac=ransac_stub,
        fdr_client=fdr_sink,
        clock=_FakeClock(),
    )
    original = _make_match_result(reprojection_residual=5.0)

    # Act
    with caplog.at_level(logging.ERROR):
        returned = refiner.refine_if_needed(
            _make_frame(), original, residual_threshold_px=2.5
        )

    # Assert
    assert returned is original
    assert returned.refinement_label == "passthrough"
    assert refiner.was_invoked() is True
    backbone_errors = []
    for log_record in caplog.records:
        if log_record.message == "c3_5.refiner.backbone_error":
            backbone_errors.append(log_record)
    assert len(backbone_errors) == 1
    done_records = fdr_sink.by_kind("refiner.frame_done")
    assert len(done_records) == 1
    assert done_records[0].payload.get("error") is True


def test_ac4_explicit_refiner_backbone_error_falls_through() -> None:
    """AC-4: an engine-raised ``RefinerBackboneError`` also falls through."""
    # Arrange — engine raises explicit RefinerBackboneError.
    runtime_stub = _ProgrammableInferenceRuntime()
    ransac_stub = _ProgrammableRansacFilter()
    fdr_sink = _CapturingFdrClient()
    runtime_stub.queue_error(RefinerBackboneError("explicit"))
    refiner = _build_refiner(
        inference=runtime_stub,
        ransac=ransac_stub,
        fdr_client=fdr_sink,
        clock=_FakeClock(),
    )
    original = _make_match_result(reprojection_residual=5.0)

    # Act
    returned = refiner.refine_if_needed(
        _make_frame(), original, residual_threshold_px=2.5
    )

    # Assert
    assert returned is original
    assert refiner.was_invoked() is True


# ----------------------------------------------------------------------
# AC-5: Other exception types re-raise.
def test_ac5_memory_error_propagates() -> None:
    """AC-5: a ``MemoryError`` raised by the engine escapes ``refine_if_needed``."""
    # Arrange
    inf = _ProgrammableInferenceRuntime()
    ransac = _ProgrammableRansacFilter()
    fdr = _CapturingFdrClient()
    inf.queue_error(MemoryError("simulated OOM"))
    refiner = _build_refiner(
        inference=inf, ransac=ransac, fdr_client=fdr, clock=_FakeClock()
    )
    mr = _make_match_result(reprojection_residual=5.0)

    # Act & Assert
    with pytest.raises(MemoryError):
        refiner.refine_if_needed(_make_frame(), mr, residual_threshold_px=2.5)


# ----------------------------------------------------------------------
# AC-6: Gate-decided passthrough — bit-identical correspondences.


def test_ac6_gate_passthrough_correspondences_identity_preserved() -> None:
    """AC-6: below the threshold the very same result object (and the very
    same correspondence arrays) come back, still labelled ``passthrough``.
    """
    # Arrange
    fdr = _CapturingFdrClient()
    refiner = _build_refiner(
        inference=_ProgrammableInferenceRuntime(),
        ransac=_ProgrammableRansacFilter(),
        fdr_client=fdr,
        clock=_FakeClock(),
    )
    mr = _make_match_result(reprojection_residual=1.0)

    # Act
    out = refiner.refine_if_needed(_make_frame(), mr, residual_threshold_px=2.5)

    # Assert
    assert out is mr
    for in_c, out_c in zip(mr.per_candidate, out.per_candidate, strict=True):
        assert out_c.inlier_correspondences is in_c.inlier_correspondences
    assert out.refinement_label == "passthrough"


# ----------------------------------------------------------------------
# AC-7: _invocation_window rate accuracy.


def test_ac7_invocation_window_rate_accuracy() -> None:
    """AC-7: 10 invoked frames out of 30 are reported as a rate of ~1/3."""
    # Arrange — 30 frames at 3 Hz: 10 invoked, 20 passthrough.
    clock = _FakeClock()
    inf = _ProgrammableInferenceRuntime()
    ransac = _ProgrammableRansacFilter()
    fdr = _CapturingFdrClient()
    for _ in range(10):
        inf.queue_refined(np.full((50, 4), 0.5, dtype=np.float32))
        ransac.queue(inliers=50, residual=1.5)
    refiner = _build_refiner(
        inference=inf, ransac=ransac, fdr_client=fdr, clock=clock
    )
    mr_invoke = _make_match_result(reprojection_residual=5.0)
    mr_passthrough = _make_match_result(reprojection_residual=1.0)
    # One frame period at 3 Hz; integer division keeps the value an exact int.
    frame_period_ns = _ONE_SECOND_NS // 3

    # Act — interleave (2 passthrough + 1 invoke) × 10 to land 10/30,
    # advancing the clock by one frame period after every frame.
    for _cycle in range(10):
        for _ in range(2):
            refiner.refine_if_needed(
                _make_frame(), mr_passthrough, residual_threshold_px=2.5
            )
            clock.advance(frame_period_ns)
        refiner.refine_if_needed(_make_frame(), mr_invoke, residual_threshold_px=2.5)
        clock.advance(frame_period_ns)

    # Assert
    rate = refiner._invocation_rate()  # noqa: SLF001
    assert rate == pytest.approx(10 / 30, abs=0.01)


# ----------------------------------------------------------------------
# AC-8: Invocation-rate WARN is rate-limited.


def test_ac8_invocation_rate_warn_rate_limited(caplog) -> None:
    """AC-8: 5 s of sustained high rate inside a 10 s window warns once."""
    # Arrange — high rate (every frame invoked) → trigger WARN.
    clock = _FakeClock()
    inf = _ProgrammableInferenceRuntime()
    ransac = _ProgrammableRansacFilter()
    fdr = _CapturingFdrClient()
    for _ in range(20):
        inf.queue_refined(np.full((50, 4), 0.5, dtype=np.float32))
        ransac.queue(inliers=50, residual=1.5)
    refiner = _build_refiner(
        inference=inf,
        ransac=ransac,
        fdr_client=fdr,
        clock=clock,
        invocation_rate_warn_threshold=0.25,
        ratelimited_warn_window_ns=10 * _ONE_SECOND_NS,
        logger=logging.getLogger("test.c3_5_adhop.warn"),
    )
    mr_invoke = _make_match_result(reprojection_residual=5.0)

    # Act — 20 invoked frames within 5 seconds → rate = 1.0 ≫ 0.25
    with caplog.at_level(logging.WARNING, logger="test.c3_5_adhop.warn"):
        for _ in range(20):
            refiner.refine_if_needed(
                _make_frame(), mr_invoke, residual_threshold_px=2.5
            )
            clock.advance(_ONE_SECOND_NS // 4)

    # Assert
    warns = [
        r for r in caplog.records if r.message == "c3_5.refiner.invocation_rate_high"
    ]
    assert len(warns) == 1


def test_ac8_warn_re_fires_after_window_expires(caplog) -> None:
    """AC-8: once the rate-limit window has elapsed the WARN is emitted again."""
    # Arrange — high rate sustained beyond the rate-limit window.
    clock = _FakeClock()
    inf = _ProgrammableInferenceRuntime()
    ransac = _ProgrammableRansacFilter()
    fdr = _CapturingFdrClient()
    for _ in range(40):
        inf.queue_refined(np.full((50, 4), 0.5, dtype=np.float32))
        ransac.queue(inliers=50, residual=1.5)
    refiner = _build_refiner(
        inference=inf,
        ransac=ransac,
        fdr_client=fdr,
        clock=clock,
        invocation_rate_warn_threshold=0.25,
        ratelimited_warn_window_ns=1 * _ONE_SECOND_NS,
        logger=logging.getLogger("test.c3_5_adhop.warn2"),
    )
    mr_invoke = _make_match_result(reprojection_residual=5.0)

    # Act — 40 invokes spaced 5 s apart so the rate-limit window expires.
    with caplog.at_level(logging.WARNING, logger="test.c3_5_adhop.warn2"):
        for _ in range(40):
            refiner.refine_if_needed(
                _make_frame(), mr_invoke, residual_threshold_px=2.5
            )
            clock.advance(5 * _ONE_SECOND_NS)

    # Assert — more than one warn proves the limiter re-arms once its window
    # has passed; only this lower bound is checked here.
    warns = [
        r for r in caplog.records if r.message == "c3_5.refiner.invocation_rate_high"
    ]
    assert len(warns) >= 2


# ----------------------------------------------------------------------
# AC-9: was_invoked() three-state semantics.


def test_ac9_was_invoked_gate_passthrough_false() -> None:
    """AC-9: a gate-decided passthrough leaves ``was_invoked()`` False."""
    refiner = _build_refiner(
        inference=_ProgrammableInferenceRuntime(),
        ransac=_ProgrammableRansacFilter(),
        fdr_client=None,
        clock=_FakeClock(),
    )
    refiner.refine_if_needed(
        _make_frame(), _make_match_result(reprojection_residual=1.0), 2.5
    )
    assert refiner.was_invoked() is False


def test_ac9_was_invoked_adhop_success_true() -> None:
    """AC-9: a successful refinement sets ``was_invoked()`` True."""
    inf = _ProgrammableInferenceRuntime()
    ransac = _ProgrammableRansacFilter()
    inf.queue_refined(np.full((50, 4), 0.5, dtype=np.float32))
    ransac.queue(inliers=50, residual=1.5)
    refiner = _build_refiner(
        inference=inf, ransac=ransac, fdr_client=None, clock=_FakeClock()
    )
    refiner.refine_if_needed(
        _make_frame(), _make_match_result(reprojection_residual=5.0), 2.5
    )
    assert refiner.was_invoked() is True


def test_ac9_was_invoked_fallthrough_true() -> None:
    """AC-9: an attempted-but-failed refinement still sets ``was_invoked()`` True."""
    inf = _ProgrammableInferenceRuntime()
    ransac = _ProgrammableRansacFilter()
    inf.queue_error(RuntimeError("trt error"))
    refiner = _build_refiner(
        inference=inf, ransac=ransac, fdr_client=None, clock=_FakeClock()
    )
    refiner.refine_if_needed(
        _make_frame(), _make_match_result(reprojection_residual=5.0), 2.5
    )
    assert refiner.was_invoked() is True


# ----------------------------------------------------------------------
# AC-10: Composition-root wiring + identity-shared RansacFilter.
def test_ac10_factory_wires_adhop_strategy(caplog) -> None:
    """AC-10: the factory returns an ``AdHoPRefiner`` that holds the very
    ``ransac_filter`` it was given and logs ``c3_5.refiner.ready`` once.
    """
    # Arrange
    runtime_stub = _ProgrammableInferenceRuntime()
    ransac_stub = _ProgrammableRansacFilter()
    fdr_sink = _CapturingFdrClient()
    fake_clock = _FakeClock()
    weights_path = Path("/tmp/adhop.engine")
    refiner_block = C3_5RefinerConfig(
        strategy="adhop", residual_threshold_px=2.5, adhop_weights_path=weights_path
    )
    matcher_block = C3MatcherConfig(
        strategy="disk_lightglue",
        min_inliers_threshold=60,
        disk_weights_path=Path("/tmp/disk.engine"),
        aliked_weights_path=Path("/tmp/aliked.engine"),
        xfeat_weights_path=Path("/tmp/xfeat.engine"),
    )
    config = Config.with_blocks(c3_5_adhop=refiner_block, c3_matcher=matcher_block)

    # Act
    with caplog.at_level(logging.INFO, logger="gps_denied_onboard.c3_5_adhop"):
        built = build_refiner_strategy(
            config,
            ransac_filter=ransac_stub,
            inference_runtime=runtime_stub,
            clock=fake_clock,
            fdr_client=fdr_sink,
        )

    # Assert
    assert isinstance(built, AdHoPRefiner)
    assert isinstance(built, ConditionalRefiner)
    assert built._ransac_filter is ransac_stub  # noqa: SLF001
    ready_count = sum(
        1 for log_record in caplog.records
        if log_record.message == "c3_5.refiner.ready"
    )
    assert ready_count == 1


def test_ac10_factory_missing_weights_rejects() -> None:
    """AC-10: the ``adhop`` strategy without a weights path is rejected."""
    # Arrange
    runtime_stub = _ProgrammableInferenceRuntime()
    ransac_stub = _ProgrammableRansacFilter()
    refiner_block = C3_5RefinerConfig(strategy="adhop")
    config = Config.with_blocks(c3_5_adhop=refiner_block)

    # Act & Assert
    with pytest.raises(RefinerConfigError):
        build_refiner_strategy(
            config, ransac_filter=ransac_stub, inference_runtime=runtime_stub
        )


def test_ac10_create_init_rejects_invalid_thresholds() -> None:
    """AC-10: the constructor rejects a 0.0 warn threshold and a negative
    RANSAC threshold with ``RefinerConfigError``.
    """
    runtime_stub = _ProgrammableInferenceRuntime()
    ransac_stub = _ProgrammableRansacFilter()
    fake_clock = _FakeClock()

    def _construct(*, warn_threshold: float, ransac_px: float) -> AdHoPRefiner:
        # Everything except the two thresholds under test is held constant.
        return AdHoPRefiner(
            inference_runtime=runtime_stub,
            engine_handle=object(),
            ransac_filter=ransac_stub,
            invocation_rate_warn_threshold=warn_threshold,
            ratelimited_warn_window_ns=_ONE_SECOND_NS,
            ransac_threshold_px=ransac_px,
            min_inliers_threshold=60,
            clock=fake_clock,
            fdr_client=None,
            logger=logging.getLogger("test"),
        )

    with pytest.raises(RefinerConfigError):
        _construct(warn_threshold=0.0, ransac_px=3.0)
    with pytest.raises(RefinerConfigError):
        _construct(warn_threshold=0.5, ransac_px=-1.0)


# ----------------------------------------------------------------------
# AC-11: FDR refiner.frame_done emitted on every call.


def test_ac11_fdr_emitted_on_gate_passthrough() -> None:
    """AC-11: a gate passthrough emits one fully-populated frame_done record."""
    fdr_sink = _CapturingFdrClient()
    refiner = _build_refiner(
        inference=_ProgrammableInferenceRuntime(),
        ransac=_ProgrammableRansacFilter(),
        fdr_client=fdr_sink,
        clock=_FakeClock(),
    )
    below_threshold = _make_match_result(reprojection_residual=1.0)

    refiner.refine_if_needed(_make_frame(), below_threshold, 2.5)

    done_records = fdr_sink.by_kind("refiner.frame_done")
    assert len(done_records) == 1
    payload = done_records[0].payload
    assert payload["was_invoked"] is False
    assert payload["refinement_label"] == "passthrough"
    assert payload["refinement_added_latency_ms"] == 0.0
    required_fields = (
        "frame_id",
        "pre_residual_px",
        "post_residual_px",
        "inlier_count_before",
        "inlier_count_after",
    )
    absent = [name for name in required_fields if name not in payload]
    assert not absent


def test_ac11_fdr_emitted_on_adhop_success() -> None:
    """AC-11: a successful refinement emits one record with the refined stats."""
    runtime_stub = _ProgrammableInferenceRuntime()
    ransac_stub = _ProgrammableRansacFilter()
    fdr_sink = _CapturingFdrClient()
    runtime_stub.queue_refined(np.full((100, 4), 0.5, dtype=np.float32))
    ransac_stub.queue(inliers=100, residual=1.1)
    refiner = _build_refiner(
        inference=runtime_stub,
        ransac=ransac_stub,
        fdr_client=fdr_sink,
        clock=_FakeClock(),
    )
    above_threshold = _make_match_result(reprojection_residual=5.0)

    refiner.refine_if_needed(_make_frame(), above_threshold, 2.5)

    done_records = fdr_sink.by_kind("refiner.frame_done")
    assert len(done_records) == 1
    payload = done_records[0].payload
    assert payload["was_invoked"] is True
    assert payload["refinement_label"] == "adhop"
    assert payload["refinement_added_latency_ms"] > 0
    assert payload["post_residual_px"] == pytest.approx(1.1)
    assert payload["inlier_count_after"] == 100


def test_ac11_fdr_emitted_on_fallthrough_with_error_flag() -> None:
    """AC-11: an engine failure emits one passthrough record flagged as error."""
    runtime_stub = _ProgrammableInferenceRuntime()
    ransac_stub = _ProgrammableRansacFilter()
    fdr_sink = _CapturingFdrClient()
    runtime_stub.queue_error(RuntimeError("trt failed"))
    refiner = _build_refiner(
        inference=runtime_stub,
        ransac=ransac_stub,
        fdr_client=fdr_sink,
        clock=_FakeClock(),
    )
    above_threshold = _make_match_result(reprojection_residual=5.0)

    refiner.refine_if_needed(_make_frame(), above_threshold, 2.5)

    done_records = fdr_sink.by_kind("refiner.frame_done")
    assert len(done_records) == 1
    payload = done_records[0].payload
    assert payload["was_invoked"] is True
    assert payload.get("error") is True
    assert payload["refinement_label"] == "passthrough"


# ----------------------------------------------------------------------
# Extra safety: bad threshold raises ValueError.


def test_extra_zero_threshold_raises_value_error() -> None:
    """A ``residual_threshold_px`` of 0.0 is rejected with ``ValueError``."""
    refiner = _build_refiner(
        inference=_ProgrammableInferenceRuntime(),
        ransac=_ProgrammableRansacFilter(),
        fdr_client=None,
        clock=_FakeClock(),
    )
    frame = _make_frame()
    match = _make_match_result()

    with pytest.raises(ValueError):
        refiner.refine_if_needed(frame, match, residual_threshold_px=0.0)


def test_extra_bad_refined_shape_falls_through() -> None:
    """A 1-D engine output falls through to the input result with an error flag."""
    # Arrange — engine returns a (50,) array instead of (M, 4).
    runtime_stub = _ProgrammableInferenceRuntime()
    ransac_stub = _ProgrammableRansacFilter()
    fdr_sink = _CapturingFdrClient()
    malformed = np.zeros((50,), dtype=np.float32)
    runtime_stub.queue_bad_output({"correspondences": malformed})
    refiner = _build_refiner(
        inference=runtime_stub,
        ransac=ransac_stub,
        fdr_client=fdr_sink,
        clock=_FakeClock(),
    )
    original = _make_match_result(reprojection_residual=5.0)

    # Act
    returned = refiner.refine_if_needed(
        _make_frame(), original, residual_threshold_px=2.5
    )

    # Assert
    assert returned is original
    done_records = fdr_sink.by_kind("refiner.frame_done")
    assert len(done_records) == 1
    assert done_records[0].payload.get("error") is True


def test_extra_non_finite_refined_falls_through() -> None:
    """An all-NaN engine output falls through to the input result with an error flag."""
    # Arrange — engine returns NaN.
    runtime_stub = _ProgrammableInferenceRuntime()
    ransac_stub = _ProgrammableRansacFilter()
    fdr_sink = _CapturingFdrClient()
    all_nan = np.full((100, 4), np.nan, dtype=np.float32)
    runtime_stub.queue_bad_output({"correspondences": all_nan})
    refiner = _build_refiner(
        inference=runtime_stub,
        ransac=ransac_stub,
        fdr_client=fdr_sink,
        clock=_FakeClock(),
    )
    original = _make_match_result(reprojection_residual=5.0)

    # Act
    returned = refiner.refine_if_needed(
        _make_frame(), original, residual_threshold_px=2.5
    )

    # Assert
    assert returned is original
    done_records = fdr_sink.by_kind("refiner.frame_done")
    assert len(done_records) == 1
    assert done_records[0].payload.get("error") is True