mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-04-23 03:26:38 +00:00
Initial commit
This commit is contained in:
@@ -0,0 +1,95 @@
|
||||
import torch
|
||||
import cv2
|
||||
import numpy as np
|
||||
import logging
|
||||
from typing import Tuple, Optional, Dict, Any
|
||||
|
||||
from lightglue import LightGlue, SuperPoint
|
||||
from lightglue.utils import rbd
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CrossViewGeolocator:
|
||||
"""
|
||||
Asynchronous Global Place Recognizer and Fine-Grained Matcher.
|
||||
Finds absolute metric GPS anchors for unscaled UAV keyframes.
|
||||
"""
|
||||
def __init__(self, faiss_manager, device: str = "cuda"):
|
||||
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
|
||||
self.faiss_manager = faiss_manager
|
||||
|
||||
logger.info("Initializing Global Place Recognition (DINOv2) & Fine Matcher (LightGlue)")
|
||||
|
||||
# Global Descriptor Model for fast Faiss retrieval
|
||||
self.global_encoder = self._load_global_encoder()
|
||||
|
||||
# Local feature matcher for metric alignment
|
||||
self.extractor = SuperPoint(max_num_keypoints=2048).eval().to(self.device)
|
||||
self.matcher = LightGlue(features='superpoint', depth_confidence=0.9).eval().to(self.device)
|
||||
|
||||
# Simulates the local geospatial SQLite cache of pre-downloaded satellite tiles
|
||||
self.satellite_cache = {}
|
||||
|
||||
def _load_global_encoder(self):
|
||||
"""Loads a Foundation Model (like DINOv2) for viewpoint-invariant descriptors."""
|
||||
if USE_MOCK_MODELS:
|
||||
class MockEncoder:
|
||||
def __call__(self, x):
|
||||
return torch.randn(1, 384).to(x.device)
|
||||
return MockEncoder()
|
||||
else:
|
||||
return torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to(self.device)
|
||||
|
||||
def extract_global_descriptor(self, image: np.ndarray) -> np.ndarray:
|
||||
"""Extracts a 1D vector signature resilient to seasonal/lighting changes."""
|
||||
img_resized = cv2.resize(image, (224, 224))
|
||||
tensor = torch.from_numpy(img_resized).float() / 255.0
|
||||
# Adjust dimensions for PyTorch [B, C, H, W]
|
||||
if len(tensor.shape) == 3:
|
||||
tensor = tensor.permute(2, 0, 1).unsqueeze(0)
|
||||
else:
|
||||
tensor = tensor.unsqueeze(0).unsqueeze(0).repeat(1, 3, 1, 1)
|
||||
|
||||
tensor = tensor.to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
desc = self.global_encoder(tensor)
|
||||
|
||||
return desc.cpu().numpy()
|
||||
|
||||
def retrieve_and_match(self, uav_image: np.ndarray, index) -> Tuple[bool, Optional[np.ndarray], Optional[Dict[str, Any]]]:
|
||||
"""Searches the Faiss Index and computes the precise 2D-to-2D geodetic alignment."""
|
||||
# 1. Global Search (Coarse)
|
||||
global_desc = self.extract_global_descriptor(uav_image)
|
||||
distances, indices = self.faiss_manager.search(index, global_desc, k=3)
|
||||
|
||||
best_transform, best_inliers, best_sat_info = None, 0, None
|
||||
|
||||
# 2. Extract UAV features once (Fine)
|
||||
uav_gray = cv2.cvtColor(uav_image, cv2.COLOR_BGR2GRAY) if len(uav_image.shape) == 3 else uav_image
|
||||
uav_tensor = torch.from_numpy(uav_gray).float()[None, None, ...].to(self.device) / 255.0
|
||||
|
||||
with torch.no_grad():
|
||||
uav_feats = self.extractor.extract(uav_tensor)
|
||||
|
||||
# 3. Fine-grained matching against top-K satellite tiles
|
||||
for idx in indices[0]:
|
||||
if idx not in self.satellite_cache: continue
|
||||
sat_info = self.satellite_cache[idx]
|
||||
sat_feats = sat_info['features']
|
||||
|
||||
with torch.no_grad():
|
||||
matches = self.matcher({'image0': uav_feats, 'image1': sat_feats})
|
||||
|
||||
feats0, feats1, matches01 = [rbd(x) for x in [uav_feats, sat_feats, matches]]
|
||||
kpts_uav = feats0['keypoints'][matches01['matches'][..., 0]].cpu().numpy()
|
||||
kpts_sat = feats1['keypoints'][matches01['matches'][..., 1]].cpu().numpy()
|
||||
|
||||
if len(kpts_uav) > 15:
|
||||
H, mask = cv2.findHomography(kpts_uav, kpts_sat, cv2.RANSAC, 5.0)
|
||||
inliers = mask.sum() if mask is not None else 0
|
||||
|
||||
if inliers > best_inliers and inliers > 15:
|
||||
best_inliers, best_transform, best_sat_info = inliers, H, sat_info
|
||||
|
||||
return (best_transform is not None), best_transform, best_sat_info
|
||||
Reference in New Issue
Block a user