mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-04-22 08:56:37 +00:00
59 lines
2.0 KiB
Python
import numpy as np
|
|
from typing import Tuple, Any
|
|
from abc import ABC, abstractmethod
|
|
|
|
# Faiss is an optional dependency: record whether it could be imported so
# callers can check FAISS_AVAILABLE before touching the ``faiss`` module.
try:
    import faiss
except ImportError:
    FAISS_AVAILABLE = False
else:
    FAISS_AVAILABLE = True
|
|
|
|
class IFaissIndexManager(ABC):
    """Abstract interface for building, querying and persisting Faiss indices.

    Index handles are typed ``Any``; their concrete type is left to each
    implementation.
    """

    @abstractmethod
    def build_index(self, descriptors: np.ndarray, index_type: str) -> Any:
        """Build and return an index over *descriptors* (kind selected by *index_type*)."""
        ...

    @abstractmethod
    def add_descriptors(self, index: Any, descriptors: np.ndarray) -> bool:
        """Add *descriptors* to an existing *index*; return a bool status."""
        ...

    @abstractmethod
    def search(
        self, index: Any, query: np.ndarray, k: int
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Query *index* for the *k* nearest neighbours of *query*.

        Returns a pair of arrays (presumably distances and ids, following
        Faiss's ``Index.search`` convention -- confirm per implementation).
        """
        ...

    @abstractmethod
    def save_index(self, index: Any, path: str) -> bool:
        """Persist *index* to *path*; return a bool status."""
        ...

    @abstractmethod
    def load_index(self, path: str) -> Any:
        """Load and return an index previously stored at *path*."""
        ...

    @abstractmethod
    def is_gpu_available(self) -> bool:
        """Report whether GPU execution is possible."""
        ...

    @abstractmethod
    def set_device(self, device: str) -> bool:
        """Select the execution *device*; return a bool status."""
        ...
|
|
|
|
class FaissIndexManager(IFaissIndexManager):
    """H04: Manages Faiss indices for DINOv2 descriptor similarity search.

    NOTE(review): ``build_index``, ``add_descriptors``, ``save_index`` and
    ``load_index`` are placeholders -- they perform no Faiss work and return
    either the mock-index sentinel or ``True``. ``search`` only delegates to
    Faiss when handed a real index object; for the sentinel, or when Faiss is
    not installed, it returns random placeholder results.
    """

    # Sentinel handle returned by the placeholder build/load methods.
    _MOCK_INDEX = "mock_index"
    # Fixed figures reported by get_stats(). Mock search ids are drawn from
    # [0, _MOCK_NTOTAL) so they stay consistent with those figures.
    # NOTE(review): presumably (vector count, descriptor dimension) -- confirm.
    _MOCK_NTOTAL = 1000
    _MOCK_DIM = 4096

    def __init__(self):
        # Start on GPU whenever one is detected.
        self.use_gpu = self.is_gpu_available()

    def is_gpu_available(self) -> bool:
        """Return True if Faiss is importable and reports at least one GPU.

        The probe is best-effort: any error raised while asking Faiss for its
        GPU count is treated as "no GPU".
        """
        if not FAISS_AVAILABLE:
            return False
        try:
            return faiss.get_num_gpus() > 0
        except Exception:
            # Narrowed from a bare ``except:`` so KeyboardInterrupt and
            # SystemExit still propagate.
            return False

    def set_device(self, device: str) -> bool:
        """Set the ``use_gpu`` flag from *device*.

        ``use_gpu`` becomes True only when *device* is ``"gpu"``
        (case-insensitive) and a GPU is available; any other value clears it.
        Always returns True.

        NOTE(review): returning True even when a GPU request cannot be
        honoured hides the fallback; callers can inspect ``use_gpu``.
        """
        self.use_gpu = device.lower() == "gpu" and self.is_gpu_available()
        return True

    def build_index(self, descriptors: np.ndarray, index_type: str) -> Any:
        """Placeholder: ignore the arguments and return the mock-index sentinel."""
        return self._MOCK_INDEX

    def add_descriptors(self, index: Any, descriptors: np.ndarray) -> bool:
        """Placeholder: ignore the arguments and return True."""
        return True

    def search(self, index: Any, query: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray]:
        """Find the *k* nearest neighbours of each query descriptor.

        Args:
            index: A Faiss index object, or the mock-index sentinel.
            query: A single descriptor (1-D) or a batch with one descriptor
                per row (2-D).
            k: Number of neighbours to return per query row.

        Returns:
            The pair returned by Faiss's ``index.search``. In the mock case
            (sentinel index, or Faiss not installed) it is a random
            ``(distances, ids)`` pair of shape ``(num_queries, k)``, with
            float32 distances in ``[0, 1)`` and int64 ids in ``[0, 1000)``.
        """
        # Faiss expects a C-contiguous float32 matrix with one query per row.
        # Promoting a 1-D vector to a single row also stops the mock path from
        # producing one result row per vector component (the old len(query)).
        queries = np.ascontiguousarray(np.atleast_2d(query), dtype=np.float32)
        num_queries = queries.shape[0]

        # isinstance guard: ``==`` against an array-like index would be
        # elementwise and ambiguous in a boolean context.
        is_mock = isinstance(index, str) and index == self._MOCK_INDEX
        if not FAISS_AVAILABLE or is_mock:
            # Match Faiss's output dtypes (float32 distances, int64 ids).
            distances = np.random.rand(num_queries, k).astype(np.float32)
            ids = np.random.randint(
                0, self._MOCK_NTOTAL, (num_queries, k), dtype=np.int64
            )
            return distances, ids
        return index.search(queries, k)

    def save_index(self, index: Any, path: str) -> bool:
        """Placeholder: write nothing and return True."""
        return True

    def load_index(self, path: str) -> Any:
        """Placeholder: read nothing and return the mock-index sentinel."""
        return self._MOCK_INDEX

    def get_stats(self) -> Tuple[int, int]:
        """Return the fixed placeholder statistics ``(1000, 4096)``."""
        return self._MOCK_NTOTAL, self._MOCK_DIM