"""SQLAlchemy-backed persistence for flights, waypoints, per-frame results and chunk state."""

import logging
import threading
import time
from abc import ABC, abstractmethod
from contextlib import contextmanager
from datetime import datetime
from typing import Any, Callable, Dict, Iterator, List, Optional

from pydantic import BaseModel, Field
from sqlalchemy import (
    JSON,
    Boolean,
    Column,
    DateTime,
    Float,
    ForeignKey,
    Integer,
    String,
    Text,
    create_engine,
    event,
)
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.orm import Session, declarative_base, relationship, sessionmaker
from sqlalchemy.pool import StaticPool

from f02_1_flight_lifecycle_manager import (
    CameraParameters,
    Flight,
    FlightState,
    Geofences,
    GPSPoint,
    Polygon,
    Waypoint,
)

logger = logging.getLogger(__name__)


# --- Data Models ---


class FrameResult(BaseModel):
    """Per-frame localisation result, persisted in the ``frame_results`` table."""

    frame_id: int
    gps_center: GPSPoint
    altitude: Optional[float] = None
    heading: float
    confidence: float
    refined: bool = False
    timestamp: datetime
    updated_at: datetime = Field(default_factory=datetime.utcnow)


class HeadingRecord(BaseModel):
    """One heading sample, persisted in the ``heading_history`` table."""

    frame_id: int
    heading: float
    timestamp: datetime


class BatchResult(BaseModel):
    """Outcome of a batch write: overall success, count updated, ids that failed."""

    success: bool
    updated_count: int
    failed_ids: List[str]


class ChunkHandle(BaseModel):
    """Snapshot of a frame chunk, persisted in the ``chunks`` table."""

    chunk_id: str
    start_frame_id: int
    end_frame_id: Optional[int] = None
    frames: List[int] = Field(default_factory=list)
    is_active: bool = True
    has_anchor: bool = False
    anchor_frame_id: Optional[int] = None
    anchor_gps: Optional[GPSPoint] = None
    matching_status: str = 'unanchored'


# --- SQLAlchemy ORM Models ---

Base = declarative_base()


class SQLFlight(Base):
    __tablename__ = 'flights'
    id = Column(String(36), primary_key=True)
    name = Column(String(255), nullable=False)
    description = Column(Text, default="")
    start_lat = Column(Float, nullable=False)
    start_lon = Column(Float, nullable=False)
    altitude = Column(Float, nullable=False)
    camera_params = Column(JSON, nullable=False)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow)


class SQLWaypoint(Base):
    __tablename__ = 'waypoints'
    id = Column(String(36), primary_key=True)
    flight_id = Column(String(36), ForeignKey('flights.id', ondelete='CASCADE'), nullable=False)
    lat = Column(Float, nullable=False)
    lon = Column(Float, nullable=False)
    altitude = Column(Float)
    confidence = Column(Float, nullable=False)
    timestamp = Column(DateTime, nullable=False)
    refined = Column(Boolean, default=False)


class SQLGeofence(Base):
    __tablename__ = 'geofences'
    id = Column(String(36), primary_key=True)
    flight_id = Column(String(36), ForeignKey('flights.id', ondelete='CASCADE'), nullable=False)
    nw_lat = Column(Float, nullable=False)
    nw_lon = Column(Float, nullable=False)
    se_lat = Column(Float, nullable=False)
    se_lon = Column(Float, nullable=False)


class SQLFlightState(Base):
    __tablename__ = 'flight_state'
    flight_id = Column(String(36), ForeignKey('flights.id', ondelete='CASCADE'), primary_key=True)
    status = Column(String(50), nullable=False)
    frames_processed = Column(Integer, default=0)
    frames_total = Column(Integer, default=0)
    current_frame = Column(Integer)
    blocked = Column(Boolean, default=False)
    search_grid_size = Column(Integer)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow)


