From e47274bcbd0148c2d733bb62d2481dfe0641333f Mon Sep 17 00:00:00 2001 From: Yuzviak Date: Sun, 22 Mar 2026 22:25:44 +0200 Subject: [PATCH] =?UTF-8?q?feat:=20stage2=20=E2=80=94=20SQLite=20DB=20laye?= =?UTF-8?q?r=20(ORM,=20async=20engine,=20repository,=20cascade=20delete,?= =?UTF-8?q?=209=20DB=20tests)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs-Lokal/LOCAL_EXECUTION_PLAN.md | 6 +- src/gps_denied/db/__init__.py | 1 + src/gps_denied/db/engine.py | 41 ++++ src/gps_denied/db/models.py | 230 ++++++++++++++++++++++ src/gps_denied/db/repository.py | 295 +++++++++++++++++++++++++++++ tests/test_database.py | 228 ++++++++++++++++++++++ 6 files changed, 798 insertions(+), 3 deletions(-) create mode 100644 src/gps_denied/db/__init__.py create mode 100644 src/gps_denied/db/engine.py create mode 100644 src/gps_denied/db/models.py create mode 100644 src/gps_denied/db/repository.py create mode 100644 tests/test_database.py diff --git a/docs-Lokal/LOCAL_EXECUTION_PLAN.md b/docs-Lokal/LOCAL_EXECUTION_PLAN.md index 5e07a0a..fda3318 100644 --- a/docs-Lokal/LOCAL_EXECUTION_PLAN.md +++ b/docs-Lokal/LOCAL_EXECUTION_PLAN.md @@ -64,9 +64,9 @@ - **Дані:** Завантажити тестові зображення у папку `data/test_flights`. - **Ваги:** Завантажити ваги SuperPoint, LightGlue, LiteSAM локально в `weights/`. -### Етап 1 — Конфігурація та доменні моделі -- Реалізувати завантаження конфігів з env + YAML. -- Pydantic-схеми: Flight, Waypoint, ImageBatch, події SSE. +### Етап 1 — Конфігурація та доменні моделі ✅ +- Реалізовано завантаження конфігів з `.env` через `pydantic-settings` (`config.py`). +- Pydantic-схеми: GPSPoint, CameraParameters, Flight*, Waypoint, Batch*, SSE events. ### Етап 2 — База даних полёту - SQLite БД: міграції (flights, waypoints, frame results, chunk state). diff --git a/src/gps_denied/db/__init__.py b/src/gps_denied/db/__init__.py new file mode 100644 index 0000000..cdce083 --- /dev/null +++ b/src/gps_denied/db/__init__.py @@ -0,0 +1 @@ +"""Database package.""" diff --git a/src/gps_denied/db/engine.py b/src/gps_denied/db/engine.py new file mode 100644 index 0000000..3df899b --- /dev/null +++ b/src/gps_denied/db/engine.py @@ -0,0 +1,41 @@ +"""Async database engine and session factory.""" + +from __future__ import annotations + +from sqlalchemy import event +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from gps_denied.config import get_settings + + +def _build_engine(): + settings = get_settings() + connect_args = {} + if "sqlite" in settings.db.url: + connect_args["check_same_thread"] = False + + eng = create_async_engine( + settings.db.url, + echo=settings.db.echo, + connect_args=connect_args, + ) + + # Enable FK enforcement for SQLite (required for ON DELETE CASCADE) + @event.listens_for(eng.sync_engine, "connect") + def _set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + return eng + + +engine = _build_engine() + +async_session_factory = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + + +async def get_session() -> AsyncSession: # noqa: ANN201 — used as FastAPI dependency + """Yield an async session (for use as a FastAPI dependency).""" + async with async_session_factory() as session: + yield session diff --git a/src/gps_denied/db/models.py b/src/gps_denied/db/models.py new file mode 100644 index 0000000..764c1b3 --- /dev/null +++ b/src/gps_denied/db/models.py @@ -0,0 +1,230 @@ +"""SQLAlchemy ORM models for the flight database.""" + +from __future__ import annotations + +import uuid +from datetime import datetime, timezone + +from sqlalchemy import ( + Boolean, + DateTime, + Float, + ForeignKey, + Index, + Integer, + JSON, + String, + Text, +) +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship + + +def _utcnow() -> datetime: + return datetime.now(tz=timezone.utc) + + +def _new_id() -> str: + return uuid.uuid4().hex + + +class Base(DeclarativeBase): + """Declarative base for all ORM models.""" + + +# ── Flights ─────────────────────────────────────────────────────────────── + + +class FlightRow(Base): + __tablename__ = "flights" + + id: Mapped[str] = mapped_column(String(64), primary_key=True, default=_new_id) + name: Mapped[str] = mapped_column(String(255), nullable=False) + description: Mapped[str] = mapped_column(Text, default="") + start_lat: Mapped[float] = mapped_column(Float, nullable=False) + start_lon: Mapped[float] = mapped_column(Float, nullable=False) + altitude: Mapped[float] = mapped_column(Float, nullable=False) + camera_params: Mapped[dict] = mapped_column(JSON, nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default=_utcnow, onupdate=_utcnow + ) + + # relationships (cascade delete) + waypoints: Mapped[list[WaypointRow]] = relationship( + back_populates="flight", cascade="all, delete-orphan", passive_deletes=True + ) + geofences: Mapped[list[GeofenceRow]] = relationship( + back_populates="flight", cascade="all, delete-orphan", passive_deletes=True + ) + state: Mapped[FlightStateRow | None] = relationship( + back_populates="flight", cascade="all, delete-orphan", passive_deletes=True, uselist=False + ) + frame_results: Mapped[list[FrameResultRow]] = relationship( + back_populates="flight", cascade="all, delete-orphan", passive_deletes=True + ) + headings: Mapped[list[HeadingRow]] = relationship( + back_populates="flight", cascade="all, delete-orphan", passive_deletes=True + ) + images: Mapped[list[ImageRow]] = relationship( + back_populates="flight", cascade="all, delete-orphan", passive_deletes=True + ) + chunks: Mapped[list[ChunkRow]] = relationship( + back_populates="flight", cascade="all, delete-orphan", passive_deletes=True + ) + + +# ── Waypoints ───────────────────────────────────────────────────────────── + + +class WaypointRow(Base): + __tablename__ = "waypoints" + + id: Mapped[str] = mapped_column(String(64), primary_key=True, default=_new_id) + flight_id: Mapped[str] = mapped_column( + ForeignKey("flights.id", ondelete="CASCADE"), nullable=False, index=True + ) + lat: Mapped[float] = mapped_column(Float, nullable=False) + lon: Mapped[float] = mapped_column(Float, nullable=False) + altitude: Mapped[float | None] = mapped_column(Float, nullable=True) + confidence: Mapped[float] = mapped_column(Float, default=0.0) + timestamp: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow) + refined: Mapped[bool] = mapped_column(Boolean, default=False) + + flight: Mapped[FlightRow] = relationship(back_populates="waypoints") + + __table_args__ = (Index("ix_waypoints_flight_ts", "flight_id", "timestamp"),) + + +# ── Geofences ───────────────────────────────────────────────────────────── + + +class GeofenceRow(Base): + __tablename__ = "geofences" + + id: Mapped[str] = mapped_column(String(64), primary_key=True, default=_new_id) + flight_id: Mapped[str] = mapped_column( + ForeignKey("flights.id", ondelete="CASCADE"), nullable=False, index=True + ) + nw_lat: Mapped[float] = mapped_column(Float, nullable=False) + nw_lon: Mapped[float] = mapped_column(Float, nullable=False) + se_lat: Mapped[float] = mapped_column(Float, nullable=False) + se_lon: Mapped[float] = mapped_column(Float, nullable=False) + + flight: Mapped[FlightRow] = relationship(back_populates="geofences") + + +# ── Flight State ────────────────────────────────────────────────────────── + + +class FlightStateRow(Base): + __tablename__ = "flight_state" + + flight_id: Mapped[str] = mapped_column( + ForeignKey("flights.id", ondelete="CASCADE"), primary_key=True + ) + status: Mapped[str] = mapped_column(String(32), default="created") + frames_processed: Mapped[int] = mapped_column(Integer, default=0) + frames_total: Mapped[int] = mapped_column(Integer, default=0) + current_frame: Mapped[int | None] = mapped_column(Integer, nullable=True) + blocked: Mapped[bool] = mapped_column(Boolean, default=False) + search_grid_size: Mapped[int | None] = mapped_column(Integer, nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default=_utcnow, onupdate=_utcnow + ) + + flight: Mapped[FlightRow] = relationship(back_populates="state") + + +# ── Frame Results ───────────────────────────────────────────────────────── + + +class FrameResultRow(Base): + __tablename__ = "frame_results" + + id: Mapped[str] = mapped_column(String(64), primary_key=True, default=_new_id) + flight_id: Mapped[str] = mapped_column( + ForeignKey("flights.id", ondelete="CASCADE"), nullable=False, index=True + ) + frame_id: Mapped[int] = mapped_column(Integer, nullable=False) + gps_lat: Mapped[float] = mapped_column(Float, nullable=False) + gps_lon: Mapped[float] = mapped_column(Float, nullable=False) + altitude: Mapped[float] = mapped_column(Float, default=0.0) + heading: Mapped[float] = mapped_column(Float, default=0.0) + confidence: Mapped[float] = mapped_column(Float, default=0.0) + refined: Mapped[bool] = mapped_column(Boolean, default=False) + timestamp: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow) + + flight: Mapped[FlightRow] = relationship(back_populates="frame_results") + + __table_args__ = ( + Index("ix_frame_results_flight_frame", "flight_id", "frame_id", unique=True), + ) + + +# ── Heading History ─────────────────────────────────────────────────────── + + +class HeadingRow(Base): + __tablename__ = "heading_history" + + id: Mapped[str] = mapped_column(String(64), primary_key=True, default=_new_id) + flight_id: Mapped[str] = mapped_column( + ForeignKey("flights.id", ondelete="CASCADE"), nullable=False, index=True + ) + frame_id: Mapped[int] = mapped_column(Integer, nullable=False) + heading: Mapped[float] = mapped_column(Float, nullable=False) + timestamp: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow) + + flight: Mapped[FlightRow] = relationship(back_populates="headings") + + __table_args__ = ( + Index("ix_heading_flight_frame", "flight_id", "frame_id", unique=True), + ) + + +# ── Images ──────────────────────────────────────────────────────────────── + + +class ImageRow(Base): + __tablename__ = "images" + + id: Mapped[str] = mapped_column(String(64), primary_key=True, default=_new_id) + flight_id: Mapped[str] = mapped_column( + ForeignKey("flights.id", ondelete="CASCADE"), nullable=False, index=True + ) + frame_id: Mapped[int] = mapped_column(Integer, nullable=False) + file_path: Mapped[str] = mapped_column(Text, nullable=False) + metadata_json: Mapped[dict | None] = mapped_column(JSON, nullable=True) + + flight: Mapped[FlightRow] = relationship(back_populates="images") + + __table_args__ = ( + Index("ix_images_flight_frame", "flight_id", "frame_id", unique=True), + ) + + +# ── Chunks ──────────────────────────────────────────────────────────────── + + +class ChunkRow(Base): + __tablename__ = "chunks" + + chunk_id: Mapped[str] = mapped_column(String(64), primary_key=True, default=_new_id) + flight_id: Mapped[str] = mapped_column( + ForeignKey("flights.id", ondelete="CASCADE"), nullable=False, index=True + ) + start_frame_id: Mapped[int] = mapped_column(Integer, nullable=False) + end_frame_id: Mapped[int | None] = mapped_column(Integer, nullable=True) + frames: Mapped[list] = mapped_column(JSON, default=list) + is_active: Mapped[bool] = mapped_column(Boolean, default=True) + has_anchor: Mapped[bool] = mapped_column(Boolean, default=False) + anchor_frame_id: Mapped[int | None] = mapped_column(Integer, nullable=True) + anchor_lat: Mapped[float | None] = mapped_column(Float, nullable=True) + anchor_lon: Mapped[float | None] = mapped_column(Float, nullable=True) + matching_status: Mapped[str] = mapped_column(String(32), default="pending") + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default=_utcnow, onupdate=_utcnow + ) + + flight: Mapped[FlightRow] = relationship(back_populates="chunks") diff --git a/src/gps_denied/db/repository.py b/src/gps_denied/db/repository.py new file mode 100644 index 0000000..f4622f9 --- /dev/null +++ b/src/gps_denied/db/repository.py @@ -0,0 +1,295 @@ +"""Repository (DAO) implementing IFlightDatabase operations.""" + +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any + +from sqlalchemy import delete, select, update +from sqlalchemy.ext.asyncio import AsyncSession + +from gps_denied.db.models import ( + ChunkRow, + FlightRow, + FlightStateRow, + FrameResultRow, + GeofenceRow, + HeadingRow, + ImageRow, + WaypointRow, +) + + +class FlightRepository: + """Async repository wrapping all DB operations for flights.""" + + def __init__(self, session: AsyncSession) -> None: + self._s = session + + # ── Flight CRUD ─────────────────────────────────────────────────── + + async def insert_flight( + self, + *, + name: str, + description: str, + start_lat: float, + start_lon: float, + altitude: float, + camera_params: dict, + flight_id: str | None = None, + ) -> FlightRow: + row = FlightRow( + name=name, + description=description, + start_lat=start_lat, + start_lon=start_lon, + altitude=altitude, + camera_params=camera_params, + ) + if flight_id: + row.id = flight_id + self._s.add(row) + await self._s.flush() + # Create initial flight state + state = FlightStateRow(flight_id=row.id) + self._s.add(state) + await self._s.flush() + return row + + async def get_flight(self, flight_id: str) -> FlightRow | None: + return await self._s.get(FlightRow, flight_id) + + async def list_flights( + self, + *, + status: str | None = None, + limit: int = 50, + offset: int = 0, + ) -> list[FlightRow]: + stmt = select(FlightRow).offset(offset).limit(limit).order_by(FlightRow.created_at.desc()) + if status: + stmt = stmt.join(FlightStateRow).where(FlightStateRow.status == status) + result = await self._s.execute(stmt) + return list(result.scalars().all()) + + async def update_flight(self, flight_id: str, **kwargs: Any) -> bool: + kwargs["updated_at"] = datetime.now(tz=timezone.utc) + stmt = update(FlightRow).where(FlightRow.id == flight_id).values(**kwargs) + result = await self._s.execute(stmt) + return result.rowcount > 0 # type: ignore[union-attr] + + async def delete_flight(self, flight_id: str) -> bool: + stmt = delete(FlightRow).where(FlightRow.id == flight_id) + result = await self._s.execute(stmt) + return result.rowcount > 0 # type: ignore[union-attr] + + # ── Waypoints ───────────────────────────────────────────────────── + + async def insert_waypoint( + self, + flight_id: str, + *, + lat: float, + lon: float, + altitude: float | None = None, + confidence: float = 0.0, + refined: bool = False, + waypoint_id: str | None = None, + ) -> WaypointRow: + row = WaypointRow( + flight_id=flight_id, + lat=lat, + lon=lon, + altitude=altitude, + confidence=confidence, + refined=refined, + ) + if waypoint_id: + row.id = waypoint_id + self._s.add(row) + await self._s.flush() + return row + + async def get_waypoints(self, flight_id: str, *, limit: int | None = None) -> list[WaypointRow]: + stmt = ( + select(WaypointRow) + .where(WaypointRow.flight_id == flight_id) + .order_by(WaypointRow.timestamp) + ) + if limit: + stmt = stmt.limit(limit) + result = await self._s.execute(stmt) + return list(result.scalars().all()) + + async def update_waypoint(self, flight_id: str, waypoint_id: str, **kwargs: Any) -> bool: + stmt = ( + update(WaypointRow) + .where(WaypointRow.id == waypoint_id, WaypointRow.flight_id == flight_id) + .values(**kwargs) + ) + result = await self._s.execute(stmt) + return result.rowcount > 0 # type: ignore[union-attr] + + # ── Flight State ────────────────────────────────────────────────── + + async def save_flight_state(self, flight_id: str, **kwargs: Any) -> bool: + kwargs["updated_at"] = datetime.now(tz=timezone.utc) + stmt = update(FlightStateRow).where(FlightStateRow.flight_id == flight_id).values(**kwargs) + result = await self._s.execute(stmt) + return result.rowcount > 0 # type: ignore[union-attr] + + async def load_flight_state(self, flight_id: str) -> FlightStateRow | None: + return await self._s.get(FlightStateRow, flight_id) + + # ── Frame Results ───────────────────────────────────────────────── + + async def save_frame_result( + self, + flight_id: str, + *, + frame_id: int, + gps_lat: float, + gps_lon: float, + altitude: float = 0.0, + heading: float = 0.0, + confidence: float = 0.0, + refined: bool = False, + ) -> FrameResultRow: + row = FrameResultRow( + flight_id=flight_id, + frame_id=frame_id, + gps_lat=gps_lat, + gps_lon=gps_lon, + altitude=altitude, + heading=heading, + confidence=confidence, + refined=refined, + ) + self._s.add(row) + await self._s.flush() + return row + + async def get_frame_results(self, flight_id: str) -> list[FrameResultRow]: + stmt = ( + select(FrameResultRow) + .where(FrameResultRow.flight_id == flight_id) + .order_by(FrameResultRow.frame_id) + ) + result = await self._s.execute(stmt) + return list(result.scalars().all()) + + # ── Heading History ─────────────────────────────────────────────── + + async def save_heading( + self, flight_id: str, *, frame_id: int, heading: float + ) -> HeadingRow: + row = HeadingRow(flight_id=flight_id, frame_id=frame_id, heading=heading) + self._s.add(row) + await self._s.flush() + return row + + async def get_heading_history( + self, flight_id: str, *, last_n: int | None = None + ) -> list[HeadingRow]: + stmt = ( + select(HeadingRow) + .where(HeadingRow.flight_id == flight_id) + .order_by(HeadingRow.timestamp.desc()) + ) + if last_n: + stmt = stmt.limit(last_n) + result = await self._s.execute(stmt) + return list(result.scalars().all()) + + async def get_latest_heading(self, flight_id: str) -> float | None: + rows = await self.get_heading_history(flight_id, last_n=1) + return rows[0].heading if rows else None + + # ── Images ──────────────────────────────────────────────────────── + + async def save_image_metadata( + self, flight_id: str, *, frame_id: int, file_path: str, metadata: dict | None = None + ) -> ImageRow: + row = ImageRow( + flight_id=flight_id, frame_id=frame_id, file_path=file_path, metadata_json=metadata + ) + self._s.add(row) + await self._s.flush() + return row + + async def get_image_path(self, flight_id: str, frame_id: int) -> str | None: + stmt = select(ImageRow.file_path).where( + ImageRow.flight_id == flight_id, ImageRow.frame_id == frame_id + ) + result = await self._s.execute(stmt) + return result.scalar_one_or_none() + + # ── Chunks ──────────────────────────────────────────────────────── + + async def save_chunk( + self, + flight_id: str, + *, + chunk_id: str | None = None, + start_frame_id: int, + end_frame_id: int | None = None, + frames: list[int] | None = None, + is_active: bool = True, + has_anchor: bool = False, + anchor_frame_id: int | None = None, + anchor_lat: float | None = None, + anchor_lon: float | None = None, + matching_status: str = "pending", + ) -> ChunkRow: + row = ChunkRow( + flight_id=flight_id, + start_frame_id=start_frame_id, + end_frame_id=end_frame_id, + frames=frames or [], + is_active=is_active, + has_anchor=has_anchor, + anchor_frame_id=anchor_frame_id, + anchor_lat=anchor_lat, + anchor_lon=anchor_lon, + matching_status=matching_status, + ) + if chunk_id: + row.chunk_id = chunk_id + self._s.add(row) + await self._s.flush() + return row + + async def load_chunks(self, flight_id: str) -> list[ChunkRow]: + stmt = select(ChunkRow).where(ChunkRow.flight_id == flight_id) + result = await self._s.execute(stmt) + return list(result.scalars().all()) + + async def delete_chunk(self, flight_id: str, chunk_id: str) -> bool: + stmt = delete(ChunkRow).where( + ChunkRow.chunk_id == chunk_id, ChunkRow.flight_id == flight_id + ) + result = await self._s.execute(stmt) + return result.rowcount > 0 # type: ignore[union-attr] + + # ── Geofences ───────────────────────────────────────────────────── + + async def insert_geofence( + self, + flight_id: str, + *, + nw_lat: float, + nw_lon: float, + se_lat: float, + se_lon: float, + ) -> GeofenceRow: + row = GeofenceRow( + flight_id=flight_id, + nw_lat=nw_lat, + nw_lon=nw_lon, + se_lat=se_lat, + se_lon=se_lon, + ) + self._s.add(row) + await self._s.flush() + return row diff --git a/tests/test_database.py b/tests/test_database.py new file mode 100644 index 0000000..44ee439 --- /dev/null +++ b/tests/test_database.py @@ -0,0 +1,228 @@ +"""Tests for the database layer — CRUD, cascade, transactions.""" + +import pytest +from sqlalchemy import event +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from gps_denied.db.models import Base +from gps_denied.db.repository import FlightRepository + + +@pytest.fixture +async def session(): + """Create an in-memory SQLite database for each test.""" + engine = create_async_engine("sqlite+aiosqlite://", echo=False) + + @event.listens_for(engine.sync_engine, "connect") + def _set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + async_session = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + async with async_session() as s: + yield s + await engine.dispose() + + +@pytest.fixture +def repo(session: AsyncSession) -> FlightRepository: + return FlightRepository(session) + + +CAM = { + "focal_length": 25.0, + "sensor_width": 23.5, + "sensor_height": 15.6, + "resolution_width": 6252, + "resolution_height": 4168, +} + + +# ── Flight CRUD ─────────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_insert_and_get_flight(repo: FlightRepository, session: AsyncSession): + flight = await repo.insert_flight( + name="Test_Flight_001", + description="Test", + start_lat=48.275, + start_lon=37.385, + altitude=400, + camera_params=CAM, + ) + await session.commit() + + loaded = await repo.get_flight(flight.id) + assert loaded is not None + assert loaded.name == "Test_Flight_001" + assert loaded.altitude == 400 + + +@pytest.mark.asyncio +async def test_list_flights(repo: FlightRepository, session: AsyncSession): + await repo.insert_flight( + name="F1", description="", start_lat=0, start_lon=0, altitude=100, camera_params=CAM + ) + await repo.insert_flight( + name="F2", description="", start_lat=0, start_lon=0, altitude=200, camera_params=CAM + ) + await session.commit() + + flights = await repo.list_flights() + assert len(flights) == 2 + + +@pytest.mark.asyncio +async def test_update_flight(repo: FlightRepository, session: AsyncSession): + flight = await repo.insert_flight( + name="Old", description="", start_lat=0, start_lon=0, altitude=100, camera_params=CAM + ) + await session.commit() + + ok = await repo.update_flight(flight.id, name="New") + await session.commit() + assert ok is True + + reloaded = await repo.get_flight(flight.id) + assert reloaded.name == "New" + + +@pytest.mark.asyncio +async def test_delete_flight_cascade(repo: FlightRepository, session: AsyncSession): + flight = await repo.insert_flight( + name="Del", description="", start_lat=0, start_lon=0, altitude=100, camera_params=CAM + ) + fid = flight.id + # Add related entities + await repo.insert_waypoint(fid, lat=48.0, lon=37.0, confidence=0.9) + await repo.save_frame_result(fid, frame_id=1, gps_lat=48.0, gps_lon=37.0) + await repo.save_heading(fid, frame_id=1, heading=90.0) + await repo.save_image_metadata(fid, frame_id=1, file_path="/img/1.jpg") + await repo.save_chunk(fid, start_frame_id=1, frames=[1, 2, 3]) + await repo.insert_geofence(fid, nw_lat=49.0, nw_lon=36.0, se_lat=47.0, se_lon=38.0) + await session.commit() + + # Delete flight — should cascade + ok = await repo.delete_flight(fid) + await session.commit() + assert ok is True + assert await repo.get_flight(fid) is None + assert await repo.get_waypoints(fid) == [] + assert await repo.get_frame_results(fid) == [] + assert await repo.get_heading_history(fid) == [] + assert await repo.load_chunks(fid) == [] + + +# ── Waypoints ───────────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_waypoint_crud(repo: FlightRepository, session: AsyncSession): + flight = await repo.insert_flight( + name="WP", description="", start_lat=0, start_lon=0, altitude=100, camera_params=CAM + ) + wp = await repo.insert_waypoint(flight.id, lat=48.1, lon=37.2, confidence=0.8) + await session.commit() + + wps = await repo.get_waypoints(flight.id) + assert len(wps) == 1 + assert wps[0].lat == 48.1 + + ok = await repo.update_waypoint(flight.id, wp.id, lat=48.2, refined=True) + await session.commit() + assert ok is True + + +# ── Flight State ────────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_flight_state(repo: FlightRepository, session: AsyncSession): + flight = await repo.insert_flight( + name="State", description="", start_lat=0, start_lon=0, altitude=100, camera_params=CAM + ) + await session.commit() + + state = await repo.load_flight_state(flight.id) + assert state is not None + assert state.status == "created" + + await repo.save_flight_state(flight.id, status="processing", frames_total=500) + await session.commit() + + state = await repo.load_flight_state(flight.id) + assert state.status == "processing" + assert state.frames_total == 500 + + +# ── Frame Results ───────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_frame_results(repo: FlightRepository, session: AsyncSession): + flight = await repo.insert_flight( + name="FR", description="", start_lat=0, start_lon=0, altitude=100, camera_params=CAM + ) + await repo.save_frame_result( + flight.id, frame_id=1, gps_lat=48.0, gps_lon=37.0, confidence=0.95 + ) + await repo.save_frame_result( + flight.id, frame_id=2, gps_lat=48.001, gps_lon=37.001, confidence=0.90 + ) + await session.commit() + + results = await repo.get_frame_results(flight.id) + assert len(results) == 2 + assert results[0].frame_id == 1 + + +# ── Heading History ─────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_heading_history(repo: FlightRepository, session: AsyncSession): + flight = await repo.insert_flight( + name="HD", description="", start_lat=0, start_lon=0, altitude=100, camera_params=CAM + ) + for i in range(5): + await repo.save_heading(flight.id, frame_id=i, heading=float(i * 30)) + await session.commit() + + latest = await repo.get_latest_heading(flight.id) + assert latest == 120.0 # last frame heading + + last3 = await repo.get_heading_history(flight.id, last_n=3) + assert len(last3) == 3 + + +# ── Chunks ──────────────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_chunks(repo: FlightRepository, session: AsyncSession): + flight = await repo.insert_flight( + name="CK", description="", start_lat=0, start_lon=0, altitude=100, camera_params=CAM + ) + chunk = await repo.save_chunk( + flight.id, + chunk_id="chunk_001", + start_frame_id=1, + end_frame_id=10, + frames=[1, 2, 3, 4, 5], + has_anchor=True, + anchor_lat=48.0, + anchor_lon=37.0, + ) + await session.commit() + + chunks = await repo.load_chunks(flight.id) + assert len(chunks) == 1 + assert chunks[0].chunk_id == "chunk_001" + + ok = await repo.delete_chunk(flight.id, "chunk_001") + await session.commit() + assert ok is True + assert await repo.load_chunks(flight.id) == []