mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-04-23 01:16:38 +00:00
Initial commit
This commit is contained in:
@@ -0,0 +1,119 @@
|
||||
import torch
|
||||
import cv2
|
||||
import numpy as np
|
||||
from typing import Tuple, Optional
|
||||
import logging
|
||||
|
||||
import os
|
||||
|
||||
# Opt-in switch: set USE_MOCK_MODELS=1 in the environment to replace the real
# lightglue models below with lightweight random stubs, so the module can be
# imported/run without the `lightglue` package installed.
USE_MOCK_MODELS = os.environ.get("USE_MOCK_MODELS", "0") == "1"
|
||||
|
||||
if USE_MOCK_MODELS:
    # Lightweight stand-ins for the lightglue models so the pipeline can run
    # (e.g. in CI) without the `lightglue` package or model weights.

    class SuperPoint(torch.nn.Module):
        """Mock SuperPoint: emits 50 random keypoints per image with 256-d descriptors."""

        def __init__(self, **kwargs):
            super().__init__()

        def extract(self, x):
            """Mirror the real lightglue API: the front-end calls
            ``extractor.extract(img_tensor)``, not ``forward`` — without this
            method, mock mode fails with AttributeError."""
            return self.forward(x)

        def forward(self, x):
            b, _, h, w = x.shape
            # Random keypoints scaled to cover the full image plane.
            kpts = torch.rand(b, 50, 2, device=x.device)
            kpts[..., 0] *= w
            kpts[..., 1] *= h
            return {
                'keypoints': kpts,
                'descriptors': torch.rand(b, 256, 50, device=x.device),
                'scores': torch.rand(b, 50, device=x.device),
            }

    class LightGlue(torch.nn.Module):
        """Mock LightGlue: pairs the first 25 keypoints of each image one-to-one."""

        def __init__(self, **kwargs):
            super().__init__()

        def forward(self, data):
            kpts0 = data['image0']['keypoints']
            b = kpts0.shape[0]
            matches = (
                torch.stack([torch.arange(25), torch.arange(25)], dim=-1)
                .unsqueeze(0)
                .repeat(b, 1, 1)
                .to(kpts0.device)
            )
            return {
                'matches': matches,
                'matching_scores': torch.rand(b, 25, device=kpts0.device),
            }

    def rbd(data):
        """Strip the batch dimension from every entry (mirrors lightglue.utils.rbd)."""
        return {k: v[0] for k, v in data.items()}
else:
    # Requires: pip install lightglue
    from lightglue import LightGlue, SuperPoint
    from lightglue.utils import rbd
