Mirror of https://github.com/azaion/gps-denied-onboard.git (synced 2026-04-22 22:06:37 +00:00). File stats: 259 lines, 11 KiB, Python.
import cv2
|
|
import numpy as np
|
|
import json
|
|
import os
|
|
import logging
|
|
from typing import List, Dict, Optional, Any, Tuple
|
|
from pydantic import BaseModel
|
|
from abc import ABC, abstractmethod
|
|
|
|
from f02_1_flight_lifecycle_manager import GPSPoint
|
|
from f04_satellite_data_manager import TileBounds
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# --- Data Models ---
|
|
|
|
class TileCandidate(BaseModel):
    """A satellite tile proposed as a possible match for the query imagery.

    Produced by GlobalPlaceRecognition._build_candidates_from_matches and
    ordered by rank_candidates (rank 1 = highest similarity).
    """

    # Unique identifier of the satellite tile.
    tile_id: str
    # GPS coordinate of the tile centre.
    gps_center: GPSPoint
    # Optional TileBounds to avoid strict cyclic coupling
    bounds: Optional[Any] = None
    # Descriptor similarity in [0, 1] (see _distance_to_similarity).
    similarity_score: float
    # 1-based position after ranking; set to 0 until rank_candidates assigns it.
    rank: int
    # Score from spatial re-ranking; currently unused hook (always None).
    spatial_score: Optional[float] = None
|
|
|
|
class DatabaseMatch(BaseModel):
    """A raw nearest-neighbour hit returned by the index query.

    Emitted by GlobalPlaceRecognition.query_database, before metadata is
    expanded into full TileCandidate objects.
    """

    # Row index into the index / tile_metadata mapping.
    index: int
    # Tile identifier from metadata, or "tile_<index>" when metadata is missing.
    tile_id: str
    # Euclidean (L2) distance between query and database descriptors.
    distance: float
    # Distance mapped to a [0, 1] similarity (see _distance_to_similarity).
    similarity_score: float
|
|
|
|
class SatelliteTile(BaseModel):
    """A satellite map tile: imagery, georeference, and optional descriptor."""

    # Unique identifier of the tile.
    tile_id: str
    # Tile imagery as a raw array (colour layout/shape not validated here).
    image: np.ndarray
    # GPS coordinate of the tile centre.
    gps_center: GPSPoint
    # Georeferenced bounds; typed Any to avoid a hard dependency on TileBounds.
    bounds: Any
    # Precomputed location descriptor, if one has been generated.
    descriptor: Optional[np.ndarray] = None

    # Required so pydantic accepts non-pydantic field types such as np.ndarray.
    model_config = {"arbitrary_types_allowed": True}
|
|
|
|
# --- Exceptions ---
|
|
class IndexNotFoundError(Exception):
    """Raised when the index file does not exist at the given path."""


class IndexCorruptedError(Exception):
    """Raised when a loaded index fails the integrity check (bad dimensions)."""


class MetadataMismatchError(Exception):
    """Raised when tile metadata is missing, invalid, or out of sync with the index."""
|
|
|
|
# --- Interface ---
|
|
|
|
class IGlobalPlaceRecognition(ABC):
    """Interface for F08 global place recognition: descriptor computation and
    candidate tile retrieval from a pre-built satellite-tile index."""

    @abstractmethod
    def retrieve_candidate_tiles(self, image: np.ndarray, top_k: int) -> List[TileCandidate]:
        """Return up to top_k ranked tile candidates for a single query image."""
        pass

    @abstractmethod
    def compute_location_descriptor(self, image: np.ndarray) -> np.ndarray:
        """Compute a global location descriptor for one image."""
        pass

    @abstractmethod
    def query_database(self, descriptor: np.ndarray, top_k: int) -> List[DatabaseMatch]:
        """Run a nearest-neighbour search for a descriptor against the index."""
        pass

    @abstractmethod
    def rank_candidates(self, candidates: List[TileCandidate]) -> List[TileCandidate]:
        """Order candidates and assign their rank fields."""
        pass

    @abstractmethod
    def load_index(self, flight_id: str, index_path: str) -> bool:
        """Load the retrieval index and its metadata for a flight."""
        pass

    @abstractmethod
    def retrieve_candidate_tiles_for_chunk(self, chunk_images: List[np.ndarray], top_k: int) -> List[TileCandidate]:
        """Return up to top_k ranked tile candidates for a chunk of frames."""
        pass

    @abstractmethod
    def compute_chunk_descriptor(self, chunk_images: List[np.ndarray]) -> np.ndarray:
        """Compute a single fused descriptor for a chunk of frames."""
        pass
