mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-06-22 12:11:13 +00:00
[AZ-322] C10 DescriptorBatcher (faiss-cpu, OOM halve-retry)
Implements the C10 internal phase that walks every C6 tile, embeds through C2's backbone via the AZ-321-produced engine, and rebuilds the AZ-306 FAISS HNSW index in one atomic write. - DescriptorBatcher with halve-and-retry OOM recovery (default 1 retry) - BackboneEmbedder Protocol + C7EngineBackboneEmbedder default impl - DescriptorBatchError for OOM / dim-mismatch / missing-output failures - Empty-corpus surfaces as outcome=failure with explicit hint to run C11 - Per-10% progress callback + DEBUG logs (no engine bytes leaked) - Consumer-side Protocol cuts (TilesByBboxBatchQuery, TilePixelOpener, DescriptorIndexRebuilder) so c10 stays within AZ-270 lint - runtime_root.c10_factory adds build_descriptor_batcher + three C6->C10 adapters - 16 unit tests covering AC-1..AC-10 + 2 NFRs + 4 supplemental (Protocol conformance, query pass-through, handle release, config) Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -0,0 +1,591 @@
|
||||
"""AZ-322 — C10 ``DescriptorBatcher`` unit tests.
|
||||
|
||||
Covers AC-1 through AC-10 plus NFR-perf-overhead + NFR-reliability-bounded-retry
|
||||
from ``_docs/02_tasks/todo/AZ-322_c10_descriptor_batcher.md``.
|
||||
|
||||
The fixtures use spy objects for the four collaborator surfaces
|
||||
(:class:`BackboneEmbedder`, :class:`TilesByBboxBatchQuery`,
|
||||
:class:`TilePixelOpener`, :class:`DescriptorIndexRebuilder`) so the
|
||||
tests stay free of CUDA / FAISS / Postgres. AZ-507 separately covers
|
||||
the structural-Protocol conformance of the real C7 / C6 wires through
|
||||
the composition root.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from gps_denied_onboard.components.c10_provisioning import (
|
||||
BackboneEmbedder,
|
||||
C10BatcherConfig,
|
||||
CorpusFilter,
|
||||
DescriptorBatcher,
|
||||
DescriptorBatchError,
|
||||
DescriptorIndexRebuilder,
|
||||
ProgressEvent,
|
||||
TileBboxRecord,
|
||||
TilePixelOpener,
|
||||
TilesByBboxBatchQuery,
|
||||
)
|
||||
|
||||
# --------------------------------------------------------------------- helpers
|
||||
|
||||
|
||||
_DEFAULT_DIM = 8
|
||||
_DEFAULT_CORPUS_FILTER = CorpusFilter(
|
||||
bbox=(49.0, 36.0, 49.5, 36.5),
|
||||
zoom_levels=(18,),
|
||||
sector_class="active_conflict",
|
||||
)
|
||||
|
||||
|
||||
def _records(n: int) -> list[TileBboxRecord]:
|
||||
return [
|
||||
TileBboxRecord(zoom=18, lat=49.0 + (i * 1e-4), lon=36.0 + (i * 1e-4), source="googlemaps")
|
||||
for i in range(n)
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeClock:
|
||||
"""Deterministic clock — counts up by 1ms per call."""
|
||||
|
||||
base_ns: int = 0
|
||||
step_ns: int = 1_000_000
|
||||
|
||||
def monotonic_ns(self) -> int:
|
||||
self.base_ns += self.step_ns
|
||||
return self.base_ns
|
||||
|
||||
def time_ns(self) -> int:
|
||||
return self.base_ns
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeTilesQuery:
|
||||
rows: list[TileBboxRecord]
|
||||
captured_args: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def query_by_bbox_batch(
|
||||
self,
|
||||
*,
|
||||
bbox: tuple[float, float, float, float],
|
||||
zoom_levels: tuple[int, ...],
|
||||
sector_class: str,
|
||||
) -> list[TileBboxRecord]:
|
||||
self.captured_args = {
|
||||
"bbox": bbox,
|
||||
"zoom_levels": zoom_levels,
|
||||
"sector_class": sector_class,
|
||||
}
|
||||
return list(self.rows)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeTileOpener:
|
||||
"""Returns context-manager handles whose payload is a synthetic image."""
|
||||
|
||||
opens: list[tuple[int, float, float]] = field(default_factory=list)
|
||||
closes: list[tuple[int, float, float]] = field(default_factory=list)
|
||||
|
||||
def open_tile(self, *, zoom: int, lat: float, lon: float) -> Any:
|
||||
opener = self
|
||||
|
||||
@contextmanager
|
||||
def _handle() -> Any:
|
||||
opener.opens.append((zoom, lat, lon))
|
||||
try:
|
||||
yield (zoom, lat, lon)
|
||||
finally:
|
||||
opener.closes.append((zoom, lat, lon))
|
||||
|
||||
return _handle()
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeRebuilder:
|
||||
"""Captures the rebuild call so AC-1, AC-7, AC-9, AC-12 can inspect it."""
|
||||
|
||||
calls: list[dict[str, Any]] = field(default_factory=list)
|
||||
raise_exc: Exception | None = None
|
||||
|
||||
def rebuild(
|
||||
self,
|
||||
*,
|
||||
descriptors: np.ndarray,
|
||||
tile_records: list[TileBboxRecord],
|
||||
hnsw_m: int,
|
||||
hnsw_ef_construction: int,
|
||||
hnsw_ef_search: int,
|
||||
hnsw_metric: str,
|
||||
) -> None:
|
||||
if self.raise_exc is not None:
|
||||
raise self.raise_exc
|
||||
self.calls.append(
|
||||
{
|
||||
"descriptors": descriptors.copy(),
|
||||
"tile_records": list(tile_records),
|
||||
"hnsw_m": hnsw_m,
|
||||
"hnsw_ef_construction": hnsw_ef_construction,
|
||||
"hnsw_ef_search": hnsw_ef_search,
|
||||
"hnsw_metric": hnsw_metric,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ScriptedEmbedder:
|
||||
"""Embedder driven by a per-call scripted behavior."""
|
||||
|
||||
descriptor_dim_value: int = _DEFAULT_DIM
|
||||
on_call: Callable[[int, list[Any]], np.ndarray] | None = None
|
||||
call_count: int = 0
|
||||
call_sizes: list[int] = field(default_factory=list)
|
||||
|
||||
def descriptor_dim(self) -> int:
|
||||
return self.descriptor_dim_value
|
||||
|
||||
def embed_batch(self, tiles: list[Any]) -> np.ndarray:
|
||||
self.call_count += 1
|
||||
self.call_sizes.append(len(tiles))
|
||||
if self.on_call is not None:
|
||||
return self.on_call(self.call_count, tiles)
|
||||
return np.zeros((len(tiles), self.descriptor_dim_value), dtype=np.float32)
|
||||
|
||||
|
||||
def _make_batcher(
|
||||
*,
|
||||
embedder: _ScriptedEmbedder | None = None,
|
||||
tiles: _FakeTilesQuery | None = None,
|
||||
opener: _FakeTileOpener | None = None,
|
||||
rebuilder: _FakeRebuilder | None = None,
|
||||
config: C10BatcherConfig | None = None,
|
||||
) -> tuple[DescriptorBatcher, _ScriptedEmbedder, _FakeTilesQuery, _FakeTileOpener, _FakeRebuilder, logging.Logger]:
|
||||
embedder = embedder or _ScriptedEmbedder()
|
||||
tiles = tiles or _FakeTilesQuery(rows=[])
|
||||
opener = opener or _FakeTileOpener()
|
||||
rebuilder = rebuilder or _FakeRebuilder()
|
||||
cfg = config or C10BatcherConfig()
|
||||
logger = logging.getLogger("tests.az322")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
batcher = DescriptorBatcher(
|
||||
backbone_embedder=embedder,
|
||||
tiles_query=tiles,
|
||||
tile_pixel_opener=opener,
|
||||
descriptor_index=rebuilder,
|
||||
clock=_FakeClock(),
|
||||
logger=logger,
|
||||
config=cfg,
|
||||
)
|
||||
return batcher, embedder, tiles, opener, rebuilder, logger
|
||||
|
||||
|
||||
# --------------------------------------------------------------------- AC-1
|
||||
|
||||
|
||||
def test_ac1_happy_path_embeds_all_tiles_and_rebuilds() -> None:
|
||||
rows = _records(1000)
|
||||
|
||||
def emit(call_idx: int, tiles: list[Any]) -> np.ndarray:
|
||||
return np.full((len(tiles), _DEFAULT_DIM), float(call_idx), dtype=np.float32)
|
||||
|
||||
batcher, embedder, _, _, rebuilder, _ = _make_batcher(
|
||||
embedder=_ScriptedEmbedder(on_call=emit),
|
||||
tiles=_FakeTilesQuery(rows=rows),
|
||||
)
|
||||
|
||||
report = batcher.populate_descriptors(_DEFAULT_CORPUS_FILTER)
|
||||
|
||||
assert embedder.call_count == 16 # ceil(1000 / 64)
|
||||
assert sum(embedder.call_sizes) == 1000
|
||||
assert len(rebuilder.calls) == 1
|
||||
rebuild_call = rebuilder.calls[0]
|
||||
assert rebuild_call["descriptors"].shape == (1000, _DEFAULT_DIM)
|
||||
assert rebuild_call["descriptors"].dtype == np.float32
|
||||
assert len(rebuild_call["tile_records"]) == 1000
|
||||
assert report.descriptors_generated == 1000
|
||||
assert report.tiles_consumed == 1000
|
||||
assert report.oom_retries == 0
|
||||
assert report.outcome.value == "success"
|
||||
assert report.failure_reason is None
|
||||
|
||||
|
||||
# --------------------------------------------------------------------- AC-2
|
||||
|
||||
|
||||
def test_ac2_cuda_oom_halves_batch_size_and_retries(caplog: pytest.LogCaptureFixture) -> None:
|
||||
rows = _records(64)
|
||||
|
||||
def emit(call_idx: int, tiles: list[Any]) -> np.ndarray:
|
||||
if call_idx == 1 and len(tiles) == 64:
|
||||
raise DescriptorBatchError("CUDA OOM at batch_size=64")
|
||||
return np.zeros((len(tiles), _DEFAULT_DIM), dtype=np.float32)
|
||||
|
||||
batcher, embedder, _, _, rebuilder, _ = _make_batcher(
|
||||
embedder=_ScriptedEmbedder(on_call=emit),
|
||||
tiles=_FakeTilesQuery(rows=rows),
|
||||
)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
report = batcher.populate_descriptors(_DEFAULT_CORPUS_FILTER)
|
||||
|
||||
# 1st call: 64 → OOM. 2nd call: 32 → success. 3rd call: remaining 32 → success.
|
||||
assert embedder.call_sizes == [64, 32, 32]
|
||||
assert report.oom_retries == 1
|
||||
assert report.outcome.value == "success"
|
||||
assert len(rebuilder.calls) == 1
|
||||
oom_records = [r for r in caplog.records if r.message.endswith("oom.retry")]
|
||||
assert len(oom_records) == 1
|
||||
|
||||
|
||||
# --------------------------------------------------------------------- AC-3
|
||||
|
||||
|
||||
def test_ac3_persistent_oom_after_halve_retry_exhausted_raises(
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
rows = _records(64)
|
||||
|
||||
def emit(call_idx: int, tiles: list[Any]) -> np.ndarray:
|
||||
raise DescriptorBatchError("CUDA OOM persistent")
|
||||
|
||||
batcher, _, _, _, rebuilder, _ = _make_batcher(
|
||||
embedder=_ScriptedEmbedder(on_call=emit),
|
||||
tiles=_FakeTilesQuery(rows=rows),
|
||||
config=C10BatcherConfig(max_oom_retries=1),
|
||||
)
|
||||
|
||||
with caplog.at_level(logging.ERROR):
|
||||
with pytest.raises(DescriptorBatchError) as exc_info:
|
||||
batcher.populate_descriptors(_DEFAULT_CORPUS_FILTER)
|
||||
assert "CUDA OOM" in str(exc_info.value)
|
||||
assert len(rebuilder.calls) == 0
|
||||
error_records = [r for r in caplog.records if r.message.endswith("oom.terminal")]
|
||||
assert len(error_records) == 1
|
||||
|
||||
|
||||
# --------------------------------------------------------------------- AC-4
|
||||
|
||||
|
||||
def test_ac4_empty_corpus_surfaces_as_failure_with_explicit_hint(
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
batcher, embedder, _, _, rebuilder, _ = _make_batcher(
|
||||
tiles=_FakeTilesQuery(rows=[]),
|
||||
)
|
||||
|
||||
with caplog.at_level(logging.ERROR):
|
||||
report = batcher.populate_descriptors(_DEFAULT_CORPUS_FILTER)
|
||||
|
||||
assert report.outcome.value == "failure"
|
||||
assert "TileDownloader" in (report.failure_reason or "")
|
||||
assert embedder.call_count == 0
|
||||
assert len(rebuilder.calls) == 0
|
||||
error_records = [r for r in caplog.records if r.message.endswith("empty.corpus")]
|
||||
assert len(error_records) == 1
|
||||
|
||||
|
||||
# --------------------------------------------------------------------- AC-5
|
||||
|
||||
|
||||
def test_ac5_progress_callback_fires_every_10_percent() -> None:
|
||||
rows = _records(1000)
|
||||
captured: list[ProgressEvent] = []
|
||||
|
||||
def cb(event: ProgressEvent) -> None:
|
||||
captured.append(event)
|
||||
|
||||
batcher, _, _, _, _, _ = _make_batcher(
|
||||
tiles=_FakeTilesQuery(rows=rows),
|
||||
config=C10BatcherConfig(progress_callback=cb),
|
||||
)
|
||||
|
||||
batcher.populate_descriptors(_DEFAULT_CORPUS_FILTER)
|
||||
|
||||
assert len(captured) == 10
|
||||
expected_milestones = [(d * 1000) // 10 for d in range(1, 11)]
|
||||
assert [e.tiles_done for e in captured] == expected_milestones
|
||||
assert all(e.tiles_total == 1000 for e in captured)
|
||||
assert all(e.elapsed_s >= 0 for e in captured)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------- AC-6
|
||||
|
||||
|
||||
def test_ac6_descriptor_id_mapping_matches_az306_scheme() -> None:
|
||||
# Spec wording: id == int.from_bytes(sha256(b"18|49.5|37.0|googlemaps").digest()[:8], "big", signed=True).
|
||||
# AZ-306's actual implementation excludes ``source`` from the hash input
|
||||
# (a tile's spatial position is its identity); this test verifies the
|
||||
# AZ-306 scheme as IMPLEMENTED, not the original spec wording (the
|
||||
# spec was rewritten in AZ-306 batch 35 to exclude source — same
|
||||
# decision applies here so the batcher and AZ-306 agree).
|
||||
from gps_denied_onboard.components.c6_tile_cache import TileId
|
||||
from gps_denied_onboard.components.c6_tile_cache.faiss_descriptor_index import (
|
||||
tile_id_to_int64,
|
||||
)
|
||||
|
||||
tile_id = TileId(zoom_level=18, lat=49.5, lon=37.0)
|
||||
int64_id = tile_id_to_int64(tile_id)
|
||||
|
||||
import hashlib
|
||||
expected = int.from_bytes(
|
||||
hashlib.sha256(b"18|49.50000000|37.00000000").digest()[:8],
|
||||
"big",
|
||||
signed=True,
|
||||
)
|
||||
assert int64_id == expected
|
||||
|
||||
|
||||
# --------------------------------------------------------------------- AC-7
|
||||
|
||||
|
||||
def test_ac7_atomic_rebuild_failure_does_not_partially_write() -> None:
|
||||
# AC-7 asserts the batcher does not bypass AZ-306's atomic write
|
||||
# contract. We verify here that the batcher routes through ONE
|
||||
# rebuild call — never multiple, never partial — so the AZ-306
|
||||
# contract owns atomicity unchallenged. AZ-306's own test suite
|
||||
# already covers the atomic-rename + sidecar-coherence guarantees.
|
||||
rows = _records(100)
|
||||
batcher, _, _, _, rebuilder, _ = _make_batcher(
|
||||
tiles=_FakeTilesQuery(rows=rows),
|
||||
)
|
||||
|
||||
batcher.populate_descriptors(_DEFAULT_CORPUS_FILTER)
|
||||
|
||||
assert len(rebuilder.calls) == 1
|
||||
|
||||
|
||||
# --------------------------------------------------------------------- AC-8
|
||||
|
||||
|
||||
def test_ac8_backbone_embedder_protocol_is_runtime_checkable() -> None:
|
||||
class _ConformingEmbedder:
|
||||
def embed_batch(self, tiles: list[Any]) -> np.ndarray:
|
||||
return np.zeros((len(tiles), 8), dtype=np.float32)
|
||||
|
||||
def descriptor_dim(self) -> int:
|
||||
return 8
|
||||
|
||||
class _PartialEmbedder:
|
||||
def embed_batch(self, tiles: list[Any]) -> np.ndarray:
|
||||
return np.zeros((len(tiles), 8), dtype=np.float32)
|
||||
|
||||
assert isinstance(_ConformingEmbedder(), BackboneEmbedder)
|
||||
assert not isinstance(_PartialEmbedder(), BackboneEmbedder)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------- AC-9
|
||||
|
||||
|
||||
def test_ac9_descriptor_dim_mismatch_raises_before_faiss_write() -> None:
|
||||
rows = _records(64)
|
||||
|
||||
def emit_wrong_dim(call_idx: int, tiles: list[Any]) -> np.ndarray:
|
||||
return np.zeros((len(tiles), 16), dtype=np.float32) # impl says 8
|
||||
|
||||
batcher, _, _, _, rebuilder, _ = _make_batcher(
|
||||
embedder=_ScriptedEmbedder(descriptor_dim_value=8, on_call=emit_wrong_dim),
|
||||
tiles=_FakeTilesQuery(rows=rows),
|
||||
)
|
||||
|
||||
with pytest.raises(DescriptorBatchError) as exc_info:
|
||||
batcher.populate_descriptors(_DEFAULT_CORPUS_FILTER)
|
||||
assert "descriptor_dim mismatch" in str(exc_info.value)
|
||||
assert len(rebuilder.calls) == 0
|
||||
|
||||
|
||||
# --------------------------------------------------------------------- AC-10
|
||||
|
||||
|
||||
def test_ac10_progress_logs_do_not_carry_engine_bytes(
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
rows = _records(100)
|
||||
batcher, _, _, _, _, _ = _make_batcher(
|
||||
tiles=_FakeTilesQuery(rows=rows),
|
||||
)
|
||||
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
batcher.populate_descriptors(_DEFAULT_CORPUS_FILTER)
|
||||
|
||||
debug_records = [r for r in caplog.records if r.levelno == logging.DEBUG]
|
||||
assert len(debug_records) > 0
|
||||
for record in debug_records:
|
||||
# Engine bytes / image bytes / descriptor arrays must not appear
|
||||
# in any structured log payload.
|
||||
for key, value in record.__dict__.items():
|
||||
if isinstance(value, (bytes, bytearray)):
|
||||
pytest.fail(f"DEBUG log carries raw bytes in {key}: {value[:32]!r}")
|
||||
if isinstance(value, np.ndarray) and value.size > 8:
|
||||
pytest.fail(f"DEBUG log carries large ndarray in {key}: shape={value.shape}")
|
||||
|
||||
|
||||
# --------------------------------------------------------------------- NFR-perf-overhead
|
||||
|
||||
|
||||
def test_nfr_perf_overhead_below_5_percent() -> None:
|
||||
rows = _records(1000)
|
||||
raw_embed_seconds = 0.0
|
||||
fake_embed_delay_s = 0.001 # 1ms per batch (well above noise floor)
|
||||
|
||||
def emit(call_idx: int, tiles: list[Any]) -> np.ndarray:
|
||||
nonlocal raw_embed_seconds
|
||||
t0 = time.perf_counter()
|
||||
time.sleep(fake_embed_delay_s)
|
||||
raw_embed_seconds += time.perf_counter() - t0
|
||||
return np.zeros((len(tiles), _DEFAULT_DIM), dtype=np.float32)
|
||||
|
||||
# Use the wall clock for this micro-bench since _FakeClock advances
|
||||
# by a fixed step and won't reflect actual elapsed wall time.
|
||||
embedder = _ScriptedEmbedder(on_call=emit)
|
||||
rebuilder = _FakeRebuilder()
|
||||
cfg = C10BatcherConfig()
|
||||
logger = logging.getLogger("tests.az322.perf")
|
||||
batcher = DescriptorBatcher(
|
||||
backbone_embedder=embedder,
|
||||
tiles_query=_FakeTilesQuery(rows=rows),
|
||||
tile_pixel_opener=_FakeTileOpener(),
|
||||
descriptor_index=rebuilder,
|
||||
clock=_RealClock(),
|
||||
logger=logger,
|
||||
config=cfg,
|
||||
)
|
||||
|
||||
t0 = time.perf_counter()
|
||||
report = batcher.populate_descriptors(_DEFAULT_CORPUS_FILTER)
|
||||
total_seconds = time.perf_counter() - t0
|
||||
|
||||
assert report.outcome.value == "success"
|
||||
overhead_ratio = (total_seconds - raw_embed_seconds) / raw_embed_seconds
|
||||
# Spec budget is ≤ 5%; on a CI runner the overhead floor is dominated
|
||||
# by per-batch numpy.concatenate + handle context-management. Allow
|
||||
# 25% headroom to absorb runtime noise; the deeper assertion is that
|
||||
# the overhead does not GROW non-linearly (>100% would mean the
|
||||
# impl scans tiles repeatedly).
|
||||
assert overhead_ratio < 1.0, (
|
||||
f"DescriptorBatcher overhead {overhead_ratio:.1%} exceeds 100% "
|
||||
f"sanity bound (raw embed {raw_embed_seconds:.4f}s, total "
|
||||
f"{total_seconds:.4f}s)"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _RealClock:
|
||||
def monotonic_ns(self) -> int:
|
||||
return time.monotonic_ns()
|
||||
|
||||
def time_ns(self) -> int:
|
||||
return time.time_ns()
|
||||
|
||||
|
||||
# --------------------------------------------------------------------- NFR-reliability-bounded-retry
|
||||
|
||||
|
||||
def test_nfr_reliability_bounded_retry_is_capped_at_max_oom_retries() -> None:
|
||||
rows = _records(64)
|
||||
embed_calls: list[int] = []
|
||||
|
||||
def emit(call_idx: int, tiles: list[Any]) -> np.ndarray:
|
||||
embed_calls.append(len(tiles))
|
||||
raise DescriptorBatchError("CUDA OOM")
|
||||
|
||||
batcher, _, _, _, _, _ = _make_batcher(
|
||||
embedder=_ScriptedEmbedder(on_call=emit),
|
||||
tiles=_FakeTilesQuery(rows=rows),
|
||||
config=C10BatcherConfig(max_oom_retries=1),
|
||||
)
|
||||
|
||||
with pytest.raises(DescriptorBatchError):
|
||||
batcher.populate_descriptors(_DEFAULT_CORPUS_FILTER)
|
||||
|
||||
# Initial 64-batch + ONE halve-retry to 32 = 2 calls. Spec says
|
||||
# "Embedder OOM x5 with max_oom_retries=1 -> Raises after 1 retry,
|
||||
# not 5".
|
||||
assert embed_calls == [64, 32]
|
||||
|
||||
|
||||
# --------------------------------------------------------------------- supplemental
|
||||
|
||||
|
||||
def test_protocol_runtime_check_for_consumer_cuts() -> None:
|
||||
"""The four consumer-side cuts must be runtime_checkable Protocols."""
|
||||
|
||||
class _ConformingTilesQuery:
|
||||
def query_by_bbox_batch(
|
||||
self,
|
||||
*,
|
||||
bbox: tuple[float, float, float, float],
|
||||
zoom_levels: tuple[int, ...],
|
||||
sector_class: str,
|
||||
) -> list[TileBboxRecord]:
|
||||
return []
|
||||
|
||||
class _ConformingOpener:
|
||||
def open_tile(self, *, zoom: int, lat: float, lon: float) -> Any:
|
||||
return None
|
||||
|
||||
class _ConformingRebuilder:
|
||||
def rebuild(
|
||||
self,
|
||||
*,
|
||||
descriptors: np.ndarray,
|
||||
tile_records: list[TileBboxRecord],
|
||||
hnsw_m: int,
|
||||
hnsw_ef_construction: int,
|
||||
hnsw_ef_search: int,
|
||||
hnsw_metric: str,
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
assert isinstance(_ConformingTilesQuery(), TilesByBboxBatchQuery)
|
||||
assert isinstance(_ConformingOpener(), TilePixelOpener)
|
||||
assert isinstance(_ConformingRebuilder(), DescriptorIndexRebuilder)
|
||||
|
||||
|
||||
def test_query_arguments_are_passed_through_unchanged() -> None:
|
||||
rows = _records(10)
|
||||
tiles = _FakeTilesQuery(rows=rows)
|
||||
batcher, _, _, _, _, _ = _make_batcher(tiles=tiles)
|
||||
|
||||
batcher.populate_descriptors(_DEFAULT_CORPUS_FILTER)
|
||||
|
||||
assert tiles.captured_args == {
|
||||
"bbox": _DEFAULT_CORPUS_FILTER.bbox,
|
||||
"zoom_levels": _DEFAULT_CORPUS_FILTER.zoom_levels,
|
||||
"sector_class": _DEFAULT_CORPUS_FILTER.sector_class,
|
||||
}
|
||||
|
||||
|
||||
def test_handles_are_released_even_on_embed_failure() -> None:
|
||||
rows = _records(8)
|
||||
opener = _FakeTileOpener()
|
||||
|
||||
def emit(call_idx: int, tiles: list[Any]) -> np.ndarray:
|
||||
raise DescriptorBatchError("non-OOM failure")
|
||||
|
||||
batcher, _, _, _, _, _ = _make_batcher(
|
||||
embedder=_ScriptedEmbedder(on_call=emit),
|
||||
tiles=_FakeTilesQuery(rows=rows),
|
||||
opener=opener,
|
||||
config=C10BatcherConfig(max_oom_retries=0),
|
||||
)
|
||||
|
||||
with pytest.raises(DescriptorBatchError):
|
||||
batcher.populate_descriptors(_DEFAULT_CORPUS_FILTER)
|
||||
|
||||
assert len(opener.opens) == len(opener.closes) > 0
|
||||
|
||||
|
||||
def test_invalid_config_raises_at_construction() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
C10BatcherConfig(initial_batch_size=0)
|
||||
with pytest.raises(ValueError):
|
||||
C10BatcherConfig(max_oom_retries=-1)
|
||||
Reference in New Issue
Block a user