import os

import cv2
import numpy as np
import onnxruntime as ort


class LightGlueMatcher:
    """Feature matching between two images using a LightGlue ONNX model.

    Wraps an ONNX Runtime session and provides grayscale preprocessing,
    inference, and Fundamental Matrix estimation via RANSAC.
    """

    def __init__(self, onnx_model_path, max_dimension=512):
        """
        Initialize the ONNX Runtime session.

        Args:
            onnx_model_path (str): Path to the .onnx model file.
            max_dimension (int): Maximum edge length for resizing.

        Raises:
            FileNotFoundError: If the model file does not exist.
        """
        self.max_dim = max_dimension
        if not os.path.exists(onnx_model_path):
            raise FileNotFoundError(f"Model not found at: {onnx_model_path}")

        # Prefer GPU when available; always keep CPU as a fallback provider.
        if 'CUDAExecutionProvider' in ort.get_available_providers():
            self.providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
            print("LightGlueMatcher: Using CUDA (GPU).")
        else:
            self.providers = ['CPUExecutionProvider']
            print("LightGlueMatcher: Using CPU.")

        sess_options = ort.SessionOptions()
        self.session = ort.InferenceSession(
            onnx_model_path, sess_options, providers=self.providers
        )

        # Cache input/output names so each inference call avoids the lookup.
        self.input_names = [inp.name for inp in self.session.get_inputs()]
        self.output_names = [out.name for out in self.session.get_outputs()]

    def _preprocess(self, img_input):
        """
        Resize and normalize an image into an ONNX input tensor.

        Accepts either a file path or a numpy array. Color arrays are
        converted to grayscale; images larger than ``max_dim`` on their
        longest edge are downscaled (never upscaled).

        Args:
            img_input (str | np.ndarray): Image path or image data.

        Returns:
            tuple: ``(tensor, scale)`` where ``tensor`` has shape
                (1, 1, H, W), dtype float32 in [0, 1], and ``scale`` is
                the resize factor applied (1.0 when no resize occurred).

        Raises:
            FileNotFoundError: If a path was given but could not be loaded.
            ValueError: If the input is neither a path nor a numpy array.
        """
        if isinstance(img_input, str):
            img_raw = cv2.imread(img_input, cv2.IMREAD_GRAYSCALE)
            if img_raw is None:
                raise FileNotFoundError(f"Could not load image at {img_input}")
        elif isinstance(img_input, np.ndarray):
            # Assumes a 3-dim array is BGR (OpenCV's loading convention).
            if len(img_input.shape) == 3:
                img_raw = cv2.cvtColor(img_input, cv2.COLOR_BGR2GRAY)
            else:
                img_raw = img_input
        else:
            raise ValueError("Input must be an image path or numpy array.")

        h, w = img_raw.shape
        scale = self.max_dim / max(h, w)

        # Only downscale; small images are passed through untouched.
        if scale < 1:
            h_new, w_new = int(round(h * scale)), int(round(w * scale))
            img_resized = cv2.resize(
                img_raw, (w_new, h_new), interpolation=cv2.INTER_AREA
            )
        else:
            scale = 1.0
            img_resized = img_raw

        # Normalize to [0, 1] and add batch + channel dims -> (1, 1, H, W).
        img_normalized = img_resized.astype(np.float32) / 255.0
        img_tensor = img_normalized[None, None]
        return img_tensor, scale

    def match(self, img0_input, img1_input):
        """
        Match two images and estimate the Fundamental Matrix.

        Args:
            img0_input (str | np.ndarray): First image (path or array).
            img1_input (str | np.ndarray): Second image (path or array).

        Returns:
            tuple:
                F (np.ndarray | None): 3x3 Fundamental Matrix, or None when
                    fewer than 8 matches were found or RANSAC could not
                    estimate a matrix from a degenerate configuration.
                mkpts0 (np.ndarray): Matched keypoints in image 0, shape (N, 2),
                    in the original (pre-resize) coordinates of image 0.
                mkpts1 (np.ndarray): Matched keypoints in image 1, shape (N, 2),
                    in the original (pre-resize) coordinates of image 1.
        """
        tensor0, scale0 = self._preprocess(img0_input)
        tensor1, scale1 = self._preprocess(img1_input)

        # Feed tensors positionally by the model's declared input order.
        inputs = {
            self.input_names[0]: tensor0,
            self.input_names[1]: tensor1,
        }
        outputs = self.session.run(None, inputs)
        outputs_dict = dict(zip(self.output_names, outputs))

        # Drop the batch dimension from the keypoint outputs.
        kpts0 = outputs_dict['kpts0'].squeeze(0)
        kpts1 = outputs_dict['kpts1'].squeeze(0)
        # NOTE(review): matches0 is indexed as an (N, 2) pair array below but,
        # unlike kpts0/kpts1, is never batch-squeezed — confirm the exported
        # model emits it without a leading batch dimension.
        matches0 = outputs_dict['matches0']

        # findFundamentalMat needs at least 8 correspondences.
        if len(matches0) < 8:
            print("Not enough matches to compute matrix.")
            # Keep the documented (N, 2) shape even when empty.
            return None, np.empty((0, 2)), np.empty((0, 2))

        # Gather matched keypoints and map them back to original image
        # coordinates by undoing the preprocessing resize.
        mkpts0 = kpts0[matches0[:, 0]] / scale0
        mkpts1 = kpts1[matches0[:, 1]] / scale1

        # Estimate F with RANSAC (3 px reprojection threshold, 99% confidence).
        # This is usually "The Matrix" required for relative pose estimation.
        # F can still come back None for degenerate point configurations.
        F, mask = cv2.findFundamentalMat(mkpts0, mkpts1, cv2.FM_RANSAC, 3, 0.99)
        return F, mkpts0, mkpts1