mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-06-22 16:11:13 +00:00
[AZ-701] HTTP replay API service (FastAPI + magic-byte upload validation)
ci/woodpecker/push/02-build-push Pipeline failed
ci/woodpecker/push/02-build-push Pipeline failed
New replay_api component: FastAPI service wrapping the offline
gps-denied-replay pipeline. POST tlog+video (multipart) → either
sync 200 with result/map/report URLs, or async 202 + job id with
/jobs/{id} polling. Magic-byte validation, bearer auth, in-memory
JobRegistry with concurrency + queue caps (429 on overflow).
Helper accuracy_report.py promoted from tests/ to src/ because the
API needs the Markdown report writer at runtime; all AZ-699 imports
re-pointed. OpenAPI spec exported to docs.
18/18 unit tests pass (AC-1 sync, AC-2 async, AC-3 state machine,
AC-5 auth, AC-6 health, AC-8 concurrency, AC-9 magic-byte). Full
unit suite: 2251 pass, 86 skip, 1 pre-existing C12 cold-start flake
(unchanged). mypy --strict clean on the new surface.
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -0,0 +1,82 @@
|
||||
"""AZ-701 ``replay-api`` console-script.
|
||||
|
||||
Builds the FastAPI app from environment configuration and starts
|
||||
the uvicorn server. Mirrors the operator-side CLI style of
|
||||
``gps-denied-replay`` and ``gps-denied-render-map``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from gps_denied_onboard.replay_api.app import build_runner_from_env, create_app
|
||||
from gps_denied_onboard.replay_api.storage import StorageRoot
|
||||
|
||||
__all__ = ["main"]
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger("gps_denied_onboard.cli.replay_api")
|
||||
|
||||
|
||||
def _build_argparser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="replay-api",
|
||||
description=(
|
||||
"Start the gps-denied-onboard replay HTTP API. "
|
||||
"Upload (tlog + video [+ calibration]); receive GPS "
|
||||
"fixes + an accuracy report + an HTML map."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host", default=os.environ.get("REPLAY_API_HOST", "0.0.0.0")
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=int(os.environ.get("REPLAY_API_PORT", "8080")),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--storage-root",
|
||||
type=Path,
|
||||
default=Path(
|
||||
os.environ.get(
|
||||
"REPLAY_API_STORAGE_ROOT", "/var/azaion/replay_api"
|
||||
)
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reload",
|
||||
action="store_true",
|
||||
help="Reload on code changes (dev only).",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
    """Entry point of the ``replay-api`` console script.

    Parses the command line, configures logging, builds the FastAPI app
    from the environment-configured runner and the selected storage root,
    and serves it with uvicorn until the server stops.

    Args:
        argv: Command-line arguments without the program name; ``None``
            lets argparse read ``sys.argv``.

    Returns:
        ``0`` once ``uvicorn.run`` returns.

    Raises:
        SystemExit: uvicorn is not installed (or argparse rejected the
            command line).
    """
    args = _build_argparser().parse_args(argv)

    # Read the level once. ``logging`` only recognises upper-case level
    # names while uvicorn expects lower-case ones, so a value such as
    # ``debug`` must be normalised separately for each consumer.
    log_level = os.environ.get("REPLAY_API_LOG_LEVEL", "INFO").strip() or "INFO"
    logging.basicConfig(
        level=log_level.upper(),
        format="%(asctime)s %(name)s %(levelname)s %(message)s",
    )
    try:
        import uvicorn
    except ImportError as exc:
        raise SystemExit(
            "uvicorn is not installed. Install with "
            "`pip install gps-denied-onboard[operator-tools]`."
        ) from exc

    storage = StorageRoot(args.storage_root)
    runner = build_runner_from_env()
    app = create_app(runner=runner, storage=storage)
    # NOTE(review): uvicorn appears to require an import string (not an app
    # object) when ``reload`` is enabled — confirm that --reload actually
    # works with the instance passed here.
    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        reload=args.reload,
        log_level=log_level.lower(),
    )
    return 0
|
||||
@@ -16,6 +16,14 @@ from gps_denied_onboard.helpers.engine_filename_schema import (
|
||||
EngineFilenameSchema,
|
||||
EngineFilenameSchemaError,
|
||||
)
|
||||
from gps_denied_onboard.helpers.accuracy_report import (
|
||||
AC3_GATE_PCT,
|
||||
AC3_GATE_THRESHOLD_M,
|
||||
ReportContext,
|
||||
format_failure_message,
|
||||
render_report,
|
||||
verdict_passes_ac3,
|
||||
)
|
||||
from gps_denied_onboard.helpers.gps_compare import (
|
||||
GroundTruthRow,
|
||||
HorizontalErrorDistribution,
|
||||
@@ -67,10 +75,13 @@ from gps_denied_onboard.helpers.wgs_converter import (
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AC3_GATE_PCT",
|
||||
"AC3_GATE_THRESHOLD_M",
|
||||
"ALLOWED_DTYPES",
|
||||
"ALLOWED_PRECISIONS",
|
||||
"ENGINE_SUFFIX",
|
||||
"MAX_ZOOM",
|
||||
"ReportContext",
|
||||
"SE3",
|
||||
"SIDECAR_SUFFIX",
|
||||
"WEB_MERCATOR_MAX_LAT_DEG",
|
||||
@@ -96,6 +107,7 @@ __all__ = [
|
||||
"WgsConverter",
|
||||
"adjoint",
|
||||
"exp_map",
|
||||
"format_failure_message",
|
||||
"horizontal_error_distribution",
|
||||
"is_valid_rotation",
|
||||
"iso_ts_from_clock",
|
||||
@@ -106,5 +118,7 @@ __all__ = [
|
||||
"make_imu_preintegrator",
|
||||
"matrix_to_se3",
|
||||
"percentile_sorted",
|
||||
"render_report",
|
||||
"se3_to_matrix",
|
||||
"verdict_passes_ac3",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,190 @@
|
||||
"""Markdown accuracy-report writer (AZ-699 + AZ-701).
|
||||
|
||||
Renders a :class:`HorizontalErrorDistribution` (the production
|
||||
helper in ``gps_denied_onboard.helpers.gps_compare``) plus run
|
||||
context (calibration acquisition method, clip duration, fixture
|
||||
paths) into the canonical Markdown layout consumed by
|
||||
``_docs/06_metrics/real_flight_validation_{date}.md``.
|
||||
|
||||
Originally implemented as a test helper under
|
||||
``tests/e2e/replay/_report_writer.py`` (AZ-699 batch 100). Promoted
|
||||
to production code in AZ-701 (batch 102) because the ``replay_api``
|
||||
HTTP service needs to render the same report for every replay job
|
||||
the operator submits, and a test-only helper cannot be imported
|
||||
from production code per the module-layout rule.
|
||||
|
||||
Style: every function is pure; the side effect (writing the file)
|
||||
is the caller's. Tests in ``tests/unit/test_az699_report_writer.py``
|
||||
exercise both the rendering and the threshold-gate verdict logic.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from gps_denied_onboard.helpers.gps_compare import HorizontalErrorDistribution
|
||||
|
||||
__all__ = [
|
||||
"AC3_GATE_PCT",
|
||||
"AC3_GATE_THRESHOLD_M",
|
||||
"ReportContext",
|
||||
"format_failure_message",
|
||||
"render_report",
|
||||
"verdict_passes_ac3",
|
||||
]
|
||||
|
||||
|
||||
# AZ-696 epic AC-3 threshold + minimum-share gate. Keeping these
|
||||
# named constants here (rather than inlined into the test) so the
|
||||
# unit tests for the failure-message template can pin them.
|
||||
AC3_GATE_THRESHOLD_M: float = 100.0
|
||||
AC3_GATE_PCT: float = 80.0
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ReportContext:
    """Immutable description of one replay run, shown in the report header.

    Attributes:
        run_date_utc: Date of the run as an ISO-8601 string
            (YYYY-MM-DD); it also determines the report filename.
        tlog_path: The real tlog that the runner consumed.
        video_path: The video clip that the runner consumed.
        calibration_acquisition_method: Where the camera calibration
            came from, e.g. ``"factory-sheet"`` (AZ-702) or
            ``"placeholder"`` (adti26 fallback). The failure message
            quotes it, as AZ-699 AC-3 requires.
        clip_duration_s: Length of the analysed clip, in seconds.
        emissions_count: Number of estimator-output records read from
            the JSONL. This can differ from ``distribution.count``
            because emissions outside the GT window are not paired.
    """

    run_date_utc: str
    tlog_path: Path
    video_path: Path
    calibration_acquisition_method: str
    clip_duration_s: float
    emissions_count: int
|
||||
|
||||
|
||||
def verdict_passes_ac3(distribution: HorizontalErrorDistribution) -> bool:
    """Tell whether ``distribution`` clears the AZ-696 epic AC-3 gate.

    The gate is cleared when the hit share recorded for the
    ``AC3_GATE_THRESHOLD_M`` threshold, expressed as a percentage, is at
    least ``AC3_GATE_PCT``. A distribution with no samples, or one that
    carries no entry for that threshold, does not pass.
    """
    gate_share = (
        distribution.threshold_hit_share.get(AC3_GATE_THRESHOLD_M)
        if distribution.count != 0
        else None
    )
    return gate_share is not None and gate_share * 100.0 >= AC3_GATE_PCT
