mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-04-23 03:26:38 +00:00
Initial commit
This commit is contained in:
@@ -0,0 +1,259 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from pydantic import BaseModel
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from f02_1_flight_lifecycle_manager import GPSPoint
|
||||
from f04_satellite_data_manager import TileBounds
|
||||
|
||||
# Module-level logger named after the module, per stdlib logging convention.
logger = logging.getLogger(__name__)
|
||||
|
||||
# --- Data Models ---
|
||||
|
||||
class TileCandidate(BaseModel):
    """A satellite tile proposed as a possible match for a query image."""

    tile_id: str  # Identifier of the candidate satellite tile
    gps_center: GPSPoint  # Geographic center of the tile
    bounds: Optional[Any] = None  # Optional TileBounds to avoid strict cyclic coupling
    similarity_score: float  # Descriptor similarity; higher means more similar
    rank: int  # 1-based position after ranking (0 until rank_candidates runs)
    spatial_score: Optional[float] = None  # Reserved for spatial re-ranking
|
||||
|
||||
class DatabaseMatch(BaseModel):
    """A single raw nearest-neighbour hit from the tile index query."""

    index: int  # Row position in the index / tile-metadata mapping
    tile_id: str  # Tile identifier resolved from metadata for this row
    distance: float  # Raw distance reported by the index search
    similarity_score: float  # Distance converted to a similarity score
|
||||
|
||||
class SatelliteTile(BaseModel):
    """A satellite tile: raster image, geolocation, and optional cached descriptor."""

    tile_id: str  # Unique tile identifier
    image: np.ndarray  # Tile raster (OpenCV-style array)
    gps_center: GPSPoint  # Geographic center of the tile
    bounds: Any  # TileBounds; typed Any to avoid strict cyclic coupling
    descriptor: Optional[np.ndarray] = None  # Cached global descriptor, if computed

    # np.ndarray is not a pydantic-native type, so arbitrary types must be allowed.
    model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
# --- Exceptions ---
|
||||
class IndexNotFoundError(Exception):
    """Raised when the index file cannot be located on disk."""
|
||||
class IndexCorruptedError(Exception):
    """Raised when a loaded index fails integrity validation."""
|
||||
class MetadataMismatchError(Exception):
    """Raised when tile metadata is missing, malformed, or out of sync with the index."""
|
||||
|
||||
# --- Interface ---
|
||||
|
||||
class IGlobalPlaceRecognition(ABC):
    """Abstract interface for the global place-recognition component (F08)."""

    @abstractmethod
    def retrieve_candidate_tiles(self, image: np.ndarray, top_k: int) -> List[TileCandidate]:
        """Return the ranked top-k candidate tiles for a single query image."""

    @abstractmethod
    def compute_location_descriptor(self, image: np.ndarray) -> np.ndarray:
        """Compute the global location descriptor for one image."""

    @abstractmethod
    def query_database(self, descriptor: np.ndarray, top_k: int) -> List[DatabaseMatch]:
        """Search the tile index for the nearest neighbours of a descriptor."""

    @abstractmethod
    def rank_candidates(self, candidates: List[TileCandidate]) -> List[TileCandidate]:
        """Order candidates and assign their rank fields."""

    @abstractmethod
    def load_index(self, flight_id: str, index_path: str) -> bool:
        """Load the pre-built index and metadata for a flight."""

    @abstractmethod
    def retrieve_candidate_tiles_for_chunk(self, chunk_images: List[np.ndarray], top_k: int) -> List[TileCandidate]:
        """Return ranked candidate tiles for a chunk of consecutive images."""

    @abstractmethod
    def compute_chunk_descriptor(self, chunk_images: List[np.ndarray]) -> np.ndarray:
        """Aggregate per-image descriptors into one chunk descriptor."""
