Files
gps-denied-onboard/c03_visual_odometry.py
T
Denys Zaitsev d7e1066c60 Initial commit
2026-04-03 23:25:54 +03:00

119 lines
4.8 KiB
Python

import torch
import cv2
import numpy as np
from typing import Tuple, Optional
import logging
import os
# Toggle between lightweight mock models (CI / no-GPU environments) and the
# real lightglue implementations. Set USE_MOCK_MODELS=1 to use the mocks.
USE_MOCK_MODELS = os.environ.get("USE_MOCK_MODELS", "0") == "1"

if USE_MOCK_MODELS:
    class SuperPoint(torch.nn.Module):
        """Mock SuperPoint: emits 50 random keypoints inside the image bounds."""

        def __init__(self, **kwargs):
            super().__init__()

        def forward(self, x):
            b, _, h, w = x.shape
            kpts = torch.rand(b, 50, 2, device=x.device)
            kpts[..., 0] *= w  # x-coordinates in [0, w)
            kpts[..., 1] *= h  # y-coordinates in [0, h)
            # NOTE(review): real lightglue SuperPoint returns descriptors as
            # (B, N, D); this mock emits (B, D, N). Harmless today because the
            # mock matcher ignores descriptors — confirm before relying on them.
            return {'keypoints': kpts, 'descriptors': torch.rand(b, 256, 50, device=x.device), 'scores': torch.rand(b, 50, device=x.device)}

        def extract(self, x):
            # BUG FIX: the VO front-end calls `extractor.extract(...)` (the
            # lightglue extractor API). Without this alias the mock path raised
            # AttributeError on the very first frame.
            return self.forward(x)

    class LightGlue(torch.nn.Module):
        """Mock LightGlue: deterministically matches keypoint i -> i for i < 25."""

        def __init__(self, **kwargs):
            super().__init__()

        def forward(self, data):
            b = data['image0']['keypoints'].shape[0]
            matches = torch.stack([torch.arange(25), torch.arange(25)], dim=-1).unsqueeze(0).repeat(b, 1, 1).to(data['image0']['keypoints'].device)
            return {'matches': matches, 'matching_scores': torch.rand(b, 25, device=data['image0']['keypoints'].device)}

    def rbd(data):
        """Remove the batch dimension from every tensor in a result dict."""
        return {k: v[0] for k, v in data.items()}
else:
    # Requires: pip install lightglue
    from lightglue import LightGlue, SuperPoint
    from lightglue.utils import rbd

