mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-04-22 22:26:38 +00:00
488 lines
18 KiB
Python
488 lines
18 KiB
Python
import logging
|
|
import uuid
|
|
from datetime import datetime
|
|
from typing import List, Optional, Tuple, Dict, Any
|
|
from pydantic import BaseModel, Field
|
|
from abc import ABC, abstractmethod
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# --- Data Models ---
|
|
|
|
class GPSPoint(BaseModel):
    """A geographic position given as a latitude/longitude pair.

    The model itself performs no range validation; callers that need
    bounds checking must do it separately.
    """

    # Latitude component of the position.
    lat: float
    # Longitude component of the position.
    lon: float
|
|
|
|
class CameraParameters(BaseModel):
    """Camera description attached to a flight."""

    # Lens focal length, in millimetres (per the field name).
    focal_length_mm: float
    # Sensor width, in millimetres (per the field name).
    sensor_width_mm: float
    # Integer values keyed by string.
    # NOTE(review): presumably "width"/"height" in pixels -- confirm with callers.
    resolution: Dict[str, int]
|
|
|
|
class Waypoint(BaseModel):
    """A single timestamped position estimate belonging to a flight."""

    # Identifier of the waypoint.
    id: str
    # Position of the waypoint.
    lat: float
    lon: float
    # Optional altitude; ``None`` when not provided.
    altitude: Optional[float] = None
    # Confidence score of the estimate.
    # NOTE(review): value range is not enforced here -- confirm expected scale.
    confidence: float
    # Time associated with the waypoint.
    timestamp: datetime
    # Whether the waypoint has been refined; defaults to False.
    refined: bool = False
|
|
|
|
class UserFixRequest(BaseModel):
    """Payload of a manual fix that ties a UAV image pixel to a GPS position."""

    # Frame the fix applies to.
    frame_id: int
    # Pixel coordinates in the UAV image as a 2-tuple of floats.
    uav_pixel: Tuple[float, float]
    # GPS position the pixel corresponds to.
    satellite_gps: GPSPoint
|
|
|
|
class Flight(BaseModel):
    """Persistent record describing one flight."""

    # Unique identifier of the flight.
    flight_id: str
    # Human-readable name of the flight.
    flight_name: str
    # GPS position the flight starts from.
    start_gps: GPSPoint
    # Flight altitude, in metres (per the field name).
    altitude_m: float
    # Parameters of the on-board camera.
    camera_params: CameraParameters
    # Free-form lifecycle state; new records default to "created".
    state: str = "created"
    # Creation / last-modification times. Both default to the moment the
    # record is instantiated, as naive UTC datetimes (datetime.utcnow).
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
|
|
|
|
class FlightState(BaseModel):
    """Snapshot of a flight's processing status."""

    # Identifier of the flight the snapshot describes.
    flight_id: str
    # Current lifecycle state of the flight.
    state: str
    # Image counters; both default to zero.
    processed_images: int = 0
    total_images: int = 0
    # True when a processing engine is currently registered for the flight.
    has_active_engine: bool = False
|
|
|
|
class ValidationResult(BaseModel):
    """Outcome of a validation: a pass/fail flag plus the collected messages."""

    # True when validation passed.
    is_valid: bool
    # Human-readable error messages; empty by default.
    errors: List[str] = []
|
|
|
|
class FlightStatusUpdate(BaseModel):
    """Request body carrying the new status for a flight."""

    # New status value to assign to the flight.
    status: str
|
|
|
|
class BatchUpdateResult(BaseModel):
    """Summary of a batch waypoint update."""

    # Overall success flag for the batch.
    success: bool
    # Number of waypoints reported as updated.
    updated_count: int
    # Identifiers of the waypoints that were not updated.
    failed_ids: List[str]
|
|
|
|
class Polygon(BaseModel):
    """Area described by two corner points (north-west and south-east)."""

    # North-west corner of the area.
    north_west: GPSPoint
    # South-east corner of the area.
    south_east: GPSPoint
|
|
|
|
class Geofences(BaseModel):
    """Collection of geofence areas."""

    # Areas making up the geofence; empty by default.
    polygons: List[Polygon] = []
