# gps-denied-onboard/f03_flight_database.py
# Commit d7e1066c60 "Initial commit" by Denys Zaitsev, 2026-04-03 23:25:54 +03:00
# (584 lines, 24 KiB, Python)
import logging
import threading
import time
from abc import ABC, abstractmethod
from contextlib import contextmanager
from datetime import datetime
from typing import List, Optional, Dict, Any, Callable, Iterator

from pydantic import BaseModel, Field
from sqlalchemy import create_engine, Column, String, Float, Boolean, DateTime, Integer, JSON, ForeignKey, Text
from sqlalchemy import event
from sqlalchemy.engine import make_url
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.orm import declarative_base, sessionmaker, Session, relationship
from sqlalchemy.pool import StaticPool

from f02_1_flight_lifecycle_manager import Flight, Waypoint, GPSPoint, CameraParameters, Geofences, Polygon, FlightState
# Module-level logger; FlightDatabase uses it to report failed transactions.
logger = logging.getLogger(__name__)
# --- Data Models ---
# Per-frame pose estimate. FlightDatabase stores each one as a row of the
# ``frame_results`` table, keyed by "{flight_id}_{frame_id}".
class FrameResult(BaseModel):
    frame_id: int
    # Persisted as the gps_lat / gps_lon columns.
    gps_center: GPSPoint
    altitude: Optional[float] = None
    heading: float
    confidence: float
    refined: bool = False
    timestamp: datetime
    # Naive UTC time (datetime.utcnow) captured when the model instance is created.
    updated_at: datetime = Field(default_factory=datetime.utcnow)
# One entry of a flight's heading history (a row of the ``heading_history`` table).
class HeadingRecord(BaseModel):
    frame_id: int
    heading: float
    timestamp: datetime
# Outcome of a batch update. As filled by FlightDatabase.batch_update_waypoints,
# ``success`` is True only when every item was updated, ``updated_count`` counts
# the items that were, and ``failed_ids`` lists the ids that were not.
class BatchResult(BaseModel):
    success: bool
    updated_count: int
    failed_ids: List[str]
# In-memory view of a frame chunk; persisted as a row of the ``chunks`` table.
class ChunkHandle(BaseModel):
    chunk_id: str
    start_frame_id: int
    end_frame_id: Optional[int] = None
    # Frame ids belonging to this chunk (stored as a JSON list).
    # NOTE(review): a list literal default is safe only because pydantic copies
    # field defaults per instance — confirm for the pydantic version in use.
    frames: List[int] = []
    is_active: bool = True
    has_anchor: bool = False
    anchor_frame_id: Optional[int] = None
    # Persisted as the anchor_lat / anchor_lon columns; None when either is missing.
    anchor_gps: Optional[GPSPoint] = None
    # 'unanchored' is the only status value visible in this file.
    matching_status: str = 'unanchored'
# --- SQLAlchemy ORM Models ---
# Declarative base shared by every ORM model below; its metadata is used to create all tables.
Base = declarative_base()
class SQLFlight(Base):
    """ORM row for a flight; the other tables reference it with ON DELETE CASCADE."""
    __tablename__ = 'flights'
    # Flight identifier (36 chars; presumably a UUID string — verify against caller).
    id = Column(String(36), primary_key=True)
    name = Column(String(255), nullable=False)
    # FlightDatabase.insert_flight fills this with the flight name.
    description = Column(Text, default="")
    start_lat = Column(Float, nullable=False)
    start_lon = Column(Float, nullable=False)
    altitude = Column(Float, nullable=False)
    # CameraParameters serialized with model_dump().
    camera_params = Column(JSON, nullable=False)
    created_at = Column(DateTime, default=datetime.utcnow)
    # Only defaulted on insert; FlightDatabase.update_flight sets it explicitly.
    updated_at = Column(DateTime, default=datetime.utcnow)
class SQLWaypoint(Base):
    """ORM row for a waypoint of a flight; deleted together with its flight."""
    __tablename__ = 'waypoints'
    id = Column(String(36), primary_key=True)
    flight_id = Column(String(36), ForeignKey('flights.id', ondelete='CASCADE'), nullable=False)
    lat = Column(Float, nullable=False)
    lon = Column(Float, nullable=False)
    altitude = Column(Float)
    confidence = Column(Float, nullable=False)
    # Waypoints are returned in timestamp order by FlightDatabase.get_waypoints.
    timestamp = Column(DateTime, nullable=False)
    refined = Column(Boolean, default=False)