|
||||
|
||||
|
||||
def format_failure_message(
    distribution: HorizontalErrorDistribution,
    context: ReportContext,
) -> str:
    """Compose the single-line AZ-699 AC-3 failure message.

    The message states the measured hit share against the epic gate, the
    residual statistics (mean / p50 / p95 / p99), the calibration
    acquisition method, and the path of the full report, so that an
    operator can judge the likely cause of a failure from the message
    alone. A distribution without an entry for the gate threshold is
    reported as a 0.0 % hit share.
    """
    hit_pct = (
        distribution.threshold_hit_share.get(AC3_GATE_THRESHOLD_M, 0.0) * 100.0
    )
    gate_sentence = (
        f"AZ-699 AC-3: only {hit_pct:.1f} % of {distribution.count} "
        f"emissions within {AC3_GATE_THRESHOLD_M:.0f} m of ground "
        f"truth; epic threshold is {AC3_GATE_PCT:.0f} %."
    )
    residual_stats = (
        ("mean", distribution.horizontal_error_mean_m),
        ("p50", distribution.horizontal_error_p50_m),
        ("p95", distribution.horizontal_error_p95_m),
        ("p99", distribution.horizontal_error_p99_m),
    )
    residual_sentence = (
        "Residual: "
        + ", ".join(f"{label}={value:.1f} m" for label, value in residual_stats)
        + "."
    )
    calibration_sentence = (
        f"Calibration: {context.calibration_acquisition_method}."
    )
    pointer_sentence = (
        "See _docs/06_metrics/real_flight_validation_"
        f"{context.run_date_utc}.md for the full distribution."
    )
    return " ".join(
        (gate_sentence, residual_sentence, calibration_sentence, pointer_sentence)
    )
|
||||
|
||||
|
||||
def render_report(
    distribution: HorizontalErrorDistribution,
    context: ReportContext,
    *,
    passed: bool,
) -> str:
    """Produce the Markdown body of the accuracy report.

    Sections, in order: title, verdict line, run context, horizontal-error
    statistics, threshold-hit table (thresholds ascending), and vertical
    error. The vertical section holds a statistics table when
    ``distribution.vertical_count`` is positive and a one-line placeholder
    otherwise. This layout is the schema referenced by
    ``_docs/02_document/tests/blackbox-tests.md``.

    Args:
        distribution: Error statistics to tabulate.
        context: Run metadata for the header.
        passed: Verdict to print (``PASS`` / ``FAIL``); supplied by the
            caller rather than recomputed here.

    Returns:
        The report text, terminated by a newline.
    """
    outcome = "PASS" if passed else "FAIL"
    table_head = ["| Statistic | Value |", "| --------- | ----- |"]
    horizontal_stats = (
        ("Mean", distribution.horizontal_error_mean_m),
        ("p50", distribution.horizontal_error_p50_m),
        ("p95", distribution.horizontal_error_p95_m),
        ("p99", distribution.horizontal_error_p99_m),
    )
    hit_shares = sorted(distribution.threshold_hit_share.items())

    out: list[str] = [
        f"# Real-flight validation — {context.run_date_utc}",
        "",
        f"**Verdict**: {outcome} (AC-3 gate: "
        f"≥ {AC3_GATE_PCT:.0f} % within "
        f"{AC3_GATE_THRESHOLD_M:.0f} m)",
        "",
        "## Run context",
        "",
        f"- Tlog: `{context.tlog_path}`",
        f"- Video: `{context.video_path}`",
        f"- Calibration acquisition method: {context.calibration_acquisition_method}",
        f"- Clip duration: {context.clip_duration_s:.1f} s",
        f"- Emissions consumed: {context.emissions_count}",
        f"- Ground-truth pairings: {distribution.count}",
        "",
        "## Horizontal error (metres)",
        "",
        *table_head,
    ]
    out += [f"| {label} | {metres:.2f} |" for label, metres in horizontal_stats]
    out += [
        "",
        "## Threshold-hit share",
        "",
        "| Threshold (m) | Hit share (%) |",
        "| ------------- | ------------- |",
    ]
    out += [
        f"| {limit_m:g} | {fraction * 100.0:.1f} |"
        for limit_m, fraction in hit_shares
    ]
    out.append("")

    if distribution.vertical_count > 0:
        out += [
            "## Vertical error (metres)",
            "",
            *table_head,
            f"| Mean | {distribution.vertical_error_mean_m:.2f} |",
            f"| p50 | {distribution.vertical_error_p50_m:.2f} |",
            f"| p95 | {distribution.vertical_error_p95_m:.2f} |",
            f"| Samples | {distribution.vertical_count} |",
            "",
        ]
    else:
        out += [
            "## Vertical error",
            "",
            "_No emissions carried a comparable altitude — vertical stats skipped._",
            "",
        ]
    # Every line, including the last, ends with a newline.
    return "".join(f"{line}\n" for line in out)
