mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-06-22 07:11:12 +00:00
[AZ-414] [AZ-415] [AZ-418] Test batch 71: sharp turn + multi-segment + smoothing
- AZ-414 (FT-P-07 + FT-N-02): sharp_turn_detector helper covering AC-1 (gyro_z run detection + synthetic-overlay fallback), AC-2/AC-3 (FT-N-02 during-turn label + monotonic covariance), AC-4/AC-5/AC-6 (FT-P-07 recovery lag/drift/heading); twin scenario files under positive/ and negative/. - AZ-415 (FT-P-08): multi_segment_evaluator helper + scenario. - AZ-418 (FT-P-10): smoothing_evaluator helper covering AC-1 (raw + smoothed pose pairing), AC-2 (improvement rate >= 0.80), AC-3 (mean improvement >= 5 m); scenario file. - All scenarios skip-gated on upstream frame_source_replay / imu_replay / fdr_reader stubs (auto-activate when AZ-441 + AZ-407 leftovers land). - +68 unit tests; full e2e unit suite: 393 passed. See _docs/03_implementation/batch_71_report.md and _docs/03_implementation/reviews/batch_71_review.md. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -0,0 +1,345 @@
|
||||
"""Unit tests for ``runner.helpers.multi_segment_evaluator`` (FT-P-08 / AZ-415).
|
||||
|
||||
Covers AC-1 (blackout window detection from the injector manifest),
|
||||
AC-2 (dead_reckoned during blackout), AC-3 (recovery ≤3 frames),
|
||||
AC-4 (trajectory continuity ≤100 m), AC-5 (≥3 windows required).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from runner.helpers.geo import offset
|
||||
from runner.helpers.multi_segment_evaluator import (
|
||||
DEAD_RECKONED,
|
||||
MAX_RECOVERY_FRAMES_SAFETY_MS,
|
||||
MAX_TRAJECTORY_JUMP_M,
|
||||
MIN_SEGMENTS_REQUIRED,
|
||||
SATELLITE_ANCHORED,
|
||||
VISUAL_PROPAGATED,
|
||||
BlackoutWindow,
|
||||
EstimateSample,
|
||||
MultiSegmentReport,
|
||||
PerWindowReport,
|
||||
evaluate,
|
||||
evaluate_window,
|
||||
load_schedule,
|
||||
write_csv_evidence,
|
||||
)
|
||||
|
||||
|
||||
def _three_windows() -> list[BlackoutWindow]:
|
||||
"""Three disjoint windows, ≥30 s apart per AC-5 of the injector."""
|
||||
return [
|
||||
BlackoutWindow(start_ms=60_000, end_ms=70_000, first_frame_idx=180, last_frame_idx=210),
|
||||
BlackoutWindow(start_ms=120_000, end_ms=130_000, first_frame_idx=360, last_frame_idx=390),
|
||||
BlackoutWindow(start_ms=180_000, end_ms=190_000, first_frame_idx=540, last_frame_idx=570),
|
||||
]
|
||||
|
||||
|
||||
def _samples_clean_run() -> list[EstimateSample]:
|
||||
"""A clean run: satellite_anchored outside windows, dead_reckoned inside,
|
||||
recovery within 333 ms of each end_ms, trajectory continuous."""
|
||||
base_lat, base_lon = 48.275, 37.385
|
||||
samples: list[EstimateSample] = []
|
||||
for win in _three_windows():
|
||||
# Pre-window anchor at end_ms - 1000.
|
||||
samples.append(EstimateSample(
|
||||
monotonic_ms=win.start_ms - 1000,
|
||||
lat_deg=base_lat,
|
||||
lon_deg=base_lon,
|
||||
source_label=SATELLITE_ANCHORED,
|
||||
))
|
||||
# 3 dead_reckoned inside.
|
||||
for i, t in enumerate((win.start_ms + 1000, win.start_ms + 3000, win.start_ms + 5000)):
|
||||
samples.append(EstimateSample(
|
||||
monotonic_ms=t,
|
||||
lat_deg=base_lat,
|
||||
lon_deg=base_lon,
|
||||
source_label=DEAD_RECKONED,
|
||||
))
|
||||
# Recovery: 200 ms after end_ms (well within the 1100 ms budget).
|
||||
rec_lat, rec_lon = offset(base_lat, base_lon, bearing_deg=0.0, distance_m=20.0)
|
||||
samples.append(EstimateSample(
|
||||
monotonic_ms=win.end_ms + 200,
|
||||
lat_deg=rec_lat,
|
||||
lon_deg=rec_lon,
|
||||
source_label=SATELLITE_ANCHORED,
|
||||
))
|
||||
return samples
|
||||
|
||||
|
||||
def test_constants_match_spec() -> None:
|
||||
"""Three thresholds + AC-5 minimum must match the spec."""
|
||||
# Assert
|
||||
assert MAX_TRAJECTORY_JUMP_M == 100.0
|
||||
assert MIN_SEGMENTS_REQUIRED == 3
|
||||
# Recovery-budget approximation: 3 frames @ ~3 fps ≈ 1 s plus a 100 ms slack.
|
||||
assert 900 <= MAX_RECOVERY_FRAMES_SAFETY_MS <= 1500
|
||||
|
||||
|
||||
def test_load_schedule_round_trips_writer_shape(tmp_path: Path) -> None:
|
||||
"""The injector's ``schedule.json`` round-trips through ``load_schedule``."""
|
||||
# Arrange
|
||||
payload = {
|
||||
"segments": [
|
||||
{"start_ms": 100, "end_ms": 200, "first_frame_idx": 3, "last_frame_idx": 6},
|
||||
{"start_ms": 1000, "end_ms": 2000, "first_frame_idx": 30, "last_frame_idx": 60},
|
||||
]
|
||||
}
|
||||
schedule = tmp_path / "schedule.json"
|
||||
schedule.write_text(json.dumps(payload))
|
||||
|
||||
# Act
|
||||
windows = load_schedule(schedule)
|
||||
|
||||
# Assert
|
||||
assert len(windows) == 2
|
||||
assert windows[0].start_ms == 100
|
||||
assert windows[1].last_frame_idx == 60
|
||||
|
||||
|
||||
def test_load_schedule_rejects_missing_file(tmp_path: Path) -> None:
|
||||
# Act / Assert
|
||||
with pytest.raises(FileNotFoundError):
|
||||
load_schedule(tmp_path / "missing.json")
|
||||
|
||||
|
||||
def test_load_schedule_rejects_missing_segments_key(tmp_path: Path) -> None:
|
||||
# Arrange
|
||||
bad = tmp_path / "bad.json"
|
||||
bad.write_text(json.dumps({"windows": []}))
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="missing 'segments' key"):
|
||||
load_schedule(bad)
|
||||
|
||||
|
||||
def test_evaluate_window_clean_run_passes_all() -> None:
|
||||
"""A by-the-book run passes label, recovery, and jump checks."""
|
||||
# Arrange
|
||||
windows = _three_windows()
|
||||
samples = _samples_clean_run()
|
||||
|
||||
# Act
|
||||
report = evaluate_window(windows[0], 0, samples)
|
||||
|
||||
# Assert
|
||||
assert report.samples_inside == 3
|
||||
assert report.dead_reckoned_inside == 3
|
||||
assert report.label_violations == ()
|
||||
assert report.passes_label is True
|
||||
assert report.recovery_lag_ms == 200
|
||||
assert report.passes_recovery is True
|
||||
assert report.trajectory_jump_m == pytest.approx(20.0, abs=0.5)
|
||||
assert report.passes_jump is True
|
||||
assert report.passes is True
|
||||
|
||||
|
||||
def test_evaluate_window_satellite_anchored_during_blackout_violates_label() -> None:
|
||||
"""AC-2: any satellite_anchored inside the window is a violation."""
|
||||
# Arrange
|
||||
win = _three_windows()[0]
|
||||
samples = [
|
||||
EstimateSample(win.start_ms - 1000, 48.275, 37.385, SATELLITE_ANCHORED),
|
||||
EstimateSample(win.start_ms + 1000, 48.275, 37.385, DEAD_RECKONED),
|
||||
EstimateSample(win.start_ms + 3000, 48.275, 37.385, SATELLITE_ANCHORED), # violation
|
||||
EstimateSample(win.end_ms + 200, 48.275, 37.385, SATELLITE_ANCHORED),
|
||||
]
|
||||
|
||||
# Act
|
||||
report = evaluate_window(win, 0, samples)
|
||||
|
||||
# Assert
|
||||
assert "satellite_anchored" in report.label_violations
|
||||
assert report.passes_label is False
|
||||
assert report.passes is False
|
||||
|
||||
|
||||
def test_evaluate_window_visual_propagated_during_blackout_violates_label() -> None:
|
||||
"""AC-2: visual_propagated during blackout is also a violation."""
|
||||
# Arrange
|
||||
win = _three_windows()[0]
|
||||
samples = [
|
||||
EstimateSample(win.start_ms + 1000, 48.275, 37.385, VISUAL_PROPAGATED),
|
||||
EstimateSample(win.end_ms + 200, 48.275, 37.385, SATELLITE_ANCHORED),
|
||||
]
|
||||
|
||||
# Act
|
||||
report = evaluate_window(win, 0, samples)
|
||||
|
||||
# Assert
|
||||
assert report.label_violations == ("visual_propagated",)
|
||||
assert report.passes_label is False
|
||||
|
||||
|
||||
def test_evaluate_window_recovery_late_violates_ac3() -> None:
|
||||
"""AC-3: recovery after the 1100 ms budget fails."""
|
||||
# Arrange
|
||||
win = _three_windows()[0]
|
||||
samples = [
|
||||
EstimateSample(win.start_ms + 1000, 48.275, 37.385, DEAD_RECKONED),
|
||||
EstimateSample(win.end_ms + 1500, 48.275, 37.385, SATELLITE_ANCHORED), # 1500 > 1100
|
||||
]
|
||||
|
||||
# Act
|
||||
report = evaluate_window(win, 0, samples)
|
||||
|
||||
# Assert
|
||||
assert report.recovery_lag_ms == 1500
|
||||
assert report.passes_recovery is False
|
||||
assert report.passes is False
|
||||
|
||||
|
||||
def test_evaluate_window_no_recovery_at_all_fails_ac3() -> None:
|
||||
"""AC-3: no satellite_anchored after end_ms → no recovery."""
|
||||
# Arrange
|
||||
win = _three_windows()[0]
|
||||
samples = [
|
||||
EstimateSample(win.start_ms + 1000, 48.275, 37.385, DEAD_RECKONED),
|
||||
EstimateSample(win.end_ms + 500, 48.275, 37.385, DEAD_RECKONED),
|
||||
]
|
||||
|
||||
# Act
|
||||
report = evaluate_window(win, 0, samples)
|
||||
|
||||
# Assert
|
||||
assert report.recovery_lag_ms is None
|
||||
assert report.passes_recovery is False
|
||||
|
||||
|
||||
def test_evaluate_window_jump_above_100m_violates_ac4() -> None:
|
||||
"""AC-4: trajectory jump > 100 m fails."""
|
||||
# Arrange
|
||||
win = _three_windows()[0]
|
||||
base_lat, base_lon = 48.275, 37.385
|
||||
far_lat, far_lon = offset(base_lat, base_lon, bearing_deg=0.0, distance_m=150.0)
|
||||
samples = [
|
||||
EstimateSample(win.start_ms - 100, base_lat, base_lon, SATELLITE_ANCHORED),
|
||||
EstimateSample(win.start_ms + 1000, base_lat, base_lon, DEAD_RECKONED),
|
||||
EstimateSample(win.end_ms + 200, far_lat, far_lon, SATELLITE_ANCHORED),
|
||||
]
|
||||
|
||||
# Act
|
||||
report = evaluate_window(win, 0, samples)
|
||||
|
||||
# Assert
|
||||
assert report.trajectory_jump_m == pytest.approx(150.0, abs=0.5)
|
||||
assert report.passes_jump is False
|
||||
assert report.passes is False
|
||||
|
||||
|
||||
def test_evaluate_aggregate_clean_passes() -> None:
|
||||
"""All 3 windows pass → overall passes."""
|
||||
# Arrange
|
||||
windows = _three_windows()
|
||||
samples = _samples_clean_run()
|
||||
|
||||
# Act
|
||||
report = evaluate(windows, samples)
|
||||
|
||||
# Assert
|
||||
assert report.window_count == 3
|
||||
assert report.passes_segment_count is True
|
||||
assert report.failed_windows == ()
|
||||
assert report.passes is True
|
||||
|
||||
|
||||
def test_evaluate_aggregate_single_window_failure_fails_overall() -> None:
|
||||
"""One window fails → overall fails; failed_windows lists it."""
|
||||
# Arrange
|
||||
windows = _three_windows()
|
||||
samples = _samples_clean_run()
|
||||
# Inject a label violation in window 1.
|
||||
samples.insert(0, EstimateSample(
|
||||
windows[1].start_ms + 4000, 48.275, 37.385, SATELLITE_ANCHORED
|
||||
))
|
||||
|
||||
# Act
|
||||
report = evaluate(windows, samples)
|
||||
|
||||
# Assert
|
||||
assert 1 in report.failed_windows
|
||||
assert report.passes is False
|
||||
|
||||
|
||||
def test_evaluate_aggregate_below_min_segments_fails_overall() -> None:
|
||||
"""AC-5: <3 windows in the schedule → aggregate fails even if each passes."""
|
||||
# Arrange — only 2 windows.
|
||||
windows = _three_windows()[:2]
|
||||
samples = _samples_clean_run()
|
||||
|
||||
# Act
|
||||
report = evaluate(windows, samples)
|
||||
|
||||
# Assert
|
||||
assert report.window_count == 2
|
||||
assert report.passes_segment_count is False
|
||||
assert report.passes is False
|
||||
|
||||
|
||||
def test_evaluate_rejects_unknown_source_label() -> None:
|
||||
"""Programming-error guard: unknown source_label raises."""
|
||||
# Arrange
|
||||
win = _three_windows()[0]
|
||||
samples = [
|
||||
EstimateSample(win.start_ms + 1000, 48.275, 37.385, "stale_cache"),
|
||||
]
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="unknown source_label"):
|
||||
evaluate([win], samples)
|
||||
|
||||
|
||||
def test_write_csv_evidence_round_trip(tmp_path: Path) -> None:
|
||||
"""CSV header + per-window row shape."""
|
||||
# Arrange
|
||||
windows = _three_windows()
|
||||
samples = _samples_clean_run()
|
||||
report = evaluate(windows, samples)
|
||||
out_path = tmp_path / "ft-p-08.csv"
|
||||
|
||||
# Act
|
||||
write_csv_evidence(out_path, report)
|
||||
|
||||
# Assert
|
||||
rows = list(csv.reader(out_path.open()))
|
||||
assert rows[0] == [
|
||||
"window_index",
|
||||
"start_ms",
|
||||
"end_ms",
|
||||
"samples_inside",
|
||||
"dead_reckoned_inside",
|
||||
"label_violations",
|
||||
"recovery_lag_ms",
|
||||
"trajectory_jump_m",
|
||||
"passes_label",
|
||||
"passes_recovery",
|
||||
"passes_jump",
|
||||
"passes",
|
||||
]
|
||||
assert len(rows) == 1 + 3
|
||||
# Every window in the clean run passes.
|
||||
for r in rows[1:]:
|
||||
assert r[-1] == "true"
|
||||
|
||||
|
||||
def test_write_csv_evidence_serialises_no_recovery_as_blank(tmp_path: Path) -> None:
|
||||
"""When recovery is None, the recovery_lag_ms + trajectory_jump_m cells are blank."""
|
||||
# Arrange
|
||||
win = _three_windows()[0]
|
||||
samples = [EstimateSample(win.start_ms + 1000, 48.275, 37.385, DEAD_RECKONED)]
|
||||
report = evaluate([win], samples)
|
||||
out_path = tmp_path / "ft-p-08.csv"
|
||||
|
||||
# Act
|
||||
write_csv_evidence(out_path, report)
|
||||
|
||||
# Assert
|
||||
rows = list(csv.reader(out_path.open()))
|
||||
assert rows[1][6] == "" # recovery_lag_ms
|
||||
assert rows[1][7] == "" # trajectory_jump_m
|
||||
@@ -0,0 +1,517 @@
|
||||
"""Unit tests for ``runner.helpers.sharp_turn_detector`` (FT-P-07 + FT-N-02 / AZ-414).
|
||||
|
||||
Covers:
|
||||
|
||||
* threshold env-var override + defaults (AC-3.2)
|
||||
* contiguous-run detection + min-run-length pruning
|
||||
* synthetic-overlay fallback when no natural turn
|
||||
* FT-N-02 AC-2 (during-turn label) + AC-3 (covariance non-decreasing)
|
||||
* FT-P-07 AC-4 (recovery lag), AC-5 (drift ≤200 m), AC-6 (heading envelope)
|
||||
* CSV evidence schema
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from runner.helpers.geo import offset
|
||||
from runner.helpers.sharp_turn_detector import (
|
||||
ALLOWED_DURING_TURN_LABELS,
|
||||
DEFAULT_GYRO_Z_THRESHOLD_MDPS,
|
||||
MAX_HEADING_CHANGE_DEG,
|
||||
MAX_RECOVERY_DRIFT_M,
|
||||
MAX_RECOVERY_FRAMES_SAFETY_MS,
|
||||
MIN_RUN_LENGTH,
|
||||
SHARP_TURN_ENV_VAR,
|
||||
GyroSample,
|
||||
TurnDetection,
|
||||
TurnFrameSample,
|
||||
TurnSegment,
|
||||
detect_or_synthesize,
|
||||
detect_turn_segments,
|
||||
evaluate_ft_n_02,
|
||||
evaluate_ft_p_07,
|
||||
get_threshold_mdps,
|
||||
load_zgyro_samples,
|
||||
synthesize_overlay_segment,
|
||||
write_csv_evidence,
|
||||
)
|
||||
|
||||
|
||||
def _samples(zgyros_mdps: list[int], dt_ms: int = 100) -> list[GyroSample]:
|
||||
return [
|
||||
GyroSample(monotonic_ms=i * dt_ms, time_s=i * dt_ms / 1000.0, zgyro_mdps=z)
|
||||
for i, z in enumerate(zgyros_mdps)
|
||||
]
|
||||
|
||||
|
||||
def _frame(
|
||||
t_ms: int,
|
||||
lat: float = 48.275,
|
||||
lon: float = 37.385,
|
||||
label: str = "satellite_anchored",
|
||||
cov: float = 5.0,
|
||||
) -> TurnFrameSample:
|
||||
return TurnFrameSample(
|
||||
monotonic_ms=t_ms,
|
||||
lat_deg=lat,
|
||||
lon_deg=lon,
|
||||
source_label=label,
|
||||
cov_semi_major_m=cov,
|
||||
)
|
||||
|
||||
|
||||
def test_default_threshold_when_env_unset(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
monkeypatch.delenv(SHARP_TURN_ENV_VAR, raising=False)
|
||||
|
||||
# Assert
|
||||
assert get_threshold_mdps() == DEFAULT_GYRO_Z_THRESHOLD_MDPS
|
||||
|
||||
|
||||
def test_threshold_env_override_applies(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
monkeypatch.setenv(SHARP_TURN_ENV_VAR, "12345")
|
||||
|
||||
# Assert
|
||||
assert get_threshold_mdps() == 12345
|
||||
|
||||
|
||||
def test_threshold_env_rejects_non_int(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
monkeypatch.setenv(SHARP_TURN_ENV_VAR, "fast")
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="not a valid int"):
|
||||
get_threshold_mdps()
|
||||
|
||||
|
||||
def test_threshold_env_rejects_non_positive(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
monkeypatch.setenv(SHARP_TURN_ENV_VAR, "0")
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="must be > 0"):
|
||||
get_threshold_mdps()
|
||||
|
||||
|
||||
def test_detect_simple_turn() -> None:
|
||||
"""A clean 5-sample run above threshold is detected as one segment."""
|
||||
# Arrange — under, 5x over, under.
|
||||
samples = _samples([0, 0, 35_000, 40_000, 50_000, 35_000, 31_000, 0, 0])
|
||||
|
||||
# Act
|
||||
detection = detect_turn_segments(samples, threshold_mdps=30_000)
|
||||
|
||||
# Assert
|
||||
assert detection.synthetic_overlay is False
|
||||
assert len(detection.segments) == 1
|
||||
seg = detection.segments[0]
|
||||
assert seg.start_index == 2
|
||||
assert seg.end_index == 6
|
||||
assert seg.sample_count == 5
|
||||
assert seg.peak_abs_zgyro_mdps == 50_000
|
||||
|
||||
|
||||
def test_detect_short_run_pruned() -> None:
|
||||
"""A 2-sample run below MIN_RUN_LENGTH is filtered out."""
|
||||
# Arrange — MIN_RUN_LENGTH is 3.
|
||||
samples = _samples([0, 35_000, 40_000, 0, 0])
|
||||
|
||||
# Act
|
||||
detection = detect_turn_segments(samples, threshold_mdps=30_000)
|
||||
|
||||
# Assert
|
||||
assert detection.segments == ()
|
||||
|
||||
|
||||
def test_detect_multiple_segments() -> None:
|
||||
"""Two separated turns produce two segments."""
|
||||
# Arrange
|
||||
samples = _samples(
|
||||
[
|
||||
0, 35_000, 40_000, 45_000, 0, 0,
|
||||
0, 50_000, 55_000, 60_000, 70_000, 0,
|
||||
]
|
||||
)
|
||||
|
||||
# Act
|
||||
detection = detect_turn_segments(samples, threshold_mdps=30_000, min_run_length=3)
|
||||
|
||||
# Assert
|
||||
assert len(detection.segments) == 2
|
||||
assert detection.segments[0].sample_count == 3
|
||||
assert detection.segments[1].sample_count == 4
|
||||
assert detection.segments[1].peak_abs_zgyro_mdps == 70_000
|
||||
|
||||
|
||||
def test_detect_negative_yaw_uses_abs_value() -> None:
|
||||
"""Sustained left-turn (negative zgyro) is detected via |zgyro|."""
|
||||
# Arrange
|
||||
samples = _samples([0, -40_000, -45_000, -55_000, 0])
|
||||
|
||||
# Act
|
||||
detection = detect_turn_segments(samples, threshold_mdps=30_000)
|
||||
|
||||
# Assert
|
||||
assert len(detection.segments) == 1
|
||||
assert detection.segments[0].peak_abs_zgyro_mdps == 55_000
|
||||
|
||||
|
||||
def test_detect_tail_run_included() -> None:
|
||||
"""A run that extends to the last sample is still detected."""
|
||||
# Arrange
|
||||
samples = _samples([0, 0, 35_000, 40_000, 45_000])
|
||||
|
||||
# Act
|
||||
detection = detect_turn_segments(samples, threshold_mdps=30_000)
|
||||
|
||||
# Assert
|
||||
assert len(detection.segments) == 1
|
||||
assert detection.segments[0].end_index == 4
|
||||
|
||||
|
||||
def test_detect_rejects_invalid_min_run_length() -> None:
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="min_run_length"):
|
||||
detect_turn_segments([], threshold_mdps=30_000, min_run_length=0)
|
||||
|
||||
|
||||
def test_synthesize_overlay_when_no_natural_turn() -> None:
|
||||
"""No-turn fixture falls back to synthetic overlay."""
|
||||
# Arrange — all zeros.
|
||||
samples = _samples([0] * 20)
|
||||
|
||||
# Act
|
||||
detection = synthesize_overlay_segment(samples, threshold_mdps=30_000, anchor_fraction=0.5)
|
||||
|
||||
# Assert
|
||||
assert detection.synthetic_overlay is True
|
||||
assert len(detection.segments) == 1
|
||||
seg = detection.segments[0]
|
||||
assert seg.sample_count >= MIN_RUN_LENGTH
|
||||
|
||||
|
||||
def test_synthesize_overlay_rejects_empty_samples() -> None:
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="samples must not be empty"):
|
||||
synthesize_overlay_segment([], threshold_mdps=30_000)
|
||||
|
||||
|
||||
def test_synthesize_overlay_rejects_invalid_anchor_fraction() -> None:
|
||||
# Arrange
|
||||
samples = _samples([0] * 5)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="anchor_fraction"):
|
||||
synthesize_overlay_segment(samples, threshold_mdps=30_000, anchor_fraction=1.5)
|
||||
|
||||
|
||||
def test_synthesize_overlay_rejects_short_duration() -> None:
|
||||
# Arrange
|
||||
samples = _samples([0] * 5)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="duration_samples"):
|
||||
synthesize_overlay_segment(samples, threshold_mdps=30_000, duration_samples=1)
|
||||
|
||||
|
||||
def test_ft_n_02_passes_with_only_propagated_labels() -> None:
|
||||
"""All inside-window labels are visual_propagated → AC-2 pass."""
|
||||
# Arrange
|
||||
seg = TurnSegment(
|
||||
start_index=0, end_index=2, start_ms=1000, end_ms=1200,
|
||||
peak_abs_zgyro_mdps=40_000, sample_count=3,
|
||||
)
|
||||
samples = [
|
||||
_frame(900, label="satellite_anchored", cov=5.0),
|
||||
_frame(1000, label="visual_propagated", cov=5.5),
|
||||
_frame(1100, label="dead_reckoned", cov=6.0),
|
||||
_frame(1200, label="visual_propagated", cov=6.5),
|
||||
_frame(1300, label="satellite_anchored", cov=2.0),
|
||||
]
|
||||
|
||||
# Act
|
||||
report = evaluate_ft_n_02(seg, segment_index=0, samples=samples)
|
||||
|
||||
# Assert
|
||||
assert report.samples_inside == 3
|
||||
assert report.label_violations == ()
|
||||
assert report.cov_non_decreasing is True
|
||||
assert report.passes is True
|
||||
|
||||
|
||||
def test_ft_n_02_fails_on_satellite_anchored_during_turn() -> None:
|
||||
"""satellite_anchored inside turn → AC-2 violation."""
|
||||
# Arrange
|
||||
seg = TurnSegment(0, 2, 1000, 1200, 40_000, 3)
|
||||
samples = [
|
||||
_frame(1100, label="satellite_anchored", cov=5.0),
|
||||
_frame(1200, label="visual_propagated", cov=6.0),
|
||||
]
|
||||
|
||||
# Act
|
||||
report = evaluate_ft_n_02(seg, 0, samples)
|
||||
|
||||
# Assert
|
||||
assert "satellite_anchored" in report.label_violations
|
||||
assert report.passes_label is False
|
||||
assert "satellite_anchored" not in ALLOWED_DURING_TURN_LABELS
|
||||
|
||||
|
||||
def test_ft_n_02_fails_on_decreasing_covariance() -> None:
|
||||
"""Covariance drop during turn → AC-3 violation."""
|
||||
# Arrange
|
||||
seg = TurnSegment(0, 2, 1000, 1200, 40_000, 3)
|
||||
samples = [
|
||||
_frame(1000, label="visual_propagated", cov=5.0),
|
||||
_frame(1100, label="visual_propagated", cov=8.0),
|
||||
_frame(1200, label="visual_propagated", cov=6.0), # drop!
|
||||
]
|
||||
|
||||
# Act
|
||||
report = evaluate_ft_n_02(seg, 0, samples)
|
||||
|
||||
# Assert
|
||||
assert report.cov_non_decreasing is False
|
||||
assert report.first_decreasing_at_ms == 1200
|
||||
assert report.passes_cov is False
|
||||
|
||||
|
||||
def test_ft_n_02_zero_samples_inside_does_not_pass() -> None:
|
||||
"""Empty turn window → passes_label is False (no data)."""
|
||||
# Arrange
|
||||
seg = TurnSegment(0, 2, 1000, 1200, 40_000, 3)
|
||||
samples = [_frame(900), _frame(1500)]
|
||||
|
||||
# Act
|
||||
report = evaluate_ft_n_02(seg, 0, samples)
|
||||
|
||||
# Assert
|
||||
assert report.samples_inside == 0
|
||||
assert report.passes_label is False
|
||||
|
||||
|
||||
def test_ft_p_07_passes_recovery_within_budget() -> None:
|
||||
"""Recovery anchor within ~1s, drift ≤200m, no pre-anchor → heading OK by default."""
|
||||
# Arrange
|
||||
seg = TurnSegment(0, 2, 1000, 1200, 40_000, 3)
|
||||
samples = [
|
||||
_frame(900, label="visual_propagated"),
|
||||
_frame(1100, label="visual_propagated"),
|
||||
_frame(1200, label="dead_reckoned"),
|
||||
_frame(1500, label="satellite_anchored"), # +300 ms
|
||||
]
|
||||
|
||||
# Act
|
||||
report = evaluate_ft_p_07(seg, 0, samples)
|
||||
|
||||
# Assert
|
||||
assert report.recovery_lag_ms == 300
|
||||
assert report.passes_recovery_lag is True
|
||||
assert report.drift_m == pytest.approx(0.0, abs=1e-6)
|
||||
assert report.passes_drift is True
|
||||
assert report.passes is True
|
||||
|
||||
|
||||
def test_ft_p_07_fails_when_recovery_takes_too_long() -> None:
|
||||
"""Recovery beyond safety budget → AC-4 fail."""
|
||||
# Arrange
|
||||
seg = TurnSegment(0, 2, 1000, 1200, 40_000, 3)
|
||||
samples = [
|
||||
_frame(1100, label="visual_propagated"),
|
||||
_frame(1200, label="dead_reckoned"),
|
||||
_frame(1200 + MAX_RECOVERY_FRAMES_SAFETY_MS + 100, label="satellite_anchored"),
|
||||
]
|
||||
|
||||
# Act
|
||||
report = evaluate_ft_p_07(seg, 0, samples)
|
||||
|
||||
# Assert
|
||||
assert report.passes_recovery_lag is False
|
||||
|
||||
|
||||
def test_ft_p_07_fails_when_drift_exceeds_budget() -> None:
|
||||
"""Drift between propagated-end and recovery anchor > 200 m → AC-5 fail."""
|
||||
# Arrange
|
||||
seg = TurnSegment(0, 2, 1000, 1200, 40_000, 3)
|
||||
base_lat, base_lon = 48.275, 37.385
|
||||
far_lat, far_lon = offset(base_lat, base_lon, bearing_deg=90.0, distance_m=MAX_RECOVERY_DRIFT_M + 50.0)
|
||||
samples = [
|
||||
_frame(1200, lat=base_lat, lon=base_lon, label="visual_propagated"),
|
||||
_frame(1500, lat=far_lat, lon=far_lon, label="satellite_anchored"),
|
||||
]
|
||||
|
||||
# Act
|
||||
report = evaluate_ft_p_07(seg, 0, samples)
|
||||
|
||||
# Assert
|
||||
assert report.drift_m is not None
|
||||
assert report.drift_m > MAX_RECOVERY_DRIFT_M
|
||||
assert report.passes_drift is False
|
||||
|
||||
|
||||
def test_ft_p_07_no_recovery_anchor_fails_all() -> None:
|
||||
"""No satellite_anchored after turn → AC-4/5/6 all fail."""
|
||||
# Arrange
|
||||
seg = TurnSegment(0, 2, 1000, 1200, 40_000, 3)
|
||||
samples = [_frame(1100, label="visual_propagated"), _frame(1500, label="visual_propagated")]
|
||||
|
||||
# Act
|
||||
report = evaluate_ft_p_07(seg, 0, samples)
|
||||
|
||||
# Assert
|
||||
assert report.recovery_anchor_ms is None
|
||||
assert report.passes is False
|
||||
|
||||
|
||||
def test_ft_p_07_heading_envelope_with_pre_anchor() -> None:
|
||||
"""Pre-anchor + propagated-end + recovery → heading delta computed and within 70°."""
|
||||
# Arrange — straight-line course; heading delta should be ~0°.
|
||||
seg = TurnSegment(0, 2, 1000, 1200, 40_000, 3)
|
||||
base_lat, base_lon = 48.275, 37.385
|
||||
mid_lat, mid_lon = offset(base_lat, base_lon, bearing_deg=90.0, distance_m=50.0)
|
||||
far_lat, far_lon = offset(mid_lat, mid_lon, bearing_deg=90.0, distance_m=50.0)
|
||||
samples = [
|
||||
_frame(900, lat=base_lat, lon=base_lon, label="satellite_anchored"),
|
||||
_frame(1200, lat=mid_lat, lon=mid_lon, label="visual_propagated"),
|
||||
_frame(1500, lat=far_lat, lon=far_lon, label="satellite_anchored"),
|
||||
]
|
||||
|
||||
# Act
|
||||
report = evaluate_ft_p_07(seg, 0, samples)
|
||||
|
||||
# Assert
|
||||
assert report.heading_change_deg is not None
|
||||
assert report.heading_change_deg < 1.0
|
||||
assert report.in_heading_envelope is True
|
||||
assert report.passes_heading is True
|
||||
|
||||
|
||||
def test_ft_p_07_heading_outside_envelope_fails() -> None:
|
||||
"""≥90° heading reversal → AC-6 fail."""
|
||||
# Arrange — pre→mid is east, mid→recovery is west (180° flip).
|
||||
seg = TurnSegment(0, 2, 1000, 1200, 40_000, 3)
|
||||
base_lat, base_lon = 48.275, 37.385
|
||||
mid_lat, mid_lon = offset(base_lat, base_lon, bearing_deg=90.0, distance_m=50.0)
|
||||
rev_lat, rev_lon = offset(mid_lat, mid_lon, bearing_deg=270.0, distance_m=50.0)
|
||||
samples = [
|
||||
_frame(900, lat=base_lat, lon=base_lon, label="satellite_anchored"),
|
||||
_frame(1200, lat=mid_lat, lon=mid_lon, label="visual_propagated"),
|
||||
_frame(1500, lat=rev_lat, lon=rev_lon, label="satellite_anchored"),
|
||||
]
|
||||
|
||||
# Act
|
||||
report = evaluate_ft_p_07(seg, 0, samples)
|
||||
|
||||
# Assert
|
||||
assert report.heading_change_deg is not None
|
||||
assert report.heading_change_deg > MAX_HEADING_CHANGE_DEG
|
||||
assert report.in_heading_envelope is False
|
||||
|
||||
|
||||
def test_load_zgyro_samples_missing_file_raises(tmp_path: Path) -> None:
|
||||
# Act / Assert
|
||||
with pytest.raises(FileNotFoundError, match="data_imu.csv not found"):
|
||||
load_zgyro_samples(tmp_path / "missing.csv")
|
||||
|
||||
|
||||
def test_load_zgyro_samples_missing_column_raises(tmp_path: Path) -> None:
|
||||
# Arrange
|
||||
csv_path = tmp_path / "data_imu.csv"
|
||||
csv_path.write_text("timestamp(ms),Time\n1000,1.0\n")
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="missing required column SCALED_IMU2.zgyro"):
|
||||
load_zgyro_samples(csv_path)
|
||||
|
||||
|
||||
def test_load_zgyro_samples_parses_rows(tmp_path: Path) -> None:
|
||||
"""Header + 2 data rows → 2 GyroSamples with correct types."""
|
||||
# Arrange
|
||||
csv_path = tmp_path / "data_imu.csv"
|
||||
csv_path.write_text(
|
||||
"timestamp(ms),Time,SCALED_IMU2.zgyro\n"
|
||||
"1000,1.0,15000\n"
|
||||
"1100,1.1,-32000\n"
|
||||
",,\n" # empty row — must be skipped, not crash
|
||||
)
|
||||
|
||||
# Act
|
||||
samples = load_zgyro_samples(csv_path)
|
||||
|
||||
# Assert
|
||||
assert len(samples) == 2
|
||||
assert samples[0].monotonic_ms == 1000
|
||||
assert samples[1].zgyro_mdps == -32000
|
||||
|
||||
|
||||
def test_detect_or_synthesize_uses_natural_when_available(tmp_path: Path) -> None:
|
||||
"""detect_or_synthesize prefers natural turns over synthetic overlay."""
|
||||
# Arrange
|
||||
csv_path = tmp_path / "data_imu.csv"
|
||||
csv_path.write_text(
|
||||
"timestamp(ms),Time,SCALED_IMU2.zgyro\n"
|
||||
+ "\n".join(
|
||||
f"{i * 100},{i * 0.1},{40_000 if 2 <= i <= 6 else 0}"
|
||||
for i in range(10)
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
# Act
|
||||
detection = detect_or_synthesize(csv_path)
|
||||
|
||||
# Assert
|
||||
assert detection.has_natural_turn is True
|
||||
|
||||
|
||||
def test_detect_or_synthesize_falls_back_to_overlay(tmp_path: Path) -> None:
|
||||
"""Quiet flight → synthetic overlay marked True."""
|
||||
# Arrange
|
||||
csv_path = tmp_path / "data_imu.csv"
|
||||
csv_path.write_text(
|
||||
"timestamp(ms),Time,SCALED_IMU2.zgyro\n"
|
||||
+ "\n".join(f"{i * 100},{i * 0.1},0" for i in range(10))
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
# Act
|
||||
detection = detect_or_synthesize(csv_path)
|
||||
|
||||
# Assert
|
||||
assert detection.synthetic_overlay is True
|
||||
assert len(detection.segments) == 1
|
||||
|
||||
|
||||
def test_write_csv_evidence_round_trip(tmp_path: Path) -> None:
|
||||
"""Combined evidence CSV has expected header + one row per segment."""
|
||||
# Arrange
|
||||
detection = TurnDetection(
|
||||
segments=(TurnSegment(0, 2, 1000, 1200, 40_000, 3),),
|
||||
threshold_mdps=30_000,
|
||||
synthetic_overlay=False,
|
||||
)
|
||||
samples = [
|
||||
_frame(900, label="satellite_anchored", cov=5.0),
|
||||
_frame(1100, label="visual_propagated", cov=5.5),
|
||||
_frame(1500, label="satellite_anchored", cov=2.0),
|
||||
]
|
||||
n02 = [evaluate_ft_n_02(detection.segments[0], 0, samples)]
|
||||
p07 = [evaluate_ft_p_07(detection.segments[0], 0, samples)]
|
||||
out = tmp_path / "ft-p-07.csv"
|
||||
|
||||
# Act
|
||||
write_csv_evidence(out, detection, n02, p07)
|
||||
|
||||
# Assert
|
||||
rows = list(csv.reader(out.open()))
|
||||
assert rows[0][:5] == [
|
||||
"segment_index", "start_ms", "end_ms", "peak_abs_zgyro_mdps", "synthetic_overlay"
|
||||
]
|
||||
assert rows[1][4] == "false"
|
||||
assert rows[1][12] == "true" # passes_ft_n_02
|
||||
assert rows[1][13] == "true" # passes_ft_p_07
|
||||
@@ -0,0 +1,284 @@
|
||||
"""Unit tests for ``runner.helpers.smoothing_evaluator`` (FT-P-10 / AZ-418).
|
||||
|
||||
Covers AC-2 (improvement rate ≥0.80), AC-3 (mean improvement ≥5 m), and
|
||||
the FDR pairing discipline (raw + smoothed per keyframe, no dupes).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from runner.helpers.geo import offset
|
||||
from runner.helpers.smoothing_evaluator import (
|
||||
IMPROVEMENT_RATE_REQUIRED,
|
||||
MEAN_IMPROVEMENT_M_REQUIRED,
|
||||
GtPose,
|
||||
KeyframePair,
|
||||
KeyframePoseRecord,
|
||||
SmoothingReport,
|
||||
evaluate,
|
||||
pair_records,
|
||||
resolve_gt_at,
|
||||
write_csv_evidence,
|
||||
)
|
||||
|
||||
|
||||
def _gt_track(n: int = 60, dt_ms: int = 100) -> list[GtPose]:
|
||||
"""A straight-line GT track 10 Hz for 6 s, base lat/lon = Derkachi-ish."""
|
||||
return [
|
||||
GtPose(monotonic_ms=i * dt_ms, lat_deg=48.275 + i * 1e-4, lon_deg=37.385)
|
||||
for i in range(n)
|
||||
]
|
||||
|
||||
|
||||
def _raw_smoothed_pair(
|
||||
keyframe_id: int,
|
||||
gt: GtPose,
|
||||
raw_offset_m: float,
|
||||
smoothed_offset_m: float,
|
||||
) -> tuple[KeyframePoseRecord, KeyframePoseRecord]:
|
||||
"""Build a (raw, smoothed) pair offset north of the GT pose by given amounts."""
|
||||
raw_lat, raw_lon = offset(gt.lat_deg, gt.lon_deg, bearing_deg=0.0, distance_m=raw_offset_m)
|
||||
sm_lat, sm_lon = offset(gt.lat_deg, gt.lon_deg, bearing_deg=0.0, distance_m=smoothed_offset_m)
|
||||
raw = KeyframePoseRecord(
|
||||
keyframe_id=keyframe_id,
|
||||
pose_kind="raw",
|
||||
monotonic_ms=gt.monotonic_ms,
|
||||
lat_deg=raw_lat,
|
||||
lon_deg=raw_lon,
|
||||
)
|
||||
smoothed = KeyframePoseRecord(
|
||||
keyframe_id=keyframe_id,
|
||||
pose_kind="smoothed",
|
||||
monotonic_ms=gt.monotonic_ms + 500, # window-exit later
|
||||
lat_deg=sm_lat,
|
||||
lon_deg=sm_lon,
|
||||
)
|
||||
return raw, smoothed
|
||||
|
||||
|
||||
def test_constants_match_spec() -> None:
|
||||
"""The AC-2 + AC-3 thresholds must match the spec text."""
|
||||
# Assert
|
||||
assert IMPROVEMENT_RATE_REQUIRED == 0.80
|
||||
assert MEAN_IMPROVEMENT_M_REQUIRED == 5.0
|
||||
|
||||
|
||||
def test_resolve_gt_at_picks_nearest() -> None:
|
||||
"""Linear scan picks the nearest GT pose."""
|
||||
# Arrange
|
||||
track = _gt_track()
|
||||
|
||||
# Act
|
||||
nearest = resolve_gt_at(monotonic_ms=523, gt_track=track)
|
||||
|
||||
# Assert — nearest 10 Hz sample to 523 ms is at 500 ms.
|
||||
assert nearest.monotonic_ms == 500
|
||||
|
||||
|
||||
def test_resolve_gt_at_rejects_empty_track() -> None:
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="gt_track is empty"):
|
||||
resolve_gt_at(monotonic_ms=0, gt_track=[])
|
||||
|
||||
|
||||
def test_pair_records_groups_by_keyframe() -> None:
|
||||
"""raw + smoothed get grouped per keyframe; partial entries remain partial."""
|
||||
# Arrange
|
||||
gt = _gt_track()[0]
|
||||
raw, sm = _raw_smoothed_pair(7, gt, raw_offset_m=10.0, smoothed_offset_m=3.0)
|
||||
records = [raw, sm]
|
||||
|
||||
# Act
|
||||
paired = pair_records(records)
|
||||
|
||||
# Assert
|
||||
assert paired == {7: (raw, sm)}
|
||||
|
||||
|
||||
def test_pair_records_keeps_orphans_partial() -> None:
|
||||
"""Smoothed without raw → (None, smoothed)."""
|
||||
# Arrange
|
||||
gt = _gt_track()[0]
|
||||
_, sm = _raw_smoothed_pair(7, gt, raw_offset_m=10.0, smoothed_offset_m=3.0)
|
||||
|
||||
# Act
|
||||
paired = pair_records([sm])
|
||||
|
||||
# Assert
|
||||
assert paired == {7: (None, sm)}
|
||||
|
||||
|
||||
def test_pair_records_rejects_duplicate_pose_kind() -> None:
|
||||
"""Two raws for the same keyframe → ValueError."""
|
||||
# Arrange
|
||||
gt = _gt_track()[0]
|
||||
raw1, _ = _raw_smoothed_pair(7, gt, raw_offset_m=10.0, smoothed_offset_m=3.0)
|
||||
raw2, _ = _raw_smoothed_pair(7, gt, raw_offset_m=8.0, smoothed_offset_m=3.0)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="duplicate raw pose"):
|
||||
pair_records([raw1, raw2])
|
||||
|
||||
|
||||
def test_pair_records_rejects_unknown_pose_kind() -> None:
|
||||
"""Programming-error guard for unknown pose_kind values."""
|
||||
# Arrange
|
||||
bogus = KeyframePoseRecord(
|
||||
keyframe_id=1, pose_kind="filtered", monotonic_ms=0, lat_deg=0.0, lon_deg=0.0
|
||||
)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="unknown pose_kind 'filtered'"):
|
||||
pair_records([bogus])
|
||||
|
||||
|
||||
def test_evaluate_all_smoothed_wins_passes() -> None:
|
||||
"""Every keyframe's smoothed is closer to GT → improvement rate 1.0."""
|
||||
# Arrange — 20 keyframes; raw 15m off, smoothed 2m off → 13m improvement each.
|
||||
track = _gt_track()
|
||||
records: list[KeyframePoseRecord] = []
|
||||
for i, gt in enumerate(track[:20]):
|
||||
raw, sm = _raw_smoothed_pair(i, gt, raw_offset_m=15.0, smoothed_offset_m=2.0)
|
||||
records += [raw, sm]
|
||||
|
||||
# Act
|
||||
report = evaluate(records, track)
|
||||
|
||||
# Assert
|
||||
assert report.pair_count == 20
|
||||
assert report.improvement_rate == 1.0
|
||||
assert report.mean_improvement_m == pytest.approx(13.0, abs=1.0)
|
||||
assert report.passes is True
|
||||
|
||||
|
||||
def test_evaluate_at_80_pct_improvement_rate_passes() -> None:
|
||||
"""80% smoothed wins AND mean improvement ≥5m → AC-2+AC-3 pass."""
|
||||
# Arrange — 10 keyframes: 8 smoothed_wins by 10m, 2 smoothed_loses by 1m.
|
||||
track = _gt_track()
|
||||
records: list[KeyframePoseRecord] = []
|
||||
for i, gt in enumerate(track[:10]):
|
||||
if i < 8:
|
||||
raw, sm = _raw_smoothed_pair(i, gt, raw_offset_m=12.0, smoothed_offset_m=2.0)
|
||||
else:
|
||||
raw, sm = _raw_smoothed_pair(i, gt, raw_offset_m=2.0, smoothed_offset_m=3.0)
|
||||
records += [raw, sm]
|
||||
|
||||
# Act
|
||||
report = evaluate(records, track)
|
||||
|
||||
# Assert
|
||||
assert report.improvement_rate == pytest.approx(0.80, abs=1e-6)
|
||||
assert report.passes_rate is True
|
||||
# mean = ((10 * 8) + (-1 * 2)) / 10 = 7.8 m
|
||||
assert report.mean_improvement_m == pytest.approx(7.8, abs=1.0)
|
||||
assert report.passes_mean is True
|
||||
assert report.passes is True
|
||||
|
||||
|
||||
def test_evaluate_below_80_pct_fails_overall() -> None:
|
||||
"""79% smoothed wins → AC-2 fails."""
|
||||
# Arrange — 100 keyframes: 79 wins, 21 losses.
|
||||
track = _gt_track(n=100)
|
||||
records: list[KeyframePoseRecord] = []
|
||||
for i, gt in enumerate(track):
|
||||
if i < 79:
|
||||
raw, sm = _raw_smoothed_pair(i, gt, raw_offset_m=15.0, smoothed_offset_m=2.0)
|
||||
else:
|
||||
raw, sm = _raw_smoothed_pair(i, gt, raw_offset_m=2.0, smoothed_offset_m=3.0)
|
||||
records += [raw, sm]
|
||||
|
||||
# Act
|
||||
report = evaluate(records, track)
|
||||
|
||||
# Assert
|
||||
assert report.improvement_rate == pytest.approx(0.79)
|
||||
assert report.passes_rate is False
|
||||
assert report.passes is False
|
||||
|
||||
|
||||
def test_evaluate_mean_improvement_below_5m_fails() -> None:
|
||||
"""100% rate but mean improvement = 3m → AC-3 fails."""
|
||||
# Arrange — every keyframe smoothed wins by 3 m.
|
||||
track = _gt_track()
|
||||
records: list[KeyframePoseRecord] = []
|
||||
for i, gt in enumerate(track[:20]):
|
||||
raw, sm = _raw_smoothed_pair(i, gt, raw_offset_m=8.0, smoothed_offset_m=5.0)
|
||||
records += [raw, sm]
|
||||
|
||||
# Act
|
||||
report = evaluate(records, track)
|
||||
|
||||
# Assert
|
||||
assert report.improvement_rate == 1.0
|
||||
assert report.mean_improvement_m == pytest.approx(3.0, abs=0.5)
|
||||
assert report.passes_mean is False
|
||||
assert report.passes is False
|
||||
|
||||
|
||||
def test_evaluate_excludes_unpaired_keyframes() -> None:
|
||||
"""Keyframe with only raw OR only smoothed is silently excluded."""
|
||||
# Arrange — keyframe 0 fully paired, keyframe 1 has only raw.
|
||||
track = _gt_track()
|
||||
raw0, sm0 = _raw_smoothed_pair(0, track[0], raw_offset_m=10.0, smoothed_offset_m=2.0)
|
||||
raw1, _ = _raw_smoothed_pair(1, track[1], raw_offset_m=10.0, smoothed_offset_m=2.0)
|
||||
|
||||
# Act
|
||||
report = evaluate([raw0, sm0, raw1], track)
|
||||
|
||||
# Assert
|
||||
assert report.pair_count == 1
|
||||
assert report.pairs[0].keyframe_id == 0
|
||||
|
||||
|
||||
def test_evaluate_empty_records_does_not_pass() -> None:
|
||||
"""Zero pairs → does NOT pass; rate + mean are 0."""
|
||||
# Arrange
|
||||
track = _gt_track()
|
||||
|
||||
# Act
|
||||
report = evaluate([], track)
|
||||
|
||||
# Assert
|
||||
assert report.pair_count == 0
|
||||
assert report.passes_rate is False
|
||||
assert report.passes_mean is False
|
||||
assert report.passes is False
|
||||
|
||||
|
||||
def test_evaluate_rejects_empty_gt_track() -> None:
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="gt_track must not be empty"):
|
||||
evaluate([], [])
|
||||
|
||||
|
||||
def test_write_csv_evidence_round_trip(tmp_path: Path) -> None:
|
||||
"""CSV header + one row per pair."""
|
||||
# Arrange
|
||||
track = _gt_track()
|
||||
raw, sm = _raw_smoothed_pair(0, track[0], raw_offset_m=15.0, smoothed_offset_m=2.0)
|
||||
report = evaluate([raw, sm], track)
|
||||
out = tmp_path / "ft-p-10.csv"
|
||||
|
||||
# Act
|
||||
write_csv_evidence(out, report)
|
||||
|
||||
# Assert
|
||||
rows = list(csv.reader(out.open()))
|
||||
assert rows[0] == [
|
||||
"keyframe_id",
|
||||
"raw_lat",
|
||||
"raw_lon",
|
||||
"smoothed_lat",
|
||||
"smoothed_lon",
|
||||
"gt_lat",
|
||||
"gt_lon",
|
||||
"raw_error_m",
|
||||
"smoothed_error_m",
|
||||
"improvement_m",
|
||||
"smoothed_wins",
|
||||
]
|
||||
assert rows[1][-1] == "true"
|
||||
@@ -46,6 +46,9 @@ E2E_ROOT = Path(__file__).resolve().parents[1]
|
||||
"runner/helpers/accuracy_evaluator.py",
|
||||
"runner/helpers/registration_classifier.py",
|
||||
"runner/helpers/mre_evaluator.py",
|
||||
"runner/helpers/multi_segment_evaluator.py",
|
||||
"runner/helpers/smoothing_evaluator.py",
|
||||
"runner/helpers/sharp_turn_detector.py",
|
||||
"fixtures/mock-suite-sat/Dockerfile",
|
||||
"fixtures/mock-suite-sat/app.py",
|
||||
"fixtures/mock-suite-sat/requirements.txt",
|
||||
@@ -84,6 +87,10 @@ E2E_ROOT = Path(__file__).resolve().parents[1]
|
||||
"tests/positive/test_ft_p_04_derkachi_f2f_registration.py",
|
||||
"tests/positive/test_ft_p_05_sat_anchor.py",
|
||||
"tests/positive/test_ft_p_06_mre_budgets.py",
|
||||
"tests/positive/test_ft_p_07_sharp_turn_recovery.py",
|
||||
"tests/positive/test_ft_p_08_multi_segment_reloc.py",
|
||||
"tests/positive/test_ft_p_10_smoothing_lookback.py",
|
||||
"tests/negative/test_ft_n_02_sharp_turn_failure.py",
|
||||
],
|
||||
)
|
||||
def test_required_path_exists(relative_path: str) -> None:
|
||||
|
||||
Reference in New Issue
Block a user