Files
gps-denied-onboard/c04_cross_view_geolocator.py
T
Denys Zaitsev d7e1066c60 Initial commit
2026-04-03 23:25:54 +03:00

95 lines
4.1 KiB
Python

import torch
import cv2
import numpy as np
import logging
from typing import Tuple, Optional, Dict, Any
from lightglue import LightGlue, SuperPoint
from lightglue.utils import rbd
logger = logging.getLogger(__name__)
class CrossViewGeolocator:
    """
    Asynchronous global place recognizer and fine-grained matcher.

    Finds absolute metric GPS anchors for unscaled UAV keyframes in two
    stages: (1) coarse retrieval of candidate satellite tiles from a Faiss
    index using a viewpoint-invariant global descriptor, then (2) fine
    SuperPoint+LightGlue matching and RANSAC homography estimation against
    each candidate tile.
    """

    def __init__(self, faiss_manager, device: str = "cuda"):
        """
        Args:
            faiss_manager: object exposing ``search(index, query, k)`` and
                returning ``(distances, indices)`` arrays, Faiss-style.
            device: preferred torch device name; silently falls back to
                CPU when CUDA is unavailable.
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.faiss_manager = faiss_manager
        logger.info("Initializing Global Place Recognition (DINOv2) & Fine Matcher (LightGlue)")
        # Global descriptor model for fast Faiss retrieval (coarse stage).
        self.global_encoder = self._load_global_encoder()
        # Local feature extractor + matcher for metric alignment (fine stage).
        self.extractor = SuperPoint(max_num_keypoints=2048).eval().to(self.device)
        self.matcher = LightGlue(features='superpoint', depth_confidence=0.9).eval().to(self.device)
        # Simulates the local geospatial SQLite cache of pre-downloaded
        # satellite tiles. Presumably maps Faiss index id -> dict holding at
        # least a 'features' entry (pre-extracted SuperPoint features) plus
        # geodetic metadata — TODO confirm against the tile-ingest code.
        self.satellite_cache = {}

    def _load_global_encoder(self):
        """Loads a foundation model (DINOv2 ViT-S/14) for viewpoint-invariant descriptors.

        Returns a callable mapping a [B, 3, 224, 224] float tensor to a
        [1, 384] descriptor tensor. When the module-level USE_MOCK_MODELS
        flag is truthy, a lightweight random-output stub is returned so the
        pipeline can run without downloading hub weights.
        """
        # FIX: the original referenced a bare USE_MOCK_MODELS name that is
        # not defined in this chunk; if it is not defined elsewhere in the
        # module this raised NameError. Resolve it defensively with a
        # False default (behavior unchanged when the flag exists).
        if globals().get("USE_MOCK_MODELS", False):
            class MockEncoder:
                def __call__(self, x):
                    # 384 matches the DINOv2 ViT-S/14 embedding dimension.
                    return torch.randn(1, 384).to(x.device)
            return MockEncoder()
        return torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to(self.device)

    def extract_global_descriptor(self, image: np.ndarray) -> np.ndarray:
        """Extracts a 1D vector signature resilient to seasonal/lighting changes.

        Args:
            image: HxW (grayscale) or HxWxC uint8/float image array.

        Returns:
            numpy array of shape [1, D] (D=384 for the DINOv2 backbone),
            suitable as a Faiss query vector.
        """
        img_resized = cv2.resize(image, (224, 224))
        tensor = torch.from_numpy(img_resized).float() / 255.0
        # Adjust dimensions to PyTorch layout [B, C, H, W].
        if tensor.ndim == 3:
            tensor = tensor.permute(2, 0, 1).unsqueeze(0)
        else:
            # Grayscale input: add batch/channel dims, replicate to 3 channels.
            tensor = tensor.unsqueeze(0).unsqueeze(0).repeat(1, 3, 1, 1)
        tensor = tensor.to(self.device)
        with torch.no_grad():
            desc = self.global_encoder(tensor)
        return desc.cpu().numpy()

    def retrieve_and_match(self, uav_image: np.ndarray, index) -> Tuple[bool, Optional[np.ndarray], Optional[Dict[str, Any]]]:
        """Searches the Faiss index and computes the precise 2D-to-2D geodetic alignment.

        Args:
            uav_image: UAV keyframe, grayscale HxW or color HxWx3 (BGR).
            index: Faiss index handle passed through to ``faiss_manager.search``.

        Returns:
            Tuple ``(success, homography, sat_info)`` where ``homography``
            is the 3x3 UAV->satellite transform (or None) and ``sat_info``
            is the matched cache entry (or None).
        """
        # 1. Global search (coarse): top-3 candidate tiles by descriptor distance.
        global_desc = self.extract_global_descriptor(uav_image)
        distances, indices = self.faiss_manager.search(index, global_desc, k=3)
        best_transform, best_inliers, best_sat_info = None, 0, None
        # 2. Extract UAV local features once (fine stage input).
        uav_gray = cv2.cvtColor(uav_image, cv2.COLOR_BGR2GRAY) if len(uav_image.shape) == 3 else uav_image
        uav_tensor = torch.from_numpy(uav_gray).float()[None, None, ...].to(self.device) / 255.0
        with torch.no_grad():
            uav_feats = self.extractor.extract(uav_tensor)
        # 3. Fine-grained matching against top-K satellite tiles.
        for idx in indices[0]:
            # Faiss pads missing results with -1; skip those and any tile
            # not present in the local cache.
            if idx < 0 or idx not in self.satellite_cache:
                continue
            sat_info = self.satellite_cache[idx]
            sat_feats = sat_info['features']
            with torch.no_grad():
                matches = self.matcher({'image0': uav_feats, 'image1': sat_feats})
            # rbd() strips the batch dimension from LightGlue outputs.
            feats0, feats1, matches01 = [rbd(x) for x in [uav_feats, sat_feats, matches]]
            kpts_uav = feats0['keypoints'][matches01['matches'][..., 0]].cpu().numpy()
            kpts_sat = feats1['keypoints'][matches01['matches'][..., 1]].cpu().numpy()
            # findHomography needs >= 4 points; require a healthy margin.
            if len(kpts_uav) > 15:
                H, mask = cv2.findHomography(kpts_uav, kpts_sat, cv2.RANSAC, 5.0)
                inliers = mask.sum() if mask is not None else 0
                # Keep the candidate with the most RANSAC inliers, subject
                # to a minimum-support threshold of 15.
                if inliers > best_inliers and inliers > 15:
                    best_inliers, best_transform, best_sat_info = inliers, H, sat_info
        return (best_transform is not None), best_transform, best_sat_info