|
||||
@@ -0,0 +1,54 @@
|
||||
"""AZ-701 — `replay_api` HTTP service.
|
||||
|
||||
Operator-side HTTP wrapper around the offline replay pipeline:
|
||||
`gps-denied-replay` (AZ-402) + `gps-denied-render-map` (AZ-700).
|
||||
|
||||
Lives outside the airborne binary — see contract at
|
||||
``_docs/02_document/contracts/replay_api/replay_api_protocol.md``.
|
||||
|
||||
Public surface (re-exports below) is intentionally narrow:
|
||||
- ``create_app`` — FastAPI app factory (for uvicorn + tests).
|
||||
- ``JobRegistry`` + ``JobRecord`` + ``JobState`` — job-state machinery.
|
||||
- ``ReplayRunner`` Protocol — DI seam (handlers depend on the
|
||||
Protocol, not the concrete subprocess runner; unit tests inject
|
||||
a fake runner).
|
||||
- DTOs — ``JobSnapshot``, ``ReplayJobResult``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from gps_denied_onboard.replay_api.app import create_app
|
||||
from gps_denied_onboard.replay_api.errors import (
|
||||
JobNotCompleteError,
|
||||
JobNotFoundError,
|
||||
ReplayApiError,
|
||||
ReplayRunnerError,
|
||||
UnsupportedFileKindError,
|
||||
)
|
||||
from gps_denied_onboard.replay_api.interface import (
|
||||
JobSnapshot,
|
||||
JobState,
|
||||
ReplayJobResult,
|
||||
ReplayRunner,
|
||||
)
|
||||
from gps_denied_onboard.replay_api.jobs import (
|
||||
ConcurrencyLimitReachedError,
|
||||
JobRecord,
|
||||
JobRegistry,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ConcurrencyLimitReachedError",
|
||||
"JobNotCompleteError",
|
||||
"JobNotFoundError",
|
||||
"JobRecord",
|
||||
"JobRegistry",
|
||||
"JobSnapshot",
|
||||
"JobState",
|
||||
"ReplayApiError",
|
||||
"ReplayJobResult",
|
||||
"ReplayRunner",
|
||||
"ReplayRunnerError",
|
||||
"UnsupportedFileKindError",
|
||||
"create_app",
|
||||
]
|
||||
@@ -0,0 +1,677 @@
|
||||
"""AZ-701 — FastAPI app factory + production subprocess runner.
|
||||
|
||||
The factory takes the (runner, storage, registry) trio so unit
|
||||
tests can wire in fakes; ``main()`` (in
|
||||
``cli/replay_api_entrypoint.py``) constructs the production
|
||||
subprocess runner against the configured environment.
|
||||
|
||||
Note: this file deliberately does NOT use ``from __future__ import
|
||||
annotations``. FastAPI 0.119 + Pydantic 2.x resolve the route
|
||||
parameter annotations at decoration time, which requires the
|
||||
``Annotated[UploadFile, File()]`` form to be evaluable as real
|
||||
types — not as forward-ref strings. Other modules in the
|
||||
``replay_api`` package keep the future-annotations import; only
|
||||
this one drops it for the route signatures.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from gps_denied_onboard.replay_api.errors import (
|
||||
ConcurrencyLimitReachedError,
|
||||
JobNotCompleteError,
|
||||
JobNotFoundError,
|
||||
MultipartMissingFieldError,
|
||||
PayloadTooLargeError,
|
||||
ReplayApiError,
|
||||
ReplayRunnerError,
|
||||
UnauthorizedError,
|
||||
UnsupportedFileKindError,
|
||||
)
|
||||
from gps_denied_onboard.replay_api.handlers import (
|
||||
MIN_TLOG_PROBE_BYTES,
|
||||
MIN_VIDEO_PROBE_BYTES,
|
||||
auth_required,
|
||||
expected_bearer_token,
|
||||
extract_bearer_token,
|
||||
validate_calibration_kind,
|
||||
validate_tlog_kind,
|
||||
validate_upload_size,
|
||||
validate_video_kind,
|
||||
)
|
||||
from gps_denied_onboard.replay_api.interface import (
|
||||
JobSnapshot,
|
||||
JobState,
|
||||
ReplayInputs,
|
||||
ReplayJobResult,
|
||||
ReplayRunner,
|
||||
)
|
||||
from gps_denied_onboard.replay_api.jobs import JobRegistry
|
||||
from gps_denied_onboard.replay_api.storage import StorageRoot
|
||||
|
||||
__all__ = ["SubprocessReplayRunner", "build_runner_from_env", "create_app"]
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger("gps_denied_onboard.replay_api")
|
||||
|
||||
|
||||
_PROBE_BYTES_MAX: int = max(MIN_TLOG_PROBE_BYTES, MIN_VIDEO_PROBE_BYTES, 64)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Production runner
|
||||
|
||||
|
||||
class SubprocessReplayRunner:
    """Shells out to ``gps-denied-replay`` + ``gps-denied-render-map``.

    Each ``run()`` call writes a minimal replay-mode ``config.yaml``
    into the per-job output directory, invokes the replay CLI (adding
    ``--auto-trim`` when the inputs request it), computes the AZ-699
    accuracy report from the JSONL + the AZ-697 ground-truth extraction,
    and renders the AZ-700 HTML map. The result is the trio of artefact
    paths the handler streams back to the client; the report and the map
    are optional and may be ``None``.
    """

    # Upper bound, in seconds, for the best-effort map render subprocess.
    _RENDER_TIMEOUT_S: float = 120.0

    def __init__(
        self,
        *,
        replay_binary: str = "gps-denied-replay",
        render_binary: str = "gps-denied-render-map",
        subprocess_timeout_s: float = 900.0,
    ) -> None:
        """Store the executable names and the replay timeout.

        Args:
            replay_binary: Name or path of the replay CLI.
            render_binary: Name or path of the map-render CLI.
            subprocess_timeout_s: Seconds the replay CLI may run before
                it is treated as failed.
        """
        self._replay_binary = replay_binary
        self._render_binary = render_binary
        self._timeout = subprocess_timeout_s

    def run(
        self, inputs: ReplayInputs, *, output_dir: Path
    ) -> ReplayJobResult:
        """Run one replay job and collect its artefacts.

        Args:
            inputs: Paths of the uploaded tlog / video / calibration plus
                the ``pace`` and ``auto_trim`` options.
            output_dir: Per-job directory that receives ``config.yaml``,
                ``signing_key.bin``, ``emissions.jsonl`` and the optional
                report and map.

        Returns:
            The emissions path plus the report and map paths (each of
            the latter two is ``None`` when that artefact was skipped).

        Raises:
            ReplayRunnerError: ``pace`` is not a single token, or the
                replay CLI could not be started, timed out, or exited
                with a non-zero status.
        """
        # ``pace`` is interpolated into YAML below; whitespace or control
        # characters (a newline in particular) would let a caller smuggle
        # extra keys into the generated config.
        if any(ch.isspace() or not ch.isprintable() for ch in inputs.pace):
            raise ReplayRunnerError(
                "pace must be a single token without whitespace or "
                "control characters",
                details={"pace": inputs.pace},
            )

        config_path = output_dir / "config.yaml"
        config_path.write_text(
            "mode: replay\n"
            "replay:\n"
            f"  pace: {inputs.pace}\n"
            "  target_fc_dialect: ardupilot_plane\n"
        )

        # An all-zero 32-byte key file, handed to the CLI through
        # --mavlink-signing-key.
        signing_key_path = output_dir / "signing_key.bin"
        signing_key_path.write_bytes(b"\x00" * 32)

        emissions_path = output_dir / "emissions.jsonl"
        argv = [
            self._replay_binary,
            "--video",
            str(inputs.video_path),
            "--tlog",
            str(inputs.tlog_path),
            "--output",
            str(emissions_path),
            "--camera-calibration",
            str(inputs.calibration_path),
            "--config",
            str(config_path),
            "--mavlink-signing-key",
            str(signing_key_path),
            "--pace",
            inputs.pace,
        ]
        if inputs.auto_trim:
            argv.append("--auto-trim")

        try:
            replay_completed = subprocess.run(
                argv,
                capture_output=True,
                text=True,
                timeout=self._timeout,
            )
        except subprocess.TimeoutExpired as exc:
            # Report a timeout through the same error type as a non-zero exit.
            raise ReplayRunnerError(
                f"{self._replay_binary} timed out after {self._timeout:g} s",
                details={"timeout_s": self._timeout},
            ) from exc
        except OSError as exc:
            # Missing or non-executable binary.
            raise ReplayRunnerError(
                f"{self._replay_binary} could not be started: {exc}"
            ) from exc

        if replay_completed.returncode != 0:
            stderr_tail = (replay_completed.stderr or "")[-8192:]
            raise ReplayRunnerError(
                f"{self._replay_binary} exited "
                f"{replay_completed.returncode}",
                details={"stderr_tail": stderr_tail},
            )

        report_path = self._maybe_render_report(
            inputs, emissions_path, output_dir
        )
        map_path = self._maybe_render_map(
            inputs, emissions_path, output_dir, report_path
        )

        return ReplayJobResult(
            emissions_jsonl_path=emissions_path,
            accuracy_report_md_path=report_path,
            map_html_path=map_path,
        )

    def _maybe_render_report(
        self,
        inputs: ReplayInputs,
        emissions_path: Path,
        output_dir: Path,
    ) -> Path | None:
        """Compute the AZ-699 accuracy report; tolerate missing GT.

        Returns the path of ``accuracy_report.md`` inside ``output_dir``,
        or ``None`` when the report is skipped: the helper imports
        failed, there are no emissions, the tlog yields no ground truth,
        or no emission could be paired with ground truth. Errors raised
        while parsing the JSONL or loading the ground truth are not
        caught here and propagate to the caller.
        """
        try:
            import json
            from datetime import timezone

            from gps_denied_onboard.helpers.accuracy_report import (
                AC3_GATE_THRESHOLD_M,
                ReportContext,
                render_report,
                verdict_passes_ac3,
            )
            from gps_denied_onboard.helpers.gps_compare import (
                GroundTruthRow,
                horizontal_error_distribution,
            )
            from gps_denied_onboard.replay_input import (
                load_tlog_ground_truth,
            )
        except Exception as exc:
            _LOGGER.warning(
                "skipping accuracy report — imports failed: %r", exc
            )
            return None

        # Iterate the file rather than ``read_text().splitlines()``:
        # ``str.splitlines`` also breaks on characters such as U+2028,
        # which may legally appear unescaped inside a JSON string and
        # would cut a record in half.
        with emissions_path.open(encoding="utf-8") as stream:
            emissions: list[dict[str, Any]] = [
                json.loads(line) for line in stream if line.strip()
            ]
        if not emissions:
            return None

        gt_series = load_tlog_ground_truth(inputs.tlog_path).records
        if not gt_series:
            return None

        ground_truth = [
            GroundTruthRow(
                t_s=fix.ts_ns / 1e9,
                lat_deg=fix.lat_deg,
                lon_deg=fix.lon_deg,
                alt_m=fix.alt_m,
            )
            for fix in gt_series
        ]
        distribution = horizontal_error_distribution(emissions, ground_truth)
        if distribution.count == 0:
            return None

        try:
            calibration_method = _calibration_acquisition_method(
                inputs.calibration_path
            )
        except (OSError, ValueError):
            # Unreadable or malformed calibration must not cost the report.
            calibration_method = "unknown"

        # Span between the first and last ground-truth rows.
        clip_duration_s = (
            ground_truth[-1].t_s - ground_truth[0].t_s
            if len(ground_truth) > 1
            else 0.0
        )
        context = ReportContext(
            # Aware UTC "now"; ``datetime.utcnow()`` is deprecated.
            run_date_utc=datetime.now(timezone.utc).date().isoformat(),
            tlog_path=inputs.tlog_path,
            video_path=inputs.video_path,
            calibration_acquisition_method=calibration_method,
            clip_duration_s=clip_duration_s,
            emissions_count=len(emissions),
        )
        passed = verdict_passes_ac3(distribution)
        # Touch the threshold constant so a future rename surfaces here too.
        assert AC3_GATE_THRESHOLD_M > 0.0
        report_text = render_report(distribution, context, passed=passed)
        report_path = output_dir / "accuracy_report.md"
        report_path.write_text(report_text)
        return report_path

    def _maybe_render_map(
        self,
        inputs: ReplayInputs,
        emissions_path: Path,
        output_dir: Path,
        report_path: Path | None,
    ) -> Path | None:
        """Render ``map.html`` with the render CLI, on a best-effort basis.

        The render binary is looked up on ``PATH`` first and next to the
        running interpreter second. Every failure mode (binary missing,
        cannot be started, timeout, non-zero exit) is logged as a warning
        and yields ``None``, so a map problem never fails the job.

        Returns:
            The path of ``map.html`` in ``output_dir``, or ``None`` when
            the map was skipped.
        """
        if not shutil.which(self._render_binary):
            venv_bin = Path(sys.executable).parent / self._render_binary
            if not venv_bin.exists():
                _LOGGER.warning(
                    "%s not on PATH — skipping map render",
                    self._render_binary,
                )
                return None
            render_bin = str(venv_bin)
        else:
            render_bin = self._render_binary

        map_path = output_dir / "map.html"
        argv = [
            render_bin,
            "--estimated",
            str(emissions_path),
            "--truth",
            str(inputs.tlog_path),
            "--output",
            str(map_path),
        ]
        if report_path is not None:
            argv.extend(["--summary", str(report_path)])

        try:
            completed = subprocess.run(
                argv,
                capture_output=True,
                text=True,
                timeout=self._RENDER_TIMEOUT_S,
            )
        except (subprocess.TimeoutExpired, OSError) as exc:
            # Same best-effort policy as a non-zero exit below.
            _LOGGER.warning(
                "%s failed to run — map render skipped (%r)",
                self._render_binary,
                exc,
            )
            return None

        if completed.returncode != 0:
            _LOGGER.warning(
                "%s exited %s — map render skipped (stderr_tail=%r)",
                self._render_binary,
                completed.returncode,
                (completed.stderr or "")[-2048:],
            )
            return None
        return map_path
