mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-06-22 20:41:12 +00:00
a1185d0a28
Implement the three concrete C3 CrossDomainMatcher strategies plus the C3.5 production-default AdHoPRefiner. C3 (AZ-345/346/347): - DiskLightGlueMatcher + AlikedLightGlueMatcher share a single _pipeline.run_lightglue_pipeline orchestrator (decode -> query extract -> per-candidate loop -> RANSAC sort -> health update -> FDR emit) so the only per-backbone delta is the keypoint+descriptor extractor closure. ALIKED adds a create-time engine output-schema probe (AC-special-1). - XFeatMatcher owns its own per-candidate loop (single forward fuses extraction + matching); it re-uses the shared FDR emission helpers to keep telemetry byte-identical across strategies. lightglue_runtime parameter accepted by factory but discarded (AC-special-1). - All three consume the shared LightGlueRuntime / RansacFilter / RollingHealthWindow helpers; no helper forks. InferenceRuntimeCut consumer-side Protocol added per AZ-507. C3.5 (AZ-349): - AdHoPRefiner implements the <= conditional gate, runs the OrthoLoC AdHoP TRT engine over best-candidate correspondences, re-runs RANSAC on the perspective-preconditioned set, and emits an enriched MatchResult with refinement_label="adhop". - Invariant 4 passthrough fall-through: any RefinerBackboneError (TRT failure, OOM, NaN, bad shape) is caught, logged ERROR, FDR-emitted with error: true, and converted to passthrough that still counts against the rolling invocation-rate window. MemoryError and other non-listed exceptions propagate by design (AC-5 closed-set semantics). - Rolling 60-s invocation-rate window + rate-limited WARN log (configurable via ratelimited_warn_window_ns; default 60 s). Shared changes: - C3MatcherConfig + C3_5RefinerConfig extended with the new weights/threshold/window fields. - matcher_factory + refiner_factory optionally forward clock + fdr_client to the strategy's create(); backward-compatible.
- fdr_client.records registers five new kinds: matcher.frame_done, matcher.backbone_error, matcher.insufficient_inliers, matcher.all_failed, refiner.frame_done. Tests: 66 new (43 C3 parametrised + 23 AdHoP) covering 47/47 ACs; focused suite green; full project test suite green except for one pre-existing flaky CLI cold-start timing test unrelated to this batch. Co-authored-by: Cursor <cursoragent@cursor.com>
774 lines
26 KiB
Python
774 lines
26 KiB
Python
"""AZ-349 — :class:`AdHoPRefiner` AC-1..AC-11 coverage.

The AdHoP TRT engine is exercised via a programmable
:class:`_ProgrammableInferenceRuntime` that returns canned
``correspondences`` arrays (or raises) per call. The
:class:`RansacFilter` is replaced by a programmable stub
(:class:`_ProgrammableRansacFilter`) that returns canned
:class:`RansacResult` instances. The fake clock is a monotonic
counter that lets AC-7 / AC-8 walk through the rolling
invocation-rate window and the WARN rate-limit window
deterministically.
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
from dataclasses import dataclass
|
||
from datetime import datetime, timezone
|
||
from pathlib import Path
|
||
from typing import Final
|
||
|
||
import numpy as np
|
||
import pytest
|
||
|
||
from gps_denied_onboard._types.matcher import (
|
||
CandidateMatchSet,
|
||
MatchResult,
|
||
)
|
||
from gps_denied_onboard._types.nav import NavCameraFrame
|
||
from gps_denied_onboard.components.c3_5_adhop import (
|
||
C3_5RefinerConfig,
|
||
ConditionalRefiner,
|
||
)
|
||
from gps_denied_onboard.components.c3_5_adhop.adhop_refiner import (
|
||
AdHoPRefiner,
|
||
create as create_adhop,
|
||
)
|
||
from gps_denied_onboard.components.c3_5_adhop.errors import (
|
||
RefinerBackboneError,
|
||
RefinerConfigError,
|
||
)
|
||
from gps_denied_onboard.components.c3_matcher import C3MatcherConfig
|
||
from gps_denied_onboard.config.schema import Config
|
||
from gps_denied_onboard.fdr_client import EnqueueResult, FdrRecord
|
||
from gps_denied_onboard.helpers.ransac_filter import RansacResult
|
||
from gps_denied_onboard.runtime_root.refiner_factory import build_refiner_strategy
|
||
|
||
|
||
_ONE_SECOND_NS: Final[int] = 1_000_000_000
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# Test doubles.
|
||
|
||
|
||
@dataclass
|
||
class _FakeClock:
|
||
"""Monotonic clock — advances 1 ns per ``monotonic_ns`` call."""
|
||
|
||
_t: int = 1_700_000_000_000_000_000
|
||
|
||
def monotonic_ns(self) -> int:
|
||
self._t += 1
|
||
return self._t
|
||
|
||
def time_ns(self) -> int:
|
||
return self._t
|
||
|
||
def sleep_until_ns(self, target_ns: int) -> None: # noqa: ARG002
|
||
return None
|
||
|
||
def advance(self, delta_ns: int) -> None:
|
||
self._t += int(delta_ns)
|
||
|
||
|
||
class _CapturingFdrClient:
|
||
def __init__(self) -> None:
|
||
self.records: list[FdrRecord] = []
|
||
|
||
@property
|
||
def producer_id(self) -> str:
|
||
return "c3_5_adhop.test"
|
||
|
||
def enqueue(self, record: FdrRecord) -> str:
|
||
self.records.append(record)
|
||
return EnqueueResult.OK
|
||
|
||
def by_kind(self, kind: str) -> list[FdrRecord]:
|
||
return [r for r in self.records if r.kind == kind]
|
||
|
||
|
||
class _ProgrammableInferenceRuntime:
|
||
def __init__(self) -> None:
|
||
self._queue: list[object] = []
|
||
self.calls: int = 0
|
||
|
||
def queue_refined(self, refined: np.ndarray) -> None:
|
||
self._queue.append({"correspondences": refined.astype(np.float32)})
|
||
|
||
def queue_error(self, exc: BaseException) -> None:
|
||
self._queue.append(exc)
|
||
|
||
def queue_bad_output(self, payload: object) -> None:
|
||
self._queue.append(payload)
|
||
|
||
def current_runtime_label(self) -> str:
|
||
return "tensorrt"
|
||
|
||
def compile_engine(self, model_path: Path, build_config) -> object: # noqa: ARG002
|
||
return _DummyEngineEntry(
|
||
engine_path=model_path,
|
||
sha256_hex="0" * 64,
|
||
sm=87,
|
||
jp="6.0",
|
||
trt="10.3",
|
||
precision=build_config.precision,
|
||
extras={"model_name": "adhop"},
|
||
)
|
||
|
||
def deserialize_engine(self, entry) -> object: # noqa: ARG002
|
||
return object()
|
||
|
||
def release_engine(self, handle) -> None: # noqa: ARG002
|
||
return None
|
||
|
||
def thermal_state(self):
|
||
raise NotImplementedError
|
||
|
||
def infer(self, handle, inputs): # noqa: ARG002
|
||
self.calls += 1
|
||
result = self._queue.pop(0)
|
||
if isinstance(result, BaseException):
|
||
raise result
|
||
return result
|
||
|
||
|
||
@dataclass(frozen=True)
|
||
class _DummyEngineEntry:
|
||
engine_path: Path
|
||
sha256_hex: str
|
||
sm: int
|
||
jp: str
|
||
trt: str
|
||
precision: object
|
||
extras: dict[str, str]
|
||
|
||
|
||
class _ProgrammableRansacFilter:
    """Scriptable stand-in for :class:`RansacFilter`.

    ``queue`` pre-loads canned :class:`RansacResult` instances. Each
    ``filter_correspondences`` call returns the next one in FIFO order and
    ignores its arguments entirely. ``calls`` counts every invocation.
    """

    def __init__(self) -> None:
        self._queue: list[RansacResult] = []
        self.calls: int = 0

    def queue(self, *, inliers: int, residual: float, outliers: int = 0) -> None:
        """Queue a result whose correspondences are ``inliers`` copies of [10, 20, 30, 40].

        For ``inliers <= 0`` the correspondence array is empty ``(0, 4)``.
        ``inlier_count`` is still stored as ``int(inliers)``.
        """
        if inliers > 0:
            arr = np.tile(
                np.array([10.0, 20.0, 30.0, 40.0], dtype=np.float32), (inliers, 1)
            )
        else:
            arr = np.zeros((0, 4), dtype=np.float32)
        self._queue.append(
            RansacResult(
                inlier_correspondences=arr,
                inlier_count=int(inliers),
                outlier_count=int(outliers),
                median_residual_px=float(residual),
            )
        )

    # The noqa must sit on the line holding the unused parameters; in the
    # original it was on the closing ``):`` line, where it suppressed nothing.
    def filter_correspondences(self, corr, threshold, min_inl):  # noqa: ARG002
        """Return the next queued result; arguments are ignored.

        Raises:
            IndexError: if no result is queued, with a descriptive message.
        """
        self.calls += 1
        if not self._queue:
            raise IndexError(
                f"_ProgrammableRansacFilter.filter_correspondences call "
                f"#{self.calls} has no queued result; call queue() first"
            )
        return self._queue.pop(0)
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# Builders.
|
||
|
||
|
||
def _make_frame(frame_id: int = 7) -> NavCameraFrame:
    """Build a 16x16x3 all-zero uint8 nav frame stamped with the current UTC time."""
    blank_image = np.zeros((16, 16, 3), dtype=np.uint8)
    stamp = datetime.now(tz=timezone.utc)
    return NavCameraFrame(
        frame_id=frame_id,
        timestamp=stamp,
        image=blank_image,
        camera_calibration_id="cam0",
    )
|
||
|
||
|
||
def _make_candidate(*, inliers: int = 120, residual: float = 5.0) -> CandidateMatchSet:
    """Build one candidate whose ``inliers`` x 4 correspondence rows are all 0.25.

    The tile id and the outlier count (3) are fixed placeholders.
    """
    correspondences = np.full((inliers, 4), 0.25, dtype=np.float32)
    return CandidateMatchSet(
        tile_id=(18, 49.9, 36.3),
        inlier_count=inliers,
        inlier_correspondences=correspondences,
        ransac_outlier_count=3,
        per_candidate_residual_px=residual,
    )
|
||
|
||
|
||
def _make_match_result(
    *,
    frame_id: int = 7,
    reprojection_residual: float = 5.0,
    inliers: int = 120,
    refinement_label: str = "passthrough",
) -> MatchResult:
    """Build a single-candidate :class:`MatchResult` from the disk_lightglue matcher.

    The lone candidate reuses ``reprojection_residual`` as its per-candidate
    residual. Three candidates are reported as input with two dropped, which
    leaves exactly the one in ``per_candidate``.
    """
    only_candidate = _make_candidate(inliers=inliers, residual=reprojection_residual)
    return MatchResult(
        frame_id=frame_id,
        per_candidate=(only_candidate,),
        best_candidate_idx=0,
        reprojection_residual_px=reprojection_residual,
        matched_at=1_000_000_000,
        matcher_label="disk_lightglue",
        candidates_input=3,
        candidates_dropped=2,
        refinement_label=refinement_label,
    )
|
||
|
||
|
||
def _build_refiner(
    *,
    inference: _ProgrammableInferenceRuntime,
    ransac: _ProgrammableRansacFilter,
    fdr_client: _CapturingFdrClient | None,
    clock: _FakeClock,
    invocation_rate_warn_threshold: float = 0.25,
    ratelimited_warn_window_ns: int = 60 * _ONE_SECOND_NS,
    ransac_threshold_px: float = 3.0,
    min_inliers_threshold: int = 60,
    logger: logging.Logger | None = None,
) -> AdHoPRefiner:
    """Construct an :class:`AdHoPRefiner` directly from test doubles.

    This bypasses the composition-root factory. The engine handle is an
    opaque ``object()`` placeholder; the fake inference runtime in this
    module ignores it. ``logger`` defaults to the ``test.c3_5_adhop`` logger.
    """
    if logger is None:
        logger = logging.getLogger("test.c3_5_adhop")
    return AdHoPRefiner(
        inference_runtime=inference,
        engine_handle=object(),
        ransac_filter=ransac,
        invocation_rate_warn_threshold=invocation_rate_warn_threshold,
        ratelimited_warn_window_ns=ratelimited_warn_window_ns,
        ransac_threshold_px=ransac_threshold_px,
        min_inliers_threshold=min_inliers_threshold,
        clock=clock,
        fdr_client=fdr_client,
        logger=logger,
    )
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# AC-1: Protocol conformance.
|
||
|
||
|
||
def test_ac1_protocol_conformance() -> None:
    """AdHoPRefiner passes an ``isinstance`` check against ConditionalRefiner."""
    clock = _FakeClock()
    refiner = _build_refiner(
        inference=_ProgrammableInferenceRuntime(),
        ransac=_ProgrammableRansacFilter(),
        fdr_client=None,
        clock=clock,
    )
    assert isinstance(refiner, ConditionalRefiner)
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# AC-2: Gate inclusive semantics — `<=` is passthrough.
|
||
|
||
|
||
def test_ac2_gate_inclusive_equality_is_passthrough() -> None:
    """A residual exactly equal to the threshold stays on the passthrough path."""
    threshold_px = 2.5
    refiner = _build_refiner(
        inference=_ProgrammableInferenceRuntime(),
        ransac=_ProgrammableRansacFilter(),
        fdr_client=_CapturingFdrClient(),
        clock=_FakeClock(),
    )
    incoming = _make_match_result(reprojection_residual=threshold_px)

    result = refiner.refine_if_needed(
        _make_frame(), incoming, residual_threshold_px=threshold_px
    )

    # Passthrough hands back the very same object.
    assert result is incoming
    assert refiner.was_invoked() is False
|
||
|
||
|
||
def test_ac2_gate_above_threshold_invokes() -> None:
    """A residual just above the threshold triggers AdHoP refinement."""
    threshold_px = 2.5
    inference = _ProgrammableInferenceRuntime()
    inference.queue_refined(np.full((100, 4), 1.0, dtype=np.float32))
    ransac = _ProgrammableRansacFilter()
    ransac.queue(inliers=100, residual=1.2)
    refiner = _build_refiner(
        inference=inference,
        ransac=ransac,
        fdr_client=_CapturingFdrClient(),
        clock=_FakeClock(),
    )
    incoming = _make_match_result(reprojection_residual=threshold_px + 1e-6)

    result = refiner.refine_if_needed(
        _make_frame(), incoming, residual_threshold_px=threshold_px
    )

    assert refiner.was_invoked() is True
    assert result.refinement_label == "adhop"
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# AC-3: Successful refinement enriches MatchResult.
|
||
|
||
|
||
def test_ac3_successful_refinement_enriches_match_result() -> None:
    """A successful AdHoP pass returns an enriched, relabelled MatchResult.

    The residual comes from the RANSAC re-run, a positive refinement latency
    is recorded, and the best candidate's correspondences change. Because
    the RANSAC stub ignores its input, the residual check alone cannot prove
    the engine ran, so the call counts of both doubles are pinned too.
    """
    # Arrange
    inf = _ProgrammableInferenceRuntime()
    ransac = _ProgrammableRansacFilter()
    fdr = _CapturingFdrClient()
    refined_corr = np.full((100, 4), 7.7, dtype=np.float32)
    inf.queue_refined(refined_corr)
    ransac.queue(inliers=100, residual=1.2)
    refiner = _build_refiner(
        inference=inf, ransac=ransac, fdr_client=fdr, clock=_FakeClock()
    )
    mr = _make_match_result(reprojection_residual=5.0)
    # Act
    out = refiner.refine_if_needed(_make_frame(), mr, residual_threshold_px=2.5)
    # Assert
    assert out.refinement_label == "adhop"
    assert out.reprojection_residual_px == pytest.approx(1.2)
    assert out.refinement_added_latency_ms > 0
    assert refiner.was_invoked() is True
    in_corr = mr.per_candidate[0].inlier_correspondences
    out_corr = out.per_candidate[0].inlier_correspondences
    assert not np.array_equal(in_corr, out_corr)
    # Engine forward and RANSAC re-run each happened exactly once.
    assert inf.calls == 1
    assert ransac.calls == 1
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# AC-4: Passthrough fall-through on RefinerBackboneError.
|
||
|
||
|
||
def test_ac4_refiner_backbone_error_falls_through(caplog) -> None:
    """A generic engine RuntimeError is handled as a backbone failure.

    The refiner must return its input unchanged and still report itself as
    invoked. It must log exactly one ``c3_5.refiner.backbone_error`` record
    (captured at ERROR level) and emit one ``refiner.frame_done`` record
    flagged ``error: True``.
    """
    fdr = _CapturingFdrClient()
    inference = _ProgrammableInferenceRuntime()
    inference.queue_error(RuntimeError("simulated TRT failure"))
    refiner = _build_refiner(
        inference=inference,
        ransac=_ProgrammableRansacFilter(),
        fdr_client=fdr,
        clock=_FakeClock(),
    )
    incoming = _make_match_result(reprojection_residual=5.0)

    with caplog.at_level(logging.ERROR):
        result = refiner.refine_if_needed(
            _make_frame(), incoming, residual_threshold_px=2.5
        )

    assert result is incoming
    assert result.refinement_label == "passthrough"
    assert refiner.was_invoked() is True
    assert caplog.messages.count("c3_5.refiner.backbone_error") == 1
    frame_done_records = fdr.by_kind("refiner.frame_done")
    assert len(frame_done_records) == 1
    assert frame_done_records[0].payload.get("error") is True
|
||
|
||
|
||
def test_ac4_explicit_refiner_backbone_error_falls_through() -> None:
    """An engine raising RefinerBackboneError directly also falls through to passthrough."""
    inference = _ProgrammableInferenceRuntime()
    inference.queue_error(RefinerBackboneError("explicit"))
    refiner = _build_refiner(
        inference=inference,
        ransac=_ProgrammableRansacFilter(),
        fdr_client=_CapturingFdrClient(),
        clock=_FakeClock(),
    )
    incoming = _make_match_result(reprojection_residual=5.0)

    result = refiner.refine_if_needed(
        _make_frame(), incoming, residual_threshold_px=2.5
    )

    assert result is incoming
    assert refiner.was_invoked() is True
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# AC-5: Other exception types re-raise.
|
||
|
||
|
||
def test_ac5_memory_error_propagates() -> None:
    """MemoryError is outside the fall-through set and must escape the refiner."""
    inference = _ProgrammableInferenceRuntime()
    inference.queue_error(MemoryError("simulated OOM"))
    refiner = _build_refiner(
        inference=inference,
        ransac=_ProgrammableRansacFilter(),
        fdr_client=_CapturingFdrClient(),
        clock=_FakeClock(),
    )
    frame = _make_frame()
    incoming = _make_match_result(reprojection_residual=5.0)

    with pytest.raises(MemoryError):
        refiner.refine_if_needed(frame, incoming, residual_threshold_px=2.5)
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# AC-6: Gate-decided passthrough — bit-identical correspondences.
|
||
|
||
|
||
def test_ac6_gate_passthrough_correspondences_identity_preserved() -> None:
    """A gate passthrough keeps every correspondence array identical by identity."""
    refiner = _build_refiner(
        inference=_ProgrammableInferenceRuntime(),
        ransac=_ProgrammableRansacFilter(),
        fdr_client=_CapturingFdrClient(),
        clock=_FakeClock(),
    )
    incoming = _make_match_result(reprojection_residual=1.0)

    result = refiner.refine_if_needed(
        _make_frame(), incoming, residual_threshold_px=2.5
    )

    assert result is incoming
    # ``is`` rather than array equality: no copy may have been made.
    for before, after in zip(incoming.per_candidate, result.per_candidate, strict=True):
        assert after.inlier_correspondences is before.inlier_correspondences
    assert result.refinement_label == "passthrough"
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# AC-7: _invocation_window rate accuracy.
|
||
|
||
|
||
def test_ac7_invocation_window_rate_accuracy() -> None:
    """The rolling invocation window reports invoked / total frames.

    30 frames are fed at 3 Hz (10 s total, inside the 60 s rolling
    invocation window). Each cycle is two gate-passthrough frames followed
    by one AdHoP invocation, so the expected rate is 10 / 30.
    """
    # Arrange — 30 frames at 3 Hz: 10 invoked, 20 passthrough.
    clock = _FakeClock()
    inf = _ProgrammableInferenceRuntime()
    ransac = _ProgrammableRansacFilter()
    fdr = _CapturingFdrClient()
    for _ in range(10):
        inf.queue_refined(np.full((50, 4), 0.5, dtype=np.float32))
        ransac.queue(inliers=50, residual=1.5)
    refiner = _build_refiner(
        inference=inf, ransac=ransac, fdr_client=fdr, clock=clock
    )
    mr_invoke = _make_match_result(reprojection_residual=5.0)
    mr_passthrough = _make_match_result(reprojection_residual=1.0)
    frame_period_ns = _ONE_SECOND_NS // 3  # 3 Hz
    # Act — (2 passthrough + 1 invoke) x 10 cycles = 30 frames, 10 invoked.
    for _ in range(10):
        for _ in range(2):
            refiner.refine_if_needed(
                _make_frame(), mr_passthrough, residual_threshold_px=2.5
            )
            clock.advance(frame_period_ns)
        refiner.refine_if_needed(_make_frame(), mr_invoke, residual_threshold_px=2.5)
        clock.advance(frame_period_ns)
    # Assert
    assert inf.calls == 10  # only above-threshold frames reached the engine
    rate = refiner._invocation_rate()  # noqa: SLF001
    assert rate == pytest.approx(10 / 30, abs=0.01)
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# AC-8: Invocation-rate WARN is rate-limited.
|
||
|
||
|
||
def test_ac8_invocation_rate_warn_rate_limited(caplog) -> None:
    """Sustained high invocation rate logs the WARN only once per rate-limit window.

    20 invoked frames at 4 Hz span 5 s, which is inside the 10 s rate-limit
    window. The invocation rate (every frame invoked) is far above the 0.25
    threshold.
    """
    clock = _FakeClock()
    inference = _ProgrammableInferenceRuntime()
    ransac = _ProgrammableRansacFilter()
    frame_count = 20
    for _ in range(frame_count):
        inference.queue_refined(np.full((50, 4), 0.5, dtype=np.float32))
        ransac.queue(inliers=50, residual=1.5)
    logger_name = "test.c3_5_adhop.warn"
    refiner = _build_refiner(
        inference=inference,
        ransac=ransac,
        fdr_client=_CapturingFdrClient(),
        clock=clock,
        invocation_rate_warn_threshold=0.25,
        ratelimited_warn_window_ns=10 * _ONE_SECOND_NS,
        logger=logging.getLogger(logger_name),
    )
    incoming = _make_match_result(reprojection_residual=5.0)

    with caplog.at_level(logging.WARNING, logger=logger_name):
        for _ in range(frame_count):
            refiner.refine_if_needed(
                _make_frame(), incoming, residual_threshold_px=2.5
            )
            clock.advance(_ONE_SECOND_NS // 4)

    assert caplog.messages.count("c3_5.refiner.invocation_rate_high") == 1
|
||
|
||
|
||
def test_ac8_warn_re_fires_after_window_expires(caplog) -> None:
    """The rate-limited WARN fires again once its rate-limit window has expired.

    Frames arrive 5 s apart against a 1 s rate-limit window, so the limit has
    lapsed before every subsequent frame. The WARN must therefore repeat,
    but never more than once per frame.
    """
    # Arrange — high rate sustained beyond the rate-limit window.
    n_frames = 40
    clock = _FakeClock()
    inf = _ProgrammableInferenceRuntime()
    ransac = _ProgrammableRansacFilter()
    fdr = _CapturingFdrClient()
    for _ in range(n_frames):
        inf.queue_refined(np.full((50, 4), 0.5, dtype=np.float32))
        ransac.queue(inliers=50, residual=1.5)
    refiner = _build_refiner(
        inference=inf,
        ransac=ransac,
        fdr_client=fdr,
        clock=clock,
        invocation_rate_warn_threshold=0.25,
        ratelimited_warn_window_ns=1 * _ONE_SECOND_NS,
        logger=logging.getLogger("test.c3_5_adhop.warn2"),
    )
    mr_invoke = _make_match_result(reprojection_residual=5.0)
    # Act — 40 invokes spaced 5 s apart so the 1 s rate-limit window expires
    # between every pair of frames.
    with caplog.at_level(logging.WARNING, logger="test.c3_5_adhop.warn2"):
        for _ in range(n_frames):
            refiner.refine_if_needed(
                _make_frame(), mr_invoke, residual_threshold_px=2.5
            )
            clock.advance(5 * _ONE_SECOND_NS)
    # Assert — the WARN re-fires, bounded above by one WARN per frame.
    warns = [
        r for r in caplog.records if r.message == "c3_5.refiner.invocation_rate_high"
    ]
    assert 2 <= len(warns) <= n_frames
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# AC-9: was_invoked() three-state semantics.
|
||
|
||
|
||
def test_ac9_was_invoked_gate_passthrough_false() -> None:
    """was_invoked() is False after a gate-decided passthrough."""
    refiner = _build_refiner(
        inference=_ProgrammableInferenceRuntime(),
        ransac=_ProgrammableRansacFilter(),
        fdr_client=None,
        clock=_FakeClock(),
    )
    below_threshold = _make_match_result(reprojection_residual=1.0)
    refiner.refine_if_needed(_make_frame(), below_threshold, 2.5)
    assert refiner.was_invoked() is False
|
||
|
||
|
||
def test_ac9_was_invoked_adhop_success_true() -> None:
    """was_invoked() is True after a successful AdHoP refinement."""
    inference = _ProgrammableInferenceRuntime()
    inference.queue_refined(np.full((50, 4), 0.5, dtype=np.float32))
    ransac = _ProgrammableRansacFilter()
    ransac.queue(inliers=50, residual=1.5)
    refiner = _build_refiner(
        inference=inference, ransac=ransac, fdr_client=None, clock=_FakeClock()
    )
    above_threshold = _make_match_result(reprojection_residual=5.0)
    refiner.refine_if_needed(_make_frame(), above_threshold, 2.5)
    assert refiner.was_invoked() is True
|
||
|
||
|
||
def test_ac9_was_invoked_fallthrough_true() -> None:
    """was_invoked() is True even when the engine fails and the refiner falls through."""
    inference = _ProgrammableInferenceRuntime()
    inference.queue_error(RuntimeError("trt error"))
    refiner = _build_refiner(
        inference=inference,
        ransac=_ProgrammableRansacFilter(),
        fdr_client=None,
        clock=_FakeClock(),
    )
    above_threshold = _make_match_result(reprojection_residual=5.0)
    refiner.refine_if_needed(_make_frame(), above_threshold, 2.5)
    assert refiner.was_invoked() is True
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# AC-10: Composition-root wiring + identity-shared RansacFilter.
|
||
|
||
|
||
def test_ac10_factory_wires_adhop_strategy(caplog) -> None:
    """build_refiner_strategy yields an AdHoPRefiner sharing the injected RansacFilter.

    It also checks that exactly one ``c3_5.refiner.ready`` record is logged
    during construction.
    """
    ransac = _ProgrammableRansacFilter()
    refiner_block = C3_5RefinerConfig(
        strategy="adhop",
        residual_threshold_px=2.5,
        adhop_weights_path=Path("/tmp/adhop.engine"),
    )
    matcher_block = C3MatcherConfig(
        strategy="disk_lightglue",
        min_inliers_threshold=60,
        disk_weights_path=Path("/tmp/disk.engine"),
        aliked_weights_path=Path("/tmp/aliked.engine"),
        xfeat_weights_path=Path("/tmp/xfeat.engine"),
    )
    config = Config.with_blocks(c3_5_adhop=refiner_block, c3_matcher=matcher_block)

    with caplog.at_level(logging.INFO, logger="gps_denied_onboard.c3_5_adhop"):
        instance = build_refiner_strategy(
            config,
            ransac_filter=ransac,
            inference_runtime=_ProgrammableInferenceRuntime(),
            clock=_FakeClock(),
            fdr_client=_CapturingFdrClient(),
        )

    assert isinstance(instance, AdHoPRefiner)
    assert isinstance(instance, ConditionalRefiner)
    # Identity, not equality: the factory must pass the same filter through.
    assert instance._ransac_filter is ransac  # noqa: SLF001
    assert caplog.messages.count("c3_5.refiner.ready") == 1
|
||
|
||
|
||
def test_ac10_factory_missing_weights_rejects() -> None:
    """Selecting the adhop strategy without a weights path is a config error."""
    config = Config.with_blocks(c3_5_adhop=C3_5RefinerConfig(strategy="adhop"))

    with pytest.raises(RefinerConfigError):
        build_refiner_strategy(
            config,
            ransac_filter=_ProgrammableRansacFilter(),
            inference_runtime=_ProgrammableInferenceRuntime(),
        )
|
||
|
||
|
||
def test_ac10_create_init_rejects_invalid_thresholds() -> None:
    """The constructor rejects a zero WARN-rate threshold and a negative RANSAC threshold."""
    inference = _ProgrammableInferenceRuntime()
    ransac = _ProgrammableRansacFilter()
    clock = _FakeClock()
    invalid_overrides = (
        # Zero invocation-rate WARN threshold.
        {"invocation_rate_warn_threshold": 0.0, "ransac_threshold_px": 3.0},
        # Negative RANSAC pixel threshold.
        {"invocation_rate_warn_threshold": 0.5, "ransac_threshold_px": -1.0},
    )
    for overrides in invalid_overrides:
        with pytest.raises(RefinerConfigError):
            AdHoPRefiner(
                inference_runtime=inference,
                engine_handle=object(),
                ransac_filter=ransac,
                ratelimited_warn_window_ns=_ONE_SECOND_NS,
                min_inliers_threshold=60,
                clock=clock,
                fdr_client=None,
                logger=logging.getLogger("test"),
                **overrides,
            )
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# AC-11: FDR refiner.frame_done emitted on every call.
|
||
|
||
|
||
def test_ac11_fdr_emitted_on_gate_passthrough() -> None:
    """A gate passthrough still emits exactly one well-formed refiner.frame_done record."""
    fdr = _CapturingFdrClient()
    refiner = _build_refiner(
        inference=_ProgrammableInferenceRuntime(),
        ransac=_ProgrammableRansacFilter(),
        fdr_client=fdr,
        clock=_FakeClock(),
    )
    refiner.refine_if_needed(
        _make_frame(), _make_match_result(reprojection_residual=1.0), 2.5
    )

    records = fdr.by_kind("refiner.frame_done")
    assert len(records) == 1
    payload = records[0].payload
    assert payload["was_invoked"] is False
    assert payload["refinement_label"] == "passthrough"
    assert payload["refinement_added_latency_ms"] == 0.0
    required_keys = (
        "frame_id",
        "pre_residual_px",
        "post_residual_px",
        "inlier_count_before",
        "inlier_count_after",
    )
    for key in required_keys:
        assert key in payload
|
||
|
||
|
||
def test_ac11_fdr_emitted_on_adhop_success() -> None:
    """A successful refinement emits one frame_done record with post-refinement stats."""
    inference = _ProgrammableInferenceRuntime()
    inference.queue_refined(np.full((100, 4), 0.5, dtype=np.float32))
    ransac = _ProgrammableRansacFilter()
    ransac.queue(inliers=100, residual=1.1)
    fdr = _CapturingFdrClient()
    refiner = _build_refiner(
        inference=inference, ransac=ransac, fdr_client=fdr, clock=_FakeClock()
    )

    refiner.refine_if_needed(
        _make_frame(), _make_match_result(reprojection_residual=5.0), 2.5
    )

    records = fdr.by_kind("refiner.frame_done")
    assert len(records) == 1
    payload = records[0].payload
    assert payload["was_invoked"] is True
    assert payload["refinement_label"] == "adhop"
    assert payload["refinement_added_latency_ms"] > 0
    assert payload["post_residual_px"] == pytest.approx(1.1)
    assert payload["inlier_count_after"] == 100
|
||
|
||
|
||
def test_ac11_fdr_emitted_on_fallthrough_with_error_flag() -> None:
    """An engine failure emits one frame_done record flagged error=True, labelled passthrough."""
    inference = _ProgrammableInferenceRuntime()
    inference.queue_error(RuntimeError("trt failed"))
    fdr = _CapturingFdrClient()
    refiner = _build_refiner(
        inference=inference,
        ransac=_ProgrammableRansacFilter(),
        fdr_client=fdr,
        clock=_FakeClock(),
    )

    refiner.refine_if_needed(
        _make_frame(), _make_match_result(reprojection_residual=5.0), 2.5
    )

    records = fdr.by_kind("refiner.frame_done")
    assert len(records) == 1
    payload = records[0].payload
    assert payload["was_invoked"] is True
    assert payload.get("error") is True
    assert payload["refinement_label"] == "passthrough"
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# Extra safety: bad threshold raises ValueError.
|
||
|
||
|
||
def test_extra_zero_threshold_raises_value_error() -> None:
    """A residual threshold of zero is rejected with ValueError."""
    refiner = _build_refiner(
        inference=_ProgrammableInferenceRuntime(),
        ransac=_ProgrammableRansacFilter(),
        fdr_client=None,
        clock=_FakeClock(),
    )
    frame = _make_frame()
    incoming = _make_match_result()

    with pytest.raises(ValueError):
        refiner.refine_if_needed(frame, incoming, residual_threshold_px=0.0)
|
||
|
||
|
||
def test_extra_bad_refined_shape_falls_through() -> None:
    """A 1-D ``(50,)`` engine output instead of ``(M, 4)`` falls through with error=True."""
    inference = _ProgrammableInferenceRuntime()
    inference.queue_bad_output({"correspondences": np.zeros((50,), dtype=np.float32)})
    fdr = _CapturingFdrClient()
    refiner = _build_refiner(
        inference=inference,
        ransac=_ProgrammableRansacFilter(),
        fdr_client=fdr,
        clock=_FakeClock(),
    )
    incoming = _make_match_result(reprojection_residual=5.0)

    result = refiner.refine_if_needed(
        _make_frame(), incoming, residual_threshold_px=2.5
    )

    assert result is incoming
    frame_done_records = fdr.by_kind("refiner.frame_done")
    assert len(frame_done_records) == 1
    assert frame_done_records[0].payload.get("error") is True
|
||
|
||
|
||
def test_extra_non_finite_refined_falls_through() -> None:
    """An all-NaN engine output falls through to passthrough with error=True."""
    inference = _ProgrammableInferenceRuntime()
    nan_output = np.full((100, 4), np.nan, dtype=np.float32)
    inference.queue_bad_output({"correspondences": nan_output})
    fdr = _CapturingFdrClient()
    refiner = _build_refiner(
        inference=inference,
        ransac=_ProgrammableRansacFilter(),
        fdr_client=fdr,
        clock=_FakeClock(),
    )
    incoming = _make_match_result(reprojection_residual=5.0)

    result = refiner.refine_if_needed(
        _make_frame(), incoming, residual_threshold_px=2.5
    )

    assert result is incoming
    frame_done_records = fdr.by_kind("refiner.frame_done")
    assert len(frame_done_records) == 1
    assert frame_done_records[0].payload.get("error") is True
|