class SQLGeofence(Base):
    """ORM row for a rectangular geofence given by its north-west and south-east corners.

    No FlightDatabase method in this file reads or writes this table.
    """
    __tablename__ = 'geofences'
    id = Column(String(36), primary_key=True)
    flight_id = Column(String(36), ForeignKey('flights.id', ondelete='CASCADE'), nullable=False)
    nw_lat = Column(Float, nullable=False)
    nw_lon = Column(Float, nullable=False)
    se_lat = Column(Float, nullable=False)
    se_lon = Column(Float, nullable=False)
class SQLFlightState(Base):
    """ORM row holding the processing state of one flight (one row per flight).

    FlightDatabase upserts it with ``session.merge`` and sets only status,
    frames_processed and frames_total; the remaining columns keep their values.
    """
    __tablename__ = 'flight_state'
    flight_id = Column(String(36), ForeignKey('flights.id', ondelete='CASCADE'), primary_key=True)
    status = Column(String(50), nullable=False)
    frames_processed = Column(Integer, default=0)
    frames_total = Column(Integer, default=0)
    current_frame = Column(Integer)
    blocked = Column(Boolean, default=False)
    search_grid_size = Column(Integer)
    created_at = Column(DateTime, default=datetime.utcnow)
    # The upsert never assigns updated_at, so refresh it on every UPDATE;
    # otherwise it would stay frozen at the row's creation time.
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
class SQLFrameResult(Base):
    """ORM row storing the pose estimate (FrameResult) of one frame of a flight."""
    __tablename__ = 'frame_results'
    id = Column(String(72), primary_key=True) # Composite key representation: {flight_id}_{frame_id}
    flight_id = Column(String(36), ForeignKey('flights.id', ondelete='CASCADE'), nullable=False)
    frame_id = Column(Integer, nullable=False)
    gps_lat = Column(Float)
    gps_lon = Column(Float)
    altitude = Column(Float)
    heading = Column(Float)
    confidence = Column(Float)
    refined = Column(Boolean, default=False)
    timestamp = Column(DateTime)
    # FlightDatabase copies FrameResult.updated_at here on every upsert.
    updated_at = Column(DateTime, default=datetime.utcnow)
class SQLHeadingHistory(Base):
    """ORM row storing the heading recorded for one frame of a flight."""
    __tablename__ = 'heading_history'
    id = Column(String(72), primary_key=True) # {flight_id}_{frame_id}, so one heading per frame
    flight_id = Column(String(36), ForeignKey('flights.id', ondelete='CASCADE'), nullable=False)
    frame_id = Column(Integer, nullable=False)
    heading = Column(Float, nullable=False)
    timestamp = Column(DateTime, nullable=False)
class SQLFlightImage(Base):
    """ORM row mapping a flight frame to its image file path and free-form metadata."""
    __tablename__ = 'flight_images'
    id = Column(String(72), primary_key=True) # {flight_id}_{frame_id}
    flight_id = Column(String(36), ForeignKey('flights.id', ondelete='CASCADE'), nullable=False)
    frame_id = Column(Integer, nullable=False)
    file_path = Column(String(500), nullable=False)
    # Caller-supplied dict stored as JSON; read back as {} when empty.
    metadata_json = Column(JSON)
    uploaded_at = Column(DateTime, default=datetime.utcnow)