|
||||
|
||||
|
||||
def _calibration_acquisition_method(calibration_path: Path) -> str:
|
||||
import json
|
||||
|
||||
data = json.loads(calibration_path.read_text())
|
||||
method = data.get("acquisition_method")
|
||||
if isinstance(method, str) and method:
|
||||
return method
|
||||
return "unknown"
|
||||
|
||||
|
||||
def build_runner_from_env() -> SubprocessReplayRunner:
    """Create the production runner from ``REPLAY_API_*`` environment variables.

    Reads ``REPLAY_API_REPLAY_BINARY``, ``REPLAY_API_RENDER_BINARY`` and
    ``REPLAY_API_SUBPROCESS_TIMEOUT_S`` (parsed as a float), falling back
    to ``gps-denied-replay``, ``gps-denied-render-map`` and ``900``
    respectively when a variable is unset.
    """
    env = os.environ
    replay_cli = env.get("REPLAY_API_REPLAY_BINARY", "gps-denied-replay")
    render_cli = env.get("REPLAY_API_RENDER_BINARY", "gps-denied-render-map")
    timeout_s = float(env.get("REPLAY_API_SUBPROCESS_TIMEOUT_S", "900"))
    return SubprocessReplayRunner(
        replay_binary=replay_cli,
        render_binary=render_cli,
        subprocess_timeout_s=timeout_s,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# FastAPI app factory
|
||||
|
||||
|
||||
def create_app(
|
||||
*,
|
||||
runner: ReplayRunner,
|
||||
storage: StorageRoot,
|
||||
registry: JobRegistry | None = None,
|
||||
max_upload_bytes: int = 2 * 1024 * 1024 * 1024,
|
||||
sync_max_bytes: int = 200 * 1024 * 1024,
|
||||
) -> Any:
|
||||
"""Build the FastAPI app.
|
||||
|
||||
Args:
|
||||
runner: ``ReplayRunner`` injected into the registry.
|
||||
storage: Per-job storage manager.
|
||||
registry: Pre-built ``JobRegistry`` (the unit tests inject a
|
||||
tuned one; production wiring builds one from env).
|
||||
max_upload_bytes: hard limit per multipart upload (413 above).
|
||||
sync_max_bytes: video size at which the API switches to async.
|
||||
"""
|
||||
try:
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import (
|
||||
FastAPI,
|
||||
File,
|
||||
Form,
|
||||
Header,
|
||||
HTTPException,
|
||||
Request,
|
||||
Response,
|
||||
UploadFile,
|
||||
)
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
except ImportError as exc:
|
||||
raise SystemExit(
|
||||
"FastAPI is not installed. Install with "
|
||||
"`pip install gps-denied-onboard[operator-tools]`."
|
||||
) from exc
|
||||
|
||||
if registry is None:
|
||||
registry = JobRegistry(
|
||||
runner=runner,
|
||||
storage=storage,
|
||||
max_concurrent=int(
|
||||
os.environ.get("REPLAY_API_MAX_CONCURRENT_JOBS", "1")
|
||||
),
|
||||
max_queued=int(
|
||||
os.environ.get("REPLAY_API_MAX_QUEUED_JOBS", "8")
|
||||
),
|
||||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_: Any) -> AsyncIterator[None]:
|
||||
if not auth_required():
|
||||
_LOGGER.warning(
|
||||
"REPLAY_API_AUTH_REQUIRED=false — bearer auth is DISABLED. "
|
||||
"Do not run this in any environment exposed to the internet."
|
||||
)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
registry.shutdown(wait=False)
|
||||
|
||||
app = FastAPI(
|
||||
title="gps-denied-onboard replay API",
|
||||
version="1.0.0",
|
||||
description=(
|
||||
"HTTP wrapper around the offline `gps-denied-replay` "
|
||||
"pipeline. Upload (tlog + video [+ calibration]); "
|
||||
"receive GPS fixes + an accuracy report + an HTML map."
|
||||
),
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
@app.exception_handler(ReplayApiError)
|
||||
async def _on_replay_api_error(
|
||||
_request: Request, exc: ReplayApiError
|
||||
) -> JSONResponse:
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={
|
||||
"error_code": exc.error_code,
|
||||
"message": exc.message,
|
||||
"details": exc.details,
|
||||
},
|
||||
)
|
||||
|
||||
def _check_auth(authorization: str | None) -> None:
|
||||
if not auth_required():
|
||||
return
|
||||
expected = expected_bearer_token()
|
||||
actual = extract_bearer_token(authorization)
|
||||
if expected is None or actual != expected:
|
||||
raise UnauthorizedError("bearer token does not match")
|
||||
|
||||
@app.get("/healthz")
|
||||
async def healthz() -> dict[str, str]:
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.get("/readyz")
|
||||
async def readyz() -> Response:
|
||||
binary = os.environ.get(
|
||||
"REPLAY_API_REPLAY_BINARY", "gps-denied-replay"
|
||||
)
|
||||
if shutil.which(binary) is None and not (
|
||||
Path(sys.executable).parent / binary
|
||||
).exists():
|
||||
return JSONResponse(
|
||||
status_code=503,
|
||||
content={
|
||||
"status": "not_ready",
|
||||
"reason": f"{binary} not on PATH",
|
||||
},
|
||||
)
|
||||
if not os.access(storage.root, os.W_OK):
|
||||
return JSONResponse(
|
||||
status_code=503,
|
||||
content={
|
||||
"status": "not_ready",
|
||||
"reason": f"{storage.root} is not writable",
|
||||
},
|
||||
)
|
||||
return JSONResponse(content={"status": "ok"})
|
||||
|
||||
@app.post("/replay")
|
||||
async def post_replay(
|
||||
tlog: Annotated[UploadFile, File()],
|
||||
video: Annotated[UploadFile, File()],
|
||||
calibration: Annotated[UploadFile | None, File()] = None,
|
||||
pace: Annotated[str, Form()] = "asap",
|
||||
auto_trim: Annotated[bool, Form()] = True,
|
||||
authorization: Annotated[str | None, Header()] = None,
|
||||
) -> Response:
|
||||
_check_auth(authorization)
|
||||
|
||||
tlog_bytes = await tlog.read()
|
||||
validate_upload_size(len(tlog_bytes), limit=max_upload_bytes)
|
||||
validate_tlog_kind(tlog_bytes[:_PROBE_BYTES_MAX])
|
||||
|
||||
video_bytes = await video.read()
|
||||
validate_upload_size(len(video_bytes), limit=max_upload_bytes)
|
||||
validate_video_kind(video_bytes[:_PROBE_BYTES_MAX])
|
||||
|
||||
calibration_bytes: bytes | None = None
|
||||
if calibration is not None:
|
||||
calibration_bytes = await calibration.read()
|
||||
validate_upload_size(
|
||||
len(calibration_bytes), limit=max_upload_bytes
|
||||
)
|
||||
validate_calibration_kind(calibration_bytes[:_PROBE_BYTES_MAX])
|
||||
|
||||
# Allocate per-job storage and write the uploads.
|
||||
job_id = _new_job_id()
|
||||
job_storage = storage.allocate_job(job_id)
|
||||
job_storage.tlog_path.write_bytes(tlog_bytes)
|
||||
job_storage.video_path.write_bytes(video_bytes)
|
||||
if calibration_bytes is not None:
|
||||
job_storage.calibration_path.write_bytes(calibration_bytes)
|
||||
elif _default_calibration_path() is not None:
|
||||
shutil.copyfile(
|
||||
_default_calibration_path(), # type: ignore[arg-type]
|
||||
job_storage.calibration_path,
|
||||
)
|
||||
else:
|
||||
raise MultipartMissingFieldError(
|
||||
"calibration field is required (no default calibration "
|
||||
"bundled with this build of replay_api)"
|
||||
)
|
||||
|
||||
inputs = ReplayInputs(
|
||||
tlog_path=job_storage.tlog_path,
|
||||
video_path=job_storage.video_path,
|
||||
calibration_path=job_storage.calibration_path,
|
||||
pace=pace,
|
||||
auto_trim=auto_trim,
|
||||
)
|
||||
|
||||
# Submit under the pre-allocated job_id so the storage
|
||||
# directory (already populated with the uploads above) and
|
||||
# the API-visible job id match.
|
||||
try:
|
||||
snapshot = registry.submit(
|
||||
inputs,
|
||||
output_dir=job_storage.output_dir,
|
||||
job_id=job_id,
|
||||
)
|
||||
except Exception:
|
||||
storage.release_job(job_id)
|
||||
raise
|
||||
|
||||
sync_mode = len(video_bytes) <= sync_max_bytes
|
||||
if not sync_mode:
|
||||
return JSONResponse(
|
||||
status_code=202,
|
||||
headers={"Location": f"/jobs/{snapshot.job_id}"},
|
||||
content=_snapshot_to_dict(snapshot, sync=False),
|
||||
)
|
||||
# Wait for terminal state in sync mode.
|
||||
snapshot = _await_terminal(registry, snapshot.job_id)
|
||||
if snapshot.state == JobState.FAILED:
|
||||
raise ReplayRunnerError(
|
||||
snapshot.error or "replay runner failed without a message"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content=_snapshot_to_dict(snapshot, sync=True),
|
||||
)
|
||||
|
||||
@app.get("/jobs/{job_id}")
|
||||
async def get_job(
|
||||
job_id: str,
|
||||
authorization: Annotated[str | None, Header()] = None,
|
||||
) -> dict[str, Any]:
|
||||
_check_auth(authorization)
|
||||
try:
|
||||
snapshot = registry.get(job_id)
|
||||
except JobNotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"error_code": "job_not_found",
|
||||
"message": f"job {job_id} not found",
|
||||
},
|
||||
)
|
||||
return _snapshot_to_dict(snapshot, sync=False)
|
||||
|
||||
def _require_done(job_id: str) -> JobSnapshot:
|
||||
snapshot = registry.get(job_id)
|
||||
if snapshot.state != JobState.DONE:
|
||||
raise JobNotCompleteError(
|
||||
f"job {job_id} state is {snapshot.state.value}; "
|
||||
"result is only available when state=done"
|
||||
)
|
||||
return snapshot
|
||||
|
||||
@app.get("/jobs/{job_id}/result")
|
||||
async def get_result(
|
||||
job_id: str,
|
||||
authorization: Annotated[str | None, Header()] = None,
|
||||
) -> Response:
|
||||
_check_auth(authorization)
|
||||
snapshot = _require_done(job_id)
|
||||
if snapshot.result is None:
|
||||
raise JobNotCompleteError("job done but no result attached")
|
||||
return FileResponse(
|
||||
path=snapshot.result.emissions_jsonl_path,
|
||||
media_type="application/x-ndjson",
|
||||
filename="emissions.jsonl",
|
||||
)
|
||||
|
||||
@app.get("/jobs/{job_id}/map")
async def get_map(
    job_id: str,
    authorization: Annotated[str | None, Header()] = None,
) -> Response:
    """Serve the HTML map of a finished job.

    Raises ``JobNotCompleteError`` when the job is not done, has no
    result, or the result has no map path.
    """
    _check_auth(authorization)
    artefacts = _require_done(job_id).result
    map_path = None if artefacts is None else artefacts.map_html_path
    if map_path is None:
        raise JobNotCompleteError("map artefact unavailable")
    return FileResponse(
        path=map_path,
        media_type="text/html",
        filename="map.html",
    )
