# Mirror of https://github.com/azaion/gps-denied-desktop.git
# Synced 2026-04-22 07:06:37 +00:00
import cv2
import numpy as np
import onnxruntime as ort
import os
class LightGlueMatcher:
    """Thin ONNX Runtime wrapper around a LightGlue feature-matching model.

    The model is loaded once in ``__init__``; :meth:`match` then runs the
    two-image matching pipeline and estimates a fundamental matrix from the
    resulting correspondences with RANSAC.
    """

    def __init__(self, onnx_model_path, max_dimension=512):
        """
        Initializes the ONNX runtime session.

        Args:
            onnx_model_path (str): Path to the .onnx model file.
            max_dimension (int): Maximum edge length for resizing.

        Raises:
            FileNotFoundError: If ``onnx_model_path`` does not exist.
        """
        self.max_dim = max_dimension

        if not os.path.exists(onnx_model_path):
            raise FileNotFoundError(f"Model not found at: {onnx_model_path}")

        # Prefer the CUDA provider when this onnxruntime build exposes it,
        # keeping CPU as a fallback in the provider list.
        if 'CUDAExecutionProvider' in ort.get_available_providers():
            self.providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
            print("LightGlueMatcher: Using CUDA (GPU).")
        else:
            self.providers = ['CPUExecutionProvider']
            print("LightGlueMatcher: Using CPU.")

        sess_options = ort.SessionOptions()
        self.session = ort.InferenceSession(onnx_model_path, sess_options, providers=self.providers)

        # Cache input/output names so we don't query the session on every call.
        self.input_names = [inp.name for inp in self.session.get_inputs()]
        self.output_names = [out.name for out in self.session.get_outputs()]

    def _preprocess(self, img_input):
        """
        Internal helper: load/convert to grayscale, downscale so the longest
        edge is at most ``self.max_dim``, normalize to [0, 1], and reshape to
        an NCHW tensor of shape (1, 1, H, W).

        Handles both file paths and numpy arrays (grayscale HxW, HxWx1,
        BGR HxWx3, or BGRA HxWx4).

        Args:
            img_input (str | np.ndarray): Image path or image array.

        Returns:
            tuple[np.ndarray, float]: ``(tensor, scale)`` where ``scale``
            maps original coordinates to resized coordinates
            (``kpt_resized = kpt_original * scale``); 1.0 if no resize.

        Raises:
            FileNotFoundError: If a path is given but cannot be loaded.
            ValueError: For unsupported input types or channel counts.
        """
        # Load image if input is a path.
        if isinstance(img_input, str):
            img_raw = cv2.imread(img_input, cv2.IMREAD_GRAYSCALE)
            if img_raw is None:
                raise FileNotFoundError(f"Could not load image at {img_input}")
        elif isinstance(img_input, np.ndarray):
            if img_input.ndim == 3:
                # Dispatch on channel count: the original code assumed BGR,
                # which crashed cvtColor on (H, W, 1) and (H, W, 4) inputs.
                channels = img_input.shape[2]
                if channels == 1:
                    img_raw = img_input[:, :, 0]
                elif channels == 3:
                    img_raw = cv2.cvtColor(img_input, cv2.COLOR_BGR2GRAY)
                elif channels == 4:
                    # BGRA, e.g. a PNG with an alpha channel.
                    img_raw = cv2.cvtColor(img_input, cv2.COLOR_BGRA2GRAY)
                else:
                    raise ValueError(f"Unsupported channel count: {channels}")
            else:
                img_raw = img_input
        else:
            raise ValueError("Input must be an image path or numpy array.")

        h, w = img_raw.shape
        scale = self.max_dim / max(h, w)

        # Only ever downscale; small images are passed through untouched
        # (upsampling would fabricate detail without helping the matcher).
        if scale < 1:
            h_new, w_new = int(round(h * scale)), int(round(w * scale))
            img_resized = cv2.resize(img_raw, (w_new, h_new), interpolation=cv2.INTER_AREA)
        else:
            scale = 1.0
            img_resized = img_raw

        # Normalize to [0, 1] and reshape to (1, 1, H, W).
        img_normalized = img_resized.astype(np.float32) / 255.0
        img_tensor = img_normalized[None, None]

        return img_tensor, scale

    def match(self, img0_input, img1_input):
        """
        Matches two images and returns the Fundamental Matrix and Keypoints.

        Args:
            img0_input (str | np.ndarray): First image (path or array).
            img1_input (str | np.ndarray): Second image (path or array).

        Returns:
            F (np.ndarray | None): The 3x3 Fundamental Matrix. None when
                fewer than 8 matches are found, or when
                ``cv2.findFundamentalMat`` itself fails to find a model.
            mkpts0 (np.ndarray): Matched keypoints in Image 0 (N, 2),
                in original-image coordinates.
            mkpts1 (np.ndarray): Matched keypoints in Image 1 (N, 2),
                in original-image coordinates.
        """
        # Preprocess both images; keep each scale to undo the resize later.
        tensor0, scale0 = self._preprocess(img0_input)
        tensor1, scale1 = self._preprocess(img1_input)

        # Inference: feed both tensors under the model's cached input names.
        inputs = {
            self.input_names[0]: tensor0,
            self.input_names[1]: tensor1,
        }
        outputs = self.session.run(None, inputs)
        outputs_dict = dict(zip(self.output_names, outputs))

        # Extract raw matches. Keypoint outputs carry a leading batch dim;
        # matches0 is expected to be an (N, 2) array of index pairs
        # (TODO confirm against the specific ONNX export).
        kpts0 = outputs_dict['kpts0'].squeeze(0)
        kpts1 = outputs_dict['kpts1'].squeeze(0)
        matches0 = outputs_dict['matches0']

        # The 8-point/RANSAC estimate needs at least 8 correspondences.
        if len(matches0) < 8:
            print("Not enough matches to compute matrix.")
            return None, np.array([]), np.array([])

        # Gather matched keypoints and divide by scale to map them back to
        # original-image coordinates.
        mkpts0 = kpts0[matches0[:, 0]] / scale0
        mkpts1 = kpts1[matches0[:, 1]] / scale1

        # Fundamental matrix via RANSAC (3 px reprojection threshold,
        # 99% confidence) — the matrix needed for relative pose estimation.
        # NOTE(review): the RANSAC inlier mask is unused, so callers receive
        # ALL raw matches, not just inliers — confirm this is intentional.
        F, mask = cv2.findFundamentalMat(mkpts0, mkpts1, cv2.FM_RANSAC, 3, 0.99)

        return F, mkpts0, mkpts1