class SQLChunk(Base):
    """ORM row persisting a ChunkHandle (a group of frames, optionally GPS-anchored)."""
    __tablename__ = 'chunks'
    chunk_id = Column(String(36), primary_key=True)
    flight_id = Column(String(36), ForeignKey('flights.id', ondelete='CASCADE'), nullable=False)
    start_frame_id = Column(Integer, nullable=False)
    end_frame_id = Column(Integer)
    # List of frame ids stored as JSON.
    frames = Column(JSON, nullable=False)
    is_active = Column(Boolean, default=True)
    has_anchor = Column(Boolean, default=False)
    anchor_frame_id = Column(Integer)
    # Anchor GPS point split into two columns; both NULL when there is no anchor.
    anchor_lat = Column(Float)
    anchor_lon = Column(Float)
    matching_status = Column(String(50), default='unanchored')
    created_at = Column(DateTime, default=datetime.utcnow)
    # Chunk upserts (session.merge) never assign updated_at, so refresh it on
    # every UPDATE; otherwise it would stay frozen at the row's creation time.
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# --- Implementation ---
class FlightDatabase:
    """
    Provides transactional CRUD operations and state persistence over SQLAlchemy.
    Supports connection pooling and thread-safe batch transactions.

    Session handling: outside ``execute_transaction`` every public call opens its
    own session, commits on success or rolls back on error, and always closes
    it. Inside ``execute_transaction`` every call made on the same thread reuses
    the transaction's session and leaves commit/rollback to the transaction.
    """

    def __init__(self, db_url: str = "sqlite:///:memory:"):
        """Create the engine, make sure all tables exist, and prepare sessions.

        Args:
            db_url: SQLAlchemy database URL. In-memory SQLite URLs
                (``sqlite://`` or ``sqlite:///:memory:``, with any driver) use a
                ``StaticPool`` so that every session sees the same database.
        """
        url = make_url(db_url)
        is_sqlite = url.get_backend_name() == "sqlite"
        is_memory_sqlite = is_sqlite and url.database in (None, "", ":memory:")
        # Pooled SQLite connections may be used from threads other than their creator.
        connect_args = {"check_same_thread": False} if is_sqlite else {}
        if is_memory_sqlite:
            # Each new :memory: connection is a separate empty database, so pin a single connection.
            self.engine = create_engine(db_url, connect_args=connect_args, poolclass=StaticPool)
        else:
            self.engine = create_engine(db_url, connect_args=connect_args)
        if is_sqlite:
            # Enable foreign key constraints (and therefore ON DELETE CASCADE) for SQLite
            @event.listens_for(self.engine, "connect")
            def set_sqlite_pragma(dbapi_connection, connection_record):
                cursor = dbapi_connection.cursor()
                cursor.execute("PRAGMA foreign_keys=ON")
                cursor.close()
        Base.metadata.create_all(self.engine)
        self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)
        # Thread-local storage to coordinate active transactions
        self._local = threading.local()

    # --- Session / connection management ---

    def _in_transaction(self) -> bool:
        """Return True while the current thread is inside ``execute_transaction``."""
        return getattr(self._local, 'in_transaction', False)

    def _get_session(self) -> Session:
        """Return the active transaction's session, or a new session otherwise."""
        if self._in_transaction():
            return self._local.session
        return self.SessionLocal()

    def _close_session_if_needed(self, session: Session):
        """Commit and close ``session`` unless it belongs to an active transaction.

        The session is closed even when the commit raises, so a failed commit
        cannot leak a pooled connection; the commit error is re-raised.
        """
        if self._in_transaction():
            return
        try:
            session.commit()
        finally:
            session.close()

    def _rollback_if_needed(self, session: Session):
        """Roll back and close ``session`` unless it belongs to an active transaction."""
        if self._in_transaction():
            return
        try:
            session.rollback()
        finally:
            session.close()

    def _get_connection(self) -> Session:
        """Alias for _get_session to map to 03.01 spec naming conventions."""
        return self._get_session()

    def _release_connection(self, conn: Session):
        """Alias to release connection back to the pool."""
        self._close_session_if_needed(conn)

    @contextmanager
    def _session_scope(self) -> Iterator[Session]:
        """Yield a session: commit and close on success, roll back and close on error.

        Inside ``execute_transaction`` both steps are no-ops and the session is
        shared; the transaction commits or rolls back at its end. DTOs must be
        built from ORM rows inside the ``with`` body, before the commit expires them.
        """
        session = self._get_connection()
        try:
            yield session
        except BaseException:
            self._rollback_if_needed(session)
            raise
        self._release_connection(session)

    def _execute_with_retry(self, operation: Callable, retries: int = 3) -> Any:
        """Execute a database operation with automatic retry on transient errors.

        Only ``OperationalError`` (e.g. SQLite "database is locked", dropped
        connections) is retried, with exponential backoff of 0.1s, 0.2s, ...
        between attempts. Any other exception propagates immediately, because
        repeating a constraint violation or a programming error cannot succeed.

        Args:
            operation: Zero-argument callable performing the database work.
            retries: Total number of attempts; values below 1 are treated as 1.

        Returns:
            Whatever ``operation`` returns.

        Raises:
            OperationalError: if every attempt failed with a transient error.
        """
        attempts = max(1, retries)
        for attempt in range(attempts):
            try:
                return operation()
            except OperationalError:
                if attempt == attempts - 1:
                    raise
                time.sleep(0.1 * (2 ** attempt))  # Exponential backoff

    def _execute_with_retry_or_false(self, operation: Callable[[], bool]) -> bool:
        """Run ``operation`` with retries, reporting exhausted transient failures as ``False``."""
        try:
            return self._execute_with_retry(operation)
        except OperationalError as e:
            logger.error("Database operation failed after retries: %s", e)
            return False

    # --- (De)serialization helpers ---

    def _serialize_camera_params(self, params: CameraParameters) -> dict:
        return params.model_dump()

    def _deserialize_camera_params(self, jsonb: dict) -> CameraParameters:
        return CameraParameters(**jsonb)

    def _serialize_metadata(self, metadata: Dict) -> dict:
        return metadata

    def _deserialize_metadata(self, jsonb: dict) -> Dict:
        return jsonb if jsonb else {}

    def _serialize_chunk_frames(self, frames: List[int]) -> list:
        return frames

    def _deserialize_chunk_frames(self, jsonb: list) -> List[int]:
        return jsonb if jsonb else []

    # --- Row -> model builders ---

    def _build_flight_from_row(self, row: SQLFlight) -> Flight:
        return Flight(
            flight_id=row.id, flight_name=row.name,
            start_gps=GPSPoint(lat=row.start_lat, lon=row.start_lon),
            altitude_m=row.altitude, camera_params=self._deserialize_camera_params(row.camera_params)
        )

    def _build_waypoint_from_row(self, row: SQLWaypoint) -> Waypoint:
        return Waypoint(
            id=row.id, lat=row.lat, lon=row.lon, altitude=row.altitude,
            confidence=row.confidence, timestamp=row.timestamp, refined=row.refined
        )

    def _build_filter_query(self, query: Any, filters: Dict[str, Any]) -> Any:
        """Apply the optional ``name`` (SQL LIKE pattern) and ``status`` filters to a flights query."""
        if filters:
            if "name" in filters:
                query = query.filter(SQLFlight.name.like(filters["name"]))
            if "status" in filters:
                query = query.join(SQLFlightState).filter(SQLFlightState.status == filters["status"])
        return query

    def _build_flight_state_from_row(self, row: SQLFlightState) -> FlightState:
        return FlightState(
            flight_id=row.flight_id, state=row.status,
            processed_images=row.frames_processed, total_images=row.frames_total
        )

    def _build_frame_result_from_row(self, row: SQLFrameResult) -> FrameResult:
        return FrameResult(
            frame_id=row.frame_id, gps_center=GPSPoint(lat=row.gps_lat, lon=row.gps_lon),
            altitude=row.altitude, heading=row.heading, confidence=row.confidence,
            refined=row.refined, timestamp=row.timestamp, updated_at=row.updated_at
        )

    def _build_heading_record_from_row(self, row: SQLHeadingHistory) -> HeadingRecord:
        return HeadingRecord(frame_id=row.frame_id, heading=row.heading, timestamp=row.timestamp)

    def _build_chunk_handle_from_row(self, row: SQLChunk) -> ChunkHandle:
        gps = None
        if row.anchor_lat is not None and row.anchor_lon is not None:
            gps = GPSPoint(lat=row.anchor_lat, lon=row.anchor_lon)
        return ChunkHandle(
            chunk_id=row.chunk_id, start_frame_id=row.start_frame_id, end_frame_id=row.end_frame_id,
            frames=self._deserialize_chunk_frames(row.frames), is_active=row.is_active, has_anchor=row.has_anchor,
            anchor_frame_id=row.anchor_frame_id, anchor_gps=gps, matching_status=row.matching_status
        )

    # --- Upserts (OperationalError propagates so callers can retry it) ---

    def _upsert_flight_state(self, state: FlightState) -> bool:
        """Insert or update the flight_state row; False on non-transient failure."""
        try:
            with self._session_scope() as session:
                session.merge(SQLFlightState(
                    flight_id=state.flight_id, status=state.state,
                    frames_processed=state.processed_images, frames_total=state.total_images
                ))
            return True
        except OperationalError:
            raise  # transient: let _execute_with_retry retry it
        except Exception:
            return False

    def _upsert_frame_result(self, flight_id: str, result: FrameResult) -> bool:
        """Insert or update the frame_results row of ``result``; False on non-transient failure."""
        try:
            with self._session_scope() as session:
                session.merge(SQLFrameResult(
                    id=f"{flight_id}_{result.frame_id}", flight_id=flight_id, frame_id=result.frame_id,
                    gps_lat=result.gps_center.lat, gps_lon=result.gps_center.lon, altitude=result.altitude,
                    heading=result.heading, confidence=result.confidence, refined=result.refined,
                    timestamp=result.timestamp, updated_at=result.updated_at
                ))
            return True
        except OperationalError:
            raise  # transient: let _execute_with_retry retry it
        except Exception:
            return False

    def _upsert_chunk_state(self, flight_id: str, chunk: ChunkHandle) -> bool:
        """Insert or update the chunks row of ``chunk``; False on non-transient failure."""
        try:
            anchor_lat = chunk.anchor_gps.lat if chunk.anchor_gps else None
            anchor_lon = chunk.anchor_gps.lon if chunk.anchor_gps else None
            with self._session_scope() as session:
                session.merge(SQLChunk(
                    chunk_id=chunk.chunk_id, flight_id=flight_id, start_frame_id=chunk.start_frame_id,
                    end_frame_id=chunk.end_frame_id, frames=self._serialize_chunk_frames(chunk.frames),
                    is_active=chunk.is_active, has_anchor=chunk.has_anchor, anchor_frame_id=chunk.anchor_frame_id,
                    anchor_lat=anchor_lat, anchor_lon=anchor_lon, matching_status=chunk.matching_status
                ))
            return True
        except OperationalError:
            raise  # transient: let _execute_with_retry retry it
        except Exception:
            return False

    # --- Transaction Support ---

    def execute_transaction(self, operations: List[Callable[[], None]]) -> bool:
        """Run ``operations`` in order inside a single session and commit once.

        Database calls made by the operations on this thread reuse the
        transaction's session and do not commit individually. If any operation
        or the final commit raises, everything is rolled back and the error is logged.

        Returns:
            True when committed, False when rolled back.
        """
        session = self.SessionLocal()
        # Remember any enclosing transaction so a nested call restores it instead
        # of silently switching the outer transaction to per-call commits.
        prev_in_transaction = getattr(self._local, 'in_transaction', False)
        prev_session = getattr(self._local, 'session', None)
        self._local.session = session
        self._local.in_transaction = True
        try:
            for op in operations:
                op()
            session.commit()
            return True
        except Exception as e:
            session.rollback()
            logger.error("Transaction failed: %s", e)
            return False
        finally:
            self._local.in_transaction = prev_in_transaction
            self._local.session = prev_session
            session.close()

    # --- Flight Operations ---

    def insert_flight(self, flight: Flight) -> str:
        """Insert a new flight row and return its id.

        Raises:
            ValueError: if the id already exists or another integrity constraint fails.
            OperationalError: if transient database errors persist through all retries.
        """
        def _do_insert() -> str:
            try:
                with self._session_scope() as session:
                    session.add(SQLFlight(
                        id=flight.flight_id, name=flight.flight_name, description=flight.flight_name,
                        start_lat=flight.start_gps.lat, start_lon=flight.start_gps.lon,
                        altitude=flight.altitude_m, camera_params=self._serialize_camera_params(flight.camera_params),
                        created_at=flight.created_at, updated_at=flight.updated_at
                    ))
            except IntegrityError as e:
                raise ValueError(f"Duplicate flight or integrity error: {e}") from e
            return flight.flight_id
        return self._execute_with_retry(_do_insert)

    def update_flight(self, flight: Flight) -> bool:
        """Update the stored name of ``flight`` and bump updated_at; False if missing or on error."""
        try:
            with self._session_scope() as session:
                sql_flight = session.query(SQLFlight).filter_by(id=flight.flight_id).first()
                if sql_flight is None:
                    return False
                sql_flight.name = flight.flight_name
                sql_flight.updated_at = datetime.utcnow()
            return True
        except Exception:
            return False

    def query_flights(self, filters: Dict[str, Any], limit: int, offset: int = 0) -> List[Flight]:
        """Return up to ``limit`` flights matching ``filters``, skipping the first ``offset``."""
        with self._session_scope() as session:
            query = self._build_filter_query(session.query(SQLFlight), filters)
            rows = query.offset(offset).limit(limit).all()
            return [self._build_flight_from_row(row) for row in rows]

    def get_flight_by_id(self, flight_id: str) -> Optional[Flight]:
        """Return the flight with ``flight_id``, or None when it does not exist."""
        with self._session_scope() as session:
            row = session.query(SQLFlight).filter_by(id=flight_id).first()
            return self._build_flight_from_row(row) if row is not None else None

    def delete_flight(self, flight_id: str) -> bool:
        """Delete a flight (dependent rows go via ON DELETE CASCADE); False if missing or on error."""
        try:
            with self._session_scope() as session:
                sql_flight = session.query(SQLFlight).filter_by(id=flight_id).first()
                if sql_flight is None:
                    return False
                session.delete(sql_flight)  # Cascade handles related rows
            return True
        except Exception:
            return False

    # --- Waypoint Operations ---

    def get_waypoints(self, flight_id: str, limit: Optional[int] = None) -> List[Waypoint]:
        """Return a flight's waypoints ordered by timestamp; a falsy ``limit`` returns all."""
        with self._session_scope() as session:
            query = session.query(SQLWaypoint).filter_by(flight_id=flight_id).order_by(SQLWaypoint.timestamp)
            if limit:
                query = query.limit(limit)
            return [self._build_waypoint_from_row(w) for w in query.all()]

    def insert_waypoint(self, flight_id: str, waypoint: Waypoint) -> str:
        """Insert or overwrite (merge) a waypoint and return its id.

        Raises:
            ValueError: wrapping any failure.
        """
        try:
            with self._session_scope() as session:
                session.merge(SQLWaypoint(
                    id=waypoint.id, flight_id=flight_id, lat=waypoint.lat, lon=waypoint.lon,
                    altitude=waypoint.altitude, confidence=waypoint.confidence,
                    timestamp=waypoint.timestamp, refined=waypoint.refined
                ))
        except Exception as e:
            raise ValueError(f"Failed to insert waypoint: {e}") from e
        return waypoint.id

    def update_waypoint(self, flight_id: str, waypoint_id: str, waypoint: Waypoint) -> bool:
        """Update position, altitude, confidence and refined flag; False if missing or on error."""
        try:
            with self._session_scope() as session:
                wp = session.query(SQLWaypoint).filter_by(id=waypoint_id, flight_id=flight_id).first()
                if wp is None:
                    return False
                wp.lat, wp.lon = waypoint.lat, waypoint.lon
                wp.altitude, wp.confidence = waypoint.altitude, waypoint.confidence
                wp.refined = waypoint.refined
            return True
        except Exception:
            return False

    def batch_update_waypoints(self, flight_id: str, waypoints: List[Waypoint]) -> BatchResult:
        """Update all ``waypoints`` in one transaction.

        Waypoints that are missing or fail individually are reported in
        ``failed_ids``; if the transaction itself fails, every id is reported.
        """
        failed: List[str] = []

        def do_update():
            for wp in waypoints:
                if not self.update_waypoint(flight_id, wp.id, wp):
                    failed.append(wp.id)

        if not self.execute_transaction([do_update]):
            return BatchResult(success=False, updated_count=0, failed_ids=[w.id for w in waypoints])
        return BatchResult(success=not failed, updated_count=len(waypoints) - len(failed), failed_ids=failed)

    # --- Flight State & auxiliary persistence ---

    def save_flight_state(self, flight_state: FlightState) -> bool:
        """Upsert a flight's processing state, retrying transient errors; False on failure."""
        return self._execute_with_retry_or_false(lambda: self._upsert_flight_state(flight_state))

    def load_flight_state(self, flight_id: str) -> Optional[FlightState]:
        """Return the stored processing state of a flight, or None."""
        with self._session_scope() as session:
            row = session.query(SQLFlightState).filter_by(flight_id=flight_id).first()
            return self._build_flight_state_from_row(row) if row is not None else None

    def query_processing_history(self, filters: Dict[str, Any]) -> List[FlightState]:
        """Return flight states filtered by optional ``status``, ``created_after`` and ``created_before``."""
        with self._session_scope() as session:
            query = session.query(SQLFlightState)
            if filters:
                if "status" in filters:
                    query = query.filter(SQLFlightState.status == filters["status"])
                if "created_after" in filters:
                    query = query.filter(SQLFlightState.created_at >= filters["created_after"])
                if "created_before" in filters:
                    query = query.filter(SQLFlightState.created_at <= filters["created_before"])
            return [self._build_flight_state_from_row(r) for r in query.all()]

    def save_frame_result(self, flight_id: str, frame_result: FrameResult) -> bool:
        """Upsert a frame result, retrying transient errors; False on failure."""
        return self._execute_with_retry_or_false(lambda: self._upsert_frame_result(flight_id, frame_result))

    def get_frame_results(self, flight_id: str) -> List[FrameResult]:
        """Return all frame results of a flight ordered by frame id."""
        with self._session_scope() as session:
            rows = session.query(SQLFrameResult).filter_by(flight_id=flight_id).order_by(SQLFrameResult.frame_id).all()
            return [self._build_frame_result_from_row(r) for r in rows]

    def save_heading(self, flight_id: str, frame_id: int, heading: float, timestamp: datetime) -> bool:
        """Upsert the heading of one frame, retrying transient errors; False on failure."""
        def _do_save() -> bool:
            try:
                with self._session_scope() as session:
                    session.merge(SQLHeadingHistory(
                        id=f"{flight_id}_{frame_id}", flight_id=flight_id, frame_id=frame_id,
                        heading=heading, timestamp=timestamp
                    ))
                return True
            except OperationalError:
                raise  # transient: let _execute_with_retry retry it
            except Exception:
                return False
        return self._execute_with_retry_or_false(_do_save)

    def get_heading_history(self, flight_id: str, last_n: Optional[int] = None) -> List[HeadingRecord]:
        """Return heading records newest frame first; a falsy ``last_n`` returns all."""
        with self._session_scope() as session:
            query = session.query(SQLHeadingHistory).filter_by(flight_id=flight_id).order_by(SQLHeadingHistory.frame_id.desc())
            if last_n:
                query = query.limit(last_n)
            return [self._build_heading_record_from_row(r) for r in query.all()]

    def get_latest_heading(self, flight_id: str) -> Optional[float]:
        """Return the heading of the highest frame id recorded for a flight, or None."""
        with self._session_scope() as session:
            row = session.query(SQLHeadingHistory).filter_by(flight_id=flight_id).order_by(SQLHeadingHistory.frame_id.desc()).first()
            return row.heading if row is not None else None

    def save_image_metadata(self, flight_id: str, frame_id: int, file_path: str, metadata: Dict) -> bool:
        """Upsert the image path and metadata of one frame, retrying transient errors; False on failure."""
        def _do_save() -> bool:
            try:
                with self._session_scope() as session:
                    session.merge(SQLFlightImage(
                        id=f"{flight_id}_{frame_id}", flight_id=flight_id, frame_id=frame_id,
                        file_path=file_path, metadata_json=self._serialize_metadata(metadata)
                    ))
                return True
            except OperationalError:
                raise  # transient: let _execute_with_retry retry it
            except Exception:
                return False
        return self._execute_with_retry_or_false(_do_save)

    def get_image_path(self, flight_id: str, frame_id: int) -> Optional[str]:
        """Return the stored image file path of a frame, or None."""
        with self._session_scope() as session:
            img = session.query(SQLFlightImage).filter_by(flight_id=flight_id, frame_id=frame_id).first()
            return img.file_path if img is not None else None

    def get_image_metadata(self, flight_id: str, frame_id: int) -> Optional[Dict]:
        """Return the stored metadata dict of a frame ({} if empty), or None if no image row exists."""
        with self._session_scope() as session:
            img = session.query(SQLFlightImage).filter_by(flight_id=flight_id, frame_id=frame_id).first()
            return self._deserialize_metadata(img.metadata_json) if img is not None else None

    def save_chunk_state(self, flight_id: str, chunk: ChunkHandle) -> bool:
        """Upsert a chunk, retrying transient errors; False on failure."""
        return self._execute_with_retry_or_false(lambda: self._upsert_chunk_state(flight_id, chunk))

    def load_chunk_states(self, flight_id: str) -> List[ChunkHandle]:
        """Return every stored chunk of a flight."""
        with self._session_scope() as session:
            rows = session.query(SQLChunk).filter_by(flight_id=flight_id).all()
            return [self._build_chunk_handle_from_row(c) for c in rows]

    def delete_chunk_state(self, flight_id: str, chunk_id: str) -> bool:
        """Delete one chunk of a flight; False if missing or on error."""
        try:
            with self._session_scope() as session:
                chunk = session.query(SQLChunk).filter_by(flight_id=flight_id, chunk_id=chunk_id).first()
                if chunk is None:
                    return False
                session.delete(chunk)
            return True
        except Exception:
            return False