mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-04-23 02:06:36 +00:00
feat: stage2 — SQLite DB layer (ORM, async engine, repository, cascade delete, 9 DB tests)
This commit is contained in:
@@ -0,0 +1,295 @@
|
||||
"""Repository (DAO) implementing IFlightDatabase operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import delete, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from gps_denied.db.models import (
|
||||
ChunkRow,
|
||||
FlightRow,
|
||||
FlightStateRow,
|
||||
FrameResultRow,
|
||||
GeofenceRow,
|
||||
HeadingRow,
|
||||
ImageRow,
|
||||
WaypointRow,
|
||||
)
|
||||
|
||||
|
||||
class FlightRepository:
|
||||
"""Async repository wrapping all DB operations for flights."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._s = session
|
||||
|
||||
# ── Flight CRUD ───────────────────────────────────────────────────
|
||||
|
||||
async def insert_flight(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
description: str,
|
||||
start_lat: float,
|
||||
start_lon: float,
|
||||
altitude: float,
|
||||
camera_params: dict,
|
||||
flight_id: str | None = None,
|
||||
) -> FlightRow:
|
||||
row = FlightRow(
|
||||
name=name,
|
||||
description=description,
|
||||
start_lat=start_lat,
|
||||
start_lon=start_lon,
|
||||
altitude=altitude,
|
||||
camera_params=camera_params,
|
||||
)
|
||||
if flight_id:
|
||||
row.id = flight_id
|
||||
self._s.add(row)
|
||||
await self._s.flush()
|
||||
# Create initial flight state
|
||||
state = FlightStateRow(flight_id=row.id)
|
||||
self._s.add(state)
|
||||
await self._s.flush()
|
||||
return row
|
||||
|
||||
async def get_flight(self, flight_id: str) -> FlightRow | None:
|
||||
return await self._s.get(FlightRow, flight_id)
|
||||
|
||||
async def list_flights(
|
||||
self,
|
||||
*,
|
||||
status: str | None = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[FlightRow]:
|
||||
stmt = select(FlightRow).offset(offset).limit(limit).order_by(FlightRow.created_at.desc())
|
||||
if status:
|
||||
stmt = stmt.join(FlightStateRow).where(FlightStateRow.status == status)
|
||||
result = await self._s.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def update_flight(self, flight_id: str, **kwargs: Any) -> bool:
|
||||
kwargs["updated_at"] = datetime.now(tz=timezone.utc)
|
||||
stmt = update(FlightRow).where(FlightRow.id == flight_id).values(**kwargs)
|
||||
result = await self._s.execute(stmt)
|
||||
return result.rowcount > 0 # type: ignore[union-attr]
|
||||
|
||||
async def delete_flight(self, flight_id: str) -> bool:
|
||||
stmt = delete(FlightRow).where(FlightRow.id == flight_id)
|
||||
result = await self._s.execute(stmt)
|
||||
return result.rowcount > 0 # type: ignore[union-attr]
|
||||
|
||||
# ── Waypoints ─────────────────────────────────────────────────────
|
||||
|
||||
async def insert_waypoint(
|
||||
self,
|
||||
flight_id: str,
|
||||
*,
|
||||
lat: float,
|
||||
lon: float,
|
||||
altitude: float | None = None,
|
||||
confidence: float = 0.0,
|
||||
refined: bool = False,
|
||||
waypoint_id: str | None = None,
|
||||
) -> WaypointRow:
|
||||
row = WaypointRow(
|
||||
flight_id=flight_id,
|
||||
lat=lat,
|
||||
lon=lon,
|
||||
altitude=altitude,
|
||||
confidence=confidence,
|
||||
refined=refined,
|
||||
)
|
||||
if waypoint_id:
|
||||
row.id = waypoint_id
|
||||
self._s.add(row)
|
||||
await self._s.flush()
|
||||
return row
|
||||
|
||||
async def get_waypoints(self, flight_id: str, *, limit: int | None = None) -> list[WaypointRow]:
|
||||
stmt = (
|
||||
select(WaypointRow)
|
||||
.where(WaypointRow.flight_id == flight_id)
|
||||
.order_by(WaypointRow.timestamp)
|
||||
)
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
result = await self._s.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def update_waypoint(self, flight_id: str, waypoint_id: str, **kwargs: Any) -> bool:
|
||||
stmt = (
|
||||
update(WaypointRow)
|
||||
.where(WaypointRow.id == waypoint_id, WaypointRow.flight_id == flight_id)
|
||||
.values(**kwargs)
|
||||
)
|
||||
result = await self._s.execute(stmt)
|
||||
return result.rowcount > 0 # type: ignore[union-attr]
|
||||
|
||||
# ── Flight State ──────────────────────────────────────────────────
|
||||
|
||||
async def save_flight_state(self, flight_id: str, **kwargs: Any) -> bool:
|
||||
kwargs["updated_at"] = datetime.now(tz=timezone.utc)
|
||||
stmt = update(FlightStateRow).where(FlightStateRow.flight_id == flight_id).values(**kwargs)
|
||||
result = await self._s.execute(stmt)
|
||||
return result.rowcount > 0 # type: ignore[union-attr]
|
||||
|
||||
async def load_flight_state(self, flight_id: str) -> FlightStateRow | None:
|
||||
return await self._s.get(FlightStateRow, flight_id)
|
||||
|
||||
# ── Frame Results ─────────────────────────────────────────────────
|
||||
|
||||
async def save_frame_result(
|
||||
self,
|
||||
flight_id: str,
|
||||
*,
|
||||
frame_id: int,
|
||||
gps_lat: float,
|
||||
gps_lon: float,
|
||||
altitude: float = 0.0,
|
||||
heading: float = 0.0,
|
||||
confidence: float = 0.0,
|
||||
refined: bool = False,
|
||||
) -> FrameResultRow:
|
||||
row = FrameResultRow(
|
||||
flight_id=flight_id,
|
||||
frame_id=frame_id,
|
||||
gps_lat=gps_lat,
|
||||
gps_lon=gps_lon,
|
||||
altitude=altitude,
|
||||
heading=heading,
|
||||
confidence=confidence,
|
||||
refined=refined,
|
||||
)
|
||||
self._s.add(row)
|
||||
await self._s.flush()
|
||||
return row
|
||||
|
||||
async def get_frame_results(self, flight_id: str) -> list[FrameResultRow]:
|
||||
stmt = (
|
||||
select(FrameResultRow)
|
||||
.where(FrameResultRow.flight_id == flight_id)
|
||||
.order_by(FrameResultRow.frame_id)
|
||||
)
|
||||
result = await self._s.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
# ── Heading History ───────────────────────────────────────────────
|
||||
|
||||
async def save_heading(
|
||||
self, flight_id: str, *, frame_id: int, heading: float
|
||||
) -> HeadingRow:
|
||||
row = HeadingRow(flight_id=flight_id, frame_id=frame_id, heading=heading)
|
||||
self._s.add(row)
|
||||
await self._s.flush()
|
||||
return row
|
||||
|
||||
async def get_heading_history(
|
||||
self, flight_id: str, *, last_n: int | None = None
|
||||
) -> list[HeadingRow]:
|
||||
stmt = (
|
||||
select(HeadingRow)
|
||||
.where(HeadingRow.flight_id == flight_id)
|
||||
.order_by(HeadingRow.timestamp.desc())
|
||||
)
|
||||
if last_n:
|
||||
stmt = stmt.limit(last_n)
|
||||
result = await self._s.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_latest_heading(self, flight_id: str) -> float | None:
|
||||
rows = await self.get_heading_history(flight_id, last_n=1)
|
||||
return rows[0].heading if rows else None
|
||||
|
||||
# ── Images ────────────────────────────────────────────────────────
|
||||
|
||||
async def save_image_metadata(
|
||||
self, flight_id: str, *, frame_id: int, file_path: str, metadata: dict | None = None
|
||||
) -> ImageRow:
|
||||
row = ImageRow(
|
||||
flight_id=flight_id, frame_id=frame_id, file_path=file_path, metadata_json=metadata
|
||||
)
|
||||
self._s.add(row)
|
||||
await self._s.flush()
|
||||
return row
|
||||
|
||||
async def get_image_path(self, flight_id: str, frame_id: int) -> str | None:
|
||||
stmt = select(ImageRow.file_path).where(
|
||||
ImageRow.flight_id == flight_id, ImageRow.frame_id == frame_id
|
||||
)
|
||||
result = await self._s.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
# ── Chunks ────────────────────────────────────────────────────────
|
||||
|
||||
async def save_chunk(
|
||||
self,
|
||||
flight_id: str,
|
||||
*,
|
||||
chunk_id: str | None = None,
|
||||
start_frame_id: int,
|
||||
end_frame_id: int | None = None,
|
||||
frames: list[int] | None = None,
|
||||
is_active: bool = True,
|
||||
has_anchor: bool = False,
|
||||
anchor_frame_id: int | None = None,
|
||||
anchor_lat: float | None = None,
|
||||
anchor_lon: float | None = None,
|
||||
matching_status: str = "pending",
|
||||
) -> ChunkRow:
|
||||
row = ChunkRow(
|
||||
flight_id=flight_id,
|
||||
start_frame_id=start_frame_id,
|
||||
end_frame_id=end_frame_id,
|
||||
frames=frames or [],
|
||||
is_active=is_active,
|
||||
has_anchor=has_anchor,
|
||||
anchor_frame_id=anchor_frame_id,
|
||||
anchor_lat=anchor_lat,
|
||||
anchor_lon=anchor_lon,
|
||||
matching_status=matching_status,
|
||||
)
|
||||
if chunk_id:
|
||||
row.chunk_id = chunk_id
|
||||
self._s.add(row)
|
||||
await self._s.flush()
|
||||
return row
|
||||
|
||||
async def load_chunks(self, flight_id: str) -> list[ChunkRow]:
|
||||
stmt = select(ChunkRow).where(ChunkRow.flight_id == flight_id)
|
||||
result = await self._s.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def delete_chunk(self, flight_id: str, chunk_id: str) -> bool:
|
||||
stmt = delete(ChunkRow).where(
|
||||
ChunkRow.chunk_id == chunk_id, ChunkRow.flight_id == flight_id
|
||||
)
|
||||
result = await self._s.execute(stmt)
|
||||
return result.rowcount > 0 # type: ignore[union-attr]
|
||||
|
||||
# ── Geofences ─────────────────────────────────────────────────────
|
||||
|
||||
async def insert_geofence(
|
||||
self,
|
||||
flight_id: str,
|
||||
*,
|
||||
nw_lat: float,
|
||||
nw_lon: float,
|
||||
se_lat: float,
|
||||
se_lon: float,
|
||||
) -> GeofenceRow:
|
||||
row = GeofenceRow(
|
||||
flight_id=flight_id,
|
||||
nw_lat=nw_lat,
|
||||
nw_lon=nw_lon,
|
||||
se_lat=se_lat,
|
||||
se_lon=se_lon,
|
||||
)
|
||||
self._s.add(row)
|
||||
await self._s.flush()
|
||||
return row
|
||||
Reference in New Issue
Block a user