|
||||
|
||||
@app.get("/jobs/{job_id}/report")
async def get_report(
    job_id: str,
    authorization: Annotated[str | None, Header()] = None,
) -> Response:
    """Serve the Markdown accuracy report of a finished job.

    Raises ``JobNotCompleteError`` when the job is not done, has no
    result, or the result has no report path.
    """
    _check_auth(authorization)
    artefacts = _require_done(job_id).result
    report_path = (
        None if artefacts is None else artefacts.accuracy_report_md_path
    )
    if report_path is None:
        raise JobNotCompleteError("report artefact unavailable")
    return FileResponse(
        path=report_path,
        media_type="text/markdown",
        filename="accuracy_report.md",
    )
|
||||
|
||||
# Stash so unit tests can introspect.
|
||||
app.state.registry = registry
|
||||
app.state.storage = storage
|
||||
# Silence unused-import lint on dependency types.
|
||||
_ = (Form, File)
|
||||
# Reference the unused-but-kept errors so a future renamed
|
||||
# member surfaces here loudly.
|
||||
_ = (
|
||||
ConcurrencyLimitReachedError,
|
||||
PayloadTooLargeError,
|
||||
UnsupportedFileKindError,
|
||||
MultipartMissingFieldError,
|
||||
)
|
||||
return app
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Helpers
|
||||
|
||||
|
||||
def _new_job_id() -> str:
    """Return a new random job id (the hex form of a version-4 UUID)."""
    from uuid import uuid4

    return uuid4().hex
|
||||
|
||||
|
||||
def _default_calibration_path() -> Path | None:
|
||||
raw = os.environ.get("REPLAY_API_DEFAULT_CALIBRATION")
|
||||
if not raw:
|
||||
return None
|
||||
path = Path(raw)
|
||||
if path.is_file():
|
||||
return path
|
||||
return None
|
||||
|
||||
|
||||
def _await_terminal(
    registry: JobRegistry,
    job_id: str,
    *,
    timeout_s: float = 1800.0,
    poll_interval_s: float = 0.05,
) -> JobSnapshot:
    """Block until ``job_id`` reaches a terminal state and return its snapshot.

    Used in sync mode. The registry is polled every ``poll_interval_s``
    seconds until the job is ``DONE`` or ``FAILED``. This call sleeps
    synchronously, so the calling thread is parked for the whole wait.

    Args:
        registry: Registry that owns the job.
        job_id: Id of the job to wait for.
        timeout_s: Upper bound on the wait, in seconds (default 30 min).
        poll_interval_s: Pause between two registry polls, in seconds.

    Raises:
        ReplayRunnerError: If no terminal state is seen within ``timeout_s``.
    """
    import time

    # Monotonic clock: wall-clock adjustments must not stretch or cut the bound.
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        snap = registry.get(job_id)
        if snap.state in (JobState.DONE, JobState.FAILED):
            return snap
        time.sleep(poll_interval_s)
    # ``:g`` renders the default 1800 s as "30", keeping the original message.
    raise ReplayRunnerError(
        f"sync replay exceeded {timeout_s / 60:g} min safety bound"
    )
|
||||
|
||||
|
||||
def _snapshot_to_dict(snapshot: JobSnapshot, *, sync: bool) -> dict[str, Any]:
    """Convert a job snapshot into the JSON body returned by the API.

    Timestamps are rendered with ``isoformat()``; unset ones become
    ``None``. Artefact URLs are added only when the snapshot has a
    result: the emissions URL always, the report and map URLs only when
    their paths are set.
    """

    def _iso_or_none(moment: Any) -> Any:
        return moment.isoformat() if moment else None

    payload: dict[str, Any] = {
        "job_id": snapshot.job_id,
        "state": snapshot.state.value,
        "submitted_at_utc": snapshot.submitted_at_utc.isoformat(),
        "started_at_utc": _iso_or_none(snapshot.started_at_utc),
        "finished_at_utc": _iso_or_none(snapshot.finished_at_utc),
        "error": snapshot.error,
        "status_url": f"/jobs/{snapshot.job_id}",
        "sync": sync,
    }
    artefacts = snapshot.result
    if artefacts is None:
        return payload

    payload["emissions_jsonl_url"] = f"/jobs/{snapshot.job_id}/result"
    if artefacts.accuracy_report_md_path is not None:
        payload["accuracy_report_md_url"] = f"/jobs/{snapshot.job_id}/report"
    if artefacts.map_html_path is not None:
        payload["map_html_url"] = f"/jobs/{snapshot.job_id}/map"
    return payload