class SQLFrameResult(Base):
    __tablename__ = 'frame_results'
    id = Column(String(72), primary_key=True)  # Composite key representation: {flight_id}_{frame_id}
    flight_id = Column(String(36), ForeignKey('flights.id', ondelete='CASCADE'), nullable=False)
    frame_id = Column(Integer, nullable=False)
    gps_lat = Column(Float)
    gps_lon = Column(Float)
    altitude = Column(Float)
    heading = Column(Float)
    confidence = Column(Float)
    refined = Column(Boolean, default=False)
    timestamp = Column(DateTime)
    updated_at = Column(DateTime, default=datetime.utcnow)


class SQLHeadingHistory(Base):
    __tablename__ = 'heading_history'
    id = Column(String(72), primary_key=True)  # {flight_id}_{frame_id}
    flight_id = Column(String(36), ForeignKey('flights.id', ondelete='CASCADE'), nullable=False)
    frame_id = Column(Integer, nullable=False)
    heading = Column(Float, nullable=False)
    timestamp = Column(DateTime, nullable=False)


class SQLFlightImage(Base):
    __tablename__ = 'flight_images'
    id = Column(String(72), primary_key=True)  # {flight_id}_{frame_id}
    flight_id = Column(String(36), ForeignKey('flights.id', ondelete='CASCADE'), nullable=False)
    frame_id = Column(Integer, nullable=False)
    file_path = Column(String(500), nullable=False)
    metadata_json = Column(JSON)
    uploaded_at = Column(DateTime, default=datetime.utcnow)


class SQLChunk(Base):
    __tablename__ = 'chunks'
    chunk_id = Column(String(36), primary_key=True)
    flight_id = Column(String(36), ForeignKey('flights.id', ondelete='CASCADE'), nullable=False)
    start_frame_id = Column(Integer, nullable=False)
    end_frame_id = Column(Integer)
    frames = Column(JSON, nullable=False)
    is_active = Column(Boolean, default=True)
    has_anchor = Column(Boolean, default=False)
    anchor_frame_id = Column(Integer)
    anchor_lat = Column(Float)
    anchor_lon = Column(Float)
    matching_status = Column(String(50), default='unanchored')
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow)


# --- Implementation ---


def _is_sqlite_memory_url(db_url: str) -> bool:
    """Return True for SQLite URLs that open a private in-memory database.

    Covers ``sqlite://``, ``sqlite:///:memory:``, driver-qualified forms such as
    ``sqlite+pysqlite:///:memory:``, and URI-style ``mode=memory`` databases.
    """
    if not db_url.startswith("sqlite"):
        return False
    _, _, rest = db_url.partition("://")
    return rest in ("", "/:memory:") or "mode=memory" in rest


