mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-06-22 08:01:25 +00:00
[AZ-230] Add local VPR retrieval boundary
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -1,24 +1,37 @@
|
||||
"""Offline satellite retrieval and synchronization component."""
|
||||
|
||||
from .interfaces import SatelliteService, SatelliteSyncBoundary
|
||||
from .interfaces import LocalVprRetriever, SatelliteService, SatelliteSyncBoundary
|
||||
from .types import (
|
||||
DescriptorFidelityReport,
|
||||
GeneratedTileUploadRecord,
|
||||
LocalVprIndexPackage,
|
||||
MissionCacheImportResult,
|
||||
MissionCachePackage,
|
||||
RelocalizationRequest,
|
||||
RuntimePhase,
|
||||
SatelliteSyncResult,
|
||||
SatelliteSyncStatus,
|
||||
UploadOutcome,
|
||||
VprDescriptorRecord,
|
||||
VprReadinessReport,
|
||||
VprRetrievalResult,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DescriptorFidelityReport",
|
||||
"GeneratedTileUploadRecord",
|
||||
"LocalVprIndexPackage",
|
||||
"LocalVprRetriever",
|
||||
"MissionCacheImportResult",
|
||||
"MissionCachePackage",
|
||||
"RelocalizationRequest",
|
||||
"RuntimePhase",
|
||||
"SatelliteService",
|
||||
"SatelliteSyncBoundary",
|
||||
"SatelliteSyncResult",
|
||||
"SatelliteSyncStatus",
|
||||
"UploadOutcome",
|
||||
"VprDescriptorRecord",
|
||||
"VprReadinessReport",
|
||||
"VprRetrievalResult",
|
||||
]
|
||||
|
||||
@@ -1,32 +1,169 @@
|
||||
"""Public satellite service interfaces."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Protocol
|
||||
from math import sqrt
|
||||
from typing import Protocol
|
||||
|
||||
from shared.contracts import VprCandidate
|
||||
from shared.errors import ErrorEnvelope
|
||||
from tile_manager import GeneratedTileSyncPackage
|
||||
|
||||
from .types import (
|
||||
DescriptorFidelityReport,
|
||||
GeneratedTileUploadRecord,
|
||||
LocalVprIndexPackage,
|
||||
MissionCacheImportResult,
|
||||
MissionCachePackage,
|
||||
RelocalizationRequest,
|
||||
RuntimePhase,
|
||||
SatelliteSyncResult,
|
||||
SatelliteSyncStatus,
|
||||
UploadOutcome,
|
||||
VprReadinessReport,
|
||||
VprRetrievalResult,
|
||||
)
|
||||
|
||||
|
||||
class SatelliteService(Protocol):
|
||||
"""Retrieves offline VPR candidates from mission cache data."""
|
||||
|
||||
def load_index(self) -> None:
|
||||
def load_index(self, package: LocalVprIndexPackage) -> VprReadinessReport:
|
||||
"""Load the local descriptor index."""
|
||||
|
||||
def retrieve(self, frame: Any) -> list[Any]:
|
||||
def retrieve(self, request: RelocalizationRequest) -> VprRetrievalResult:
|
||||
"""Return candidate anchor records for one frame."""
|
||||
|
||||
|
||||
class LocalVprRetriever:
|
||||
"""Triggered local VPR retrieval over preloaded descriptor records."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._index: LocalVprIndexPackage | None = None
|
||||
|
||||
def load_index(self, package: LocalVprIndexPackage) -> VprReadinessReport:
|
||||
self._index = package
|
||||
return VprReadinessReport(
|
||||
ready=True,
|
||||
engine=package.engine,
|
||||
loaded_records=len(package.records),
|
||||
)
|
||||
|
||||
def readiness(self) -> VprReadinessReport:
|
||||
if self._index is None:
|
||||
return VprReadinessReport(
|
||||
ready=False,
|
||||
engine="cpu_faiss",
|
||||
loaded_records=0,
|
||||
error=self._error("local VPR index is not loaded", "index_not_loaded"),
|
||||
)
|
||||
return VprReadinessReport(
|
||||
ready=True,
|
||||
engine=self._index.engine,
|
||||
loaded_records=len(self._index.records),
|
||||
)
|
||||
|
||||
def retrieve(self, request: RelocalizationRequest) -> VprRetrievalResult:
|
||||
readiness = self.readiness()
|
||||
if not readiness.ready:
|
||||
return VprRetrievalResult(
|
||||
ready=False,
|
||||
degraded=True,
|
||||
error=readiness.error,
|
||||
)
|
||||
|
||||
assert self._index is not None
|
||||
query_descriptor = request.query_descriptor or self._extract_descriptor(request.image_ref)
|
||||
scored = sorted(
|
||||
(
|
||||
(self._similarity(query_descriptor, record.descriptor), record)
|
||||
for record in self._index.records
|
||||
if record.freshness_status != "rejected"
|
||||
),
|
||||
key=lambda item: item[0],
|
||||
reverse=True,
|
||||
)
|
||||
candidates = tuple(
|
||||
VprCandidate(
|
||||
chunk_id=record.chunk_id,
|
||||
tile_id=record.tile_id,
|
||||
score=score,
|
||||
footprint=record.footprint,
|
||||
freshness_status=record.freshness_status,
|
||||
)
|
||||
for score, record in scored[: request.top_k]
|
||||
)
|
||||
if not candidates:
|
||||
return VprRetrievalResult(
|
||||
ready=True,
|
||||
degraded=True,
|
||||
error=self._error("local VPR index produced no valid candidates", "no_candidates"),
|
||||
)
|
||||
|
||||
return VprRetrievalResult(ready=True, degraded=False, candidates=candidates)
|
||||
|
||||
def verify_descriptor_fidelity(
|
||||
self,
|
||||
reference_descriptor: tuple[float, ...],
|
||||
optimized_descriptor: tuple[float, ...],
|
||||
max_l2_delta: float,
|
||||
) -> DescriptorFidelityReport:
|
||||
observed_delta = self._l2_distance(reference_descriptor, optimized_descriptor)
|
||||
return DescriptorFidelityReport(
|
||||
accepted=observed_delta <= max_l2_delta,
|
||||
observed_l2_delta=observed_delta,
|
||||
max_l2_delta=max_l2_delta,
|
||||
)
|
||||
|
||||
def _extract_descriptor(self, image_ref: str) -> tuple[float, ...]:
|
||||
encoded = image_ref.encode("utf-8")
|
||||
buckets = [0.0, 0.0, 0.0, 0.0]
|
||||
for index, value in enumerate(encoded):
|
||||
buckets[index % len(buckets)] += value / 255.0
|
||||
magnitude = sqrt(sum(value * value for value in buckets)) or 1.0
|
||||
return tuple(value / magnitude for value in buckets)
|
||||
|
||||
def _similarity(
|
||||
self,
|
||||
query_descriptor: tuple[float, ...],
|
||||
record_descriptor: tuple[float, ...],
|
||||
) -> float:
|
||||
max_length = max(len(query_descriptor), len(record_descriptor))
|
||||
padded_query = query_descriptor + (0.0,) * (max_length - len(query_descriptor))
|
||||
padded_record = record_descriptor + (0.0,) * (max_length - len(record_descriptor))
|
||||
dot_product = sum(
|
||||
query_value * record_value
|
||||
for query_value, record_value in zip(padded_query, padded_record)
|
||||
)
|
||||
query_norm = sqrt(sum(value * value for value in padded_query)) or 1.0
|
||||
record_norm = sqrt(sum(value * value for value in padded_record)) or 1.0
|
||||
return max(0.0, min(1.0, dot_product / (query_norm * record_norm)))
|
||||
|
||||
def _l2_distance(
|
||||
self,
|
||||
reference_descriptor: tuple[float, ...],
|
||||
optimized_descriptor: tuple[float, ...],
|
||||
) -> float:
|
||||
max_length = max(len(reference_descriptor), len(optimized_descriptor))
|
||||
padded_reference = reference_descriptor + (0.0,) * (max_length - len(reference_descriptor))
|
||||
padded_optimized = optimized_descriptor + (0.0,) * (max_length - len(optimized_descriptor))
|
||||
return sqrt(
|
||||
sum(
|
||||
(reference_value - optimized_value) ** 2
|
||||
for reference_value, optimized_value in zip(padded_reference, padded_optimized)
|
||||
)
|
||||
)
|
||||
|
||||
def _error(self, message: str, cause: str) -> ErrorEnvelope:
|
||||
return ErrorEnvelope(
|
||||
component="satellite_service",
|
||||
category="runtime",
|
||||
message=message,
|
||||
severity="warning",
|
||||
retryable=False,
|
||||
cause=cause,
|
||||
)
|
||||
|
||||
|
||||
class SatelliteSyncBoundary:
|
||||
"""Owns pre-flight and post-flight package exchange only."""
|
||||
|
||||
|
||||
@@ -2,8 +2,9 @@
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, PositiveInt
|
||||
|
||||
from shared.contracts import VprCandidate
|
||||
from shared.errors import ErrorEnvelope
|
||||
from tile_manager import TileManifestEntry
|
||||
|
||||
@@ -45,5 +46,47 @@ class SatelliteSyncResult(SatelliteServiceModel):
|
||||
error: ErrorEnvelope | None = None
|
||||
|
||||
|
||||
class VprDescriptorRecord(SatelliteServiceModel):
|
||||
chunk_id: str = Field(min_length=1)
|
||||
tile_id: str = Field(min_length=1)
|
||||
descriptor: tuple[float, ...] = Field(min_length=1)
|
||||
footprint: dict[str, float]
|
||||
freshness_status: Literal["fresh", "stale", "rejected"]
|
||||
|
||||
|
||||
class LocalVprIndexPackage(SatelliteServiceModel):
|
||||
package_id: str = Field(min_length=1)
|
||||
engine: Literal["cpu_faiss"] = "cpu_faiss"
|
||||
records: tuple[VprDescriptorRecord, ...] = Field(min_length=1)
|
||||
|
||||
|
||||
class RelocalizationRequest(SatelliteServiceModel):
|
||||
frame_id: str = Field(min_length=1)
|
||||
image_ref: str = Field(min_length=1)
|
||||
trigger_reason: str = Field(min_length=1)
|
||||
top_k: PositiveInt = Field(le=50)
|
||||
query_descriptor: tuple[float, ...] | None = None
|
||||
|
||||
|
||||
class VprReadinessReport(SatelliteServiceModel):
|
||||
ready: bool
|
||||
engine: Literal["cpu_faiss"]
|
||||
loaded_records: int = Field(ge=0)
|
||||
error: ErrorEnvelope | None = None
|
||||
|
||||
|
||||
class VprRetrievalResult(SatelliteServiceModel):
|
||||
ready: bool
|
||||
degraded: bool
|
||||
candidates: tuple[VprCandidate, ...] = ()
|
||||
error: ErrorEnvelope | None = None
|
||||
|
||||
|
||||
class DescriptorFidelityReport(SatelliteServiceModel):
|
||||
accepted: bool
|
||||
observed_l2_delta: float = Field(ge=0.0)
|
||||
max_l2_delta: float = Field(ge=0.0)
|
||||
|
||||
|
||||
RuntimePhase = Literal["pre_flight", "in_flight", "post_flight"]
|
||||
UploadOutcome = Literal["success", "retryable_failure", "rejected"]
|
||||
|
||||
Reference in New Issue
Block a user