mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-04-22 08:56:37 +00:00
241 lines
9.8 KiB
Python
241 lines
9.8 KiB
Python
import pytest
|
|
from fastapi.testclient import TestClient
|
|
from unittest.mock import Mock, MagicMock
|
|
import json
|
|
from datetime import datetime
|
|
import asyncio
|
|
|
|
# Import the app and router from the correct modules
|
|
from f01_flight_api import router, get_lifecycle_manager, get_flight_database
|
|
from f02_1_flight_lifecycle_manager import FlightLifecycleManager, Flight, GPSPoint, CameraParameters, FlightState, Waypoint
|
|
from fastapi import FastAPI
|
|
|
|
# --- Test Setup ---

# Stand-alone FastAPI application that hosts only the router under test.
app = FastAPI()
app.include_router(router)

# Test doubles standing in for the collaborators the API receives via
# dependency injection. The manager double is constrained to the
# FlightLifecycleManager interface; the database double is unconstrained.
mock_manager = Mock(spec=FlightLifecycleManager)
mock_db = Mock()

# Route both dependency providers to the doubles above for every test.
# The lambdas look the doubles up at call time, on each request.
app.dependency_overrides.update(
    {
        get_lifecycle_manager: lambda: mock_manager,
        get_flight_database: lambda: mock_db,
    }
)

# HTTP client bound to the configured application.
client = TestClient(app)
|
|
|
|
@pytest.fixture(autouse=True)
def reset_mocks():
    """
    Reset the shared mock objects before every test.

    The fixture is ``autouse``, so pytest applies it to each test without the
    test having to request it. Besides clearing recorded calls, it also asks
    the mocks to drop any configured ``return_value`` / ``side_effect``.
    A plain ``reset_mock()`` keeps those, which would let a value configured
    by one test (for example ``get_flight.return_value = None``) leak into
    the tests that run after it.
    """
    # NOTE(review): clearing values configured on child mocks (e.g.
    # ``mock_manager.get_flight``) relies on ``reset_mock`` forwarding these
    # flags to children -- confirm on the project's minimum Python version.
    for shared_mock in (mock_manager, mock_db):
        shared_mock.reset_mock(return_value=True, side_effect=True)