class FlightDatabase:
    """Transactional CRUD operations and state persistence over SQLAlchemy.

    Each public call opens its own short-lived session, committed and closed when
    the call finishes. While the calling thread is inside ``execute_transaction``
    every call joins that thread's shared session instead, and the commit/rollback
    is decided once by ``execute_transaction``.
    """

    def __init__(self, db_url: str = "sqlite:///:memory:"):
        """Create the engine, ensure the schema exists and prepare the session factory.

        Args:
            db_url: SQLAlchemy database URL. Defaults to a private in-memory SQLite DB.
        """
        is_sqlite = db_url.startswith("sqlite")
        # Sessions may be created on any thread, so SQLite's same-thread check is disabled.
        connect_args = {"check_same_thread": False} if is_sqlite else {}
        if _is_sqlite_memory_url(db_url):
            # Every new connection to an in-memory SQLite DB is a separate, empty
            # database; StaticPool keeps one connection so all sessions share it.
            self.engine = create_engine(db_url, connect_args=connect_args, poolclass=StaticPool)
        else:
            self.engine = create_engine(db_url, connect_args=connect_args)

        # SQLite only enforces FOREIGN KEY / ON DELETE CASCADE when enabled per connection.
        if is_sqlite:
            @event.listens_for(self.engine, "connect")
            def set_sqlite_pragma(dbapi_connection, connection_record):
                cursor = dbapi_connection.cursor()
                cursor.execute("PRAGMA foreign_keys=ON")
                cursor.close()

        Base.metadata.create_all(self.engine)
        self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)
        # Thread-local storage to coordinate active transactions
        self._local = threading.local()

    # --- Session management ---

    def _in_transaction(self) -> bool:
        """Return True if the calling thread is inside ``execute_transaction``."""
        return getattr(self._local, 'in_transaction', False)

    def _get_session(self) -> Session:
        """Return the thread's transactional session, or a fresh session otherwise."""
        if self._in_transaction():
            return self._local.session
        return self.SessionLocal()

    def _close_session_if_needed(self, session: Session):
        """Commit and close *session* unless it belongs to an active transaction.

        The session is closed even when the commit fails; the commit error propagates.
        """
        if self._in_transaction():
            return
        try:
            session.commit()
        finally:
            session.close()

    def _rollback_if_needed(self, session: Session):
        """Roll back and close *session* unless it belongs to an active transaction."""
        if self._in_transaction():
            return
        try:
            session.rollback()
        finally:
            session.close()

    def _get_connection(self) -> Session:
        """Alias for _get_session to map to 03.01 spec naming conventions."""
        return self._get_session()

    def _release_connection(self, conn: Session):
        """Alias to release connection back to the pool."""
        self._close_session_if_needed(conn)

    @contextmanager
    def _session_scope(self) -> Iterator[Session]:
        """Yield a session; release (commit) it on success, roll back on any error.

        Exceptions raised inside the ``with`` body are re-raised after cleanup.
        Inside an active transaction both release and rollback are no-ops.
        """
        session = self._get_connection()
        try:
            yield session
        except BaseException:
            self._rollback_if_needed(session)
            raise
        self._release_connection(session)

    @staticmethod
    def _is_transient_error(exc: Exception) -> bool:
        """Return True for errors worth retrying (operational errors, dropped connections)."""
        return isinstance(exc, OperationalError) or bool(getattr(exc, "connection_invalidated", False))

    def _execute_with_retry(self, operation: Callable, retries: int = 3, base_delay: float = 0.1) -> Any:
        """Execute *operation*, retrying with exponential backoff on transient DB errors.

        Non-transient errors (e.g. integrity violations, ``ValueError``) propagate
        immediately. Inside an active transaction the operation runs exactly once,
        because a failed statement leaves the shared session unusable for a retry.

        Args:
            operation: Zero-argument callable; its return value is returned.
            retries: Total number of attempts (values below 1 are treated as 1).
            base_delay: Seconds to wait before the second attempt; doubles each time.

        Raises:
            The last exception raised by *operation* once attempts are exhausted.
        """
        if self._in_transaction():
            return operation()
        attempts = max(1, retries)
        for attempt in range(attempts):
            try:
                return operation()
            except Exception as exc:
                if attempt == attempts - 1 or not self._is_transient_error(exc):
                    raise
                delay = base_delay * (2 ** attempt)  # Exponential backoff
                logger.warning(
                    "Transient database error (attempt %d/%d), retrying in %.2fs: %s",
                    attempt + 1, attempts, delay, exc,
                )
                time.sleep(delay)

    def _merge_with_retry(self, make_row: Callable[[], Any], description: str) -> bool:
        """Upsert the ORM row built by *make_row* (via ``Session.merge``) with retries.

        Returns:
            True on success, False if the write ultimately failed (the error is logged).
        """
        def _attempt() -> bool:
            with self._session_scope() as session:
                session.merge(make_row())
            return True

        try:
            return self._execute_with_retry(_attempt)
        except Exception:
            logger.exception("Failed to persist %s", description)
            return False

    # --- Serialization helpers ---

    def _serialize_camera_params(self, params: CameraParameters) -> dict:
        return params.model_dump()

    def _deserialize_camera_params(self, jsonb: dict) -> CameraParameters:
        return CameraParameters(**jsonb)

    def _serialize_metadata(self, metadata: Dict) -> dict:
        return metadata

    def _deserialize_metadata(self, jsonb: dict) -> Dict:
        return jsonb if jsonb else {}

    def _serialize_chunk_frames(self, frames: List[int]) -> list:
        return frames

    def _deserialize_chunk_frames(self, jsonb: list) -> List[int]:
        return jsonb if jsonb else []

    # --- Row -> domain model builders ---

    def _build_flight_from_row(self, row: SQLFlight) -> Flight:
        return Flight(
            flight_id=row.id,
            flight_name=row.name,
            start_gps=GPSPoint(lat=row.start_lat, lon=row.start_lon),
            altitude_m=row.altitude,
            camera_params=self._deserialize_camera_params(row.camera_params),
        )

    def _build_waypoint_from_row(self, row: SQLWaypoint) -> Waypoint:
        return Waypoint(
            id=row.id,
            lat=row.lat,
            lon=row.lon,
            altitude=row.altitude,
            confidence=row.confidence,
            timestamp=row.timestamp,
            refined=row.refined,
        )

    def _build_filter_query(self, query: Any, filters: Dict[str, Any]) -> Any:
        """Apply optional ``name`` (SQL LIKE pattern) and ``status`` filters to a flight query."""
        if filters:
            if "name" in filters:
                query = query.filter(SQLFlight.name.like(filters["name"]))
            if "status" in filters:
                query = query.join(SQLFlightState).filter(SQLFlightState.status == filters["status"])
        return query

    def _build_flight_state_from_row(self, row: SQLFlightState) -> FlightState:
        return FlightState(
            flight_id=row.flight_id,
            state=row.status,
            processed_images=row.frames_processed,
            total_images=row.frames_total,
        )

    def _build_frame_result_from_row(self, row: SQLFrameResult) -> FrameResult:
        return FrameResult(
            frame_id=row.frame_id,
            gps_center=GPSPoint(lat=row.gps_lat, lon=row.gps_lon),
            altitude=row.altitude,
            heading=row.heading,
            confidence=row.confidence,
            refined=row.refined,
            timestamp=row.timestamp,
            updated_at=row.updated_at,
        )

    def _build_heading_record_from_row(self, row: SQLHeadingHistory) -> HeadingRecord:
        return HeadingRecord(frame_id=row.frame_id, heading=row.heading, timestamp=row.timestamp)

    def _build_chunk_handle_from_row(self, row: SQLChunk) -> ChunkHandle:
        has_gps = row.anchor_lat is not None and row.anchor_lon is not None
        gps = GPSPoint(lat=row.anchor_lat, lon=row.anchor_lon) if has_gps else None
        return ChunkHandle(
            chunk_id=row.chunk_id,
            start_frame_id=row.start_frame_id,
            end_frame_id=row.end_frame_id,
            frames=self._deserialize_chunk_frames(row.frames),
            is_active=row.is_active,
            has_anchor=row.has_anchor,
            anchor_frame_id=row.anchor_frame_id,
            anchor_gps=gps,
            matching_status=row.matching_status,
        )

    # --- Upserts (return True/False, never raise) ---

    def _upsert_flight_state(self, state: FlightState) -> bool:
        """Insert or update the ``flight_state`` row for ``state.flight_id``."""
        return self._merge_with_retry(
            lambda: SQLFlightState(
                flight_id=state.flight_id,
                status=state.state,
                frames_processed=state.processed_images,
                frames_total=state.total_images,
                updated_at=datetime.utcnow(),
            ),
            f"flight state for {state.flight_id}",
        )

    def _upsert_frame_result(self, flight_id: str, result: FrameResult) -> bool:
        """Insert or update the frame result keyed by ``{flight_id}_{frame_id}``."""
        return self._merge_with_retry(
            lambda: SQLFrameResult(
                id=f"{flight_id}_{result.frame_id}",
                flight_id=flight_id,
                frame_id=result.frame_id,
                gps_lat=result.gps_center.lat,
                gps_lon=result.gps_center.lon,
                altitude=result.altitude,
                heading=result.heading,
                confidence=result.confidence,
                refined=result.refined,
                timestamp=result.timestamp,
                updated_at=result.updated_at,
            ),
            f"frame result {result.frame_id} for {flight_id}",
        )

    def _upsert_chunk_state(self, flight_id: str, chunk: ChunkHandle) -> bool:
        """Insert or update the ``chunks`` row for ``chunk.chunk_id``."""
        def _make_row() -> SQLChunk:
            anchor = chunk.anchor_gps
            return SQLChunk(
                chunk_id=chunk.chunk_id,
                flight_id=flight_id,
                start_frame_id=chunk.start_frame_id,
                end_frame_id=chunk.end_frame_id,
                frames=self._serialize_chunk_frames(chunk.frames),
                is_active=chunk.is_active,
                has_anchor=chunk.has_anchor,
                anchor_frame_id=chunk.anchor_frame_id,
                anchor_lat=anchor.lat if anchor else None,
                anchor_lon=anchor.lon if anchor else None,
                matching_status=chunk.matching_status,
                updated_at=datetime.utcnow(),
            )

        return self._merge_with_retry(_make_row, f"chunk {chunk.chunk_id} for {flight_id}")

    # --- Transaction Support ---

    def execute_transaction(self, operations: List[Callable[[], None]]) -> bool:
        """Run *operations* in one all-or-nothing transaction on this thread.

        Every FlightDatabase call made by the operations shares one session, which
        is committed only if all operations succeed.

        A nested call (made from inside another ``execute_transaction`` on the same
        thread) joins the enclosing transaction; any error it hits propagates so the
        outermost call rolls back everything.

        Returns:
            True if committed, False if any operation failed and the work was rolled back.
        """
        if self._in_transaction():
            for op in operations:
                op()
            return True

        session = self.SessionLocal()
        self._local.session = session
        self._local.in_transaction = True
        try:
            for op in operations:
                op()
            session.commit()
            return True
        except Exception as e:
            session.rollback()
            logger.error("Transaction failed: %s", e)
            return False
        finally:
            self._local.in_transaction = False
            self._local.session = None
            session.close()

    # --- Flight Operations ---

    def insert_flight(self, flight: Flight) -> str:
        """Insert a new flight row and return its id.

        Raises:
            ValueError: on a duplicate id or other integrity violation (not retried).
        """
        def _do_insert() -> str:
            try:
                with self._session_scope() as session:
                    session.add(
                        SQLFlight(
                            id=flight.flight_id,
                            name=flight.flight_name,
                            # NOTE(review): description mirrors the name; confirm intended.
                            description=flight.flight_name,
                            start_lat=flight.start_gps.lat,
                            start_lon=flight.start_gps.lon,
                            altitude=flight.altitude_m,
                            camera_params=self._serialize_camera_params(flight.camera_params),
                            created_at=flight.created_at,
                            updated_at=flight.updated_at,
                        )
                    )
            except IntegrityError as e:
                raise ValueError(f"Duplicate flight or integrity error: {e}") from e
            return flight.flight_id

        return self._execute_with_retry(_do_insert)

    def update_flight(self, flight: Flight) -> bool:
        """Update the stored name of *flight*; return False if missing or on error."""
        try:
            with self._session_scope() as session:
                sql_flight = session.query(SQLFlight).filter_by(id=flight.flight_id).first()
                if sql_flight is None:
                    return False
                sql_flight.name = flight.flight_name
                sql_flight.updated_at = datetime.utcnow()
            return True
        except Exception:
            logger.exception("Failed to update flight %s", flight.flight_id)
            return False

    def query_flights(self, filters: Dict[str, Any], limit: int, offset: int = 0) -> List[Flight]:
        """Return flights matching *filters* (see ``_build_filter_query``), paginated."""
        with self._session_scope() as session:
            query = self._build_filter_query(session.query(SQLFlight), filters)
            rows = query.offset(offset).limit(limit).all()
            return [self._build_flight_from_row(row) for row in rows]

    def get_flight_by_id(self, flight_id: str) -> Optional[Flight]:
        """Return the flight with *flight_id*, or None if it does not exist."""
        with self._session_scope() as session:
            row = session.query(SQLFlight).filter_by(id=flight_id).first()
            return self._build_flight_from_row(row) if row is not None else None

    def delete_flight(self, flight_id: str) -> bool:
        """Delete a flight; dependent rows go via ON DELETE CASCADE. False if missing/error."""
        try:
            with self._session_scope() as session:
                sql_flight = session.query(SQLFlight).filter_by(id=flight_id).first()
                if sql_flight is None:
                    return False
                session.delete(sql_flight)  # Cascade handles related rows
            return True
        except Exception:
            logger.exception("Failed to delete flight %s", flight_id)
            return False

    # --- Waypoint Operations ---

    def get_waypoints(self, flight_id: str, limit: Optional[int] = None) -> List[Waypoint]:
        """Return the flight's waypoints ordered by timestamp, at most *limit* if given."""
        with self._session_scope() as session:
            query = session.query(SQLWaypoint).filter_by(flight_id=flight_id).order_by(SQLWaypoint.timestamp)
            if limit is not None:
                query = query.limit(limit)
            return [self._build_waypoint_from_row(row) for row in query.all()]

    def insert_waypoint(self, flight_id: str, waypoint: Waypoint) -> str:
        """Insert (or overwrite, via merge) a waypoint and return its id.

        Raises:
            ValueError: if the write fails for any reason.
        """
        try:
            with self._session_scope() as session:
                session.merge(
                    SQLWaypoint(
                        id=waypoint.id,
                        flight_id=flight_id,
                        lat=waypoint.lat,
                        lon=waypoint.lon,
                        altitude=waypoint.altitude,
                        confidence=waypoint.confidence,
                        timestamp=waypoint.timestamp,
                        refined=waypoint.refined,
                    )
                )
        except Exception as e:
            raise ValueError(f"Failed to insert waypoint: {e}") from e
        return waypoint.id

    def update_waypoint(self, flight_id: str, waypoint_id: str, waypoint: Waypoint) -> bool:
        """Update position, altitude, confidence and refined flag; False if missing/error."""
        try:
            with self._session_scope() as session:
                wp = session.query(SQLWaypoint).filter_by(id=waypoint_id, flight_id=flight_id).first()
                if wp is None:
                    return False
                wp.lat, wp.lon = waypoint.lat, waypoint.lon
                wp.altitude, wp.confidence = waypoint.altitude, waypoint.confidence
                wp.refined = waypoint.refined
            return True
        except Exception:
            logger.exception("Failed to update waypoint %s for flight %s", waypoint_id, flight_id)
            return False

    def batch_update_waypoints(self, flight_id: str, waypoints: List[Waypoint]) -> BatchResult:
        """Update *waypoints* in one transaction, reporting ids that could not be updated."""
        failed: List[str] = []

        def do_update():
            for wp in waypoints:
                if not self.update_waypoint(flight_id, wp.id, wp):
                    failed.append(wp.id)

        if not self.execute_transaction([do_update]):
            return BatchResult(success=False, updated_count=0, failed_ids=[w.id for w in waypoints])
        return BatchResult(success=not failed, updated_count=len(waypoints) - len(failed), failed_ids=failed)

    # --- Flight State & auxiliary persistence ---

    def save_flight_state(self, flight_state: FlightState) -> bool:
        """Upsert *flight_state*; retries transient errors and returns success."""
        return self._upsert_flight_state(flight_state)

    def load_flight_state(self, flight_id: str) -> Optional[FlightState]:
        """Return the stored state for *flight_id*, or None."""
        with self._session_scope() as session:
            row = session.query(SQLFlightState).filter_by(flight_id=flight_id).first()
            return self._build_flight_state_from_row(row) if row is not None else None

    def query_processing_history(self, filters: Dict[str, Any]) -> List[FlightState]:
        """Return flight states filtered by optional ``status``, ``created_after``, ``created_before``."""
        with self._session_scope() as session:
            query = session.query(SQLFlightState)
            if filters:
                if "status" in filters:
                    query = query.filter(SQLFlightState.status == filters["status"])
                if "created_after" in filters:
                    query = query.filter(SQLFlightState.created_at >= filters["created_after"])
                if "created_before" in filters:
                    query = query.filter(SQLFlightState.created_at <= filters["created_before"])
            return [self._build_flight_state_from_row(row) for row in query.all()]

    def save_frame_result(self, flight_id: str, frame_result: FrameResult) -> bool:
        """Upsert *frame_result*; retries transient errors and returns success."""
        return self._upsert_frame_result(flight_id, frame_result)

    def get_frame_results(self, flight_id: str) -> List[FrameResult]:
        """Return all frame results of the flight ordered by frame id."""
        with self._session_scope() as session:
            rows = (
                session.query(SQLFrameResult)
                .filter_by(flight_id=flight_id)
                .order_by(SQLFrameResult.frame_id)
                .all()
            )
            return [self._build_frame_result_from_row(row) for row in rows]

    def save_heading(self, flight_id: str, frame_id: int, heading: float, timestamp: datetime) -> bool:
        """Upsert the heading for one frame; retries transient errors and returns success."""
        return self._merge_with_retry(
            lambda: SQLHeadingHistory(
                id=f"{flight_id}_{frame_id}",
                flight_id=flight_id,
                frame_id=frame_id,
                heading=heading,
                timestamp=timestamp,
            ),
            f"heading for frame {frame_id} of {flight_id}",
        )

    def get_heading_history(self, flight_id: str, last_n: Optional[int] = None) -> List[HeadingRecord]:
        """Return headings newest-frame first, at most *last_n* if given."""
        with self._session_scope() as session:
            query = (
                session.query(SQLHeadingHistory)
                .filter_by(flight_id=flight_id)
                .order_by(SQLHeadingHistory.frame_id.desc())
            )
            if last_n is not None:
                query = query.limit(last_n)
            return [self._build_heading_record_from_row(row) for row in query.all()]

    def get_latest_heading(self, flight_id: str) -> Optional[float]:
        """Return the heading of the highest frame id, or None if none is stored."""
        with self._session_scope() as session:
            row = (
                session.query(SQLHeadingHistory)
                .filter_by(flight_id=flight_id)
                .order_by(SQLHeadingHistory.frame_id.desc())
                .first()
            )
            return row.heading if row is not None else None

    def save_image_metadata(self, flight_id: str, frame_id: int, file_path: str, metadata: Dict) -> bool:
        """Upsert an image path and metadata; retries transient errors and returns success."""
        return self._merge_with_retry(
            lambda: SQLFlightImage(
                id=f"{flight_id}_{frame_id}",
                flight_id=flight_id,
                frame_id=frame_id,
                file_path=file_path,
                metadata_json=self._serialize_metadata(metadata),
            ),
            f"image metadata for frame {frame_id} of {flight_id}",
        )

    def get_image_path(self, flight_id: str, frame_id: int) -> Optional[str]:
        """Return the stored file path of one frame image, or None."""
        with self._session_scope() as session:
            img = session.query(SQLFlightImage).filter_by(flight_id=flight_id, frame_id=frame_id).first()
            return img.file_path if img is not None else None

    def get_image_metadata(self, flight_id: str, frame_id: int) -> Optional[Dict]:
        """Return the stored metadata of one frame image ({} if empty), or None if missing."""
        with self._session_scope() as session:
            img = session.query(SQLFlightImage).filter_by(flight_id=flight_id, frame_id=frame_id).first()
            return self._deserialize_metadata(img.metadata_json) if img is not None else None

    def save_chunk_state(self, flight_id: str, chunk: ChunkHandle) -> bool:
        """Upsert *chunk*; retries transient errors and returns success."""
        return self._upsert_chunk_state(flight_id, chunk)

    def load_chunk_states(self, flight_id: str) -> List[ChunkHandle]:
        """Return all stored chunks of the flight."""
        with self._session_scope() as session:
            rows = session.query(SQLChunk).filter_by(flight_id=flight_id).all()
            return [self._build_chunk_handle_from_row(row) for row in rows]

    def delete_chunk_state(self, flight_id: str, chunk_id: str) -> bool:
        """Delete one chunk; return False if it does not exist or the delete fails."""
        try:
            with self._session_scope() as session:
                chunk = session.query(SQLChunk).filter_by(flight_id=flight_id, chunk_id=chunk_id).first()
                if chunk is None:
                    return False
                session.delete(chunk)
            return True
        except Exception:
            logger.exception("Failed to delete chunk %s for flight %s", chunk_id, flight_id)
            return False