# Mirror of https://github.com/azaion/gps-denied-onboard.git
# (synced 2026-04-22 08:56:37 +00:00)
"""Feature extraction / matching backends for the V-SLAM front-end.

When the environment variable ``USE_MOCK_MODELS=1``, lightweight mock
implementations of ``SuperPoint``, ``LightGlue`` and ``rbd`` are defined so
this module can be imported and exercised without the ``lightglue`` package
or model weights. Otherwise the real models are imported from ``lightglue``.
"""

import logging
import os
from typing import Optional, Tuple

import cv2
import numpy as np
import torch

# Toggle mock backends via the environment (hardware/weight-free testing).
USE_MOCK_MODELS = os.environ.get("USE_MOCK_MODELS", "0") == "1"

if USE_MOCK_MODELS:

    class SuperPoint(torch.nn.Module):
        """Mock SuperPoint: emits random keypoints/descriptors/scores."""

        def __init__(self, **kwargs):
            super().__init__()

        def forward(self, x):
            b, _, h, w = x.shape
            kpts = torch.rand(b, 50, 2, device=x.device)
            kpts[..., 0] *= w  # x coordinate in [0, w)
            kpts[..., 1] *= h  # y coordinate in [0, h)
            return {
                'keypoints': kpts,
                'descriptors': torch.rand(b, 256, 50, device=x.device),
                'scores': torch.rand(b, 50, device=x.device),
            }

        def extract(self, x, **kwargs):
            # BUGFIX: the real lightglue SuperPoint exposes `.extract()`,
            # which VisualOdometryFrontEnd.process_frame calls; without this
            # method the mock raised AttributeError in mock mode.
            return self.forward(x)

    class LightGlue(torch.nn.Module):
        """Mock LightGlue: emits a fixed set of 25 identity matches."""

        def __init__(self, **kwargs):
            super().__init__()

        def forward(self, data):
            kpts0 = data['image0']['keypoints']
            b = kpts0.shape[0]
            matches = (
                torch.stack([torch.arange(25), torch.arange(25)], dim=-1)
                .unsqueeze(0)
                .repeat(b, 1, 1)
                .to(kpts0.device)
            )
            return {
                'matches': matches,
                'matching_scores': torch.rand(b, 25, device=kpts0.device),
            }

    def rbd(data):
        """Mock of lightglue.utils.rbd: drop the leading batch dimension."""
        return {k: v[0] for k, v in data.items()}

else:
    # Requires: pip install lightglue
    from lightglue import LightGlue, SuperPoint
    from lightglue.utils import rbd