|
|
|
|
# --- Interface ---
|
|
|
|
class IFlightLifecycleManager(ABC):
    """Abstract contract for the flight lifecycle manager.

    Concrete implementations must provide every method declared here.
    """

    # --- Flight records ---

    @abstractmethod
    def create_flight(self, flight_data: dict) -> str:
        """Create a flight from ``flight_data`` and return a string result."""

    @abstractmethod
    def get_flight(self, flight_id: str) -> Optional[Flight]:
        """Return the flight for ``flight_id``, or ``None``."""

    @abstractmethod
    def get_flight_state(self, flight_id: str) -> Optional[FlightState]:
        """Return the state snapshot for ``flight_id``, or ``None``."""

    @abstractmethod
    def delete_flight(self, flight_id: str) -> bool:
        """Delete the flight and report the outcome as a bool."""

    @abstractmethod
    def update_flight_status(self, flight_id: str, status: FlightStatusUpdate) -> bool:
        """Apply ``status`` to the flight and report the outcome as a bool."""

    # --- Waypoints ---

    @abstractmethod
    def update_waypoint(self, flight_id: str, waypoint_id: str, waypoint: Waypoint) -> bool:
        """Update one waypoint of a flight and report the outcome as a bool."""

    @abstractmethod
    def batch_update_waypoints(self, flight_id: str, waypoints: List[Waypoint]) -> BatchUpdateResult:
        """Update several waypoints of a flight and summarise the outcome."""

    @abstractmethod
    def get_flight_metadata(self, flight_id: str) -> Optional[dict]:
        """Return metadata for the flight as a dict, or ``None``."""

    # --- Processing delegation ---

    @abstractmethod
    def queue_images(self, flight_id: str, batch: Any) -> bool:
        """Queue an image batch for the flight and report the outcome as a bool."""

    @abstractmethod
    def handle_user_fix(self, flight_id: str, fix_data: UserFixRequest) -> dict:
        """Handle a user-supplied fix for the flight and return a result dict."""

    @abstractmethod
    def create_client_stream(self, flight_id: str, client_id: str) -> Any:
        """Create a stream for ``client_id`` on the given flight."""

    @abstractmethod
    def convert_object_to_gps(
        self, flight_id: str, frame_id: int, pixel: Tuple[float, float]
    ) -> Optional[GPSPoint]:
        """Convert a pixel in a frame to a GPS point, or ``None``."""

    @abstractmethod
    def get_frame_context(self, flight_id: str, frame_id: int) -> Optional[dict]:
        """Return context for a frame as a dict, or ``None``."""

    # --- Validation ---

    @abstractmethod
    def validate_waypoint(self, waypoint: Waypoint) -> ValidationResult:
        """Validate a single waypoint."""

    @abstractmethod
    def validate_geofence(self, geofence: Geofences) -> ValidationResult:
        """Validate a set of geofences."""

    @abstractmethod
    def validate_flight_continuity(self, waypoints: List[Waypoint]) -> ValidationResult:
        """Validate the continuity of a sequence of waypoints."""

    # --- Results and system state ---

    @abstractmethod
    def get_flight_results(self, flight_id: str) -> List[Any]:
        """Return the results recorded for the flight as a list."""

    @abstractmethod
    def initialize_system(self) -> bool:
        """Initialize the system and report the outcome as a bool."""

    @abstractmethod
    def is_system_initialized(self) -> bool:
        """Report whether the system is initialized."""