|
|
|
|
# --- Test Cases ---
|
|
|
|
class TestFlightAPI:
    """
    Endpoint tests for the F01 Flight API.

    Every test drives the module-level ``client`` and configures the shared
    ``mock_manager`` / ``mock_db`` doubles that the app receives through its
    dependency overrides.
    """

    def test_create_flight_success(self):
        """POST /api/v1/flights with a complete body returns 201 and the new ID."""
        # Arrange: the manager double hands back a known flight ID.
        expected_id = "test_flight_abc"
        mock_manager.create_flight.return_value = expected_id

        request_body = {
            "name": "Test Mission",
            "start_gps": {"lat": 48.0, "lon": 37.0},
            "altitude": 400.0,
            "camera_params": {
                "focal_length_mm": 25.0,
                "sensor_width_mm": 23.5,
                "resolution": {"width": 6000, "height": 4000},
            },
        }

        # Act
        resp = client.post("/api/v1/flights", json=request_body)

        # Assert: status code, response body, and that the manager was used once.
        assert resp.status_code == 201
        created = resp.json()
        assert created["flight_id"] == expected_id
        assert created["status"] == "prefetching"
        mock_manager.create_flight.assert_called_once()

    def test_create_flight_validation_error(self):
        """POST /api/v1/flights without the 'name' field is rejected with 422."""
        # Arrange: same shape as a valid request, but 'name' is left out.
        camera = {
            "focal_length_mm": 25.0,
            "sensor_width_mm": 23.5,
            "resolution": {"width": 6000, "height": 4000},
        }
        body_without_name = {
            "start_gps": {"lat": 48.0, "lon": 37.0},
            "altitude": 400.0,
            "camera_params": camera,
        }

        # Act
        resp = client.post("/api/v1/flights", json=body_without_name)

        # Assert: the request is answered with 422 Unprocessable Entity.
        assert resp.status_code == 422

    def test_get_flight_success(self):
        """GET /api/v1/flights/{flight_id} returns the stored flight details."""
        flight_id = "test_flight_xyz"

        # Arrange: build the flight and state objects the doubles will return.
        origin = GPSPoint(lat=48.0, lon=37.0)
        camera = CameraParameters(
            focal_length_mm=25,
            sensor_width_mm=23.5,
            resolution={"width": 6000, "height": 4000},
        )
        stored_flight = Flight(
            flight_id=flight_id,
            flight_name="Test Mission",
            start_gps=origin,
            altitude_m=400.0,
            camera_params=camera,
            created_at=datetime.utcnow(),
            updated_at=datetime.utcnow(),
        )
        stored_state = FlightState(
            flight_id=flight_id,
            state="active",
            processed_images=10,
            total_images=100,
            has_active_engine=True,
        )

        mock_manager.get_flight.return_value = stored_flight
        mock_manager.get_flight_state.return_value = stored_state
        mock_db.get_waypoints.return_value = []

        # Act
        resp = client.get(f"/api/v1/flights/{flight_id}")

        # Assert
        assert resp.status_code == 200
        details = resp.json()
        assert details["flight_id"] == flight_id
        assert details["name"] == "Test Mission"
        assert details["frames_processed"] == 10

    def test_get_flight_not_found(self):
        """GET for a flight the manager does not know returns 404."""
        # Arrange: the manager double reports "no such flight" as None.
        mock_manager.get_flight.return_value = None

        # Act & Assert
        resp = client.get("/api/v1/flights/non_existent_id")
        assert resp.status_code == 404

    def test_upload_image_batch_success(self):
        """POSTing a ten-image multipart batch is accepted with 202."""
        flight_id = "test_upload_flight"
        mock_manager.queue_images.return_value = True

        # Arrange: ten multipart file parts named AD000001.jpg .. AD000010.jpg,
        # plus the form fields describing the batch.
        image_parts = []
        for seq in range(1, 11):
            filename = f"AD{seq:06d}.jpg"
            content = f"dummy_bytes_{seq}".encode()
            image_parts.append(("images", (filename, content, "image/jpeg")))
        form_fields = {"start_sequence": 1, "end_sequence": 10, "batch_number": 1}

        # Act
        resp = client.post(
            f"/api/v1/flights/{flight_id}/images/batch",
            data=form_fields,
            files=image_parts,
        )

        # Assert: exact response body and a single hand-off to the manager.
        assert resp.status_code == 202
        expected_body = {
            "accepted": True,
            "sequences": list(range(1, 11)),
            "next_expected": 11,
            "message": "Batch queued for processing.",
        }
        assert resp.json() == expected_body
        mock_manager.queue_images.assert_called_once()

    def test_sse_stream(self):
        """The stream endpoint delivers the manager's events as server-sent events."""
        flight_id = "test_stream_flight"
        processed_payload = {"frame_id": 1, "gps": [1, 1]}
        refined_payload = {"frame_id": 1, "gps": [1.1, 1.1]}

        # Arrange: an async generator that plays the role of the event source.
        async def fake_event_source(_flight, _client_name):
            yield {"event": "frame_processed", "data": json.dumps(processed_payload)}
            await asyncio.sleep(0.01)
            yield {"event": "frame_refined", "data": json.dumps(refined_payload)}

        mock_manager.get_flight.return_value = Mock()
        mock_manager.create_client_stream.return_value = fake_event_source(flight_id, "testclient")

        # Act: open the endpoint through the client's streaming context manager.
        stream_url = f"/api/v1/flights/{flight_id}/stream"
        with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}) as resp:
            # Assert: status and content type first.
            assert resp.status_code == 200
            assert "text/event-stream" in resp.headers["content-type"]

            # The body is read line by line: an "event:" line, then a "data:" line.
            lines = resp.iter_lines()

            # First event, followed by the blank line that separates events.
            first_type = next(lines)
            first_data = next(lines)
            next(lines)  # skip the blank separator line

            assert first_type == "event: frame_processed"
            assert json.loads(first_data.replace("data: ", "")) == processed_payload

            # Second event.
            second_type = next(lines)
            second_data = next(lines)

            assert second_type == "event: frame_refined"
            assert json.loads(second_data.replace("data: ", "")) == refined_payload

    def test_list_flights_success(self):
        """GET /api/v1/flights?status=created returns the flights found by the DB double."""
        origin = GPSPoint(lat=48.0, lon=37.0)
        camera = CameraParameters(
            focal_length_mm=25,
            sensor_width_mm=23.5,
            resolution={"width": 6000, "height": 4000},
        )
        only_flight = Flight(
            flight_id="test_flight_1",
            flight_name="Test Mission",
            start_gps=origin,
            altitude_m=400.0,
            camera_params=camera,
        )
        mock_db.query_flights.return_value = [only_flight]

        resp = client.get("/api/v1/flights?status=created")

        assert resp.status_code == 200
        listed = resp.json()
        assert len(listed) == 1
        assert listed[0]["flight_id"] == "test_flight_1"

    def test_get_results(self):
        """GET .../results returns the results supplied by the manager double."""
        single_result = Mock(
            image_id="AD000001.jpg",
            sequence_number=1,
            estimated_gps=GPSPoint(lat=48.0, lon=37.0),
            confidence=0.9,
            source="vo",
        )
        mock_manager.get_flight_results.return_value = [single_result]

        resp = client.get("/api/v1/flights/test_flight_1/results")

        assert resp.status_code == 200
        results = resp.json()
        assert len(results) == 1
        assert results[0]["image_id"] == "AD000001.jpg"

    def test_object_to_gps(self):
        """POST .../frames/15/object-to-gps answers 200 and echoes the frame ID."""
        mock_manager.convert_object_to_gps.return_value = GPSPoint(lat=48.0, lon=37.0)

        pixel = {"pixel_x": 3126.0, "pixel_y": 2084.0}
        resp = client.post(
            "/api/v1/flights/test_flight_1/frames/15/object-to-gps",
            json=pixel,
        )

        assert resp.status_code == 200
        assert resp.json()["frame_id"] == 15

    def test_get_frame_context_success(self):
        """GET .../frames/15/context returns the context supplied by the manager double."""
        candidate = {
            "tile_id": "t1",
            "image_url": "http://example.com/t1.jpg",
            "center_gps": {"lat": 48.0, "lon": 37.0},
        }
        mock_manager.get_frame_context.return_value = {
            "frame_id": 15,
            "uav_image_url": "http://example.com/uav.jpg",
            "satellite_candidates": [candidate],
        }

        resp = client.get("/api/v1/flights/test_flight_1/frames/15/context")

        assert resp.status_code == 200
        context = resp.json()
        assert context["frame_id"] == 15
        assert len(context["satellite_candidates"]) == 1

    def test_get_frame_context_not_found(self):
        """GET .../frames/15/context returns 404 when the manager has no context."""
        mock_manager.get_frame_context.return_value = None

        resp = client.get("/api/v1/flights/test_flight_1/frames/15/context")
        assert resp.status_code == 404