|
|
|
|
|
|
# --- Implementation ---
|
|
|
|
class GlobalPlaceRecognition(IGlobalPlaceRecognition):
    """
    F08: Global Place Recognition

    Computes DINOv2+VLAD semantic descriptors and queries a pre-built Faiss index
    of satellite tiles to relocalize the UAV after catastrophic tracking loss.

    All external collaborators (model, Faiss, satellite managers) are injected
    and optional; when absent, deterministic mock fallbacks keep the pipeline
    runnable end-to-end (useful for tests).
    """

    def __init__(self, model_manager=None, faiss_manager=None, satellite_manager=None):
        self.model_manager = model_manager
        self.faiss_manager = faiss_manager
        self.satellite_manager = satellite_manager

        self.is_index_loaded = False
        # Index row -> tile metadata dict (expected keys: "tile_id", "lat", "lon").
        self.tile_metadata: Dict[int, Dict] = {}
        self.dim = 4096  # DINOv2 + VLAD standard dimension

    # --- Descriptor Computation (08.02) ---

    def _preprocess_image(self, image: np.ndarray) -> np.ndarray:
        """Convert a 3-channel BGR image to RGB, resize to 224x224, scale to [0, 1] float32."""
        if len(image.shape) == 3 and image.shape[2] == 3:
            img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        else:
            img = image
        # Standard DINOv2 input size
        img = cv2.resize(img, (224, 224))
        return img.astype(np.float32) / 255.0

    def _extract_dense_features(self, preprocessed: np.ndarray) -> np.ndarray:
        """Run the DINOv2 backbone on a preprocessed image.

        Without a model manager, falls back to pseudo-random features
        [256 patches x 384 dims] seeded by the image content, so identical
        inputs always yield identical features.
        """
        if self.model_manager and hasattr(self.model_manager, 'run_dinov2'):
            return self.model_manager.run_dinov2(preprocessed)
        # Mock fallback: return random features [num_patches, feat_dim]
        rng = np.random.RandomState(int(np.sum(preprocessed) * 1000) % (2**32))
        return rng.rand(256, 384).astype(np.float32)

    def _vlad_aggregate(self, dense_features: np.ndarray, codebook: Optional[np.ndarray] = None) -> np.ndarray:
        """Aggregate dense features into one global descriptor of size self.dim.

        Currently a deterministic mock projection seeded by the feature content;
        `codebook` is reserved for a real VLAD implementation.
        """
        rng = np.random.RandomState(int(np.sum(dense_features) * 1000) % (2**32))
        vlad_desc = rng.rand(self.dim).astype(np.float32)
        return vlad_desc

    def _l2_normalize(self, descriptor: np.ndarray) -> np.ndarray:
        """Scale the descriptor to unit L2 norm; zero vectors pass through unchanged."""
        norm = np.linalg.norm(descriptor)
        if norm == 0:
            return descriptor
        return descriptor / norm

    def compute_location_descriptor(self, image: np.ndarray) -> np.ndarray:
        """Compute an L2-normalized global location descriptor for one image."""
        preprocessed = self._preprocess_image(image)
        dense_feat = self._extract_dense_features(preprocessed)
        vlad_desc = self._vlad_aggregate(dense_feat)
        return self._l2_normalize(vlad_desc)

    def _aggregate_chunk_descriptors(self, descriptors: List[np.ndarray], strategy: str = "mean") -> np.ndarray:
        """Fuse per-frame descriptors into one L2-normalized chunk descriptor.

        Supported strategies: "mean", "max", "vlad" (currently a mean fallback).

        Raises:
            ValueError: if `descriptors` is empty or the strategy is unknown.
        """
        if not descriptors:
            raise ValueError("Cannot aggregate empty descriptor list.")
        stacked = np.stack(descriptors)
        if strategy == "mean":
            agg = np.mean(stacked, axis=0)
        elif strategy == "max":
            agg = np.max(stacked, axis=0)
        elif strategy == "vlad":
            agg = np.mean(stacked, axis=0)  # Simplified fallback for vlad aggregation
        else:
            raise ValueError(f"Unknown aggregation strategy: {strategy}")
        return self._l2_normalize(agg)

    def compute_chunk_descriptor(self, chunk_images: List[np.ndarray]) -> np.ndarray:
        """Compute one mean-fused descriptor for a chunk of consecutive frames.

        Raises:
            ValueError: if `chunk_images` is empty.
        """
        if not chunk_images:
            raise ValueError("Chunk images list is empty.")
        descriptors = [self.compute_location_descriptor(img) for img in chunk_images]
        return self._aggregate_chunk_descriptors(descriptors, strategy="mean")

    # --- Index Management (08.01) ---

    def _validate_index_integrity(self, index_dim: int, expected_count: int) -> bool:
        """Sanity-check the loaded index.

        NOTE(review): `expected_count` is currently unused — only the descriptor
        dimension is validated; confirm whether a count check was intended.

        Raises:
            IndexCorruptedError: if the dimension is not a known value (4096/8192).
        """
        if index_dim not in [4096, 8192]:
            raise IndexCorruptedError(f"Invalid index dimensions: {index_dim}")
        return True

    def _load_tile_metadata(self, metadata_path: str) -> Dict[int, dict]:
        """Load the row-index -> tile-metadata mapping from a JSON file.

        JSON object keys are coerced to int row indices.

        Raises:
            MetadataMismatchError: if the file is missing, empty, or invalid JSON.
        """
        if not os.path.exists(metadata_path):
            raise MetadataMismatchError("Metadata file not found.")
        try:
            with open(metadata_path, 'r') as f:
                content = f.read().strip()
            if not content:
                raise MetadataMismatchError("Metadata file is empty.")
            data = json.loads(content)
            if not data:
                raise MetadataMismatchError("Metadata file contains empty JSON object.")
        except json.JSONDecodeError:
            raise MetadataMismatchError("Metadata file contains invalid JSON.")
        return {int(k): v for k, v in data.items()}

    def _verify_metadata_alignment(self, index_count: int, metadata: Dict) -> bool:
        """Ensure the index and metadata describe the same number of tiles.

        Raises:
            MetadataMismatchError: on a count mismatch.
        """
        if index_count != len(metadata):
            raise MetadataMismatchError(f"Index count ({index_count}) does not match metadata count ({len(metadata)}).")
        return True

    def load_index(self, flight_id: str, index_path: str) -> bool:
        """Load the Faiss index and its sibling JSON metadata for a flight.

        Returns True on success and sets is_index_loaded.

        Raises:
            IndexNotFoundError: if the index file is absent.
            IndexCorruptedError: if the index fails the integrity check.
            MetadataMismatchError: if metadata is missing/invalid/misaligned.
        """
        # NOTE(review): str.replace would also rewrite ".index" occurring
        # elsewhere in the path; acceptable for the expected "<name>.index" layout.
        meta_path = index_path.replace(".index", ".json")
        if not os.path.exists(index_path):
            raise IndexNotFoundError(f"Index file {index_path} not found.")

        if self.faiss_manager:
            self.faiss_manager.load_index(index_path)
            idx_count, idx_dim = self.faiss_manager.get_stats()
        else:
            # Mock Faiss loading
            idx_count, idx_dim = 1000, 4096

        self._validate_index_integrity(idx_dim, idx_count)
        self.tile_metadata = self._load_tile_metadata(meta_path)
        self._verify_metadata_alignment(idx_count, self.tile_metadata)

        self.is_index_loaded = True
        logger.info(f"Successfully loaded global index for flight {flight_id}.")
        return True

    # --- Candidate Retrieval (08.03) ---

    def _retrieve_tile_metadata(self, indices: List[int]) -> List[Dict[str, Any]]:
        """Fetches metadata for a list of tile indices (empty dict for unknown rows)."""
        # In a real system, this might delegate to F04 if metadata is not held in memory
        return [self.tile_metadata.get(idx, {}) for idx in indices]

    def _build_candidates_from_matches(self, matches: List[DatabaseMatch]) -> List[TileCandidate]:
        """Expand raw index matches into TileCandidate objects (rank assigned later)."""
        candidates = []
        for m in matches:
            meta = self.tile_metadata.get(m.index, {})
            lat, lon = meta.get("lat", 0.0), meta.get("lon", 0.0)
            cand = TileCandidate(
                tile_id=m.tile_id,
                gps_center=GPSPoint(lat=lat, lon=lon),
                similarity_score=m.similarity_score,
                rank=0  # placeholder; rank_candidates assigns the final 1-based rank
            )
            candidates.append(cand)
        return candidates

    def _distance_to_similarity(self, distance: float) -> float:
        """Map an L2 distance between unit vectors to a similarity in [0, 1]."""
        # For L2 normalized vectors, Euclidean distance is in [0, 2].
        # Sim = 1 - (dist^2 / 4) maps [0, 2] to [1, 0].
        return max(0.0, 1.0 - (distance**2 / 4.0))

    def query_database(self, descriptor: np.ndarray, top_k: int) -> List[DatabaseMatch]:
        """Nearest-neighbour search for one descriptor.

        Returns an empty list when no index is loaded (or, in mock mode,
        when there is no metadata to sample from).
        """
        if not self.is_index_loaded:
            return []

        if self.faiss_manager:
            distances, indices = self.faiss_manager.search(descriptor.reshape(1, -1), top_k)
        else:
            # Mock Faiss search.
            k = min(top_k, len(self.tile_metadata))
            if k <= 0:
                # BUGFIX: np.random.choice raises on an empty population; the
                # old code crashed if the index was loaded with empty metadata.
                return []
            indices = np.random.choice(len(self.tile_metadata), k, replace=False).reshape(1, -1)
            # BUGFIX: draw exactly k distances so they stay aligned 1:1 with
            # indices (previously top_k distances were drawn even when k < top_k).
            distances = np.sort(np.random.rand(k) * 1.5).reshape(1, -1)  # ascending

        matches = []
        for i in range(len(indices[0])):
            idx = int(indices[0][i])
            dist = float(distances[0][i])
            meta = self.tile_metadata.get(idx, {})
            tile_id = meta.get("tile_id", f"tile_{idx}")
            sim = self._distance_to_similarity(dist)
            matches.append(DatabaseMatch(index=idx, tile_id=tile_id, distance=dist, similarity_score=sim))

        return matches

    def _apply_spatial_reranking(self, candidates: List[TileCandidate], dead_reckoning_estimate: Optional[GPSPoint] = None) -> List[TileCandidate]:
        """Hook for future GPS-proximity re-ranking; currently a pass-through."""
        return candidates

    def rank_candidates(self, candidates: List[TileCandidate]) -> List[TileCandidate]:
        """Sort candidates by similarity (descending) and assign 1-based ranks.

        Note: sorts the given list in place and returns it.
        """
        # Primary sort by similarity score descending
        candidates.sort(key=lambda x: x.similarity_score, reverse=True)
        for i, cand in enumerate(candidates):
            cand.rank = i + 1
        return self._apply_spatial_reranking(candidates)

    def retrieve_candidate_tiles(self, image: np.ndarray, top_k: int = 5) -> List[TileCandidate]:
        """Full pipeline for one image: descriptor -> index search -> ranked candidates."""
        descriptor = self.compute_location_descriptor(image)
        matches = self.query_database(descriptor, top_k)
        candidates = self._build_candidates_from_matches(matches)
        return self.rank_candidates(candidates)

    def retrieve_candidate_tiles_for_chunk(self, chunk_images: List[np.ndarray], top_k: int = 5) -> List[TileCandidate]:
        """Full pipeline for a chunk of frames, using one fused chunk descriptor."""
        descriptor = self.compute_chunk_descriptor(chunk_images)
        matches = self.query_database(descriptor, top_k)
        candidates = self._build_candidates_from_matches(matches)
        return self.rank_candidates(candidates)