[AZ-232] Add safety anchor state machine

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
Oleksandr Bezdieniezhnykh
2026-05-03 19:10:10 +03:00
parent 7819ae7a38
commit 9fb9e4a349
8 changed files with 388 additions and 7 deletions
+17
View File
@@ -1 +1,18 @@
"""Safety and anchor wrapper component."""
from .interfaces import LocalizationStateMachine, SafetyAnchorStateMachine
from .types import (
LocalizationSnapshot,
SafetyStateConfig,
TelemetryContext,
TileWriteEligibility,
)
__all__ = [
"LocalizationSnapshot",
"LocalizationStateMachine",
"SafetyAnchorStateMachine",
"SafetyStateConfig",
"TelemetryContext",
"TileWriteEligibility",
]
+141 -3
View File
@@ -1,13 +1,151 @@
"""Public localization state-machine interfaces."""
from typing import Any, Protocol
from typing import Protocol
from shared.contracts import AnchorDecision, PositionEstimate, VioStatePacket
from .types import (
LocalizationSnapshot,
SafetyStateConfig,
TelemetryContext,
TileWriteEligibility,
)
class LocalizationStateMachine(Protocol):
"""Coordinates VIO propagation and anchor promotion decisions."""
def update_vio(self, vio_state: Any) -> Any:
def update_vio(
self, vio_state: VioStatePacket, telemetry: TelemetryContext
) -> LocalizationSnapshot:
"""Update the state machine with a VIO state packet."""
def consider_anchor(self, anchor_decision: Any) -> Any:
def consider_anchor(self, anchor_decision: AnchorDecision) -> LocalizationSnapshot:
"""Evaluate a verified anchor decision."""
class SafetyAnchorStateMachine:
"""Owns authoritative source labels, covariance, and tile eligibility."""
def __init__(self, config: SafetyStateConfig | None = None) -> None:
self._config = config or SafetyStateConfig()
self._snapshot: LocalizationSnapshot | None = None
@property
def snapshot(self) -> LocalizationSnapshot | None:
return self._snapshot
def update_vio(
self,
vio_state: VioStatePacket,
telemetry: TelemetryContext,
) -> LocalizationSnapshot:
covariance_m = self._covariance_from_vio(vio_state)
estimate = PositionEstimate(
timestamp_ns=vio_state.timestamp_ns,
latitude_deg=telemetry.latitude_hint_deg,
longitude_deg=telemetry.longitude_hint_deg,
altitude_m=telemetry.altitude_m,
covariance_semimajor_m=covariance_m,
source_label="vo_extrapolated",
fix_type=3,
horizontal_accuracy_m=covariance_m,
anchor_age_ms=0,
)
self._snapshot = LocalizationSnapshot(
estimate=estimate,
mode="vo_extrapolated",
last_vio_state=vio_state,
)
return self._snapshot
def consider_anchor(self, anchor_decision: AnchorDecision) -> LocalizationSnapshot:
self._require_snapshot()
assert self._snapshot is not None
if not anchor_decision.accepted:
return self._snapshot
pose = anchor_decision.estimated_pose or {}
covariance_m = max(anchor_decision.mean_reprojection_error_px, 0.5)
estimate = PositionEstimate(
timestamp_ns=self._snapshot.estimate.timestamp_ns,
latitude_deg=float(pose.get("latitude_deg", self._snapshot.estimate.latitude_deg)),
longitude_deg=float(pose.get("longitude_deg", self._snapshot.estimate.longitude_deg)),
altitude_m=float(pose.get("altitude_m", self._snapshot.estimate.altitude_m)),
covariance_semimajor_m=covariance_m,
source_label="satellite_anchored",
fix_type=3,
horizontal_accuracy_m=covariance_m,
anchor_age_ms=0,
)
self._snapshot = LocalizationSnapshot(
estimate=estimate,
mode="satellite_anchored",
anchor_evidence=anchor_decision,
last_vio_state=self._snapshot.last_vio_state,
)
return self._snapshot
def propagate_blackout(self, timestamp_ns: int) -> LocalizationSnapshot:
self._require_snapshot()
assert self._snapshot is not None
previous = self._snapshot.estimate
covariance_m = previous.covariance_semimajor_m + self._config.dead_reckoning_growth_m
no_fix = covariance_m >= self._config.no_fix_covariance_threshold_m
source_label = "no_fix" if no_fix else "dead_reckoned"
fix_type = 0 if no_fix else 2
estimate = PositionEstimate(
timestamp_ns=timestamp_ns,
latitude_deg=previous.latitude_deg,
longitude_deg=previous.longitude_deg,
altitude_m=previous.altitude_m,
covariance_semimajor_m=covariance_m,
source_label=source_label,
fix_type=fix_type,
horizontal_accuracy_m=max(covariance_m, 999.0 if no_fix else covariance_m),
anchor_age_ms=previous.anchor_age_ms + 1_000,
)
self._snapshot = LocalizationSnapshot(
estimate=estimate,
mode=source_label,
anchor_evidence=self._snapshot.anchor_evidence,
last_vio_state=self._snapshot.last_vio_state,
)
return self._snapshot
def tile_write_eligibility(self) -> TileWriteEligibility:
self._require_snapshot()
assert self._snapshot is not None
estimate = self._snapshot.estimate
if estimate.source_label not in {"satellite_anchored", "vo_extrapolated"}:
return TileWriteEligibility(
eligible=False,
reason="untrusted_source_label",
estimate=estimate,
)
if estimate.covariance_semimajor_m > self._config.tile_write_covariance_max_m:
return TileWriteEligibility(
eligible=False,
reason="covariance_too_high",
estimate=estimate,
)
return TileWriteEligibility(
eligible=True,
reason="trusted_pose",
estimate=estimate,
)
def _covariance_from_vio(self, vio_state: VioStatePacket) -> float:
if not vio_state.covariance_hint:
return max(
self._config.vio_covariance_floor_m,
self._config.initial_covariance_m / max(vio_state.tracking_quality, 0.1),
)
diagonal = [
row[index] for index, row in enumerate(vio_state.covariance_hint) if index < len(row)
]
return max(self._config.vio_covariance_floor_m, max(diagonal, default=0.0))
def _require_snapshot(self) -> None:
if self._snapshot is None:
raise RuntimeError("safety state requires a VIO update before this operation")
+37 -3
View File
@@ -1,5 +1,39 @@
"""Public safety wrapper type aliases."""
"""Public safety wrapper models."""
from typing import Any
from typing import Literal
PositionEstimateLike = Any
from pydantic import BaseModel, ConfigDict, Field, NonNegativeFloat, NonNegativeInt
from shared.contracts import AnchorDecision, PositionEstimate, VioStatePacket
class SafetyWrapperModel(BaseModel):
model_config = ConfigDict(extra="forbid", frozen=True)
class TelemetryContext(SafetyWrapperModel):
timestamp_ns: NonNegativeInt
latitude_hint_deg: float = Field(ge=-90.0, le=90.0)
longitude_hint_deg: float = Field(ge=-180.0, le=180.0)
altitude_m: float
class SafetyStateConfig(SafetyWrapperModel):
initial_covariance_m: NonNegativeFloat = 2.0
vio_covariance_floor_m: NonNegativeFloat = 1.0
dead_reckoning_growth_m: NonNegativeFloat = 50.0
no_fix_covariance_threshold_m: NonNegativeFloat = 500.0
tile_write_covariance_max_m: NonNegativeFloat = 3.0
class LocalizationSnapshot(SafetyWrapperModel):
estimate: PositionEstimate
mode: Literal["satellite_anchored", "vo_extrapolated", "dead_reckoned", "no_fix"]
anchor_evidence: AnchorDecision | None = None
last_vio_state: VioStatePacket | None = None
class TileWriteEligibility(SafetyWrapperModel):
eligible: bool
reason: str = Field(min_length=1)
estimate: PositionEstimate