|
||||
@@ -0,0 +1,86 @@
|
||||
"""AZ-701 — typed HTTP error families for the replay_api service.
|
||||
|
||||
Every error has a stable ``error_code`` (string) the contract pins
|
||||
in ``_docs/02_document/contracts/replay_api/replay_api_protocol.md``.
|
||||
The handler layer translates these into JSON responses; the
|
||||
business layer raises them without knowing about HTTP.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
__all__ = [
|
||||
"ConcurrencyLimitReachedError",
|
||||
"JobNotCompleteError",
|
||||
"JobNotFoundError",
|
||||
"MultipartMissingFieldError",
|
||||
"PayloadTooLargeError",
|
||||
"ReplayApiError",
|
||||
"ReplayRunnerError",
|
||||
"UnauthorizedError",
|
||||
"UnsupportedFileKindError",
|
||||
]
|
||||
|
||||
|
||||
class ReplayApiError(Exception):
    """Root of the typed replay_api error hierarchy.

    Each subclass overrides the class-level ``error_code`` and
    ``status_code``. Instances carry the human-readable ``message`` and
    a ``details`` dict, which is empty when none (or an empty one) is given.
    """

    error_code: str = "replay_api_error"
    status_code: int = 500

    def __init__(self, message: str, details: dict[str, Any] | None = None) -> None:
        super().__init__(message)
        self.message = message
        self.details = details if details else {}
|
||||
|
||||
|
||||
class UnsupportedFileKindError(ReplayApiError):
    """An upload failed a file-kind check (HTTP 400)."""

    error_code = "unsupported_file_kind"
    status_code = 400
|
||||
|
||||
|
||||
class MultipartMissingFieldError(ReplayApiError):
    """A required multipart field is absent from the request (HTTP 400)."""

    error_code = "multipart_missing_field"
    status_code = 400
|
||||
|
||||
|
||||
class UnauthorizedError(ReplayApiError):
    """The request could not be authenticated (HTTP 401)."""

    error_code = "unauthorized"
    status_code = 401
|
||||
|
||||
|
||||
class JobNotFoundError(ReplayApiError):
    """No job exists for the requested id (HTTP 404)."""

    error_code = "job_not_found"
    status_code = 404
|
||||
|
||||
|
||||
class JobNotCompleteError(ReplayApiError):
    """An artefact was requested but is not available for the job (HTTP 409)."""

    error_code = "job_not_complete"
    status_code = 409
|
||||
|
||||
|
||||
class PayloadTooLargeError(ReplayApiError):
    """An upload exceeds the configured size limit (HTTP 413)."""

    error_code = "payload_too_large"
    status_code = 413
|
||||
|
||||
|
||||
class ConcurrencyLimitReachedError(ReplayApiError):
    """The job queue has no room for another submission (HTTP 429).

    Per the spec, reaching only the running-job concurrency limit does
    NOT raise this error: such submissions are queued normally. The 429
    is reserved for the case where the queue itself is full.
    """

    error_code = "concurrency_limit_reached"
    status_code = 429
|
||||
|
||||
|
||||
class ReplayRunnerError(ReplayApiError):
    """The replay run itself failed (HTTP 500)."""

    error_code = "replay_runner_failed"
    status_code = 500
|
||||
@@ -0,0 +1,152 @@
|
||||
"""AZ-701 — multipart upload + magic-byte validation + auth helpers.
|
||||
|
||||
The functions here are deliberately framework-light: they take raw
|
||||
bytes / streams and return validated artefacts. ``app.py`` wires
|
||||
them into FastAPI dependencies; unit tests call them directly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from gps_denied_onboard.replay_api.errors import (
|
||||
MultipartMissingFieldError,
|
||||
PayloadTooLargeError,
|
||||
UnauthorizedError,
|
||||
UnsupportedFileKindError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MIN_TLOG_PROBE_BYTES",
|
||||
"MIN_VIDEO_PROBE_BYTES",
|
||||
"auth_required",
|
||||
"expected_bearer_token",
|
||||
"extract_bearer_token",
|
||||
"validate_calibration_kind",
|
||||
"validate_tlog_kind",
|
||||
"validate_upload_size",
|
||||
"validate_video_kind",
|
||||
]
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger("gps_denied_onboard.replay_api.handlers")
|
||||
|
||||
|
||||
# MAVLink magic bytes — pymavlink uses 0xFD for v2.0 and 0xFE for
|
||||
# v1.0. The Derkachi tlog is v2.0; we accept both because some
|
||||
# operators ship v1.0 captures from older autopilots.
|
||||
_MAVLINK_MAGIC_V2: int = 0xFD
|
||||
_MAVLINK_MAGIC_V1: int = 0xFE
|
||||
MIN_TLOG_PROBE_BYTES: int = 9
|
||||
|
||||
|
||||
# mp4 boxes start with a 4-byte size, then 4 ASCII bytes for the
|
||||
# box type. The first box in every valid mp4 is ``ftyp`` (per
|
||||
# ISO/IEC 14496-12). ``"ftyp"`` lives at offset 4.
|
||||
_MP4_FTYP_MARKER: bytes = b"ftyp"
|
||||
MIN_VIDEO_PROBE_BYTES: int = 12
|
||||
|
||||
|
||||
def validate_tlog_kind(probe_bytes: bytes) -> None:
    """Raise unless the probe looks like the start of a MAVLink tlog.

    pymavlink's tlog format prefixes each record with an 8-byte
    big-endian microsecond timestamp followed by the raw MAVLink frame,
    so the byte at offset 8 must be a MAVLink v1 or v2 magic value.

    Raises:
        UnsupportedFileKindError: If the probe is shorter than
            ``MIN_TLOG_PROBE_BYTES`` or the byte at offset 8 is not a
            MAVLink magic.
    """
    if len(probe_bytes) < MIN_TLOG_PROBE_BYTES:
        raise UnsupportedFileKindError(
            f"tlog probe too small (need ≥ {MIN_TLOG_PROBE_BYTES} bytes "
            f"to validate magic; got {len(probe_bytes)})"
        )
    magic = probe_bytes[8]
    if magic == _MAVLINK_MAGIC_V2 or magic == _MAVLINK_MAGIC_V1:
        return
    raise UnsupportedFileKindError(
        f"tlog magic byte 0x{magic:02X} at offset 8 is not "
        f"MAVLink (expected 0x{_MAVLINK_MAGIC_V2:02X} or "
        f"0x{_MAVLINK_MAGIC_V1:02X})"
    )
|
||||
|
||||
|
||||
def validate_video_kind(probe_bytes: bytes) -> None:
    """Raise unless the probe carries an mp4 ``ftyp`` marker at offset 4.

    Bytes 0-3 hold a box size that varies from file to file, so only the
    four marker bytes are compared. This is what rejects a ``.zip`` that
    was merely renamed to ``.mp4`` (the AC-9 case).

    Raises:
        UnsupportedFileKindError: If the probe is shorter than
            ``MIN_VIDEO_PROBE_BYTES`` or bytes 4-7 are not ``ftyp``.
    """
    if len(probe_bytes) < MIN_VIDEO_PROBE_BYTES:
        raise UnsupportedFileKindError(
            f"video probe too small (need ≥ {MIN_VIDEO_PROBE_BYTES} "
            f"bytes to validate ftyp; got {len(probe_bytes)})"
        )
    marker = probe_bytes[4:8]
    if marker == _MP4_FTYP_MARKER:
        return
    raise UnsupportedFileKindError(
        "video does not begin with an mp4 'ftyp' box at offset 4 "
        f"(saw {marker!r})"
    )
|
||||
|
||||
|
||||
def validate_calibration_kind(probe_bytes: bytes) -> None:
    """Cheap shape check: the upload must look like a JSON object.

    Only the first non-whitespace byte is inspected; strict validation
    is left to the renderer.

    Raises:
        UnsupportedFileKindError: If the upload is empty or its first
            non-whitespace byte is not ``{``.
    """
    if not probe_bytes:
        raise UnsupportedFileKindError("calibration upload is empty")
    first_byte = probe_bytes.lstrip()[:1]
    if first_byte != b"{":
        raise UnsupportedFileKindError(
            "calibration must be a JSON object (first non-whitespace "
            "byte should be '{')"
        )
|
||||
|
||||
|
||||
def validate_upload_size(num_bytes: int, *, limit: int) -> None:
    """Raise ``PayloadTooLargeError`` when ``num_bytes`` exceeds ``limit``.

    A size exactly equal to ``limit`` is accepted.
    """
    if num_bytes <= limit:
        return
    raise PayloadTooLargeError(
        f"upload size {num_bytes} exceeds REPLAY_API_MAX_UPLOAD_BYTES "
        f"({limit})"
    )
|
||||
|
||||
|
||||
def expected_bearer_token() -> str | None:
    """Return the configured bearer token, read at call time.

    ``None`` means auth is disabled (``REPLAY_API_AUTH_REQUIRED=false``);
    the caller is expected to have logged the WARN once at service start.

    Raises:
        UnauthorizedError: If auth is required but
            ``REPLAY_API_BEARER_TOKEN`` is unset or empty.
    """
    if auth_required():
        configured = os.environ.get("REPLAY_API_BEARER_TOKEN")
        if configured:
            return configured
        raise UnauthorizedError(
            "REPLAY_API_BEARER_TOKEN is not configured but auth is required"
        )
    return None
|
||||
|
||||
|
||||
def auth_required() -> bool:
    """Report whether bearer authentication is enabled.

    Reads ``REPLAY_API_AUTH_REQUIRED`` at call time. Auth is disabled
    only when the value, ignoring case and surrounding whitespace, is
    one of ``0``, ``false``, ``no`` or ``off``. An unset variable or any
    other value keeps auth enabled.
    """
    # Strip first: values injected from env files or secrets often carry a
    # trailing newline or space, which would otherwise not match any "off" word.
    value = os.environ.get("REPLAY_API_AUTH_REQUIRED", "true").strip().lower()
    return value not in {"0", "false", "no", "off"}