logger = logging.getLogger(__name__)
|
|
|
|
class VisualOdometryFrontEnd:
    """
    Visual Odometry Front-End using SuperPoint and LightGlue.

    Provides robust, unscaled relative frame-to-frame tracking: each call to
    process_frame() yields a 4x4 SE(3) transform between the previous and the
    current frame, with translation known only up to scale (monocular VO).
    """

    def __init__(self, device: str = "cuda", resize_max: int = 1536):
        """
        Args:
            device: Preferred torch device string; silently falls back to CPU
                when CUDA is unavailable.
            resize_max: Longest image side (pixels) after preprocessing.
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.resize_max = resize_max

        logger.info("Initializing V-SLAM Front-End on %s", self.device)

        # Load SuperPoint and LightGlue.
        # LightGlue automatically leverages FlashAttention if available
        # for faster inference.
        self.extractor = SuperPoint(max_num_keypoints=2048).eval().to(self.device)
        self.matcher = LightGlue(features='superpoint', depth_confidence=0.9).eval().to(self.device)

        self.last_image_data = None  # features of the last tracked frame
        self.last_frame_id = -1
        self.camera_matrix = None    # 3x3 intrinsics; must be set before use

    def set_camera_intrinsics(self, k_matrix: np.ndarray) -> None:
        """Set the 3x3 camera intrinsics matrix K.

        NOTE(review): keypoints are produced in the coordinates of the
        (possibly downscaled) image from _preprocess_image — K must match
        that resolution, not the full-sensor resolution; confirm with caller.
        """
        self.camera_matrix = k_matrix

    def _preprocess_image(self, image: np.ndarray) -> torch.Tensor:
        """Downscale, grayscale and normalize an image to a [1, 1, H, W] tensor.

        Aggressive downscaling of 6.2K imagery to LR keeps tracking under the
        sub-5s latency budget.
        """
        h, w = image.shape[:2]
        scale = self.resize_max / max(h, w)

        # Only ever shrink; never upscale small inputs.
        if scale < 1.0:
            new_size = (int(w * scale), int(h * scale))
            image = cv2.resize(image, new_size, interpolation=cv2.INTER_AREA)

        # Convert to grayscale if needed (3-channel input assumed BGR).
        if len(image.shape) == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        # Convert to float tensor in [0, 1] (assumes uint8 input — TODO confirm).
        tensor = torch.from_numpy(image).float() / 255.0
        return tensor[None, None, ...].to(self.device)  # [B, C, H, W]

    def process_frame(self, frame_id: int, image: np.ndarray) -> Tuple[bool, Optional[np.ndarray]]:
        """
        Extract features and match against the previous frame to compute an
        unscaled 6-DoF relative pose.

        Args:
            frame_id: Identifier of the incoming frame.
            image: Grayscale or 3-channel image as a numpy array.

        Returns:
            (success, transform): ``transform`` is a 4x4 SE(3) matrix with a
            unit-norm translation (scale unknown), ``np.eye(4)`` for the very
            first frame, or ``None`` when tracking failed.
        """
        if self.camera_matrix is None:
            logger.error("Camera intrinsics must be set before processing frames.")
            return False, None

        # 1. Preprocess & Extract Features
        img_tensor = self._preprocess_image(image)
        with torch.no_grad():
            feats = self.extractor.extract(img_tensor)

        if self.last_image_data is None:
            # Bootstrap: nothing to match against yet.
            self.last_image_data = feats
            self.last_frame_id = frame_id
            return True, np.eye(4)  # Identity transform for the first frame

        # 2. Adaptive Matching with LightGlue
        with torch.no_grad():
            matches01 = self.matcher({'image0': self.last_image_data, 'image1': feats})

        # Strip the batch dimension, then gather the matched keypoint pairs.
        feats0, feats1, matches01 = [rbd(x) for x in [self.last_image_data, feats, matches01]]
        kpts0 = feats0['keypoints'][matches01['matches'][..., 0]].cpu().numpy()
        kpts1 = feats1['keypoints'][matches01['matches'][..., 1]].cpu().numpy()

        if len(kpts0) < 20:
            logger.warning("Not enough matches (%s) to compute pose for frame %s.", len(kpts0), frame_id)
            return False, None

        # 3. Compute Essential Matrix and Relative Pose (Unscaled SE(3))
        E, mask = cv2.findEssentialMat(
            kpts1, kpts0, self.camera_matrix,
            method=cv2.RANSAC, prob=0.999, threshold=1.0,
        )
        # findEssentialMat may return None, a None mask, or (per OpenCV docs)
        # several stacked 3x3 candidate matrices; recoverPose would crash on
        # those, so reject them all as a tracking failure.
        if E is None or E.shape != (3, 3) or mask is None or int(mask.sum()) < 15:
            return False, None

        _, R, t, _ = cv2.recoverPose(E, kpts1, kpts0, self.camera_matrix, mask=mask)

        transform = np.eye(4)
        transform[:3, :3] = R
        transform[:3, 3] = t.flatten()  # unit-norm translation (scale unknown)

        # On failure above, last_image_data intentionally stays at the last
        # successfully tracked frame so the next frame matches against it.
        self.last_image_data = feats
        self.last_frame_id = frame_id

        return True, transform