|
||||
|
||||
|
||||
# --- Implementation ---
|
||||
|
||||
class GlobalPlaceRecognition(IGlobalPlaceRecognition):
    """
    F08: Global Place Recognition
    Computes DINOv2+VLAD semantic descriptors and queries a pre-built Faiss index
    of satellite tiles to relocalize the UAV after catastrophic tracking loss.
    """

    def __init__(self, model_manager=None, faiss_manager=None, satellite_manager=None):
        # Injected collaborators; each may be None, in which case the
        # corresponding step falls back to a deterministic mock (see below).
        self.model_manager = model_manager
        self.faiss_manager = faiss_manager
        self.satellite_manager = satellite_manager

        self.is_index_loaded = False  # True once load_index() succeeds
        self.tile_metadata: Dict[int, Dict] = {}  # index row -> tile metadata dict
        self.dim = 4096  # DINOv2 + VLAD standard dimension

    # --- Descriptor Computation (08.02) ---

    def _preprocess_image(self, image: np.ndarray) -> np.ndarray:
        """Convert BGR to RGB (3-channel input only), resize to 224x224, scale to [0, 1] float32."""
        if len(image.shape) == 3 and image.shape[2] == 3:
            img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        else:
            # Non-3-channel input (e.g. grayscale) passes through unconverted.
            img = image
        # Standard DINOv2 input size
        img = cv2.resize(img, (224, 224))
        return img.astype(np.float32) / 255.0

    def _extract_dense_features(self, preprocessed: np.ndarray) -> np.ndarray:
        """Run the DINOv2 backbone, or a deterministic mock when no model manager is wired in."""
        if self.model_manager and hasattr(self.model_manager, 'run_dinov2'):
            return self.model_manager.run_dinov2(preprocessed)
        # Mock fallback: return random features [num_patches, feat_dim]
        # Seeded from the image content so identical inputs yield identical features.
        rng = np.random.RandomState(int(np.sum(preprocessed) * 1000) % (2**32))
        return rng.rand(256, 384).astype(np.float32)

    def _vlad_aggregate(self, dense_features: np.ndarray, codebook: Optional[np.ndarray] = None) -> np.ndarray:
        """Aggregate dense patch features into a single self.dim-length descriptor (mock VLAD).

        The codebook parameter is currently unused by the mock implementation.
        """
        # Mock VLAD aggregation projecting to 4096 dims
        rng = np.random.RandomState(int(np.sum(dense_features) * 1000) % (2**32))
        vlad_desc = rng.rand(self.dim).astype(np.float32)
        return vlad_desc

    def _l2_normalize(self, descriptor: np.ndarray) -> np.ndarray:
        """Return the descriptor scaled to unit L2 norm; zero vectors are returned unchanged."""
        norm = np.linalg.norm(descriptor)
        if norm == 0:
            # Avoid division by zero; a zero descriptor stays zero.
            return descriptor
        return descriptor / norm

    def compute_location_descriptor(self, image: np.ndarray) -> np.ndarray:
        """Full single-image pipeline: preprocess -> dense features -> VLAD -> L2-normalize."""
        preprocessed = self._preprocess_image(image)
        dense_feat = self._extract_dense_features(preprocessed)
        vlad_desc = self._vlad_aggregate(dense_feat)
        return self._l2_normalize(vlad_desc)

    def _aggregate_chunk_descriptors(self, descriptors: List[np.ndarray], strategy: str = "mean") -> np.ndarray:
        """Combine per-image descriptors into one L2-normalized descriptor.

        Supported strategies: "mean", "max", "vlad" (currently a mean fallback).
        Raises ValueError on an empty list or unknown strategy.
        """
        if not descriptors:
            raise ValueError("Cannot aggregate empty descriptor list.")
        stacked = np.stack(descriptors)
        if strategy == "mean":
            agg = np.mean(stacked, axis=0)
        elif strategy == "max":
            agg = np.max(stacked, axis=0)
        elif strategy == "vlad":
            agg = np.mean(stacked, axis=0)  # Simplified fallback for vlad aggregation
        else:
            raise ValueError(f"Unknown aggregation strategy: {strategy}")
        return self._l2_normalize(agg)

    def compute_chunk_descriptor(self, chunk_images: List[np.ndarray]) -> np.ndarray:
        """Descriptor for a chunk of frames: mean of per-image descriptors, re-normalized.

        Raises ValueError when chunk_images is empty.
        """
        if not chunk_images:
            raise ValueError("Chunk images list is empty.")
        descriptors = [self.compute_location_descriptor(img) for img in chunk_images]
        return self._aggregate_chunk_descriptors(descriptors, strategy="mean")

    # --- Index Management (08.01) ---

    def _validate_index_integrity(self, index_dim: int, expected_count: int) -> bool:
        """Sanity-check index dimensionality; raises IndexCorruptedError on mismatch.

        Note: expected_count is currently unused — only the dimension is validated.
        """
        if index_dim not in [4096, 8192]:
            raise IndexCorruptedError(f"Invalid index dimensions: {index_dim}")
        return True

    def _load_tile_metadata(self, metadata_path: str) -> Dict[int, dict]:
        """Load the index-row -> tile-metadata mapping from a JSON sidecar file.

        Raises MetadataMismatchError when the file is missing, empty, invalid
        JSON, or parses to an empty object.
        """
        if not os.path.exists(metadata_path):
            raise MetadataMismatchError("Metadata file not found.")
        try:
            with open(metadata_path, 'r') as f:
                content = f.read().strip()
                if not content:
                    raise MetadataMismatchError("Metadata file is empty.")
                data = json.loads(content)
                if not data:
                    raise MetadataMismatchError("Metadata file contains empty JSON object.")
        except json.JSONDecodeError:
            raise MetadataMismatchError("Metadata file contains invalid JSON.")
        # JSON object keys are strings; convert them back to integer row ids.
        return {int(k): v for k, v in data.items()}

    def _verify_metadata_alignment(self, index_count: int, metadata: Dict) -> bool:
        """Ensure the index and metadata describe the same number of tiles."""
        if index_count != len(metadata):
            raise MetadataMismatchError(f"Index count ({index_count}) does not match metadata count ({len(metadata)}).")
        return True

    def load_index(self, flight_id: str, index_path: str) -> bool:
        """Load the tile index and its JSON metadata sidecar for a flight.

        Returns True on success. Raises IndexNotFoundError, IndexCorruptedError,
        or MetadataMismatchError on failure. Uses mock stats (1000 tiles, 4096
        dims) when no faiss_manager is wired in.
        """
        # Metadata sidecar sits next to the index with a .json extension.
        meta_path = index_path.replace(".index", ".json")
        if not os.path.exists(index_path):
            raise IndexNotFoundError(f"Index file {index_path} not found.")

        if self.faiss_manager:
            self.faiss_manager.load_index(index_path)
            idx_count, idx_dim = self.faiss_manager.get_stats()
        else:
            # Mock Faiss loading
            idx_count, idx_dim = 1000, 4096

        self._validate_index_integrity(idx_dim, idx_count)
        self.tile_metadata = self._load_tile_metadata(meta_path)
        self._verify_metadata_alignment(idx_count, self.tile_metadata)

        self.is_index_loaded = True
        logger.info(f"Successfully loaded global index for flight {flight_id}.")
        return True

    # --- Candidate Retrieval (08.03) ---

    def _retrieve_tile_metadata(self, indices: List[int]) -> List[Dict[str, Any]]:
        """Fetches metadata for a list of tile indices."""
        # In a real system, this might delegate to F04 if metadata is not held in memory
        return [self.tile_metadata.get(idx, {}) for idx in indices]

    def _build_candidates_from_matches(self, matches: List[DatabaseMatch]) -> List[TileCandidate]:
        """Convert raw index matches into TileCandidate objects; ranks are assigned later."""
        candidates = []
        for m in matches:
            meta = self.tile_metadata.get(m.index, {})
            # Missing metadata degrades to a (0, 0) GPS center rather than failing.
            lat, lon = meta.get("lat", 0.0), meta.get("lon", 0.0)
            cand = TileCandidate(
                tile_id=m.tile_id,
                gps_center=GPSPoint(lat=lat, lon=lon),
                similarity_score=m.similarity_score,
                rank=0  # placeholder; rank_candidates() assigns the final rank
            )
            candidates.append(cand)
        return candidates

    def _distance_to_similarity(self, distance: float) -> float:
        """Map a search distance to a similarity score in [0, 1]."""
        # For L2 normalized vectors, Euclidean distance is in [0, 2].
        # Sim = 1 - (dist^2 / 4) maps [0, 2] to [1, 0].
        # NOTE(review): this assumes the search returns Euclidean (not squared)
        # distances; Faiss L2 indexes commonly report squared distances —
        # confirm against faiss_manager.search.
        return max(0.0, 1.0 - (distance**2 / 4.0))

    def query_database(self, descriptor: np.ndarray, top_k: int) -> List[DatabaseMatch]:
        """Search the tile index for the top_k nearest neighbours of a descriptor.

        Returns an empty list when no index has been loaded. Falls back to a
        random mock search when no faiss_manager is wired in.
        """
        if not self.is_index_loaded:
            return []

        if self.faiss_manager:
            # Search APIs expect a (batch, dim) matrix; we query a single row.
            distances, indices = self.faiss_manager.search(descriptor.reshape(1, -1), top_k)
        else:
            # Mock Faiss search
            indices = np.random.choice(len(self.tile_metadata), min(top_k, len(self.tile_metadata)), replace=False).reshape(1, -1)
            distances = np.sort(np.random.rand(top_k) * 1.5).reshape(1, -1)  # Distances sorted ascending

        matches = []
        for i in range(len(indices[0])):
            idx = int(indices[0][i])
            dist = float(distances[0][i])
            meta = self.tile_metadata.get(idx, {})
            # Fall back to a synthetic id when the row has no metadata entry.
            tile_id = meta.get("tile_id", f"tile_{idx}")
            sim = self._distance_to_similarity(dist)
            matches.append(DatabaseMatch(index=idx, tile_id=tile_id, distance=dist, similarity_score=sim))

        return matches

    def _apply_spatial_reranking(self, candidates: List[TileCandidate], dead_reckoning_estimate: Optional[GPSPoint] = None) -> List[TileCandidate]:
        """Hook for GPS-proximity re-ranking; currently a pass-through."""
        # Currently returns unmodified, leaving hook for future GPS-proximity heuristics
        return candidates

    def rank_candidates(self, candidates: List[TileCandidate]) -> List[TileCandidate]:
        """Sort candidates by similarity (descending) and assign 1-based ranks.

        Note: sorts the caller's list in place before returning it.
        """
        # Primary sort by similarity score descending
        candidates.sort(key=lambda x: x.similarity_score, reverse=True)
        for i, cand in enumerate(candidates):
            cand.rank = i + 1
        return self._apply_spatial_reranking(candidates)

    def retrieve_candidate_tiles(self, image: np.ndarray, top_k: int = 5) -> List[TileCandidate]:
        """End-to-end retrieval for one image: descriptor -> index query -> ranked candidates."""
        descriptor = self.compute_location_descriptor(image)
        matches = self.query_database(descriptor, top_k)
        candidates = self._build_candidates_from_matches(matches)
        return self.rank_candidates(candidates)

    def retrieve_candidate_tiles_for_chunk(self, chunk_images: List[np.ndarray], top_k: int = 5) -> List[TileCandidate]:
        """End-to-end retrieval for a frame chunk using the aggregated chunk descriptor."""
        descriptor = self.compute_chunk_descriptor(chunk_images)
        matches = self.query_database(descriptor, top_k)
        candidates = self._build_candidates_from_matches(matches)
        return self.rank_candidates(candidates)
|
||||
Reference in New Issue
Block a user