|
||||
|
||||
|
||||
def extract_bearer_token(header_value: str | None) -> str:
    """Return the token from an ``Authorization: Bearer <token>`` value.

    The scheme is matched case-insensitively and the token is stripped
    of surrounding whitespace.

    Raises:
        UnauthorizedError: If the header is missing or empty, is not of
            the form ``Bearer <token>``, or the token part is blank.
    """
    if not header_value:
        raise UnauthorizedError("missing Authorization header")
    scheme, separator, remainder = header_value.partition(" ")
    is_bearer = scheme.strip().lower() == "bearer"
    if not separator or not is_bearer:
        raise UnauthorizedError(
            "Authorization header must be 'Bearer <token>'"
        )
    token = remainder.strip()
    if token:
        return token
    raise UnauthorizedError("Authorization bearer token is empty")
|
||||
|
||||
|
||||
def _ensure_field(name: str, value: object) -> None:
    """Raise ``MultipartMissingFieldError`` when ``value`` is ``None``.

    Only ``None`` counts as missing; other falsy values are accepted.
    """
    if value is not None:
        return
    raise MultipartMissingFieldError(f"missing multipart field: {name}")
|
||||
@@ -0,0 +1,99 @@
|
||||
"""AZ-701 — DTOs + ``ReplayRunner`` Protocol for the replay_api service.
|
||||
|
||||
The Protocol is the dependency-injection seam: ``handlers.py``
|
||||
depends on the Protocol, not the concrete ``SubprocessReplayRunner``.
|
||||
Unit tests inject a deterministic fake; the production wiring in
|
||||
``app.py`` constructs the subprocess runner.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
__all__ = [
|
||||
"JobSnapshot",
|
||||
"JobState",
|
||||
"ReplayInputs",
|
||||
"ReplayJobResult",
|
||||
"ReplayRunner",
|
||||
]
|
||||
|
||||
|
||||
class JobState(str, Enum):
    """Lifecycle states of a replay job.

    Transitions only move forward: ``queued → running → done``, with
    ``failed`` reachable from any non-terminal state. Members are also
    ``str`` instances, so they compare equal to their string values.
    """

    # Non-terminal states.
    QUEUED = "queued"
    RUNNING = "running"
    # Terminal states.
    DONE = "done"
    FAILED = "failed"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ReplayInputs:
|
||||
"""The (tlog + video + calibration) bundle a runner consumes.
|
||||
|
||||
Storage paths are absolute. The handler builds these from a
|
||||
per-job temp directory (see ``storage.py``).
|
||||
|
||||
``pace`` and ``auto_trim`` mirror the ``gps-denied-replay`` CLI
|
||||
flags; the runner is responsible for translating them into argv.
|
||||
"""
|
||||
|
||||
tlog_path: Path
|
||||
video_path: Path
|
||||
calibration_path: Path
|
||||
pace: str = "asap"
|
||||
auto_trim: bool = True
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ReplayJobResult:
|
||||
"""The artefacts a finished job exposes.
|
||||
|
||||
Each path is absolute and lives under the per-job storage dir.
|
||||
The handler layer maps these to URLs in the JSON response.
|
||||
"""
|
||||
|
||||
emissions_jsonl_path: Path
|
||||
accuracy_report_md_path: Path | None
|
||||
map_html_path: Path | None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class JobSnapshot:
|
||||
"""Serialisable snapshot of one job.
|
||||
|
||||
Mutable; the registry mutates the snapshot in-place under its
|
||||
lock and yields copies to API readers.
|
||||
"""
|
||||
|
||||
job_id: str
|
||||
state: JobState
|
||||
submitted_at_utc: datetime
|
||||
started_at_utc: datetime | None = None
|
||||
finished_at_utc: datetime | None = None
|
||||
error: str | None = None
|
||||
result: ReplayJobResult | None = None
|
||||
sync: bool = False
|
||||
extra: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
@runtime_checkable
class ReplayRunner(Protocol):
    """Structural interface for whatever executes one replay job.

    ``run`` is deliberately synchronous because the registry calls it
    from a worker thread. A normal return means success; any raised
    exception means failure, and the registry stores its stringified
    message on the job.
    """

    def run(self, inputs: ReplayInputs, *, output_dir: Path) -> ReplayJobResult:
        ...
|
||||
@@ -0,0 +1,233 @@
|
||||
"""AZ-701 — in-memory job registry with a concurrency limit.
|
||||
|
||||
``JobRegistry`` is the single source of truth for job state. It is
|
||||
intentionally simple — a dict plus a thread pool plus a queue cap.
|
||||
Operators that need durable history persist the JSONL + Markdown
|
||||
report + HTML map artefacts out-of-band (invariant 2 in the
|
||||
contract).
|
||||
|
||||
Concurrency model:
|
||||
- ``max_concurrent``: at most this many jobs may be in state
|
||||
``RUNNING`` at once. Excess submissions land in state ``QUEUED``
|
||||
and get promoted by the worker pool.
|
||||
- ``max_queued``: hard cap on queued jobs. Exceeding it raises
|
||||
``ConcurrencyLimitReachedError`` (HTTP 429 at the handler layer).
|
||||
|
||||
The registry runs jobs in a thread pool (``ThreadPoolExecutor``)
|
||||
so the FastAPI event loop is never blocked. The runner is
|
||||
intentionally synchronous (``ReplayRunner.run``) because the
|
||||
underlying ``gps-denied-replay`` subprocess is synchronous.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from copy import copy
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from gps_denied_onboard.replay_api.errors import (
|
||||
ConcurrencyLimitReachedError,
|
||||
JobNotFoundError,
|
||||
)
|
||||
from gps_denied_onboard.replay_api.interface import (
|
||||
JobSnapshot,
|
||||
JobState,
|
||||
ReplayInputs,
|
||||
ReplayRunner,
|
||||
)
|
||||
from gps_denied_onboard.replay_api.storage import StorageRoot
|
||||
|
||||
__all__ = ["ConcurrencyLimitReachedError", "JobRecord", "JobRegistry"]
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger("gps_denied_onboard.replay_api.jobs")
|
||||
|
||||
|
||||
class JobRecord:
    """Registry-internal bundle of a job's inputs, output dir and snapshot.

    Callers outside the registry only ever receive copies of
    ``snapshot``, never the live object, so they cannot corrupt state.
    """

    __slots__ = ("inputs", "output_dir", "snapshot")

    def __init__(
        self, inputs: ReplayInputs, output_dir: Path, snapshot: JobSnapshot
    ) -> None:
        self.inputs, self.output_dir, self.snapshot = inputs, output_dir, snapshot