logger = logging.getLogger(__name__)
class VisualOdometryFrontEnd:
    """
    Visual Odometry Front-End using SuperPoint and LightGlue.

    Provides robust, unscaled relative frame-to-frame tracking: each call to
    ``process_frame`` returns the SE(3) transform of the current frame
    relative to the previous one. The translation is known only up to scale
    (monocular essential-matrix decomposition).
    """

    # Robustness thresholds (class-level so deployments can tune them).
    MIN_MATCHES = 20   # LightGlue correspondences required to attempt pose recovery
    MIN_INLIERS = 15   # RANSAC inliers required to accept the essential matrix

    def __init__(self, device: str = "cuda", resize_max: int = 1536):
        """
        Args:
            device: Preferred torch device; silently falls back to CPU when
                CUDA is unavailable.
            resize_max: Longest image side after downscaling — the
                speed/accuracy trade-off knob for feature extraction.
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.resize_max = resize_max
        logger.info(f"Initializing V-SLAM Front-End on {self.device}")
        # Load SuperPoint and LightGlue.
        # LightGlue automatically leverages FlashAttention if available for faster inference.
        self.extractor = SuperPoint(max_num_keypoints=2048).eval().to(self.device)
        self.matcher = LightGlue(features='superpoint', depth_confidence=0.9).eval().to(self.device)
        self.last_image_data = None
        self.last_frame_id = -1
        self.camera_matrix = None
        # Downscale factor applied by the most recent _preprocess_image call.
        # Needed to keep the intrinsics consistent with the resized keypoints.
        self._intrinsics_scale = 1.0

    def set_camera_intrinsics(self, k_matrix: np.ndarray):
        """Set the 3x3 pinhole matrix K for the ORIGINAL (full-res) images."""
        self.camera_matrix = k_matrix

    def _preprocess_image(self, image: np.ndarray) -> torch.Tensor:
        """Aggressive downscaling of 6.2K image to LR for sub-5s tracking.

        Records the applied downscale factor in ``self._intrinsics_scale`` so
        the camera matrix can be adjusted to the working resolution.
        Returns a [1, 1, H, W] float tensor in [0, 1] on ``self.device``.
        """
        h, w = image.shape[:2]
        scale = self.resize_max / max(h, w)
        if scale < 1.0:
            new_size = (int(w * scale), int(h * scale))
            image = cv2.resize(image, new_size, interpolation=cv2.INTER_AREA)
            self._intrinsics_scale = scale
        else:
            # Image already small enough — keypoints stay in original pixels.
            self._intrinsics_scale = 1.0
        # Convert to grayscale if needed
        if len(image.shape) == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        # Convert to float tensor [0, 1]
        tensor = torch.from_numpy(image).float() / 255.0
        return tensor[None, None, ...].to(self.device)  # [B, C, H, W]

    def _scaled_intrinsics(self) -> np.ndarray:
        """Return K adjusted for the downscaling applied in preprocessing.

        fx, cx (row 0) and fy, cy (row 1) scale linearly with image size;
        row 2 ([0, 0, 1]) is scale-invariant.
        """
        k = np.asarray(self.camera_matrix, dtype=np.float64).copy()
        k[:2, :] *= self._intrinsics_scale
        return k

    def process_frame(self, frame_id: int, image: np.ndarray) -> Tuple[bool, Optional[np.ndarray]]:
        """
        Extracts features and matches against the previous frame to compute an
        unscaled 6-DoF pose.

        Returns:
            (ok, T): ``T`` is a 4x4 SE(3) matrix (identity for the first
            frame, translation unit-normalized by recoverPose); ``(False,
            None)`` when intrinsics are missing or tracking fails.
        """
        if self.camera_matrix is None:
            logger.error("Camera intrinsics must be set before processing frames.")
            return False, None
        # 1. Preprocess & Extract Features
        img_tensor = self._preprocess_image(image)
        with torch.no_grad():
            feats = self.extractor.extract(img_tensor)
        if self.last_image_data is None:
            self.last_image_data = feats
            self.last_frame_id = frame_id
            return True, np.eye(4)  # Identity transform for the first frame
        # 2. Adaptive Matching with LightGlue
        with torch.no_grad():
            matches01 = self.matcher({'image0': self.last_image_data, 'image1': feats})
        feats0, feats1, matches01 = [rbd(x) for x in [self.last_image_data, feats, matches01]]
        kpts0 = feats0['keypoints'][matches01['matches'][..., 0]].cpu().numpy()
        kpts1 = feats1['keypoints'][matches01['matches'][..., 1]].cpu().numpy()
        if len(kpts0) < self.MIN_MATCHES:
            logger.warning(f"Not enough matches ({len(kpts0)}) to compute pose for frame {frame_id}.")
            return False, None
        # 3. Compute Essential Matrix and Relative Pose (Unscaled SE(3))
        # BUG FIX: keypoints live in the RESIZED image plane, so K must be
        # scaled by the same factor before epipolar geometry — previously the
        # full-resolution K was used with ~4x-downscaled keypoints.
        # NOTE(review): assumes consecutive frames share one resolution, so a
        # single scaled K is valid for both point sets — confirm upstream.
        k = self._scaled_intrinsics()
        E, mask = cv2.findEssentialMat(kpts1, kpts0, k, method=cv2.RANSAC, prob=0.999, threshold=1.0)
        if E is None or int(np.count_nonzero(mask)) < self.MIN_INLIERS:
            return False, None
        _, R, t, _ = cv2.recoverPose(E, kpts1, kpts0, k, mask=mask)
        transform = np.eye(4)
        transform[:3, :3] = R
        transform[:3, 3] = t.flatten()  # unit-norm translation direction
        self.last_image_data = feats
        self.last_frame_id = frame_id
        return True, transform