|
|
|
|
|
|
# --- Implementation ---
|
|
|
|
class FlightLifecycleManager(IFlightLifecycleManager):
    """
    Manages flight lifecycle, delegates processing to F02.2 Engine,
    and acts as the core entry point for the REST API (F01).

    Every collaborator passed to the constructor is optional. When a
    collaborator is missing (or lacks the expected method) the corresponding
    step is skipped. Without a database adapter, flights are kept in the
    in-memory ``flights`` dict.
    """

    # Largest allowed time gap, in seconds, between consecutive waypoints
    # before ``validate_flight_continuity`` reports an error.
    _MAX_WAYPOINT_GAP_SECONDS = 300

    def __init__(
        self,
        db_adapter=None,
        orchestrator=None,
        config_manager=None,
        model_manager=None,
        satellite_manager=None,
        place_recognition=None,
        coordinate_transformer=None,
        sse_streamer=None
    ):
        self.db = db_adapter
        self.orchestrator = orchestrator
        self.config_manager = config_manager
        self.model_manager = model_manager
        self.satellite_manager = satellite_manager
        self.place_recognition = place_recognition
        self.f13_transformer = coordinate_transformer
        self.f15_streamer = sse_streamer
        # flight_id -> processing engine currently attached to that flight.
        self.active_engines = {}
        self.flights = {}  # Fallback in-memory storage for environments without a database
        self._is_initialized = False

    # --- Storage helpers ---

    def _persist_flight(self, flight: Flight):
        """Store ``flight`` in the database, or in memory when no database is set.

        With a database, an existing record is updated and a new one is
        inserted. If the adapter cannot insert, a warning is logged instead
        of dropping the record silently.
        """
        if not self.db:
            self.flights[flight.flight_id] = flight
            return

        # Check if it exists to decide between insert and update
        if hasattr(self.db, "get_flight_by_id") and self.db.get_flight_by_id(flight.flight_id):
            self.db.update_flight(flight)
        elif hasattr(self.db, "insert_flight"):
            self.db.insert_flight(flight)
        else:
            logger.warning(
                "Database adapter has no insert_flight; flight %s was not persisted.",
                flight.flight_id,
            )

    def _load_flight(self, flight_id: str) -> Optional[Flight]:
        """Fetch a flight from the database, falling back to in-memory storage."""
        if self.db:
            if hasattr(self.db, "get_flight_by_id"):
                return self.db.get_flight_by_id(flight_id)
            elif hasattr(self.db, "get_flight"):
                return self.db.get_flight(flight_id)
        return self.flights.get(flight_id)

    def _set_flight_state(self, flight: Flight, state: str):
        """Assign ``state``, refresh ``updated_at`` and persist the flight.

        Centralised so that every state transition stamps the modification
        time the same way ``update_flight_status`` does.
        """
        flight.state = state
        flight.updated_at = datetime.utcnow()
        self._persist_flight(flight)

    # --- GPS validation helpers ---

    @staticmethod
    def _is_valid_gps(lat: float, lon: float) -> bool:
        """Return True when ``lat`` is in [-90, 90] and ``lon`` is in [-180, 180]."""
        return (-90.0 <= lat <= 90.0) and (-180.0 <= lon <= 180.0)

    def _validate_gps_bounds(self, lat: float, lon: float):
        """Raise ``ValueError`` when the coordinates are outside valid GPS ranges."""
        if not self._is_valid_gps(lat, lon):
            raise ValueError(f"Invalid GPS bounds: {lat}, {lon}")

    # --- System Initialization Methods (Feature 02.1.03) ---

    def _load_configuration(self):
        """Ask the config manager to load its configuration, if supported."""
        if self.config_manager and hasattr(self.config_manager, "load_config"):
            self.config_manager.load_config()

    def _initialize_models(self):
        """Ask the model manager to initialize its models, if supported."""
        if self.model_manager and hasattr(self.model_manager, "initialize_models"):
            self.model_manager.initialize_models()

    def _initialize_database(self):
        """Ask the database adapter to open its connection, if supported."""
        if self.db and hasattr(self.db, "initialize_connection"):
            self.db.initialize_connection()

    def _initialize_satellite_cache(self):
        """Ask the satellite manager to prepare its cache, if supported."""
        if self.satellite_manager and hasattr(self.satellite_manager, "prepare_cache"):
            self.satellite_manager.prepare_cache()

    def _load_place_recognition_indexes(self):
        """Ask the place-recognition component to load its indexes, if supported."""
        if self.place_recognition and hasattr(self.place_recognition, "load_indexes"):
            self.place_recognition.load_indexes()

    def _verify_health_checks(self):
        """Run post-initialization health checks (currently a no-op)."""
        # Placeholder for _verify_gpu_availability, _verify_model_loading,
        # _verify_database_connection, _verify_index_integrity
        pass

    def _handle_initialization_failure(self, component: str, error: Exception):
        """Log the failure (with traceback) and roll back partial initialization."""
        logger.error(
            "System initialization failed at %s: %s", component, error, exc_info=error
        )
        self._rollback_partial_initialization()

    def _rollback_partial_initialization(self):
        """Mark the system as uninitialized after a failed start-up."""
        logger.info("Rolling back partial initialization...")
        self._is_initialized = False
        # Add specific cleanup logic here for any allocated resources

    def is_system_initialized(self) -> bool:
        """Return True once ``initialize_system`` has completed successfully."""
        return self._is_initialized

    # --- Internal Delegation Methods (Feature 02.1.02) ---

    def _get_active_engine(self, flight_id: str) -> Any:
        """Return the engine registered for ``flight_id``, or ``None``."""
        return self.active_engines.get(flight_id)

    def _get_or_create_engine(self, flight_id: str) -> Any:
        """Return the flight's engine, registering a stand-in engine if absent."""
        if flight_id not in self.active_engines:
            class MockEngine:
                def start_processing(self): pass
                def stop(self): pass
                def apply_user_fix(self, fix_data): return {"status": "success", "message": "Processing resumed."}
            self.active_engines[flight_id] = MockEngine()
        return self.active_engines[flight_id]

    def _delegate_queue_batch(self, flight_id: str, batch: Any):
        """Hand the batch over to the image queue (not implemented yet)."""
        pass  # Delegates to F05.queue_batch

    def _trigger_processing(self, engine: Any, flight_id: str):
        """Start processing on ``engine``, passing ``flight_id`` when accepted.

        The call form is chosen by inspecting the signature rather than by
        retrying on ``TypeError``: a retry would hide a genuine ``TypeError``
        raised inside the engine and could start processing twice.
        """
        import inspect  # local import keeps this method self-contained

        start = getattr(engine, "start_processing", None)
        if start is None:
            return

        try:
            parameters = inspect.signature(start).parameters.values()
        except (TypeError, ValueError):
            # Signature cannot be introspected; use the regular calling form.
            takes_flight_id = True
        else:
            positional_kinds = (
                inspect.Parameter.POSITIONAL_ONLY,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                inspect.Parameter.VAR_POSITIONAL,
            )
            takes_flight_id = any(p.kind in positional_kinds for p in parameters)

        if takes_flight_id:
            start(flight_id)
        else:
            start()  # Fallback for test mocks that take no arguments

    def _validate_fix_request(self, fix_data: UserFixRequest) -> bool:
        """Return True when the pixel is non-negative and the GPS point is in range."""
        pixel_x, pixel_y = fix_data.uav_pixel[0], fix_data.uav_pixel[1]
        if pixel_x < 0 or pixel_y < 0:
            return False
        return self._is_valid_gps(fix_data.satellite_gps.lat, fix_data.satellite_gps.lon)

    def _apply_fix_to_engine(self, engine: Any, fix_data: UserFixRequest) -> dict:
        """Forward the fix to the engine; report success if it has no fix hook."""
        if hasattr(engine, "apply_user_fix"):
            return engine.apply_user_fix(fix_data)
        return {"status": "success", "message": "Processing resumed."}

    def _delegate_stream_creation(self, flight_id: str, client_id: str) -> Any:
        """Create a client stream via the streamer, or a one-shot keepalive generator."""
        if self.f15_streamer:
            return self.f15_streamer.create_stream(flight_id, client_id)

        async def event_generator():
            yield {"event": "ping", "data": "keepalive"}

        return event_generator()

    def _delegate_coordinate_transform(self, flight_id: str, frame_id: int, pixel: Tuple[float, float]) -> Optional[GPSPoint]:
        """Return a placeholder GPS point offset from the flight's start position.

        Returns ``None`` when the flight is unknown. ``frame_id`` and ``pixel``
        are currently ignored.
        """
        flight = self._load_flight(flight_id)
        if not flight:
            return None
        return GPSPoint(lat=flight.start_gps.lat + 0.001, lon=flight.start_gps.lon + 0.001)

    # --- Core Lifecycle Implementation ---

    def create_flight(self, flight_data: dict) -> str:
        """Create, validate and persist a new flight, then trigger prefetching.

        Args:
            flight_data: Must contain ``start_gps`` and ``camera_params``
                mappings; ``flight_name`` and ``altitude_m`` are optional.

        Returns:
            The generated flight identifier (a UUID4 string).

        Raises:
            KeyError: If ``start_gps`` or ``camera_params`` is missing.
            ValueError: If the start position is outside valid GPS ranges.
        """
        flight_id = str(uuid.uuid4())
        flight = Flight(
            flight_id=flight_id,
            flight_name=flight_data.get("flight_name", f"Flight-{flight_id[:6]}"),
            start_gps=GPSPoint(**flight_data["start_gps"]),
            altitude_m=flight_data.get("altitude_m", 100.0),
            camera_params=CameraParameters(**flight_data["camera_params"]),
            state="prefetching"
        )

        # Reject out-of-range coordinates before anything is stored.
        self._validate_gps_bounds(flight.start_gps.lat, flight.start_gps.lon)
        self._persist_flight(flight)

        if self.f13_transformer:
            self.f13_transformer.set_enu_origin(flight_id, flight.start_gps)

        logger.info("Created flight %s, triggering prefetch.", flight_id)
        # Trigger F04 prefetch logic here (mocked via orchestrator if present)
        if self.orchestrator and hasattr(self.orchestrator, "trigger_prefetch"):
            self.orchestrator.trigger_prefetch(flight_id, flight.start_gps)
        if self.satellite_manager:
            self.satellite_manager.prefetch_route_corridor([flight.start_gps], 100.0, 18)

        return flight_id

    def get_flight(self, flight_id: str) -> Optional[Flight]:
        """Return the flight for ``flight_id``, or ``None`` when unknown."""
        return self._load_flight(flight_id)

    def get_flight_state(self, flight_id: str) -> Optional[FlightState]:
        """Return a state snapshot for the flight, or ``None`` when unknown.

        Image counters are always reported as zero by this implementation.
        """
        flight = self._load_flight(flight_id)
        if not flight:
            return None

        return FlightState(
            flight_id=flight_id,
            state=flight.state,
            processed_images=0,
            total_images=0,
            has_active_engine=flight_id in self.active_engines
        )

    def delete_flight(self, flight_id: str) -> bool:
        """Delete a flight and release its engine.

        Any engine registered for the flight is removed and stopped, whatever
        the flight's state, so a deleted flight never leaves an engine behind.

        Returns:
            False when the flight does not exist, True otherwise.
        """
        flight = self._load_flight(flight_id)
        if not flight:
            return False

        engine = self.active_engines.pop(flight_id, None)
        if engine is not None and hasattr(engine, "stop"):
            engine.stop()

        if self.db:
            self.db.delete_flight(flight_id)
        elif flight_id in self.flights:
            del self.flights[flight_id]

        logger.info("Deleted flight %s", flight_id)
        return True

    def update_flight_status(self, flight_id: str, status: FlightStatusUpdate) -> bool:
        """Set the flight's state from ``status``; False when the flight is unknown."""
        flight = self._load_flight(flight_id)
        if not flight:
            return False
        self._set_flight_state(flight, status.status)
        return True

    def update_waypoint(self, flight_id: str, waypoint_id: str, waypoint: Waypoint) -> bool:
        """Validate and store a waypoint.

        Returns False for an invalid waypoint. With a database the adapter's
        result is returned; without one the update is reported as successful.
        """
        if not self.validate_waypoint(waypoint).is_valid:
            return False
        if self.db:
            return self.db.update_waypoint(flight_id, waypoint_id, waypoint)
        return True  # Return true in mock mode

    def batch_update_waypoints(self, flight_id: str, waypoints: List[Waypoint]) -> BatchUpdateResult:
        """Validate and store several waypoints in one call.

        Waypoints are partitioned by their own validation result, so repeated
        IDs cannot cause a valid entry to be dropped or miscounted.

        Returns:
            A summary in which ``failed_ids`` lists validation failures
            followed by adapter-reported failures, and ``updated_count`` is
            the number of valid waypoints the adapter did not report as failed.
        """
        valid_wps: List[Waypoint] = []
        failed: List[str] = []
        for wp in waypoints:
            if self.validate_waypoint(wp).is_valid:
                valid_wps.append(wp)
            else:
                failed.append(wp.id)

        db_failed: List[str] = []
        if self.db:
            db_res = self.db.batch_update_waypoints(flight_id, valid_wps)
            # Tolerate adapters that return no failed_ids attribute or None.
            db_failed = list(getattr(db_res, "failed_ids", None) or [])
        failed.extend(db_failed)

        return BatchUpdateResult(
            success=not failed,
            updated_count=max(len(valid_wps) - len(db_failed), 0),
            failed_ids=failed,
        )

    def get_flight_metadata(self, flight_id: str) -> Optional[dict]:
        """Return a dict of basic flight fields, or ``None`` when unknown."""
        flight = self._load_flight(flight_id)
        if not flight:
            return None
        return {
            "flight_id": flight.flight_id,
            "flight_name": flight.flight_name,
            "start_gps": flight.start_gps.model_dump(),
            "created_at": flight.created_at,
            "state": flight.state
        }

    def queue_images(self, flight_id: str, batch: Any) -> bool:
        """Move the flight to "processing", queue the batch and start its engine.

        Returns:
            False when the flight does not exist, True otherwise.
        """
        flight = self._load_flight(flight_id)
        if not flight:
            return False

        self._set_flight_state(flight, "processing")

        self._delegate_queue_batch(flight_id, batch)
        engine = self._get_or_create_engine(flight_id)
        self._trigger_processing(engine, flight_id)

        logger.info("Queued image batch for %s", flight_id)
        return True

    def handle_user_fix(self, flight_id: str, fix_data: UserFixRequest) -> dict:
        """Apply a user fix to a blocked flight and resume processing on success.

        Returns:
            The engine's result dict, or an error dict when the flight is
            unknown, is not in the "blocked" state, the fix data is invalid,
            or no engine is registered for the flight.
        """
        flight = self._load_flight(flight_id)
        if not flight:
            return {"status": "error", "message": "Flight not found"}

        if flight.state != "blocked":
            return {"status": "error", "message": "Flight not in blocked state."}

        if not self._validate_fix_request(fix_data):
            return {"status": "error", "message": "Invalid fix data."}

        engine = self._get_active_engine(flight_id)
        if not engine:
            return {"status": "error", "message": "No active engine found for flight."}

        result = self._apply_fix_to_engine(engine, fix_data)

        if result.get("status") == "success":
            self._set_flight_state(flight, "processing")
            logger.info("Applied user fix for %s", flight_id)

        return result

    def create_client_stream(self, flight_id: str, client_id: str) -> Any:
        """Create a client stream for the flight, or return ``None`` when unknown."""
        flight = self._load_flight(flight_id)
        if not flight:
            return None

        return self._delegate_stream_creation(flight_id, client_id)

    def convert_object_to_gps(self, flight_id: str, frame_id: int, pixel: Tuple[float, float]) -> Optional[GPSPoint]:
        """Convert an image pixel to a GPS point via the coordinate transformer.

        Returns:
            The transformer's result, or ``None`` when no transformer is set.

        Raises:
            ValueError: If the flight does not exist.
        """
        flight = self._load_flight(flight_id)
        if not flight:
            raise ValueError("Flight not found")

        if self.f13_transformer:
            return self.f13_transformer.image_object_to_gps(flight_id, frame_id, pixel)
        return None

    def get_flight_results(self, flight_id: str) -> List[Any]:
        """Return the flight's results (always an empty list for now)."""
        # In a complete implementation, this delegates to F14 Result Manager
        # Returning an empty list here to satisfy the API contract
        return []

    def get_frame_context(self, flight_id: str, frame_id: int) -> Optional[dict]:
        """Return a context dict for a frame, or ``None`` when the flight is unknown."""
        flight = self._load_flight(flight_id)
        if not flight:
            return None

        return {
            "frame_id": frame_id,
            "uav_image_url": f"/media/{flight_id}/frames/{frame_id}.jpg",
            "satellite_candidates": []
        }

    def validate_waypoint(self, waypoint: Waypoint) -> ValidationResult:
        """Check that the waypoint's latitude and longitude are in range."""
        errors = []
        if not (-90.0 <= waypoint.lat <= 90.0):
            errors.append("Invalid latitude")
        if not (-180.0 <= waypoint.lon <= 180.0):
            errors.append("Invalid longitude")
        return ValidationResult(is_valid=not errors, errors=errors)

    def validate_geofence(self, geofence: Geofences) -> ValidationResult:
        """Check that both corners of every polygon are within GPS ranges."""
        errors = []
        for poly in geofence.polygons:
            if not self._is_valid_gps(poly.north_west.lat, poly.north_west.lon):
                errors.append("Invalid NW coordinates")
            if not self._is_valid_gps(poly.south_east.lat, poly.south_east.lon):
                errors.append("Invalid SE coordinates")
        return ValidationResult(is_valid=not errors, errors=errors)

    def validate_flight_continuity(self, waypoints: List[Waypoint]) -> ValidationResult:
        """Report every gap above the allowed limit between time-ordered waypoints."""
        errors = []
        ordered = sorted(waypoints, key=lambda w: w.timestamp)
        for previous, current in zip(ordered, ordered[1:]):
            gap_seconds = (current.timestamp - previous.timestamp).total_seconds()
            if gap_seconds > self._MAX_WAYPOINT_GAP_SECONDS:
                errors.append(f"Excessive gap between {previous.id} and {current.id}")
        return ValidationResult(is_valid=not errors, errors=errors)

    def initialize_system(self) -> bool:
        """Run the start-up sequence.

        Returns:
            True when every step succeeds. On any exception the failure is
            logged, partial initialization is rolled back and False is returned.
        """
        try:
            logger.info("Starting system initialization sequence...")

            self._load_configuration()
            self._initialize_models()
            self._initialize_database()
            self._initialize_satellite_cache()
            self._load_place_recognition_indexes()

            self._verify_health_checks()

            self._is_initialized = True
            logger.info("System fully initialized.")
            return True

        except Exception as e:
            # Determine component from traceback/exception type in real implementation
            component = "system_core"
            self._handle_initialization_failure(component, e)
            return False