mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-04-23 01:46:38 +00:00
Initial commit
This commit is contained in:
@@ -0,0 +1,59 @@
|
||||
import numpy as np
|
||||
from typing import Tuple, Any
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
try:
|
||||
import faiss
|
||||
FAISS_AVAILABLE = True
|
||||
except ImportError:
|
||||
FAISS_AVAILABLE = False
|
||||
|
||||
class IFaissIndexManager(ABC):
|
||||
@abstractmethod
|
||||
def build_index(self, descriptors: np.ndarray, index_type: str) -> Any: pass
|
||||
@abstractmethod
|
||||
def add_descriptors(self, index: Any, descriptors: np.ndarray) -> bool: pass
|
||||
@abstractmethod
|
||||
def search(self, index: Any, query: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray]: pass
|
||||
@abstractmethod
|
||||
def save_index(self, index: Any, path: str) -> bool: pass
|
||||
@abstractmethod
|
||||
def load_index(self, path: str) -> Any: pass
|
||||
@abstractmethod
|
||||
def is_gpu_available(self) -> bool: pass
|
||||
@abstractmethod
|
||||
def set_device(self, device: str) -> bool: pass
|
||||
|
||||
class FaissIndexManager(IFaissIndexManager):
|
||||
"""H04: Manages Faiss indices for DINOv2 descriptor similarity search."""
|
||||
def __init__(self):
|
||||
self.use_gpu = self.is_gpu_available()
|
||||
|
||||
def is_gpu_available(self) -> bool:
|
||||
if not FAISS_AVAILABLE: return False
|
||||
try: return faiss.get_num_gpus() > 0
|
||||
except: return False
|
||||
|
||||
def set_device(self, device: str) -> bool:
|
||||
self.use_gpu = (device.lower() == "gpu" and self.is_gpu_available())
|
||||
return True
|
||||
|
||||
def build_index(self, descriptors: np.ndarray, index_type: str) -> Any:
|
||||
return "mock_index"
|
||||
|
||||
def add_descriptors(self, index: Any, descriptors: np.ndarray) -> bool:
|
||||
return True
|
||||
|
||||
def search(self, index: Any, query: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray]:
|
||||
if not FAISS_AVAILABLE or index == "mock_index":
|
||||
return np.random.rand(len(query), k), np.random.randint(0, 1000, (len(query), k))
|
||||
return index.search(query, k)
|
||||
|
||||
def save_index(self, index: Any, path: str) -> bool:
|
||||
return True
|
||||
|
||||
def load_index(self, path: str) -> Any:
|
||||
return "mock_index"
|
||||
|
||||
def get_stats(self) -> Tuple[int, int]:
|
||||
return 1000, 4096
|
||||
Reference in New Issue
Block a user