Initial commit

This commit is contained in:
Denys Zaitsev
2026-04-03 23:25:54 +03:00
parent 531a1301d5
commit d7e1066c60
3843 changed files with 1554468 additions and 0 deletions
+95
View File
@@ -0,0 +1,95 @@
import torch
import cv2
import numpy as np
import logging
from typing import Tuple, Optional, Dict, Any
from lightglue import LightGlue, SuperPoint
from lightglue.utils import rbd
logger = logging.getLogger(__name__)
class CrossViewGeolocator:
    """
    Asynchronous Global Place Recognizer and Fine-Grained Matcher.

    Finds absolute metric GPS anchors for unscaled UAV keyframes in two stages:
      1. Coarse: a global image descriptor is matched against a Faiss index of
         pre-indexed satellite tiles.
      2. Fine: SuperPoint + LightGlue feature matching against the top-K tile
         candidates, followed by a RANSAC homography to get the pixel-accurate
         UAV-to-satellite alignment.
    """
    def __init__(self, faiss_manager, device: str = "cuda"):
        """
        Args:
            faiss_manager: object exposing ``search(index, query, k)`` returning
                a Faiss-style ``(distances, indices)`` pair — TODO confirm exact
                contract against the manager implementation.
            device: preferred torch device string; silently falls back to CPU
                when CUDA is unavailable.
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.faiss_manager = faiss_manager
        logger.info("Initializing Global Place Recognition (DINOv2) & Fine Matcher (LightGlue)")
        # Global descriptor model used for the coarse Faiss retrieval stage.
        self.global_encoder = self._load_global_encoder()
        # Local feature extractor + matcher used for the fine metric alignment.
        self.extractor = SuperPoint(max_num_keypoints=2048).eval().to(self.device)
        self.matcher = LightGlue(features='superpoint', depth_confidence=0.9).eval().to(self.device)
        # Simulates the local geospatial SQLite cache of pre-downloaded
        # satellite tiles. Maps Faiss row id -> tile-info dict; each entry is
        # expected to carry pre-extracted SuperPoint features under 'features'.
        self.satellite_cache = {}

    def _load_global_encoder(self):
        """Loads a Foundation Model (like DINOv2) for viewpoint-invariant descriptors.

        BUGFIX: the original referenced an undefined name ``USE_MOCK_MODELS``,
        which raised ``NameError`` as soon as the class was constructed. The
        flag is now looked up in module globals with a safe ``False`` default,
        so tests may opt into the mock by setting ``USE_MOCK_MODELS = True``
        at module scope without breaking normal operation.
        """
        use_mock = bool(globals().get("USE_MOCK_MODELS", False))
        if use_mock:
            class MockEncoder:
                # Emits a random 384-dim descriptor (DINOv2 ViT-S width)
                # on the same device as the input batch.
                def __call__(self, x):
                    return torch.randn(1, 384).to(x.device)
            return MockEncoder()
        # Downloads weights over the network on first use; cached by torch.hub.
        return torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to(self.device)

    def extract_global_descriptor(self, image: np.ndarray) -> np.ndarray:
        """Extracts a 1D vector signature resilient to seasonal/lighting changes.

        Args:
            image: HxW (grayscale) or HxWxC uint8/float image array.

        Returns:
            The encoder output as a NumPy array (shape [1, D]).
        """
        img_resized = cv2.resize(image, (224, 224))
        tensor = torch.from_numpy(img_resized).float() / 255.0
        # Adjust dimensions for PyTorch [B, C, H, W].
        if len(tensor.shape) == 3:
            tensor = tensor.permute(2, 0, 1).unsqueeze(0)
        else:
            # Grayscale input: replicate the single channel to 3 channels.
            tensor = tensor.unsqueeze(0).unsqueeze(0).repeat(1, 3, 1, 1)
        tensor = tensor.to(self.device)
        with torch.no_grad():
            desc = self.global_encoder(tensor)
        return desc.cpu().numpy()

    def retrieve_and_match(self, uav_image: np.ndarray, index) -> Tuple[bool, Optional[np.ndarray], Optional[Dict[str, Any]]]:
        """Searches the Faiss Index and computes the precise 2D-to-2D geodetic alignment.

        Args:
            uav_image: UAV keyframe, BGR (HxWx3) or grayscale (HxW) array.
            index: the Faiss index handle passed through to ``faiss_manager.search``.

        Returns:
            ``(success, homography, sat_info)`` where ``homography`` is the 3x3
            UAV->satellite transform from ``cv2.findHomography`` and ``sat_info``
            is the cached tile record of the best candidate; ``(False, None,
            None)`` when no candidate clears the inlier threshold.
        """
        # 1. Global search (coarse). Distances are unused; only row ids matter.
        global_desc = self.extract_global_descriptor(uav_image)
        _, indices = self.faiss_manager.search(index, global_desc, k=3)
        best_transform, best_inliers, best_sat_info = None, 0, None
        # 2. Extract UAV features once (fine stage reuses them per candidate).
        uav_gray = cv2.cvtColor(uav_image, cv2.COLOR_BGR2GRAY) if len(uav_image.shape) == 3 else uav_image
        uav_tensor = torch.from_numpy(uav_gray).float()[None, None, ...].to(self.device) / 255.0
        with torch.no_grad():
            uav_feats = self.extractor.extract(uav_tensor)
        # 3. Fine-grained matching against the top-K satellite tile candidates.
        for idx in indices[0]:
            # Skip ids for which no tile has been pre-downloaded/featurized.
            if idx not in self.satellite_cache:
                continue
            sat_info = self.satellite_cache[idx]
            sat_feats = sat_info['features']
            with torch.no_grad():
                matches = self.matcher({'image0': uav_feats, 'image1': sat_feats})
            # rbd strips the batch dimension from LightGlue outputs.
            feats0, feats1, matches01 = [rbd(x) for x in [uav_feats, sat_feats, matches]]
            kpts_uav = feats0['keypoints'][matches01['matches'][..., 0]].cpu().numpy()
            kpts_sat = feats1['keypoints'][matches01['matches'][..., 1]].cpu().numpy()
            # Require enough correspondences for a stable RANSAC homography.
            if len(kpts_uav) > 15:
                H, mask = cv2.findHomography(kpts_uav, kpts_sat, cv2.RANSAC, 5.0)
                # mask may be None when estimation fails entirely.
                inliers = int(mask.sum()) if mask is not None else 0
                if inliers > best_inliers and inliers > 15:
                    best_inliers, best_transform, best_sat_info = inliers, H, sat_info
        return (best_transform is not None), best_transform, best_sat_info