"""Repository (DAO) implementing IFlightDatabase operations."""

from __future__ import annotations

from datetime import datetime, timezone
from typing import Any

from sqlalchemy import delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from gps_denied.db.models import (
    ChunkRow,
    FlightRow,
    FlightStateRow,
    FrameResultRow,
    GeofenceRow,
    HeadingRow,
    ImageRow,
    WaypointRow,
)


class FlightRepository:
    """Async repository wrapping all DB operations for flights.

    All methods operate on the injected ``AsyncSession``. Inserts are
    flushed (so generated column values become available on the returned
    rows) but no method here commits; committing is presumably the
    responsibility of whoever owns the session.
    """

    def __init__(self, session: AsyncSession) -> None:
        self._s = session

    # ── Flight CRUD ───────────────────────────────────────────────────

    async def insert_flight(
        self,
        *,
        name: str,
        description: str,
        start_lat: float,
        start_lon: float,
        altitude: float,
        camera_params: dict,
        flight_id: str | None = None,
    ) -> FlightRow:
        """Insert a flight together with its initial ``FlightStateRow``.

        Args:
            name: Flight name.
            description: Free-text description.
            start_lat: Starting latitude.
            start_lon: Starting longitude.
            altitude: Flight altitude.
            camera_params: Camera parameters stored on the flight row.
            flight_id: Explicit primary key; when falsy the model's
                default id is used.

        Returns:
            The flushed ``FlightRow``.
        """
        row = FlightRow(
            name=name,
            description=description,
            start_lat=start_lat,
            start_lon=start_lon,
            altitude=altitude,
            camera_params=camera_params,
        )
        if flight_id:
            row.id = flight_id
        self._s.add(row)
        # Flush first so any generated ``row.id`` is available for the
        # state row created below.
        await self._s.flush()

        # Create initial flight state
        state = FlightStateRow(flight_id=row.id)
        self._s.add(state)
        await self._s.flush()
        return row

    async def get_flight(self, flight_id: str) -> FlightRow | None:
        """Return the flight with primary key *flight_id*, or ``None``."""
        return await self._s.get(FlightRow, flight_id)

    async def list_flights(
        self,
        *,
        status: str | None = None,
        limit: int = 50,
        offset: int = 0,
    ) -> list[FlightRow]:
        """List flights, newest first, optionally filtered by state status.

        Args:
            status: When truthy, only flights whose ``FlightStateRow.status``
                equals this value are returned.
            limit: Maximum number of rows to return.
            offset: Number of rows to skip (for pagination).

        Returns:
            Flights ordered by ``created_at`` descending.
        """
        stmt = select(FlightRow)
        if status:
            # The status lives on the associated FlightStateRow, not on
            # FlightRow itself.
            stmt = stmt.join(FlightStateRow).where(FlightStateRow.status == status)
        stmt = stmt.order_by(FlightRow.created_at.desc()).offset(offset).limit(limit)
        result = await self._s.execute(stmt)
        return list(result.scalars().all())

    async def update_flight(self, flight_id: str, **kwargs: Any) -> bool:
        """Update columns of a flight and bump its ``updated_at``.

        Any ``updated_at`` passed in *kwargs* is overwritten with the current
        UTC time.

        Returns:
            ``True`` if the UPDATE reported at least one affected row.
        """
        kwargs["updated_at"] = datetime.now(tz=timezone.utc)
        stmt = update(FlightRow).where(FlightRow.id == flight_id).values(**kwargs)
        result = await self._s.execute(stmt)
        return result.rowcount > 0  # type: ignore[union-attr]

    async def delete_flight(self, flight_id: str) -> bool:
        """Delete a flight row by id.

        NOTE(review): this is a bulk DELETE on ``FlightRow`` only; whether
        dependent rows (state, waypoints, frames, ...) are removed depends on
        database-level cascades not visible here.

        Returns:
            ``True`` if the DELETE reported at least one affected row.
        """
        stmt = delete(FlightRow).where(FlightRow.id == flight_id)
        result = await self._s.execute(stmt)
        return result.rowcount > 0  # type: ignore[union-attr]

    # ── Waypoints ─────────────────────────────────────────────────────

    async def insert_waypoint(
        self,
        flight_id: str,
        *,
        lat: float,
        lon: float,
        altitude: float | None = None,
        confidence: float = 0.0,
        refined: bool = False,
        waypoint_id: str | None = None,
    ) -> WaypointRow:
        """Insert a waypoint for *flight_id* and return the flushed row.

        ``waypoint_id``, when truthy, overrides the model's default id.
        """
        row = WaypointRow(
            flight_id=flight_id,
            lat=lat,
            lon=lon,
            altitude=altitude,
            confidence=confidence,
            refined=refined,
        )
        if waypoint_id:
            row.id = waypoint_id
        self._s.add(row)
        await self._s.flush()
        return row

    async def get_waypoints(
        self, flight_id: str, *, limit: int | None = None
    ) -> list[WaypointRow]:
        """Return the waypoints of a flight ordered by ``timestamp``.

        Args:
            flight_id: Owning flight.
            limit: Maximum number of rows; ``None`` means no limit. ``0``
                returns an empty list.
        """
        stmt = (
            select(WaypointRow)
            .where(WaypointRow.flight_id == flight_id)
            .order_by(WaypointRow.timestamp)
        )
        # Compare against None explicitly so limit=0 is honoured instead of
        # being mistaken for "no limit".
        if limit is not None:
            stmt = stmt.limit(limit)
        result = await self._s.execute(stmt)
        return list(result.scalars().all())

    async def update_waypoint(self, flight_id: str, waypoint_id: str, **kwargs: Any) -> bool:
        """Update columns of one waypoint, scoped to its owning flight.

        Returns:
            ``True`` if the UPDATE reported at least one affected row.
        """
        stmt = (
            update(WaypointRow)
            .where(WaypointRow.id == waypoint_id, WaypointRow.flight_id == flight_id)
            .values(**kwargs)
        )
        result = await self._s.execute(stmt)
        return result.rowcount > 0  # type: ignore[union-attr]

    # ── Flight State ──────────────────────────────────────────────────

    async def save_flight_state(self, flight_id: str, **kwargs: Any) -> bool:
        """Update the state row of a flight and bump its ``updated_at``.

        Returns:
            ``True`` if the UPDATE reported at least one affected row.
        """
        kwargs["updated_at"] = datetime.now(tz=timezone.utc)
        stmt = update(FlightStateRow).where(FlightStateRow.flight_id == flight_id).values(**kwargs)
        result = await self._s.execute(stmt)
        return result.rowcount > 0  # type: ignore[union-attr]

    async def load_flight_state(self, flight_id: str) -> FlightStateRow | None:
        """Return the state row for *flight_id*, or ``None``.

        NOTE(review): ``session.get`` looks up by primary key, so this assumes
        ``flight_id`` is ``FlightStateRow``'s primary key — confirm against
        the model.
        """
        return await self._s.get(FlightStateRow, flight_id)

    # ── Frame Results ─────────────────────────────────────────────────

    async def save_frame_result(
        self,
        flight_id: str,
        *,
        frame_id: int,
        gps_lat: float,
        gps_lon: float,
        altitude: float = 0.0,
        heading: float = 0.0,
        confidence: float = 0.0,
        refined: bool = False,
    ) -> FrameResultRow:
        """Insert a per-frame localisation result and return the flushed row."""
        row = FrameResultRow(
            flight_id=flight_id,
            frame_id=frame_id,
            gps_lat=gps_lat,
            gps_lon=gps_lon,
            altitude=altitude,
            heading=heading,
            confidence=confidence,
            refined=refined,
        )
        self._s.add(row)
        await self._s.flush()
        return row

    async def get_frame_results(self, flight_id: str) -> list[FrameResultRow]:
        """Return all frame results of a flight ordered by ``frame_id``."""
        stmt = (
            select(FrameResultRow)
            .where(FrameResultRow.flight_id == flight_id)
            .order_by(FrameResultRow.frame_id)
        )
        result = await self._s.execute(stmt)
        return list(result.scalars().all())

    # ── Heading History ───────────────────────────────────────────────

    async def save_heading(
        self, flight_id: str, *, frame_id: int, heading: float
    ) -> HeadingRow:
        """Insert a heading sample for a frame and return the flushed row."""
        row = HeadingRow(flight_id=flight_id, frame_id=frame_id, heading=heading)
        self._s.add(row)
        await self._s.flush()
        return row

    async def get_heading_history(
        self, flight_id: str, *, last_n: int | None = None
    ) -> list[HeadingRow]:
        """Return heading samples of a flight, most recent ``timestamp`` first.

        Args:
            flight_id: Owning flight.
            last_n: Maximum number of samples; ``None`` means all. ``0``
                returns an empty list.
        """
        stmt = (
            select(HeadingRow)
            .where(HeadingRow.flight_id == flight_id)
            .order_by(HeadingRow.timestamp.desc())
        )
        # Explicit None check so last_n=0 is not treated as "unlimited".
        if last_n is not None:
            stmt = stmt.limit(last_n)
        result = await self._s.execute(stmt)
        return list(result.scalars().all())

    async def get_latest_heading(self, flight_id: str) -> float | None:
        """Return the most recent heading value, or ``None`` if none exist."""
        rows = await self.get_heading_history(flight_id, last_n=1)
        return rows[0].heading if rows else None

    # ── Images ────────────────────────────────────────────────────────

    async def save_image_metadata(
        self, flight_id: str, *, frame_id: int, file_path: str, metadata: dict | None = None
    ) -> ImageRow:
        """Insert image metadata for a frame and return the flushed row.

        *metadata* is stored in the ``metadata_json`` column.
        """
        row = ImageRow(
            flight_id=flight_id, frame_id=frame_id, file_path=file_path, metadata_json=metadata
        )
        self._s.add(row)
        await self._s.flush()
        return row

    async def get_image_path(self, flight_id: str, frame_id: int) -> str | None:
        """Return the stored file path for a frame's image, or ``None``.

        NOTE(review): ``scalar_one_or_none`` raises if more than one image row
        matches the (flight_id, frame_id) pair.
        """
        stmt = select(ImageRow.file_path).where(
            ImageRow.flight_id == flight_id, ImageRow.frame_id == frame_id
        )
        result = await self._s.execute(stmt)
        return result.scalar_one_or_none()

    # ── Chunks ────────────────────────────────────────────────────────

    async def save_chunk(
        self,
        flight_id: str,
        *,
        chunk_id: str | None = None,
        start_frame_id: int,
        end_frame_id: int | None = None,
        frames: list[int] | None = None,
        is_active: bool = True,
        has_anchor: bool = False,
        anchor_frame_id: int | None = None,
        anchor_lat: float | None = None,
        anchor_lon: float | None = None,
        matching_status: str = "pending",
    ) -> ChunkRow:
        """Insert a chunk for *flight_id* and return the flushed row.

        ``frames`` defaults to a fresh empty list; ``chunk_id``, when truthy,
        overrides the model's default id.
        """
        row = ChunkRow(
            flight_id=flight_id,
            start_frame_id=start_frame_id,
            end_frame_id=end_frame_id,
            frames=frames or [],
            is_active=is_active,
            has_anchor=has_anchor,
            anchor_frame_id=anchor_frame_id,
            anchor_lat=anchor_lat,
            anchor_lon=anchor_lon,
            matching_status=matching_status,
        )
        if chunk_id:
            row.chunk_id = chunk_id
        self._s.add(row)
        await self._s.flush()
        return row

    async def load_chunks(self, flight_id: str) -> list[ChunkRow]:
        """Return all chunks of a flight ordered by ``start_frame_id``."""
        # Explicit ordering: without ORDER BY the row order is unspecified.
        stmt = (
            select(ChunkRow)
            .where(ChunkRow.flight_id == flight_id)
            .order_by(ChunkRow.start_frame_id)
        )
        result = await self._s.execute(stmt)
        return list(result.scalars().all())

    async def delete_chunk(self, flight_id: str, chunk_id: str) -> bool:
        """Delete one chunk, scoped to its owning flight.

        Returns:
            ``True`` if the DELETE reported at least one affected row.
        """
        stmt = delete(ChunkRow).where(
            ChunkRow.chunk_id == chunk_id, ChunkRow.flight_id == flight_id
        )
        result = await self._s.execute(stmt)
        return result.rowcount > 0  # type: ignore[union-attr]

    # ── Geofences ─────────────────────────────────────────────────────

    async def insert_geofence(
        self,
        flight_id: str,
        *,
        nw_lat: float,
        nw_lon: float,
        se_lat: float,
        se_lon: float,
    ) -> GeofenceRow:
        """Insert a geofence given by its NW and SE corners; return the row."""
        row = GeofenceRow(
            flight_id=flight_id,
            nw_lat=nw_lat,
            nw_lon=nw_lon,
            se_lat=se_lat,
            se_lon=se_lon,
        )
        self._s.add(row)
        await self._s.flush()
        return row