mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-06-22 18:01:13 +00:00
[AZ-527] Consolidate _assert_engine_output_dim into c2-internal helper
Closes cumulative review batches 49-51 Finding F1 (Medium / Maintainability) -- the 7-way duplication of _assert_engine_output_dim across c2_vpr secondary VPR strategy modules. Add c2-internal helper assert_engine_output_dim(inference_runtime, handle, preprocessor, descriptor_dim, *, output_key='embedding', input_key='input') in src/gps_denied_onboard/components/c2_vpr/ _engine_dim_assertion.py. The helper runs a zero-init dry-run inference at preprocessor.input_shape() and asserts the engine output dict carries (1, descriptor_dim) under output_key. Raises gps_denied_onboard.config.schema.ConfigError on mismatch (preserving the prior error envelope and message wording byte-identically). Migrate 7 strategy modules (ultra_vpr, net_vlad, mega_loc, mix_vpr, sela_vpr, eigen_places, salad) to import the helper and delete the local _assert_engine_output_dim definitions + their inline 'AZ-527 (planned)' comments. NetVLAD is the only call site that overrides output_key='vlad_descriptor'; the other 6 explicitly pass output_key=_OUTPUT_KEY + input_key=_ENGINE_INPUT_KEY (matching helper defaults but documenting strategy contract at the call site). Add tests/unit/c2_vpr/test_az527_engine_dim_assertion.py (14 tests, AAA pattern, Protocol-conforming fakes) covering AC-1..AC-4: helper signature; wrong shape raises ConfigError naming both dims; missing output key raises ConfigError naming the missing key; AST-walk regression guard for stray definitions outside the helper module (modeled on AZ-526's test_ac4_az526_no_module_level_iso_ts_from_clock_outside_helper); import-grep regression guard verifying all 7 strategy modules import the helper. AC-5 (existing AZ-337/338/339/340 AC-6 sub-tests pass unmodified) is exercised transitively: c2_vpr/ full directory 230/230 PASS, no test file modified outside the new test_az527_*. AC-6 (AZ-270 + AZ-507 layer lints) verified by tests/unit/test_az270_compose_root.py 8/8 PASS. Code-review verdict: PASS (zero findings). Ruff clean. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -0,0 +1,102 @@
|
||||
"""C2-internal engine output-dim assertion helper (AZ-527).
|
||||
|
||||
Single home for the dry-run probe that every c2_vpr secondary VPR
|
||||
strategy module runs at :func:`create` time to verify that the loaded
|
||||
inference engine emits a descriptor of the shape the strategy
|
||||
contract promises.
|
||||
|
||||
This helper consolidates the formerly seven-way duplicated
|
||||
``_assert_engine_output_dim`` helper across ``ultra_vpr.py``,
|
||||
``net_vlad.py``, ``mega_loc.py``, ``mix_vpr.py``, ``sela_vpr.py``,
|
||||
``eigen_places.py`` and ``salad.py``. Behaviour, error envelope, and
|
||||
error-message wording are byte-identical to the prior copies so that
|
||||
the AZ-337 / 338 / 339 / 340 AC-6 sub-tests continue to pass
|
||||
unmodified.
|
||||
|
||||
The helper is C2-internal: the underscore-prefixed module name and
|
||||
the unimported-from-``c2_vpr.__init__`` placement keep it inside the
|
||||
component boundary per
|
||||
``components/02_c2_vpr/description.md`` § 6 (engine output-shape
|
||||
contracts are a C2 internal concern; C7 owns its own engine-shape
|
||||
assertions inside the runtime).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gps_denied_onboard.config.schema import ConfigError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gps_denied_onboard._types.inference import EngineHandle
|
||||
from gps_denied_onboard.components.c2_vpr._preprocessor import (
|
||||
BackbonePreprocessor,
|
||||
)
|
||||
from gps_denied_onboard.components.c2_vpr.inference_runtime_cut import (
|
||||
InferenceRuntimeCut,
|
||||
)
|
||||
|
||||
__all__ = ["assert_engine_output_dim"]
|
||||
|
||||
|
||||
def assert_engine_output_dim(
|
||||
inference_runtime: InferenceRuntimeCut,
|
||||
handle: EngineHandle,
|
||||
preprocessor: BackbonePreprocessor,
|
||||
descriptor_dim: int,
|
||||
*,
|
||||
output_key: str = "embedding",
|
||||
input_key: str = "input",
|
||||
) -> None:
|
||||
"""Verify the engine emits ``(1, descriptor_dim)`` under ``output_key``.
|
||||
|
||||
Runs a single zero-init dry-run inference at the preprocessor's
|
||||
declared ``input_shape()`` (FP16 NCHW, batch=1, RGB) and asserts
|
||||
two contracts on the returned output dict:
|
||||
|
||||
1. ``output_key`` is present in the dict (default ``"embedding"``,
|
||||
NetVLAD overrides to ``"vlad_descriptor"``).
|
||||
2. The output ndarray has shape ``(1, descriptor_dim)``.
|
||||
|
||||
Both contracts are verified at ``create()`` time (startup), not
|
||||
per frame.
|
||||
|
||||
:param inference_runtime: the C2-internal :class:`InferenceRuntimeCut`
|
||||
that wraps the loaded engine.
|
||||
:param handle: the deserialised :class:`EngineHandle` to probe.
|
||||
:param preprocessor: the strategy's :class:`BackbonePreprocessor`;
|
||||
only ``input_shape()`` is consulted to size the probe tensor.
|
||||
:param descriptor_dim: the strategy's expected descriptor width.
|
||||
For most backbones this is a module-level ``DESCRIPTOR_DIM``
|
||||
constant; NetVLAD passes its runtime-resolved descriptor_dim.
|
||||
:param output_key: the key under which the engine surfaces the
|
||||
descriptor. Defaults to ``"embedding"`` (used by 6 of the 7
|
||||
c2_vpr secondary backbones); NetVLAD passes ``"vlad_descriptor"``.
|
||||
:param input_key: the key under which the engine accepts the input
|
||||
tensor. Defaults to ``"input"`` (used by all 7 backbones).
|
||||
|
||||
:raises ConfigError: if ``output_key`` is absent from the engine
|
||||
output, or the output ndarray shape does not match
|
||||
``(1, descriptor_dim)``. The error message names both the
|
||||
expected and actual values.
|
||||
"""
|
||||
h, w = preprocessor.input_shape()
|
||||
probe = np.zeros((1, 3, h, w), dtype=np.float16)
|
||||
outputs = inference_runtime.infer(handle, {input_key: probe})
|
||||
if output_key not in outputs:
|
||||
raise ConfigError(
|
||||
f"engine output shape mismatch: {output_key!r} key absent; "
|
||||
f"got keys {sorted(outputs.keys())!r}"
|
||||
)
|
||||
actual = np.asarray(outputs[output_key])
|
||||
if (
|
||||
actual.ndim != 2
|
||||
or actual.shape[0] != 1
|
||||
or actual.shape[1] != descriptor_dim
|
||||
):
|
||||
raise ConfigError(
|
||||
f"engine output shape mismatch: expected (1, {descriptor_dim}), "
|
||||
f"got {tuple(actual.shape)}"
|
||||
)
|
||||
@@ -60,6 +60,9 @@ from gps_denied_onboard._types.inference import (
|
||||
)
|
||||
from gps_denied_onboard._types.vpr import VprQuery, VprResult
|
||||
from gps_denied_onboard.clock import Clock
|
||||
from gps_denied_onboard.components.c2_vpr._engine_dim_assertion import (
|
||||
assert_engine_output_dim,
|
||||
)
|
||||
from gps_denied_onboard.components.c2_vpr._faiss_bridge import FaissBridge
|
||||
from gps_denied_onboard.components.c2_vpr._preprocessor_eigen_places import (
|
||||
EigenPlacesBackbonePreprocessor,
|
||||
@@ -394,7 +397,14 @@ def create(
|
||||
clock=clock,
|
||||
)
|
||||
|
||||
_assert_engine_output_dim(inference_runtime, handle, preprocessor)
|
||||
assert_engine_output_dim(
|
||||
inference_runtime,
|
||||
handle,
|
||||
preprocessor,
|
||||
DESCRIPTOR_DIM,
|
||||
output_key=_OUTPUT_KEY,
|
||||
input_key=_ENGINE_INPUT_KEY,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"C2 VPR strategy ready",
|
||||
@@ -422,31 +432,3 @@ def create(
|
||||
)
|
||||
|
||||
|
||||
def _assert_engine_output_dim(
|
||||
inference_runtime: InferenceRuntimeCut,
|
||||
handle: EngineHandle,
|
||||
preprocessor: EigenPlacesBackbonePreprocessor,
|
||||
) -> None:
|
||||
# The 7-way duplication of this helper (ultra_vpr / net_vlad /
|
||||
# mega_loc / mix_vpr / sela_vpr / eigen_places / salad) is tracked
|
||||
# by AZ-527 (hygiene PBI sized in parallel with AZ-339 land). The
|
||||
# duplication is intentional for now: extracting earlier would
|
||||
# expand AZ-340's scope past the three new strategies.
|
||||
h, w = preprocessor.input_shape()
|
||||
probe = np.zeros((1, 3, h, w), dtype=np.float16)
|
||||
outputs = inference_runtime.infer(handle, {_ENGINE_INPUT_KEY: probe})
|
||||
if _OUTPUT_KEY not in outputs:
|
||||
raise ConfigError(
|
||||
f"engine output shape mismatch: {_OUTPUT_KEY!r} key absent; "
|
||||
f"got keys {sorted(outputs.keys())!r}"
|
||||
)
|
||||
actual = np.asarray(outputs[_OUTPUT_KEY])
|
||||
if (
|
||||
actual.ndim != 2
|
||||
or actual.shape[0] != 1
|
||||
or actual.shape[1] != DESCRIPTOR_DIM
|
||||
):
|
||||
raise ConfigError(
|
||||
f"engine output shape mismatch: expected (1, {DESCRIPTOR_DIM}), "
|
||||
f"got {tuple(actual.shape)}"
|
||||
)
|
||||
|
||||
@@ -59,6 +59,9 @@ from gps_denied_onboard._types.inference import (
|
||||
)
|
||||
from gps_denied_onboard._types.vpr import VprQuery, VprResult
|
||||
from gps_denied_onboard.clock import Clock
|
||||
from gps_denied_onboard.components.c2_vpr._engine_dim_assertion import (
|
||||
assert_engine_output_dim,
|
||||
)
|
||||
from gps_denied_onboard.components.c2_vpr._faiss_bridge import FaissBridge
|
||||
from gps_denied_onboard.components.c2_vpr._preprocessor_mega_loc import (
|
||||
MegaLocBackbonePreprocessor,
|
||||
@@ -393,7 +396,14 @@ def create(
|
||||
clock=clock,
|
||||
)
|
||||
|
||||
_assert_engine_output_dim(inference_runtime, handle, preprocessor)
|
||||
assert_engine_output_dim(
|
||||
inference_runtime,
|
||||
handle,
|
||||
preprocessor,
|
||||
DESCRIPTOR_DIM,
|
||||
output_key=_OUTPUT_KEY,
|
||||
input_key=_ENGINE_INPUT_KEY,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"C2 VPR strategy ready",
|
||||
@@ -421,31 +431,3 @@ def create(
|
||||
)
|
||||
|
||||
|
||||
def _assert_engine_output_dim(
|
||||
inference_runtime: InferenceRuntimeCut,
|
||||
handle: EngineHandle,
|
||||
preprocessor: MegaLocBackbonePreprocessor,
|
||||
) -> None:
|
||||
# The 4-way duplication of this helper (ultra_vpr / net_vlad /
|
||||
# mega_loc / mix_vpr) will be consolidated by AZ-527 (hygiene
|
||||
# PBI sized in parallel with AZ-339 land). The duplication is
|
||||
# intentional for now: extracting earlier would expand AZ-339's
|
||||
# scope past the two new strategies.
|
||||
h, w = preprocessor.input_shape()
|
||||
probe = np.zeros((1, 3, h, w), dtype=np.float16)
|
||||
outputs = inference_runtime.infer(handle, {_ENGINE_INPUT_KEY: probe})
|
||||
if _OUTPUT_KEY not in outputs:
|
||||
raise ConfigError(
|
||||
f"engine output shape mismatch: {_OUTPUT_KEY!r} key absent; "
|
||||
f"got keys {sorted(outputs.keys())!r}"
|
||||
)
|
||||
actual = np.asarray(outputs[_OUTPUT_KEY])
|
||||
if (
|
||||
actual.ndim != 2
|
||||
or actual.shape[0] != 1
|
||||
or actual.shape[1] != DESCRIPTOR_DIM
|
||||
):
|
||||
raise ConfigError(
|
||||
f"engine output shape mismatch: expected (1, {DESCRIPTOR_DIM}), "
|
||||
f"got {tuple(actual.shape)}"
|
||||
)
|
||||
|
||||
@@ -60,6 +60,9 @@ from gps_denied_onboard._types.inference import (
|
||||
)
|
||||
from gps_denied_onboard._types.vpr import VprQuery, VprResult
|
||||
from gps_denied_onboard.clock import Clock
|
||||
from gps_denied_onboard.components.c2_vpr._engine_dim_assertion import (
|
||||
assert_engine_output_dim,
|
||||
)
|
||||
from gps_denied_onboard.components.c2_vpr._faiss_bridge import FaissBridge
|
||||
from gps_denied_onboard.components.c2_vpr._preprocessor_mix_vpr import (
|
||||
MixVprBackbonePreprocessor,
|
||||
@@ -396,7 +399,14 @@ def create(
|
||||
clock=clock,
|
||||
)
|
||||
|
||||
_assert_engine_output_dim(inference_runtime, handle, preprocessor)
|
||||
assert_engine_output_dim(
|
||||
inference_runtime,
|
||||
handle,
|
||||
preprocessor,
|
||||
DESCRIPTOR_DIM,
|
||||
output_key=_OUTPUT_KEY,
|
||||
input_key=_ENGINE_INPUT_KEY,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"C2 VPR strategy ready",
|
||||
@@ -424,31 +434,3 @@ def create(
|
||||
)
|
||||
|
||||
|
||||
def _assert_engine_output_dim(
|
||||
inference_runtime: InferenceRuntimeCut,
|
||||
handle: EngineHandle,
|
||||
preprocessor: MixVprBackbonePreprocessor,
|
||||
) -> None:
|
||||
# The 4-way duplication of this helper (ultra_vpr / net_vlad /
|
||||
# mega_loc / mix_vpr) will be consolidated by AZ-527 (hygiene
|
||||
# PBI sized in parallel with AZ-339 land). The duplication is
|
||||
# intentional for now: extracting earlier would expand AZ-339's
|
||||
# scope past the two new strategies.
|
||||
h, w = preprocessor.input_shape()
|
||||
probe = np.zeros((1, 3, h, w), dtype=np.float16)
|
||||
outputs = inference_runtime.infer(handle, {_ENGINE_INPUT_KEY: probe})
|
||||
if _OUTPUT_KEY not in outputs:
|
||||
raise ConfigError(
|
||||
f"engine output shape mismatch: {_OUTPUT_KEY!r} key absent; "
|
||||
f"got keys {sorted(outputs.keys())!r}"
|
||||
)
|
||||
actual = np.asarray(outputs[_OUTPUT_KEY])
|
||||
if (
|
||||
actual.ndim != 2
|
||||
or actual.shape[0] != 1
|
||||
or actual.shape[1] != DESCRIPTOR_DIM
|
||||
):
|
||||
raise ConfigError(
|
||||
f"engine output shape mismatch: expected (1, {DESCRIPTOR_DIM}), "
|
||||
f"got {tuple(actual.shape)}"
|
||||
)
|
||||
|
||||
@@ -80,6 +80,9 @@ from gps_denied_onboard._types.inference import (
|
||||
)
|
||||
from gps_denied_onboard._types.vpr import VprQuery, VprResult
|
||||
from gps_denied_onboard.clock import Clock
|
||||
from gps_denied_onboard.components.c2_vpr._engine_dim_assertion import (
|
||||
assert_engine_output_dim,
|
||||
)
|
||||
from gps_denied_onboard.components.c2_vpr._faiss_bridge import FaissBridge
|
||||
from gps_denied_onboard.components.c2_vpr._net_vlad_architecture import (
|
||||
DEFAULT_NUM_CLUSTERS,
|
||||
@@ -461,8 +464,12 @@ def create(
|
||||
clock=clock,
|
||||
)
|
||||
|
||||
_assert_engine_output_dim(
|
||||
inference_runtime, handle, descriptor_dim, preprocessor
|
||||
assert_engine_output_dim(
|
||||
inference_runtime,
|
||||
handle,
|
||||
preprocessor,
|
||||
descriptor_dim,
|
||||
output_key="vlad_descriptor",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -491,23 +498,3 @@ def create(
|
||||
)
|
||||
|
||||
|
||||
def _assert_engine_output_dim(
|
||||
inference_runtime: InferenceRuntimeCut,
|
||||
handle: EngineHandle,
|
||||
expected_dim: int,
|
||||
preprocessor: NetVladBackbonePreprocessor,
|
||||
) -> None:
|
||||
h, w = preprocessor.input_shape()
|
||||
probe = np.zeros((1, 3, h, w), dtype=np.float16)
|
||||
outputs = inference_runtime.infer(handle, {"input": probe})
|
||||
if "vlad_descriptor" not in outputs:
|
||||
raise ConfigError(
|
||||
f"engine output shape mismatch: 'vlad_descriptor' key absent; "
|
||||
f"got keys {sorted(outputs.keys())!r}"
|
||||
)
|
||||
actual = np.asarray(outputs["vlad_descriptor"])
|
||||
if actual.ndim != 2 or actual.shape[0] != 1 or actual.shape[1] != expected_dim:
|
||||
raise ConfigError(
|
||||
f"engine output shape mismatch: expected (1, {expected_dim}), "
|
||||
f"got {tuple(actual.shape)}"
|
||||
)
|
||||
|
||||
@@ -69,6 +69,9 @@ from gps_denied_onboard._types.inference import (
|
||||
)
|
||||
from gps_denied_onboard._types.vpr import VprQuery, VprResult
|
||||
from gps_denied_onboard.clock import Clock
|
||||
from gps_denied_onboard.components.c2_vpr._engine_dim_assertion import (
|
||||
assert_engine_output_dim,
|
||||
)
|
||||
from gps_denied_onboard.components.c2_vpr._faiss_bridge import FaissBridge
|
||||
from gps_denied_onboard.components.c2_vpr._preprocessor_salad import (
|
||||
SaladBackbonePreprocessor,
|
||||
@@ -406,7 +409,14 @@ def create(
|
||||
clock=clock,
|
||||
)
|
||||
|
||||
_assert_engine_output_dim(inference_runtime, handle, preprocessor)
|
||||
assert_engine_output_dim(
|
||||
inference_runtime,
|
||||
handle,
|
||||
preprocessor,
|
||||
DESCRIPTOR_DIM,
|
||||
output_key=_OUTPUT_KEY,
|
||||
input_key=_ENGINE_INPUT_KEY,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"C2 VPR strategy ready",
|
||||
@@ -434,31 +444,3 @@ def create(
|
||||
)
|
||||
|
||||
|
||||
def _assert_engine_output_dim(
|
||||
inference_runtime: InferenceRuntimeCut,
|
||||
handle: EngineHandle,
|
||||
preprocessor: SaladBackbonePreprocessor,
|
||||
) -> None:
|
||||
# The 7-way duplication of this helper (ultra_vpr / net_vlad /
|
||||
# mega_loc / mix_vpr / sela_vpr / eigen_places / salad) is tracked
|
||||
# by AZ-527 (hygiene PBI sized in parallel with AZ-339 land). The
|
||||
# duplication is intentional for now: extracting earlier would
|
||||
# expand AZ-340's scope past the three new strategies.
|
||||
h, w = preprocessor.input_shape()
|
||||
probe = np.zeros((1, 3, h, w), dtype=np.float16)
|
||||
outputs = inference_runtime.infer(handle, {_ENGINE_INPUT_KEY: probe})
|
||||
if _OUTPUT_KEY not in outputs:
|
||||
raise ConfigError(
|
||||
f"engine output shape mismatch: {_OUTPUT_KEY!r} key absent; "
|
||||
f"got keys {sorted(outputs.keys())!r}"
|
||||
)
|
||||
actual = np.asarray(outputs[_OUTPUT_KEY])
|
||||
if (
|
||||
actual.ndim != 2
|
||||
or actual.shape[0] != 1
|
||||
or actual.shape[1] != DESCRIPTOR_DIM
|
||||
):
|
||||
raise ConfigError(
|
||||
f"engine output shape mismatch: expected (1, {DESCRIPTOR_DIM}), "
|
||||
f"got {tuple(actual.shape)}"
|
||||
)
|
||||
|
||||
@@ -59,6 +59,9 @@ from gps_denied_onboard._types.inference import (
|
||||
)
|
||||
from gps_denied_onboard._types.vpr import VprQuery, VprResult
|
||||
from gps_denied_onboard.clock import Clock
|
||||
from gps_denied_onboard.components.c2_vpr._engine_dim_assertion import (
|
||||
assert_engine_output_dim,
|
||||
)
|
||||
from gps_denied_onboard.components.c2_vpr._faiss_bridge import FaissBridge
|
||||
from gps_denied_onboard.components.c2_vpr._preprocessor_sela_vpr import (
|
||||
SelaVprBackbonePreprocessor,
|
||||
@@ -393,7 +396,14 @@ def create(
|
||||
clock=clock,
|
||||
)
|
||||
|
||||
_assert_engine_output_dim(inference_runtime, handle, preprocessor)
|
||||
assert_engine_output_dim(
|
||||
inference_runtime,
|
||||
handle,
|
||||
preprocessor,
|
||||
DESCRIPTOR_DIM,
|
||||
output_key=_OUTPUT_KEY,
|
||||
input_key=_ENGINE_INPUT_KEY,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"C2 VPR strategy ready",
|
||||
@@ -421,31 +431,3 @@ def create(
|
||||
)
|
||||
|
||||
|
||||
def _assert_engine_output_dim(
|
||||
inference_runtime: InferenceRuntimeCut,
|
||||
handle: EngineHandle,
|
||||
preprocessor: SelaVprBackbonePreprocessor,
|
||||
) -> None:
|
||||
# The 7-way duplication of this helper (ultra_vpr / net_vlad /
|
||||
# mega_loc / mix_vpr / sela_vpr / eigen_places / salad) is tracked
|
||||
# by AZ-527 (hygiene PBI sized in parallel with AZ-339 land). The
|
||||
# duplication is intentional for now: extracting earlier would
|
||||
# expand AZ-340's scope past the three new strategies.
|
||||
h, w = preprocessor.input_shape()
|
||||
probe = np.zeros((1, 3, h, w), dtype=np.float16)
|
||||
outputs = inference_runtime.infer(handle, {_ENGINE_INPUT_KEY: probe})
|
||||
if _OUTPUT_KEY not in outputs:
|
||||
raise ConfigError(
|
||||
f"engine output shape mismatch: {_OUTPUT_KEY!r} key absent; "
|
||||
f"got keys {sorted(outputs.keys())!r}"
|
||||
)
|
||||
actual = np.asarray(outputs[_OUTPUT_KEY])
|
||||
if (
|
||||
actual.ndim != 2
|
||||
or actual.shape[0] != 1
|
||||
or actual.shape[1] != DESCRIPTOR_DIM
|
||||
):
|
||||
raise ConfigError(
|
||||
f"engine output shape mismatch: expected (1, {DESCRIPTOR_DIM}), "
|
||||
f"got {tuple(actual.shape)}"
|
||||
)
|
||||
|
||||
@@ -63,6 +63,9 @@ from gps_denied_onboard._types.inference import (
|
||||
)
|
||||
from gps_denied_onboard._types.vpr import VprQuery, VprResult
|
||||
from gps_denied_onboard.clock import Clock
|
||||
from gps_denied_onboard.components.c2_vpr._engine_dim_assertion import (
|
||||
assert_engine_output_dim,
|
||||
)
|
||||
from gps_denied_onboard.components.c2_vpr._faiss_bridge import FaissBridge
|
||||
from gps_denied_onboard.components.c2_vpr._preprocessor_ultra_vpr import (
|
||||
UltraVprBackbonePreprocessor,
|
||||
@@ -401,7 +404,14 @@ def create(
|
||||
clock=clock,
|
||||
)
|
||||
|
||||
_assert_engine_output_dim(inference_runtime, handle, preprocessor)
|
||||
assert_engine_output_dim(
|
||||
inference_runtime,
|
||||
handle,
|
||||
preprocessor,
|
||||
DESCRIPTOR_DIM,
|
||||
output_key=_OUTPUT_KEY,
|
||||
input_key=_ENGINE_INPUT_KEY,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"C2 VPR strategy ready",
|
||||
@@ -429,26 +439,3 @@ def create(
|
||||
)
|
||||
|
||||
|
||||
def _assert_engine_output_dim(
|
||||
inference_runtime: InferenceRuntimeCut,
|
||||
handle: EngineHandle,
|
||||
preprocessor: UltraVprBackbonePreprocessor,
|
||||
) -> None:
|
||||
h, w = preprocessor.input_shape()
|
||||
probe = np.zeros((1, 3, h, w), dtype=np.float16)
|
||||
outputs = inference_runtime.infer(handle, {_ENGINE_INPUT_KEY: probe})
|
||||
if _OUTPUT_KEY not in outputs:
|
||||
raise ConfigError(
|
||||
f"engine output shape mismatch: {_OUTPUT_KEY!r} key absent; "
|
||||
f"got keys {sorted(outputs.keys())!r}"
|
||||
)
|
||||
actual = np.asarray(outputs[_OUTPUT_KEY])
|
||||
if (
|
||||
actual.ndim != 2
|
||||
or actual.shape[0] != 1
|
||||
or actual.shape[1] != DESCRIPTOR_DIM
|
||||
):
|
||||
raise ConfigError(
|
||||
f"engine output shape mismatch: expected (1, {DESCRIPTOR_DIM}), "
|
||||
f"got {tuple(actual.shape)}"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user