|
||||
|
||||
|
||||
class JobRegistry:
    """In-memory job pool + worker dispatch.

    Records live in a dict guarded by one lock. Every accepted job is
    handed to a ``ThreadPoolExecutor`` with ``max_concurrent`` workers.
    A job is ``RUNNING`` once it owns one of the ``max_concurrent``
    slots and ``QUEUED`` until then; at most ``max_queued`` jobs may
    wait at a time.
    """

    def __init__(
        self,
        runner: ReplayRunner,
        storage: StorageRoot,
        *,
        max_concurrent: int = 1,
        max_queued: int = 8,
    ) -> None:
        """Create the registry and its worker pool.

        Raises:
            ValueError: If ``max_concurrent < 1`` or ``max_queued < 0``.
        """
        if max_concurrent < 1:
            raise ValueError("max_concurrent must be ≥ 1")
        if max_queued < 0:
            raise ValueError("max_queued must be ≥ 0")
        self._runner = runner
        self._storage = storage
        self._max_concurrent = max_concurrent
        self._max_queued = max_queued
        self._lock = threading.Lock()
        self._records: dict[str, JobRecord] = {}
        self._running_count = 0
        self._executor = ThreadPoolExecutor(
            max_workers=max_concurrent,
            thread_name_prefix="replay-api-job",
        )

    @property
    def max_concurrent(self) -> int:
        """Maximum number of jobs that may hold a running slot at once."""
        return self._max_concurrent

    def submit(
        self,
        inputs: ReplayInputs,
        output_dir: Path,
        *,
        job_id: str | None = None,
    ) -> JobSnapshot:
        """Register a new job, dispatch it, and return a snapshot copy.

        ``job_id`` is optional — when omitted the registry generates a
        fresh uuid. The handler layer passes its own id so the per-job
        storage directory and the API-visible job id match.

        The job starts as ``RUNNING`` only when a slot is free and no
        other job is still ``QUEUED``; otherwise it starts as ``QUEUED``.
        The returned copy is taken after dispatch, so a very fast job
        may already show a later state.

        Raises:
            ConcurrencyLimitReachedError: If the job would have to queue
                and ``max_queued`` jobs are already waiting.
            ValueError: If ``job_id`` is already registered.
        """
        with self._lock:
            queued_count = self._queued_count_locked()
            # The executor hands work to its workers in submission order.
            # Taking a free slot while older jobs are still QUEUED would put
            # this job *behind* them: a worker would pick up an older job,
            # find no slot (this job holds it) and wait forever, because this
            # job can never reach a worker. So a newcomer queues whenever
            # anything is already waiting.
            must_queue = (
                self._running_count >= self._max_concurrent or queued_count > 0
            )
            if must_queue and queued_count >= self._max_queued:
                raise ConcurrencyLimitReachedError(
                    f"queue full: running={self._running_count}, "
                    f"queued={queued_count}, max_queued={self._max_queued}"
                )

            if job_id is None:
                job_id = uuid.uuid4().hex
            if job_id in self._records:
                raise ValueError(
                    f"duplicate job_id supplied to submit(): {job_id}"
                )
            state = JobState.QUEUED if must_queue else JobState.RUNNING
            now = _utc_now()
            snapshot = JobSnapshot(
                job_id=job_id,
                state=state,
                submitted_at_utc=now,
                started_at_utc=now if state == JobState.RUNNING else None,
            )
            self._records[job_id] = JobRecord(
                inputs=inputs, output_dir=output_dir, snapshot=snapshot
            )
            if state == JobState.RUNNING:
                self._running_count += 1

        try:
            self._executor.submit(self._run_or_wait, job_id)
        except BaseException:
            # Dispatch failed (for example the pool is already shut down):
            # undo the registration so no phantom job keeps a slot.
            with self._lock:
                self._records.pop(job_id, None)
                if state == JobState.RUNNING:
                    self._running_count = max(0, self._running_count - 1)
            raise

        with self._lock:
            return copy(self._records[job_id].snapshot)

    def get(self, job_id: str) -> JobSnapshot:
        """Return a copy of the snapshot for ``job_id``.

        Raises:
            JobNotFoundError: If the id is unknown.
        """
        with self._lock:
            record = self._records.get(job_id)
            if record is None:
                raise JobNotFoundError(f"job not found: {job_id}")
            return copy(record.snapshot)

    def list_ids(self) -> list[str]:
        """Return the ids of all registered jobs."""
        with self._lock:
            return list(self._records)

    def running_count(self) -> int:
        """Return the number of running slots currently taken."""
        with self._lock:
            return self._running_count

    def queued_count(self) -> int:
        """Return the number of jobs currently in state ``QUEUED``."""
        with self._lock:
            return self._queued_count_locked()

    def shutdown(self, *, wait: bool = True) -> None:
        """Shut the worker pool down, then wipe all per-job storage.

        With ``wait=False`` futures that have not started are cancelled.
        """
        self._executor.shutdown(wait=wait, cancel_futures=not wait)
        self._storage.cleanup_all()

    def _queued_count_locked(self) -> int:
        """Count QUEUED records; the caller must hold ``self._lock``."""
        return sum(
            1
            for record in self._records.values()
            if record.snapshot.state == JobState.QUEUED
        )

    def _run_or_wait(self, job_id: str) -> None:
        """Worker entry point: claim a slot if needed, run, record failure."""
        with self._lock:
            record = self._records.get(job_id)
        if record is None:
            return
        try:
            if record.snapshot.state == JobState.QUEUED:
                self._wait_for_slot(record)
            self._execute(record)
        except Exception as exc:
            self._mark_failed(record, exc)

    def _wait_for_slot(self, record: JobRecord) -> None:
        """Poll until a running slot is free, then mark the job RUNNING."""
        pause = threading.Event()  # never set; used only as a timed sleep
        while True:
            with self._lock:
                if self._running_count < self._max_concurrent:
                    record.snapshot.state = JobState.RUNNING
                    record.snapshot.started_at_utc = _utc_now()
                    self._running_count += 1
                    return
            pause.wait(0.05)

    def _execute(self, record: JobRecord) -> None:
        """Run the job; mark it DONE on success and always free its slot."""
        try:
            result = self._runner.run(record.inputs, output_dir=record.output_dir)
            with self._lock:
                record.snapshot.state = JobState.DONE
                record.snapshot.finished_at_utc = _utc_now()
                record.snapshot.result = result
                self._running_count = max(0, self._running_count - 1)
        except Exception:
            # Free the slot here; the caller records the FAILED state.
            with self._lock:
                self._running_count = max(0, self._running_count - 1)
            raise

    def _mark_failed(self, record: JobRecord, exc: BaseException) -> None:
        """Log the failure and store ``"<ExcType>: <message>"`` on the job."""
        message = f"{type(exc).__name__}: {exc}"
        _LOGGER.exception("job %s failed", record.snapshot.job_id)
        with self._lock:
            record.snapshot.state = JobState.FAILED
            record.snapshot.finished_at_utc = _utc_now()
            record.snapshot.error = message
|
||||
|
||||
|
||||
def _utc_now() -> datetime:
    """Return the current time as a timezone-aware UTC ``datetime``."""
    return datetime.now(tz=timezone.utc)
|
||||
|
||||
|
||||
# Re-export Any so type-checkers don't trim the local import.
|
||||
_ = Any
|
||||
@@ -0,0 +1,89 @@
|
||||
"""AZ-701 — per-job temp-file lifecycle.
|
||||
|
||||
One ``StorageRoot`` rooted at ``REPLAY_API_STORAGE_ROOT``.
|
||||
Each job allocates a subdirectory ``<root>/<job_id>/`` containing
|
||||
the uploaded ``tlog`` + ``video`` + ``calibration`` plus the
|
||||
estimator's outputs (``emissions.jsonl``, the AZ-699 report, the
|
||||
AZ-700 map).
|
||||
|
||||
The directory is deleted on job completion (``release_job``) and on
|
||||
service shutdown (``cleanup_all``). The service deliberately does
|
||||
NOT keep finished-job artefacts forever — invariant 2 in the
|
||||
contract.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
__all__ = ["JobStorage", "StorageRoot"]
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger("gps_denied_onboard.replay_api.storage")
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class JobStorage:
|
||||
"""The per-job paths the handler hands to the runner."""
|
||||
|
||||
root: Path
|
||||
tlog_path: Path
|
||||
video_path: Path
|
||||
calibration_path: Path
|
||||
output_dir: Path
|
||||
|
||||
|
||||
class StorageRoot:
    """Parent of per-job storage directories.

    Each job owns the single directory ``<root>/<job_id>/``.
    ``allocate_job`` creates it, ``release_job`` removes it, and
    ``cleanup_all`` removes every directory under the root.
    """

    def __init__(self, root: Path) -> None:
        """Remember ``root`` and create it (with parents) if missing."""
        self._root = root
        self._root.mkdir(parents=True, exist_ok=True)

    @property
    def root(self) -> Path:
        """The directory that contains all per-job directories."""
        return self._root

    def _job_dir(self, job_id: str) -> Path:
        """Map ``job_id`` to its directory, rejecting ids that would escape.

        An empty id resolves to the root itself and ``..`` to its parent;
        an id with a path separator nests or escapes. Any of these would
        let ``release_job`` delete far more than one job.

        Raises:
            ValueError: If ``job_id`` is not a single plain path component.
        """
        if (
            not job_id
            or job_id in {".", ".."}
            or Path(job_id).name != job_id
        ):
            raise ValueError(f"invalid job_id for storage: {job_id!r}")
        return self._root / job_id

    @staticmethod
    def _remove_tree(target: Path) -> None:
        """Best-effort recursive delete; failures are logged, not raised."""
        try:
            shutil.rmtree(target)
        except FileNotFoundError:
            # Already gone (e.g. removed concurrently): nothing to do.
            return
        except OSError as exc:
            _LOGGER.warning(
                "failed to delete per-job storage %s: %s", target, exc
            )

    def allocate_job(self, job_id: str) -> JobStorage:
        """Create ``<root>/<job_id>/`` plus its ``output`` dir; return the paths.

        Raises:
            ValueError: If ``job_id`` is not a single plain path component.
            FileExistsError: If the job directory already exists.
        """
        job_root = self._job_dir(job_id)
        job_root.mkdir(parents=True, exist_ok=False)
        output_dir = job_root / "output"
        output_dir.mkdir(parents=True, exist_ok=True)
        return JobStorage(
            root=job_root,
            tlog_path=job_root / "input.tlog",
            video_path=job_root / "input.mp4",
            calibration_path=job_root / "calibration.json",
            output_dir=output_dir,
        )

    def release_job(self, job_id: str) -> None:
        """Delete the directory of ``job_id``; a missing directory is fine.

        Raises:
            ValueError: If ``job_id`` is not a single plain path component.
        """
        self._remove_tree(self._job_dir(job_id))

    def cleanup_all(self) -> None:
        """Delete every directory directly under the root.

        Plain files under the root are left alone, and a root that no
        longer exists is treated as already clean.
        """
        try:
            children = list(self._root.iterdir())
        except FileNotFoundError:
            return
        for child in children:
            if child.is_dir():
                self._remove_tree(child)
|
||||
Reference in New Issue
Block a user