|
||||
|
||||
# Module-level logger named after this module, per stdlib logging convention.
logger = logging.getLogger(__name__)
|
||||
|
||||
class VisualOdometryFrontEnd:
    """
    Visual Odometry Front-End using SuperPoint and LightGlue.

    Provides robust, unscaled relative frame-to-frame tracking: each call to
    :meth:`process_frame` matches the new frame against the previous one and
    returns a 4x4 SE(3) transform whose translation direction is metric-less
    (scale must be recovered elsewhere).
    """

    def __init__(self, device: str = "cuda", resize_max: int = 1536):
        """
        Args:
            device: Preferred torch device; falls back to CPU when CUDA is unavailable.
            resize_max: Maximum length of the longest image side after preprocessing (px).
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.resize_max = resize_max

        logger.info(f"Initializing V-SLAM Front-End on {self.device}")

        # Load SuperPoint and LightGlue.
        # LightGlue automatically leverages FlashAttention if available for faster inference.
        self.extractor = SuperPoint(max_num_keypoints=2048).eval().to(self.device)
        self.matcher = LightGlue(features='superpoint', depth_confidence=0.9).eval().to(self.device)

        self.last_image_data = None  # cached features of the previous frame
        self.last_frame_id = -1
        self.camera_matrix = None    # 3x3 intrinsics for the full-resolution input image
        self._scale = 1.0            # resize factor applied by the last _preprocess_image call

    def set_camera_intrinsics(self, k_matrix: np.ndarray):
        """Set the 3x3 camera matrix K.

        NOTE(review): K is assumed to be expressed in the coordinates of the
        raw (pre-resize) input image — confirm against the caller.
        """
        self.camera_matrix = k_matrix

    def _preprocess_image(self, image: np.ndarray) -> torch.Tensor:
        """Aggressive downscaling of 6.2K image to LR for sub-5s tracking.

        Records the applied resize factor in ``self._scale`` so the camera
        intrinsics can later be rescaled to the keypoint coordinate frame.

        Returns:
            Grayscale float tensor in [0, 1] of shape [1, 1, H, W] on self.device.
        """
        h, w = image.shape[:2]
        scale = self.resize_max / max(h, w)

        if scale < 1.0:
            new_size = (int(w * scale), int(h * scale))
            image = cv2.resize(image, new_size, interpolation=cv2.INTER_AREA)
            self._scale = scale
        else:
            # Image already small enough — keypoints stay in native coordinates.
            self._scale = 1.0

        # Convert to grayscale if needed
        if len(image.shape) == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        # Convert to float tensor [0, 1]
        tensor = torch.from_numpy(image).float() / 255.0
        return tensor[None, None, ...].to(self.device)  # [B, C, H, W]

    def process_frame(self, frame_id: int, image: np.ndarray) -> Tuple[bool, Optional[np.ndarray]]:
        """
        Extracts features and matches against the previous frame to compute an
        unscaled 6-DoF pose.

        Args:
            frame_id: Identifier of the incoming frame.
            image: BGR or grayscale image at the resolution the intrinsics were set for.
                   Assumed to be the same resolution for all frames — TODO confirm.

        Returns:
            (success, transform): ``transform`` is a 4x4 homogeneous matrix —
            identity for the very first frame, unit-scale relative pose
            otherwise — or ``None`` on failure.
        """
        if self.camera_matrix is None:
            logger.error("Camera intrinsics must be set before processing frames.")
            return False, None

        # 1. Preprocess & Extract Features
        img_tensor = self._preprocess_image(image)
        with torch.no_grad():
            feats = self.extractor.extract(img_tensor)

        if self.last_image_data is None:
            self.last_image_data = feats
            self.last_frame_id = frame_id
            return True, np.eye(4)  # Identity transform for the first frame

        # 2. Adaptive Matching with LightGlue
        with torch.no_grad():
            matches01 = self.matcher({'image0': self.last_image_data, 'image1': feats})

        feats0, feats1, matches01 = [rbd(x) for x in [self.last_image_data, feats, matches01]]
        kpts0 = feats0['keypoints'][matches01['matches'][..., 0]].cpu().numpy()
        kpts1 = feats1['keypoints'][matches01['matches'][..., 1]].cpu().numpy()

        if len(kpts0) < 20:
            logger.warning(f"Not enough matches ({len(kpts0)}) to compute pose for frame {frame_id}.")
            return False, None

        # 3. Compute Essential Matrix and Relative Pose (Unscaled SE(3))
        # Keypoints live in the *resized* image, so K must be scaled to match:
        # rows 0 and 1 of K (fx, skew, cx / fy, cy) scale linearly with the
        # image size, while row 2 stays [0, 0, 1].
        k_scaled = self.camera_matrix.astype(np.float64).copy()
        k_scaled[:2, :] *= self._scale

        E, mask = cv2.findEssentialMat(
            kpts1, kpts0, k_scaled, method=cv2.RANSAC, prob=0.999, threshold=1.0
        )
        # findEssentialMat may return None, or a vertical stack of candidate
        # matrices (Nx3 with N > 3); recoverPose needs exactly one 3x3 matrix.
        if E is None or E.shape != (3, 3) or mask.sum() < 15:
            return False, None

        _, R, t, _ = cv2.recoverPose(E, kpts1, kpts0, k_scaled, mask=mask)

        transform = np.eye(4)
        transform[:3, :3] = R
        transform[:3, 3] = t.flatten()  # unit-norm direction; scale is unknown

        self.last_image_data = feats
        self.last_frame_id = frame_id

        return True, transform
|
||||
Reference in New Issue
Block a user