# Mirror of https://github.com/azaion/gps-denied-onboard.git
# (synced 2026-04-22 17:46:43 +00:00; upstream file: 296 lines, 10 KiB, Python)
"""Repository (DAO) implementing IFlightDatabase operations."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from datetime import datetime, timezone
|
|
from typing import Any
|
|
|
|
from sqlalchemy import delete, select, update
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from gps_denied.db.models import (
|
|
ChunkRow,
|
|
FlightRow,
|
|
FlightStateRow,
|
|
FrameResultRow,
|
|
GeofenceRow,
|
|
HeadingRow,
|
|
ImageRow,
|
|
WaypointRow,
|
|
)
|
|
|
|
|
|
class FlightRepository:
    """Async repository wrapping all DB operations for flights.

    Every method works on the ``AsyncSession`` passed to the constructor.
    Methods only ``flush`` or ``execute``; they never commit or roll back, so
    transaction boundaries belong to whoever owns the session.
    """

    def __init__(self, session: AsyncSession) -> None:
        self._s = session

    # ── Flight CRUD ───────────────────────────────────────────────────

    async def insert_flight(
        self,
        *,
        name: str,
        description: str,
        start_lat: float,
        start_lon: float,
        altitude: float,
        camera_params: dict,
        flight_id: str | None = None,
    ) -> FlightRow:
        """Insert a flight plus its initial ``FlightStateRow`` and return the flight.

        If ``flight_id`` is truthy it is used as the row id; otherwise the id is
        left to the model/database default.
        """
        row = FlightRow(
            name=name,
            description=description,
            start_lat=start_lat,
            start_lon=start_lon,
            altitude=altitude,
            camera_params=camera_params,
        )
        if flight_id:
            row.id = flight_id
        self._s.add(row)
        # Flush first so a generated ``row.id`` is available (and the parent row
        # exists) before the state row references it.
        await self._s.flush()

        # Create initial flight state
        self._s.add(FlightStateRow(flight_id=row.id))
        await self._s.flush()
        return row

    async def get_flight(self, flight_id: str) -> FlightRow | None:
        """Return the flight with primary key ``flight_id``, or None."""
        return await self._s.get(FlightRow, flight_id)

    async def list_flights(
        self,
        *,
        status: str | None = None,
        limit: int = 50,
        offset: int = 0,
    ) -> list[FlightRow]:
        """Return a page of flights ordered newest-first by ``created_at``.

        When ``status`` is truthy, only flights whose ``FlightStateRow.status``
        equals it are returned (inner join on the state table).
        """
        stmt = select(FlightRow)
        if status:
            stmt = stmt.join(FlightStateRow).where(FlightStateRow.status == status)
        stmt = stmt.order_by(FlightRow.created_at.desc()).offset(offset).limit(limit)
        result = await self._s.execute(stmt)
        return list(result.scalars().all())

    async def update_flight(self, flight_id: str, **kwargs: Any) -> bool:
        """Update the given columns of a flight and stamp ``updated_at`` (UTC now).

        A caller-supplied ``updated_at`` is overridden. Returns True if the
        UPDATE reported at least one affected row.
        """
        values = {**kwargs, "updated_at": datetime.now(tz=timezone.utc)}
        stmt = update(FlightRow).where(FlightRow.id == flight_id).values(**values)
        result = await self._s.execute(stmt)
        return result.rowcount > 0  # type: ignore[union-attr]

    async def delete_flight(self, flight_id: str) -> bool:
        """Delete a flight by id; return True if a row was deleted.

        NOTE(review): this is a bulk ``DELETE`` statement, so ORM relationship
        cascades are not applied; removal of child rows (state, waypoints,
        frames, ...) depends on database-level FK cascades — confirm the schema.
        """
        stmt = delete(FlightRow).where(FlightRow.id == flight_id)
        result = await self._s.execute(stmt)
        return result.rowcount > 0  # type: ignore[union-attr]

    # ── Waypoints ─────────────────────────────────────────────────────

    async def insert_waypoint(
        self,
        flight_id: str,
        *,
        lat: float,
        lon: float,
        altitude: float | None = None,
        confidence: float = 0.0,
        refined: bool = False,
        waypoint_id: str | None = None,
    ) -> WaypointRow:
        """Insert a waypoint for ``flight_id`` and return the flushed row.

        If ``waypoint_id`` is truthy it is used as the row id.
        """
        row = WaypointRow(
            flight_id=flight_id,
            lat=lat,
            lon=lon,
            altitude=altitude,
            confidence=confidence,
            refined=refined,
        )
        if waypoint_id:
            row.id = waypoint_id
        self._s.add(row)
        await self._s.flush()
        return row

    async def get_waypoints(self, flight_id: str, *, limit: int | None = None) -> list[WaypointRow]:
        """Return a flight's waypoints in ascending ``timestamp`` order.

        ``limit=None`` means no limit; any integer (including 0) is applied
        as-is rather than being treated as "unlimited".
        """
        stmt = (
            select(WaypointRow)
            .where(WaypointRow.flight_id == flight_id)
            .order_by(WaypointRow.timestamp)
        )
        if limit is not None:
            stmt = stmt.limit(limit)
        result = await self._s.execute(stmt)
        return list(result.scalars().all())

    async def update_waypoint(self, flight_id: str, waypoint_id: str, **kwargs: Any) -> bool:
        """Update columns of one waypoint, scoped to its flight.

        Returns True if the UPDATE reported at least one affected row.
        """
        stmt = (
            update(WaypointRow)
            .where(WaypointRow.id == waypoint_id, WaypointRow.flight_id == flight_id)
            .values(**kwargs)
        )
        result = await self._s.execute(stmt)
        return result.rowcount > 0  # type: ignore[union-attr]

    # ── Flight State ──────────────────────────────────────────────────

    async def save_flight_state(self, flight_id: str, **kwargs: Any) -> bool:
        """Update the state row of ``flight_id`` and stamp ``updated_at`` (UTC now).

        Only updates an existing state row (created by ``insert_flight``); it
        does not insert one. Returns True if a row was affected.
        """
        values = {**kwargs, "updated_at": datetime.now(tz=timezone.utc)}
        stmt = update(FlightStateRow).where(FlightStateRow.flight_id == flight_id).values(**values)
        result = await self._s.execute(stmt)
        return result.rowcount > 0  # type: ignore[union-attr]

    async def load_flight_state(self, flight_id: str) -> FlightStateRow | None:
        """Return the state row for ``flight_id``, or None.

        NOTE(review): this is a primary-key lookup, so it assumes
        ``FlightStateRow``'s primary key is ``flight_id`` — confirm in models.
        """
        return await self._s.get(FlightStateRow, flight_id)

    # ── Frame Results ─────────────────────────────────────────────────

    async def save_frame_result(
        self,
        flight_id: str,
        *,
        frame_id: int,
        gps_lat: float,
        gps_lon: float,
        altitude: float = 0.0,
        heading: float = 0.0,
        confidence: float = 0.0,
        refined: bool = False,
    ) -> FrameResultRow:
        """Insert a per-frame result for ``flight_id`` and return the flushed row."""
        row = FrameResultRow(
            flight_id=flight_id,
            frame_id=frame_id,
            gps_lat=gps_lat,
            gps_lon=gps_lon,
            altitude=altitude,
            heading=heading,
            confidence=confidence,
            refined=refined,
        )
        self._s.add(row)
        await self._s.flush()
        return row

    async def get_frame_results(self, flight_id: str) -> list[FrameResultRow]:
        """Return all frame results of a flight ordered by ``frame_id``."""
        stmt = (
            select(FrameResultRow)
            .where(FrameResultRow.flight_id == flight_id)
            .order_by(FrameResultRow.frame_id)
        )
        result = await self._s.execute(stmt)
        return list(result.scalars().all())

    # ── Heading History ───────────────────────────────────────────────

    async def save_heading(
        self, flight_id: str, *, frame_id: int, heading: float
    ) -> HeadingRow:
        """Insert a heading sample for ``flight_id`` and return the flushed row."""
        row = HeadingRow(flight_id=flight_id, frame_id=frame_id, heading=heading)
        self._s.add(row)
        await self._s.flush()
        return row

    async def get_heading_history(
        self, flight_id: str, *, last_n: int | None = None
    ) -> list[HeadingRow]:
        """Return a flight's heading rows newest-first.

        ``last_n=None`` returns all rows; any integer (including 0) is applied
        as a LIMIT rather than being treated as "unlimited".
        """
        stmt = (
            select(HeadingRow)
            .where(HeadingRow.flight_id == flight_id)
            # frame_id breaks ties between rows that share a timestamp, so the
            # "latest" row is deterministic.
            .order_by(HeadingRow.timestamp.desc(), HeadingRow.frame_id.desc())
        )
        if last_n is not None:
            stmt = stmt.limit(last_n)
        result = await self._s.execute(stmt)
        return list(result.scalars().all())

    async def get_latest_heading(self, flight_id: str) -> float | None:
        """Return the most recent heading value of a flight, or None if none exist."""
        rows = await self.get_heading_history(flight_id, last_n=1)
        return rows[0].heading if rows else None

    # ── Images ────────────────────────────────────────────────────────

    async def save_image_metadata(
        self, flight_id: str, *, frame_id: int, file_path: str, metadata: dict | None = None
    ) -> ImageRow:
        """Insert an image record (path plus optional metadata) and return the row."""
        row = ImageRow(
            flight_id=flight_id, frame_id=frame_id, file_path=file_path, metadata_json=metadata
        )
        self._s.add(row)
        await self._s.flush()
        return row

    async def get_image_path(self, flight_id: str, frame_id: int) -> str | None:
        """Return the stored file path for a flight's frame, or None.

        Raises SQLAlchemy's ``MultipleResultsFound`` if more than one image row
        matches the (flight_id, frame_id) pair.
        """
        stmt = select(ImageRow.file_path).where(
            ImageRow.flight_id == flight_id, ImageRow.frame_id == frame_id
        )
        result = await self._s.execute(stmt)
        return result.scalar_one_or_none()

    # ── Chunks ────────────────────────────────────────────────────────

    async def save_chunk(
        self,
        flight_id: str,
        *,
        chunk_id: str | None = None,
        start_frame_id: int,
        end_frame_id: int | None = None,
        frames: list[int] | None = None,
        is_active: bool = True,
        has_anchor: bool = False,
        anchor_frame_id: int | None = None,
        anchor_lat: float | None = None,
        anchor_lon: float | None = None,
        matching_status: str = "pending",
    ) -> ChunkRow:
        """Insert a chunk for ``flight_id`` and return the flushed row.

        ``frames`` defaults to a fresh empty list. If ``chunk_id`` is truthy it
        is used as the chunk's id.
        """
        row = ChunkRow(
            flight_id=flight_id,
            start_frame_id=start_frame_id,
            end_frame_id=end_frame_id,
            frames=frames or [],
            is_active=is_active,
            has_anchor=has_anchor,
            anchor_frame_id=anchor_frame_id,
            anchor_lat=anchor_lat,
            anchor_lon=anchor_lon,
            matching_status=matching_status,
        )
        if chunk_id:
            row.chunk_id = chunk_id
        self._s.add(row)
        await self._s.flush()
        return row

    async def load_chunks(self, flight_id: str) -> list[ChunkRow]:
        """Return all chunks of a flight ordered by ``start_frame_id``.

        Without an ORDER BY the database may return rows in any order, so the
        order is fixed here.
        """
        stmt = (
            select(ChunkRow)
            .where(ChunkRow.flight_id == flight_id)
            .order_by(ChunkRow.start_frame_id)
        )
        result = await self._s.execute(stmt)
        return list(result.scalars().all())

    async def delete_chunk(self, flight_id: str, chunk_id: str) -> bool:
        """Delete one chunk, scoped to its flight; return True if a row was deleted."""
        stmt = delete(ChunkRow).where(
            ChunkRow.chunk_id == chunk_id, ChunkRow.flight_id == flight_id
        )
        result = await self._s.execute(stmt)
        return result.rowcount > 0  # type: ignore[union-attr]

    # ── Geofences ─────────────────────────────────────────────────────

    async def insert_geofence(
        self,
        flight_id: str,
        *,
        nw_lat: float,
        nw_lon: float,
        se_lat: float,
        se_lon: float,
    ) -> GeofenceRow:
        """Insert a geofence given by its NW and SE corners; return the flushed row."""
        row = GeofenceRow(
            flight_id=flight_id,
            nw_lat=nw_lat,
            nw_lon=nw_lon,
            se_lat=se_lat,
            se_lon=se_lon,
        )
        self._s.add(row)
        await self._s.flush()
        return row