import torch
import cv2
import numpy as np
from typing import Tuple, Optional
import logging
import os

# When set, replace the heavyweight learned models with cheap mocks so the
# pipeline can be exercised without the `lightglue` package or a GPU.
USE_MOCK_MODELS = os.environ.get("USE_MOCK_MODELS", "0") == "1"

if USE_MOCK_MODELS:

    class SuperPoint(torch.nn.Module):
        """Mock SuperPoint: emits 50 random keypoints per image.

        Mirrors the subset of the real lightglue SuperPoint API that the
        front-end uses (`extract`, returning a dict of batched tensors).
        """

        def __init__(self, **kwargs):
            super().__init__()

        def forward(self, x):
            b, _, h, w = x.shape
            kpts = torch.rand(b, 50, 2, device=x.device)
            kpts[..., 0] *= w
            kpts[..., 1] *= h
            return {
                'keypoints': kpts,
                'descriptors': torch.rand(b, 256, 50, device=x.device),
                'scores': torch.rand(b, 50, device=x.device),
            }

        def extract(self, x):
            # BUG FIX: the front-end calls `self.extractor.extract(...)` (the
            # real lightglue API), but the mock previously only defined
            # forward(), so running with USE_MOCK_MODELS=1 raised
            # AttributeError on the first frame.
            return self.forward(x)

    class LightGlue(torch.nn.Module):
        """Mock LightGlue: identity matches for the first 25 keypoints."""

        def __init__(self, **kwargs):
            super().__init__()

        def forward(self, data):
            kpts0 = data['image0']['keypoints']
            b = kpts0.shape[0]
            matches = (
                torch.stack([torch.arange(25), torch.arange(25)], dim=-1)
                .unsqueeze(0)
                .repeat(b, 1, 1)
                .to(kpts0.device)
            )
            return {
                'matches': matches,
                'matching_scores': torch.rand(b, 25, device=kpts0.device),
            }

    def rbd(data):
        """Remove the batch dimension (mirror of lightglue.utils.rbd)."""
        return {k: v[0] for k, v in data.items()}

else:
    # Requires: pip install lightglue
    from lightglue import LightGlue, SuperPoint
    from lightglue.utils import rbd

logger = logging.getLogger(__name__)


class VisualOdometryFrontEnd:
    """
    Visual Odometry Front-End using SuperPoint and LightGlue.
    Provides robust, unscaled relative frame-to-frame tracking.
    """

    def __init__(self, device: str = "cuda", resize_max: int = 1536):
        """
        Args:
            device: Preferred torch device; silently falls back to CPU when
                CUDA is unavailable.
            resize_max: Longest image side after preprocessing; larger images
                are downscaled to this before feature extraction.
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.resize_max = resize_max
        logger.info("Initializing V-SLAM Front-End on %s", self.device)
        # Load SuperPoint and LightGlue.
        # LightGlue automatically leverages FlashAttention if available for
        # faster inference.
        self.extractor = SuperPoint(max_num_keypoints=2048).eval().to(self.device)
        self.matcher = (
            LightGlue(features='superpoint', depth_confidence=0.9).eval().to(self.device)
        )
        self.last_image_data = None   # features of the previous frame (batched dict)
        self.last_frame_id = -1
        self.camera_matrix = None     # 3x3 intrinsics at FULL image resolution
        # Downscale factor applied by the most recent _preprocess_image call;
        # used to map keypoints back to full-resolution pixel coordinates.
        self._last_scale = 1.0

    def set_camera_intrinsics(self, k_matrix: np.ndarray):
        """Set the 3x3 camera matrix (expressed at full image resolution)."""
        self.camera_matrix = k_matrix

    def _preprocess_image(self, image: np.ndarray) -> torch.Tensor:
        """Aggressive downscaling of 6.2K image to LR for sub-5s tracking.

        Records the applied scale in self._last_scale so callers can convert
        keypoints back to full-resolution coordinates.
        """
        h, w = image.shape[:2]
        scale = self.resize_max / max(h, w)
        if scale < 1.0:
            new_size = (int(w * scale), int(h * scale))
            image = cv2.resize(image, new_size, interpolation=cv2.INTER_AREA)
            self._last_scale = scale
        else:
            self._last_scale = 1.0
        # Convert to grayscale if needed
        if len(image.shape) == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        # Convert to float tensor [0, 1]
        tensor = torch.from_numpy(image).float() / 255.0
        return tensor[None, None, ...].to(self.device)  # [B, C, H, W]

    def process_frame(self, frame_id: int, image: np.ndarray) -> Tuple[bool, Optional[np.ndarray]]:
        """
        Extracts features and matches against the previous frame to compute
        an unscaled 6-DoF pose.

        Returns:
            (success, transform): transform is a 4x4 SE(3) matrix with
            unit-norm translation (scale is unobservable from two views),
            identity for the first frame, or None on failure.
        """
        if self.camera_matrix is None:
            logger.error("Camera intrinsics must be set before processing frames.")
            return False, None

        # 1. Preprocess & Extract Features
        img_tensor = self._preprocess_image(image)
        with torch.no_grad():
            feats = self.extractor.extract(img_tensor)

        # BUG FIX: keypoints are detected on the DOWNSCALED image, while
        # self.camera_matrix describes the full-resolution sensor. Without
        # rescaling, the essential-matrix estimation mixes coordinate frames
        # and the recovered pose is wrong. Map keypoints back to original
        # pixel coordinates here so both stored and current features agree
        # with the intrinsics.
        if self._last_scale < 1.0:
            feats['keypoints'] = feats['keypoints'] / self._last_scale

        if self.last_image_data is None:
            self.last_image_data = feats
            self.last_frame_id = frame_id
            return True, np.eye(4)  # Identity transform for the first frame

        # 2. Adaptive Matching with LightGlue
        with torch.no_grad():
            matches01 = self.matcher({'image0': self.last_image_data, 'image1': feats})

        feats0, feats1, matches01 = [rbd(x) for x in [self.last_image_data, feats, matches01]]
        kpts0 = feats0['keypoints'][matches01['matches'][..., 0]].cpu().numpy()
        kpts1 = feats1['keypoints'][matches01['matches'][..., 1]].cpu().numpy()

        if len(kpts0) < 20:
            logger.warning(
                "Not enough matches (%d) to compute pose for frame %d.",
                len(kpts0), frame_id,
            )
            return False, None

        # 3. Compute Essential Matrix and Relative Pose (Unscaled SE(3))
        E, mask = cv2.findEssentialMat(
            kpts1, kpts0, self.camera_matrix,
            method=cv2.RANSAC, prob=0.999, threshold=1.0,
        )
        if E is None or mask.sum() < 15:
            return False, None

        _, R, t, _ = cv2.recoverPose(E, kpts1, kpts0, self.camera_matrix, mask=mask)

        transform = np.eye(4)
        transform[:3, :3] = R
        transform[:3, 3] = t.flatten()  # unit-norm direction only (unscaled)

        self.last_image_data = feats
        self.last_frame_id = frame_id
        return True, transform