mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-06-21 19:31:12 +00:00
0ad3278b12
Land the fallback InferenceRuntime strategy that satisfies C7-IT-05: when the TRT-direct path (AZ-298) cannot deserialise a cached engine or when the operator explicitly selects ORT, the system stays in the air at degraded latency rather than dropping the request. Conforms to the AZ-297 Protocol; current_runtime_label() == "onnx_trt_ep". Production - onnx_trt_ep_runtime.py: compile_engine is a no-op returning an EngineCacheEntry pointing at the source .onnx; deserialize_engine is gate-first for .engine entries and gate-skip for .onnx, builds an ORT InferenceSession with the provider list [TensorrtExecutionProvider, CUDAExecutionProvider, CPUExecutionProvider], stages cached engines into the ORT TRT EP cache directory via symlink-or-copy, warms up with one session.run after construction, and honours config.inference.ort_disallow_cpu_fallback by raising EngineDeserializeError when the active provider resolves to CPU; infer emits a one-shot c7.fallback_to_onnx_trt_ep WARN log plus gcs_alert callback on first call when is_fallback=True; release_engine is idempotent. _build_provider_args is the single point that pins TRT EP option-key names (Risk-3) and caps trt_max_workspace_size at gpu_memory_budget_bytes // 4 (AC-8). - config.py: adds ort_trt_cache_dir (validated non-empty) and ort_disallow_cpu_fallback to C7InferenceConfig. - fdr_client/records.py: adds c7.fallback_to_onnx_trt_ep and c7.cpu_fallback FDR record kinds. Tests - test_onnx_trt_ep_runtime.py: covers AC-1..AC-8 + Risk-2 CPU-fallback alert + Risk-3 option-key pin + NFR-reliability error rewrap; Tier-1 via fake ORT session; Tier-2 placeholders skip on macOS dev for numerical FP16 comparison and session-creation perf NFR. - test_protocol_conformance.py: drops onnx_trt_ep from the missing-module parametrize now that the module ships. - test_az272_fdr_record_schema.py: extends per-kind fixture builder to cover the two new C7 FDR kinds in the roundtrip / schema-version AC tests. 
Docs - module-layout.md: replaces the pending onnx_trt_runtime row with the shipped onnx_trt_ep_runtime row + capabilities list. - batch_32_cycle1_report.md + reviews/batch_32_review.md: full batch + self-review (PASS_WITH_WARNINGS, 4 Low findings accepted). Tests run: c7_inference 139 passing + 17 Tier-2 skips; combined unit suite (excluding pending components) 529 passing, 19 env-skipped. Co-authored-by: Cursor <cursoragent@cursor.com>
817 lines
27 KiB
Python
817 lines
27 KiB
Python
"""AZ-299 — :class:`OnnxTrtEpRuntime` acceptance tests.
|
|
|
|
The real ORT TRT EP path (provider negotiation, session creation,
|
|
warm-up, TRT subgraph compile) requires onnxruntime + a Tier-2 Jetson
|
|
host; those tests are guarded by :data:`_REQUIRE_ORT` and skip cleanly
|
|
on Tier-1 / macOS dev. CPU-runnable coverage uses fake ORT
|
|
:class:`InferenceSession` shims to verify:
|
|
|
|
- Protocol conformance + label (AC-1).
|
|
- ``.onnx`` deserialise skips the gate (AC-2).
|
|
- ``.engine`` deserialise invokes the gate first (AC-3).
|
|
- ``infer`` round-trips through ``session.run`` and returns named
|
|
outputs (AC-4) + the input-binding-missing rewrap.
|
|
- First-infer fallback WARN log + GCS callback + FDR record fire
|
|
exactly once (AC-5).
|
|
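- The active provider is the first ``session.get_providers()`` entry, and
  the runtime label stays ``onnx_trt_ep`` when TRT EP declines (AC-6).
- TRT EP options carry the workspace cap from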
|
|
``config.gpu_memory_budget_bytes // 4`` (AC-8).
|
|
- ``release_engine`` is idempotent (AC-7).
|
|
- NFR-reliability — ORT internal exceptions rewrap to
|
|
:class:`InferenceError`.
|
|
- CPU-fallback handling (Risk 2) emits the FDR + log; the hard-refusal
|
|
toggle raises :class:`EngineDeserializeError`.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from unittest.mock import MagicMock
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from gps_denied_onboard._types.inference import (
|
|
BuildConfig,
|
|
EngineCacheEntry,
|
|
OptimizationProfile,
|
|
PrecisionMode,
|
|
)
|
|
from gps_denied_onboard._types.thermal import ThermalState
|
|
from gps_denied_onboard.components.c7_inference import (
|
|
C7InferenceConfig,
|
|
DeploymentManifest,
|
|
EngineDeserializeError,
|
|
EngineSchemaMismatchError,
|
|
HostTuple,
|
|
InferenceError,
|
|
InferenceRuntime,
|
|
OutOfMemoryError,
|
|
)
|
|
from gps_denied_onboard.components.c7_inference.onnx_trt_ep_runtime import (
|
|
CPU_EP,
|
|
CUDA_EP,
|
|
ENGINE_SUFFIX,
|
|
ONNX_SUFFIX,
|
|
OnnxTrtEpEngineHandle,
|
|
OnnxTrtEpRuntime,
|
|
TRT_EP,
|
|
_sha256_of_file,
|
|
)
|
|
from gps_denied_onboard.config.schema import Config
|
|
from gps_denied_onboard.fdr_client.client import FdrClient
|
|
from gps_denied_onboard.fdr_client.records import (
|
|
FdrRecord,
|
|
)
|
|
from gps_denied_onboard.helpers.sha256_sidecar import (
|
|
SIDECAR_SUFFIX,
|
|
)
|
|
|
|
try:
    import onnxruntime  # type: ignore[import-not-found] # noqa: F401
except ImportError:
    _HAS_ORT = False
else:
    _HAS_ORT = True

# Real-ORT tests only run where onnxruntime is importable.
_REQUIRE_ORT = pytest.mark.skipif(
    not _HAS_ORT,
    reason="onnxruntime not installed (Tier-2 Jetson / JetPack 6.2 only)",
)

# Shared skip reason for the Tier-2 placeholder tests at the end of the module.
_TIER2_REASON = (
    "AZ-299 Tier-2 microbench harness owns the real-ORT perf / numerical "
    "asserts (C7-IT-05); skipped on Tier-1 CI / macOS dev."
)

# Host tuple reported by the injected host_tuple_provider in every test; it
# equals the default (sm, jp, trt, precision) of _make_engine_entry().
_TIER2_HOST = HostTuple(sm=87, jp="6.2", trt="10.3", precision=PrecisionMode.FP16)
|
|
|
|
|
|
# ----------------------------------------------------------------------
|
|
# Fakes.
|
|
|
|
|
|
class _FakeIoTensor:
|
|
"""Minimal stand-in for ``onnxruntime.NodeArg``."""
|
|
|
|
def __init__(self, name: str, shape: tuple[int, ...], dtype: str) -> None:
|
|
self.name = name
|
|
self.shape = list(shape)
|
|
self.type = dtype
|
|
|
|
|
|
class _FakeOrtSession:
    """Hand-rolled :class:`onnxruntime.InferenceSession` substitute.

    State inspected by the tests:
        run_calls: one ``(output_names, feed)`` tuple per ``run`` call,
            including the runtime's warm-up call.
        profiling_ended: number of ``end_profiling`` calls.
        _run_side_effect: when not ``None``, ``run`` records the call and then
            raises it; tests may also assign it after construction.
    """

    def __init__(
        self,
        *,
        active_providers: tuple[str, ...] = (TRT_EP, CUDA_EP, CPU_EP),
        inputs: tuple[_FakeIoTensor, ...] = (
            _FakeIoTensor("x", (1, 3, 224, 224), "tensor(float)"),
        ),
        outputs: tuple[_FakeIoTensor, ...] = (
            _FakeIoTensor("y", (1, 1000), "tensor(float)"),
        ),
        run_side_effect: Any | None = None,
    ) -> None:
        self._active_providers = active_providers
        self._inputs = inputs
        self._outputs = outputs
        self.run_calls: list[tuple[list[str] | None, dict[str, np.ndarray]]] = []
        self.profiling_ended = 0
        self._run_side_effect = run_side_effect

    def get_providers(self) -> list[str]:
        """Return the configured providers in order (AC-6 treats entry 0 as active)."""
        return list(self._active_providers)

    def get_inputs(self) -> list[_FakeIoTensor]:
        return list(self._inputs)

    def get_outputs(self) -> list[_FakeIoTensor]:
        return list(self._outputs)

    def run(
        self,
        output_names: list[str] | None,
        feed: dict[str, np.ndarray],
    ) -> list[np.ndarray]:
        """Record the call, then raise the side effect or return zero outputs."""
        self.run_calls.append((output_names, feed))
        if self._run_side_effect is not None:
            raise self._run_side_effect
        return [
            np.zeros(self._concrete_shape(out.shape), dtype=np.float32)
            for out in self._outputs
        ]

    @staticmethod
    def _concrete_shape(dims: list[Any]) -> tuple[int, ...]:
        """Turn declared dims into a concrete shape.

        Positive integer dims are kept. Dynamic dims (non-positive ints,
        symbolic names, ``None``) become 1; comparing those with ``> 0``
        directly would raise ``TypeError`` for the non-int forms.
        """
        return tuple(
            int(d) if isinstance(d, (int, np.integer)) and d > 0 else 1
            for d in dims
        )

    def end_profiling(self) -> None:
        self.profiling_ended += 1
|
|
|
|
|
|
class _FakeOrt:
|
|
"""Replaces the lazy-imported ``onnxruntime`` module."""
|
|
|
|
def __init__(self, session: _FakeOrtSession) -> None:
|
|
self._session = session
|
|
self.session_kwargs: dict[str, Any] = {}
|
|
self.session_path: str | None = None
|
|
self.create_side_effect: Any | None = None
|
|
|
|
def InferenceSession(
|
|
self,
|
|
path: str,
|
|
*,
|
|
providers: list[str],
|
|
provider_options: list[dict[str, Any]],
|
|
) -> _FakeOrtSession:
|
|
self.session_path = path
|
|
self.session_kwargs = {
|
|
"providers": providers,
|
|
"provider_options": provider_options,
|
|
}
|
|
if self.create_side_effect is not None:
|
|
raise self.create_side_effect
|
|
return self._session
|
|
|
|
|
|
# ----------------------------------------------------------------------
|
|
# Fixtures.
|
|
|
|
|
|
@pytest.fixture
def config(tmp_path: Path) -> Config:
    """Config selecting the ``onnx_trt_ep`` runtime with per-test cache dirs."""
    c7_block = C7InferenceConfig(
        runtime="onnx_trt_ep",
        engine_cache_dir=str(tmp_path / "engines"),
        ort_trt_cache_dir=str(tmp_path / "ort_cache"),
    )
    return Config.with_blocks(c7_inference=c7_block)
|
|
|
|
|
|
@pytest.fixture
def runtime_basic(config: Config) -> OnnxTrtEpRuntime:
    """Runtime built from ``config`` alone, with no injected collaborators."""
    runtime = OnnxTrtEpRuntime(config)
    return runtime
|
|
|
|
|
|
def _make_engine_entry(
    tmp_path: Path,
    *,
    sm: int = 87,
    jp: str = "6.2",
    trt: str = "10.3",
    precision: PrecisionMode = PrecisionMode.FP16,
    payload: bytes = b"fake-engine-bytes",
) -> tuple[EngineCacheEntry, Path]:
    """Write a fake TRT engine plus its SHA-256 sidecar and return a cache entry.

    The filename encodes the (sm, jp, trt, precision) tuple that the engine
    gate compares against the host (see the AC-3 tests). The suffix comes from
    the runtime's own ``ENGINE_SUFFIX`` constant, so this fixture and the
    runtime's ``.engine``/``.onnx`` dispatch cannot drift apart.

    Returns:
        ``(entry, engine_path)``.
    """
    name = f"ultravpr__sm{sm}_jp{jp}_trt{trt}_{precision.value}{ENGINE_SUFFIX}"
    engine_path = tmp_path / name
    engine_path.write_bytes(payload)
    sha_hex = hashlib.sha256(payload).hexdigest()
    # Sidecar digest file lives next to the engine: "<engine file><SIDECAR_SUFFIX>".
    Path(str(engine_path) + SIDECAR_SUFFIX).write_text(sha_hex, encoding="utf-8")
    entry = EngineCacheEntry(
        engine_path=engine_path,
        sha256_hex=sha_hex,
        sm=sm,
        jp=jp,
        trt=trt,
        precision=precision,
        extras={},
    )
    return entry, engine_path
|
|
|
|
|
|
def _make_onnx_entry(tmp_path: Path) -> tuple[EngineCacheEntry, Path]:
    """Write a placeholder ONNX file and return a cache entry pointing at it.

    The bytes are not a real ONNX model; the fake ORT session never parses
    them. The suffix comes from the runtime's ``ONNX_SUFFIX`` constant so the
    entry takes the runtime's ``.onnx`` (gate-skip) path. ``model_name`` is the
    file stem, which the tests compare against handle and FDR payloads.

    Returns:
        ``(entry, onnx_path)``.
    """
    payload = b"\x08\x07"  # not a real ONNX, but file exists.
    onnx_path = tmp_path / f"ultravpr{ONNX_SUFFIX}"
    onnx_path.write_bytes(payload)
    entry = EngineCacheEntry(
        engine_path=onnx_path,
        sha256_hex=hashlib.sha256(payload).hexdigest(),
        sm=87,
        jp="6.2",
        trt="10.3",
        precision=PrecisionMode.FP16,
        extras={"model_name": onnx_path.stem},
    )
    return entry, onnx_path
|
|
|
|
|
|
def _manifest_for(engine_path: Path) -> DeploymentManifest:
    """Build a one-entry manifest whose digest matches the file's current bytes."""
    digest = hashlib.sha256(engine_path.read_bytes()).hexdigest()
    manifest_entries = {engine_path.name: digest}
    return DeploymentManifest(root=engine_path.parent, entries=manifest_entries)
|
|
|
|
|
|
# ----------------------------------------------------------------------
|
|
# AC-1: Protocol + label.
|
|
|
|
|
|
def test_ac1_protocol_conformance(runtime_basic: OnnxTrtEpRuntime) -> None:
    """AC-1: satisfies the ``InferenceRuntime`` Protocol and reports its label."""
    assert isinstance(runtime_basic, InferenceRuntime)
    label = runtime_basic.current_runtime_label()
    assert label == "onnx_trt_ep"
|
|
|
|
|
|
# ----------------------------------------------------------------------
|
|
# AC-2: deserialize from .onnx skips the gate.
|
|
|
|
|
|
def test_ac2_onnx_path_skips_gate(
    tmp_path: Path, config: Config
) -> None:
    """AC-2: deserialising a ``.onnx`` entry never touches the engine gate.

    Also checks that the session is built from the ``.onnx`` path with the
    full TRT -> CUDA -> CPU provider list, and is warmed up exactly once.
    """
    # Arrange
    entry, onnx_path = _make_onnx_entry(tmp_path)
    fake_session = _FakeOrtSession()
    fake_ort = _FakeOrt(fake_session)

    # A MagicMock never raises by itself; it only records interactions, which
    # are asserted to be empty after the act step.
    gate = MagicMock()
    runtime = OnnxTrtEpRuntime(
        config,
        engine_gate=gate,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(onnx_path),
    )
    runtime._load_ort = lambda: fake_ort  # type: ignore[method-assign]
    # Act
    handle = runtime.deserialize_engine(entry)
    # Assert
    gate.validate.assert_not_called()
    # Stronger than the line above: no method of the gate may be called at all.
    assert gate.mock_calls == [], f"gate used on .onnx path: {gate.mock_calls}"
    assert isinstance(handle, OnnxTrtEpEngineHandle)
    assert handle._active_provider == TRT_EP
    assert handle._model_name == onnx_path.stem
    assert fake_ort.session_path == str(onnx_path)
    assert fake_ort.session_kwargs["providers"] == [TRT_EP, CUDA_EP, CPU_EP]
    # warm-up call recorded.
    assert len(fake_session.run_calls) == 1
|
|
|
|
|
|
# ----------------------------------------------------------------------
|
|
# AC-3: deserialize from .engine invokes the gate first.
|
|
|
|
|
|
def test_ac3_engine_path_invokes_gate_first(
    tmp_path: Path, config: Config
) -> None:
    """AC-3: a ``.engine`` entry is gated before any ORT session is created.

    The engine filename claims sm=86 while the host reports sm=87, so the gate
    must refuse. The ORT shim fails the test if a session is ever requested.
    """
    entry, engine_path = _make_engine_entry(tmp_path, sm=86)

    class _ExplodingOrt:
        def InferenceSession(self, *args: Any, **kwargs: Any) -> Any:
            raise AssertionError(
                "AC-3: ORT session must NOT be created when gate refuses"
            )

    runtime = OnnxTrtEpRuntime(
        config,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(engine_path),
    )
    runtime._load_ort = lambda: _ExplodingOrt()  # type: ignore[method-assign]

    with pytest.raises(EngineSchemaMismatchError, match="sm=86"):
        runtime.deserialize_engine(entry)
|
|
|
|
|
|
def test_ac3_engine_path_passes_gate_then_creates_session(
    tmp_path: Path, config: Config
) -> None:
    """AC-3: a host-matching ``.engine`` passes the gate and gets a session.

    The cached engine must also be staged into the ORT TRT EP cache directory.
    """
    entry, engine_path = _make_engine_entry(tmp_path, sm=87)
    session = _FakeOrtSession()
    ort_shim = _FakeOrt(session)
    runtime = OnnxTrtEpRuntime(
        config,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(engine_path),
    )
    runtime._load_ort = lambda: ort_shim  # type: ignore[method-assign]

    handle = runtime.deserialize_engine(entry)

    assert isinstance(handle, OnnxTrtEpEngineHandle)
    assert ort_shim.session_path == str(engine_path)
    ort_cache_dir = Path(config.components["c7_inference"].ort_trt_cache_dir)
    assert (ort_cache_dir / engine_path.name).exists()
|
|
|
|
|
|
# ----------------------------------------------------------------------
|
|
# AC-4: infer round-trips and returns named outputs.
|
|
|
|
|
|
def test_ac4_infer_round_trip_named_outputs(
    tmp_path: Path, config: Config
) -> None:
    """AC-4: ``infer`` drives ``session.run`` and returns outputs keyed by name."""
    entry, onnx_path = _make_onnx_entry(tmp_path)
    session = _FakeOrtSession()
    ort_shim = _FakeOrt(session)
    runtime = OnnxTrtEpRuntime(
        config,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(onnx_path),
    )
    runtime._load_ort = lambda: ort_shim  # type: ignore[method-assign]
    handle = runtime.deserialize_engine(entry)
    # Forget the warm-up run so only the call under test is inspected.
    session.run_calls.clear()
    frame = np.ones((1, 3, 224, 224), dtype=np.float32)

    outputs = runtime.infer(handle, {"x": frame})

    assert set(outputs) == {"y"}
    assert outputs["y"].shape == (1, 1000)
    assert len(session.run_calls) == 1
    [(requested, fed)] = session.run_calls
    assert requested == ["y"]
    assert "x" in fed
    assert fed["x"].shape == (1, 3, 224, 224)
|
|
|
|
|
|
def test_ac4_infer_missing_input_binding_rewraps(
    tmp_path: Path, config: Config
) -> None:
    """AC-4: a feed that lacks a declared input raises :class:`InferenceError`."""
    entry, onnx_path = _make_onnx_entry(tmp_path)
    ort_shim = _FakeOrt(_FakeOrtSession())
    runtime = OnnxTrtEpRuntime(
        config,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(onnx_path),
    )
    runtime._load_ort = lambda: ort_shim  # type: ignore[method-assign]
    handle = runtime.deserialize_engine(entry)

    empty_feed: dict[str, np.ndarray] = {}
    with pytest.raises(InferenceError, match="missing input binding"):
        runtime.infer(handle, empty_feed)
|
|
|
|
|
|
# ----------------------------------------------------------------------
|
|
# AC-5: fallback WARN + gcs_alert + FDR fire exactly once.
|
|
|
|
|
|
def test_ac5_first_infer_fallback_alert_fires_once(
    tmp_path: Path, config: Config, caplog: pytest.LogCaptureFixture
) -> None:
    """AC-5: the fallback WARN log, GCS alert and FDR record fire exactly once.

    Only the first ``infer`` on an ``is_fallback=True`` runtime announces the
    fallback; later calls must stay silent on all three channels. FDR drains
    use the client's full capacity, so a (duplicate) fallback record cannot
    hide behind other records outside the drain window.
    """
    fdr_capacity = 16
    fallback_kind = "c7.fallback_to_onnx_trt_ep"
    # Arrange
    entry, onnx_path = _make_onnx_entry(tmp_path)
    fake_session = _FakeOrtSession()
    fake_ort = _FakeOrt(fake_session)
    gcs_calls: list[str] = []
    fdr = FdrClient(producer_id="c7_inference.onnx_trt_ep", capacity=fdr_capacity)
    runtime = OnnxTrtEpRuntime(
        config,
        is_fallback=True,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(onnx_path),
        fdr_client=fdr,
        gcs_alert=lambda msg: gcs_calls.append(msg),
    )
    runtime._load_ort = lambda: fake_ort  # type: ignore[method-assign]
    handle = runtime.deserialize_engine(entry)

    # Act — first infer.
    caplog.set_level("WARNING", logger="c7_inference.onnx_trt_ep")
    runtime.infer(handle, {"x": np.ones((1, 3, 224, 224), dtype=np.float32)})

    # Assert — exactly one WARN log on the fallback kind.
    warn_records = [
        r for r in caplog.records
        if r.levelname == "WARNING" and "fallback" in r.getMessage()
    ]
    assert len(warn_records) == 1
    assert len(gcs_calls) == 1
    # FDR drained: one c7.fallback_to_onnx_trt_ep record.
    fallback_records = [
        r for r in fdr.drain(max_records=fdr_capacity) if r.kind == fallback_kind
    ]
    assert len(fallback_records) == 1
    rec = fallback_records[0]
    assert rec.payload["model_name"] == onnx_path.stem
    assert rec.payload["active_provider"] == TRT_EP

    # Act — second infer should be silent.
    caplog.clear()
    runtime.infer(handle, {"x": np.ones((1, 3, 224, 224), dtype=np.float32)})

    further_warns = [r for r in caplog.records if r.levelname == "WARNING"]
    assert further_warns == []
    assert len(gcs_calls) == 1
    further_records = fdr.drain(max_records=fdr_capacity)
    assert all(r.kind != fallback_kind for r in further_records)
|
|
|
|
|
|
def test_ac5_non_fallback_runtime_is_silent(
    tmp_path: Path,
    config: Config,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """AC-5 (negative): with ``is_fallback=False`` no fallback alert is raised.

    All three channels are checked: no log line mentioning "fallback", no GCS
    alert, and no ``c7.fallback_to_onnx_trt_ep`` FDR record.
    """
    fdr_capacity = 16
    entry, onnx_path = _make_onnx_entry(tmp_path)
    fake_session = _FakeOrtSession()
    fake_ort = _FakeOrt(fake_session)
    fdr = FdrClient(producer_id="c7_inference.onnx_trt_ep", capacity=fdr_capacity)
    gcs_calls: list[str] = []
    runtime = OnnxTrtEpRuntime(
        config,
        is_fallback=False,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(onnx_path),
        fdr_client=fdr,
        gcs_alert=lambda msg: gcs_calls.append(msg),
    )
    runtime._load_ort = lambda: fake_ort  # type: ignore[method-assign]
    handle = runtime.deserialize_engine(entry)
    caplog.set_level("WARNING", logger="c7_inference.onnx_trt_ep")

    runtime.infer(handle, {"x": np.ones((1, 3, 224, 224), dtype=np.float32)})

    fallback_warns = [
        r for r in caplog.records if "fallback" in r.getMessage()
    ]
    assert fallback_warns == []
    assert gcs_calls == []
    # The FDR client was wired in, so it must also carry no fallback record.
    drained = fdr.drain(max_records=fdr_capacity)
    assert all(r.kind != "c7.fallback_to_onnx_trt_ep" for r in drained)
|
|
|
|
|
|
# ----------------------------------------------------------------------
|
|
# AC-6: provider fallback chain (TRT refuses → CUDA active; label unchanged).
|
|
|
|
|
|
def test_ac6_active_provider_is_first_get_providers_entry(
    tmp_path: Path, config: Config
) -> None:
    """AC-6: if TRT EP is not active, CUDA takes over and the label is unchanged."""
    entry, onnx_path = _make_onnx_entry(tmp_path)
    # Simulate ORT dropping TRT EP from the active provider list.
    ort_shim = _FakeOrt(_FakeOrtSession(active_providers=(CUDA_EP, CPU_EP)))
    runtime = OnnxTrtEpRuntime(
        config,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(onnx_path),
    )
    runtime._load_ort = lambda: ort_shim  # type: ignore[method-assign]

    handle = runtime.deserialize_engine(entry)

    assert handle._active_provider == CUDA_EP
    assert runtime.current_runtime_label() == "onnx_trt_ep"
|
|
|
|
|
|
# ----------------------------------------------------------------------
|
|
# AC-7: release_engine idempotent.
|
|
|
|
|
|
def test_ac7_release_engine_idempotent(
    tmp_path: Path, config: Config
) -> None:
    """AC-7: releasing a handle twice tears the session down only once."""
    entry, onnx_path = _make_onnx_entry(tmp_path)
    session = _FakeOrtSession()
    ort_shim = _FakeOrt(session)
    runtime = OnnxTrtEpRuntime(
        config,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(onnx_path),
    )
    runtime._load_ort = lambda: ort_shim  # type: ignore[method-assign]
    handle = runtime.deserialize_engine(entry)

    runtime.release_engine(handle)
    assert handle._released is True
    assert handle._session is None
    assert session.profiling_ended == 1

    # Releasing again must not touch the (already dropped) session.
    runtime.release_engine(handle)
    assert session.profiling_ended == 1
|
|
|
|
|
|
def test_release_engine_ignores_foreign_handle(
    runtime_basic: OnnxTrtEpRuntime,
) -> None:
    """Releasing an object the runtime never issued is a silent no-op."""
    foreign_handle = type("_Foreign", (), {})()
    # The test passes as long as no exception escapes.
    runtime_basic.release_engine(foreign_handle)  # type: ignore[arg-type]
|
|
|
|
|
|
# ----------------------------------------------------------------------
|
|
# AC-8: TRT EP options carry the budget // 4 workspace cap.
|
|
|
|
|
|
def test_ac8_trt_ep_options_carry_budget_workspace_cap(
    tmp_path: Path, config: Config
) -> None:
    """AC-8: the TRT EP workspace is capped at a quarter of the GPU budget.

    ORT pairs ``provider_options[i]`` with ``providers[i]`` and rejects lists
    of different lengths. The test pins that pairing (equal lengths, TRT EP
    first) before reading ``opts[0]`` as the TRT EP options, so a reordered
    provider list cannot make the workspace check inspect the wrong dict.
    """
    entry, onnx_path = _make_onnx_entry(tmp_path)
    fake_session = _FakeOrtSession()
    fake_ort = _FakeOrt(fake_session)
    runtime = OnnxTrtEpRuntime(
        config,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(onnx_path),
    )
    runtime._load_ort = lambda: fake_ort  # type: ignore[method-assign]
    runtime.deserialize_engine(entry)

    providers = fake_ort.session_kwargs["providers"]
    opts = fake_ort.session_kwargs["provider_options"]
    assert len(opts) == len(providers)
    assert providers[0] == TRT_EP
    trt_opts = opts[0]
    budget = int(config.components["c7_inference"].gpu_memory_budget_bytes)
    assert trt_opts["trt_max_workspace_size"] == budget // 4
    assert trt_opts["trt_engine_cache_enable"] is True
    assert Path(trt_opts["trt_engine_cache_path"]).exists()
|
|
|
|
|
|
# ----------------------------------------------------------------------
|
|
# Risk 2: CPU-fallback handling.
|
|
|
|
|
|
def test_cpu_fallback_emits_record_when_warn_allowed(
    tmp_path: Path, config: Config
) -> None:
    """Risk 2: a CPU-only session is tolerated but leaves a ``c7.cpu_fallback`` record.

    This path needs ``ort_disallow_cpu_fallback`` off in the fixture config.
    The precondition is asserted first, so a changed default fails with a
    clear message instead of an unexpected ``EngineDeserializeError``. The
    drain covers the client's full capacity so the record cannot be missed.
    """
    fdr_capacity = 16
    assert not config.components["c7_inference"].ort_disallow_cpu_fallback, (
        "precondition: fixture config must allow CPU fallback"
    )
    entry, onnx_path = _make_onnx_entry(tmp_path)
    fake_session = _FakeOrtSession(active_providers=(CPU_EP,))
    fake_ort = _FakeOrt(fake_session)
    fdr = FdrClient(producer_id="c7_inference.onnx_trt_ep", capacity=fdr_capacity)
    runtime = OnnxTrtEpRuntime(
        config,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(onnx_path),
        fdr_client=fdr,
    )
    runtime._load_ort = lambda: fake_ort  # type: ignore[method-assign]

    runtime.deserialize_engine(entry)

    drained = fdr.drain(max_records=fdr_capacity)
    cpu_records = [r for r in drained if r.kind == "c7.cpu_fallback"]
    assert len(cpu_records) == 1
    payload = cpu_records[0].payload
    assert payload["active_provider"] == CPU_EP
    assert payload["requested_providers"] == [TRT_EP, CUDA_EP, CPU_EP]
|
|
|
|
|
|
def test_cpu_fallback_raises_when_disallowed(
    tmp_path: Path,
) -> None:
    """Risk 2: with ``ort_disallow_cpu_fallback=True`` a CPU-only session is refused."""
    entry, onnx_path = _make_onnx_entry(tmp_path)
    ort_shim = _FakeOrt(_FakeOrtSession(active_providers=(CPU_EP,)))
    cache_root = onnx_path.parent
    strict_config = Config.with_blocks(
        c7_inference=C7InferenceConfig(
            runtime="onnx_trt_ep",
            engine_cache_dir=str(cache_root / "engines"),
            ort_trt_cache_dir=str(cache_root / "ort_cache"),
            ort_disallow_cpu_fallback=True,
        )
    )
    runtime = OnnxTrtEpRuntime(
        strict_config,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(onnx_path),
    )
    runtime._load_ort = lambda: ort_shim  # type: ignore[method-assign]

    with pytest.raises(EngineDeserializeError, match="CPUExecutionProvider"):
        runtime.deserialize_engine(entry)
|
|
|
|
|
|
# ----------------------------------------------------------------------
|
|
# NFR-reliability: ORT internal exceptions rewrap into the AZ-297 family.
|
|
|
|
|
|
def test_nfr_reliability_session_creation_failure_rewraps(
    tmp_path: Path, config: Config
) -> None:
    """NFR-reliability: an error raised while creating the ORT session surfaces
    as :class:`EngineDeserializeError` that keeps the original message."""
    entry, onnx_path = _make_onnx_entry(tmp_path)
    ort_shim = _FakeOrt(_FakeOrtSession())
    ort_shim.create_side_effect = RuntimeError("ORT C++ exception")
    runtime = OnnxTrtEpRuntime(
        config,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(onnx_path),
    )
    runtime._load_ort = lambda: ort_shim  # type: ignore[method-assign]

    with pytest.raises(EngineDeserializeError, match=r"ORT C\+\+ exception"):
        runtime.deserialize_engine(entry)
|
|
|
|
|
|
def test_nfr_reliability_infer_rewraps_runtime_error(
    tmp_path: Path, config: Config
) -> None:
    """NFR-reliability: a ``RuntimeError`` from ``session.run`` becomes
    :class:`InferenceError`."""
    entry, onnx_path = _make_onnx_entry(tmp_path)
    session = _FakeOrtSession()
    ort_shim = _FakeOrt(session)
    runtime = OnnxTrtEpRuntime(
        config,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(onnx_path),
    )
    runtime._load_ort = lambda: ort_shim  # type: ignore[method-assign]
    handle = runtime.deserialize_engine(entry)
    # Armed only after deserialisation so the warm-up run still succeeds.
    session._run_side_effect = RuntimeError("ORT runtime: OrtInvalidArgument")
    feed = {"x": np.ones((1, 3, 224, 224), dtype=np.float32)}

    with pytest.raises(InferenceError, match="OrtInvalidArgument"):
        runtime.infer(handle, feed)
|
|
|
|
|
|
def test_nfr_reliability_infer_rewraps_memory_error(
    tmp_path: Path, config: Config
) -> None:
    """NFR-reliability: a ``MemoryError`` from ``session.run`` becomes
    :class:`OutOfMemoryError`."""
    entry, onnx_path = _make_onnx_entry(tmp_path)
    session = _FakeOrtSession()
    ort_shim = _FakeOrt(session)
    runtime = OnnxTrtEpRuntime(
        config,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(onnx_path),
    )
    runtime._load_ort = lambda: ort_shim  # type: ignore[method-assign]
    handle = runtime.deserialize_engine(entry)
    # Armed only after deserialisation so the warm-up run still succeeds.
    session._run_side_effect = MemoryError("ORT mid-run OOM")
    feed = {"x": np.ones((1, 3, 224, 224), dtype=np.float32)}

    with pytest.raises(OutOfMemoryError, match="OOM"):
        runtime.infer(handle, feed)
|
|
|
|
|
|
def test_infer_rejects_foreign_handle(runtime_basic: OnnxTrtEpRuntime) -> None:
    """``infer`` refuses a handle that this runtime did not create."""
    not_a_handle = object()
    with pytest.raises(InferenceError, match="foreign handle"):
        runtime_basic.infer(not_a_handle, {})  # type: ignore[arg-type]
|
|
|
|
|
|
def test_infer_rejects_released_handle(
    tmp_path: Path, config: Config
) -> None:
    """``infer`` refuses a handle that has already been released."""
    entry, onnx_path = _make_onnx_entry(tmp_path)
    ort_shim = _FakeOrt(_FakeOrtSession())
    runtime = OnnxTrtEpRuntime(
        config,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(onnx_path),
    )
    runtime._load_ort = lambda: ort_shim  # type: ignore[method-assign]
    handle = runtime.deserialize_engine(entry)
    runtime.release_engine(handle)
    feed = {"x": np.ones((1, 3, 224, 224), dtype=np.float32)}

    with pytest.raises(InferenceError, match="released handle"):
        runtime.infer(handle, feed)
|
|
|
|
|
|
# ----------------------------------------------------------------------
|
|
# compile_engine + helpers.
|
|
|
|
|
|
def test_compile_engine_returns_onnx_entry(
    tmp_path: Path, runtime_basic: OnnxTrtEpRuntime
) -> None:
    """``compile_engine`` returns a cache entry that points at the source ``.onnx``."""
    model_path = tmp_path / "model.onnx"
    model_path.write_bytes(b"\x08\x07")
    static_shape = (1, 3, 224, 224)
    profile = OptimizationProfile(
        input_name="x",
        min_shape=static_shape,
        opt_shape=static_shape,
        max_shape=static_shape,
    )
    build_cfg = BuildConfig(
        precision=PrecisionMode.FP16,
        workspace_mb=512,
        calibration_dataset=None,
        optimization_profiles=(profile,),
    )
    runtime_basic._host_tuple_provider = lambda _p: _TIER2_HOST

    entry = runtime_basic.compile_engine(model_path, build_cfg)

    assert entry.engine_path == model_path
    assert entry.sha256_hex == _sha256_of_file(model_path)
    assert entry.precision is PrecisionMode.FP16
    assert entry.extras["model_name"] == "model"
|
|
|
|
|
|
def test_compile_engine_missing_onnx_raises(
    tmp_path: Path, runtime_basic: OnnxTrtEpRuntime
) -> None:
    """``compile_engine`` rejects a source path with no file behind it."""
    runtime_basic._host_tuple_provider = lambda _p: _TIER2_HOST
    build_cfg = BuildConfig(
        precision=PrecisionMode.FP16,
        workspace_mb=512,
        calibration_dataset=None,
        optimization_profiles=(),
    )
    missing_model = tmp_path / "nope.onnx"

    with pytest.raises(EngineDeserializeError, match="not found"):
        runtime_basic.compile_engine(missing_model, build_cfg)
|
|
|
|
|
|
def test_deserialize_engine_unknown_suffix_raises(
    tmp_path: Path, config: Config
) -> None:
    """Entries whose file is neither ``.onnx`` nor ``.engine`` are rejected."""
    tflite_path = tmp_path / "model.tflite"
    tflite_path.write_bytes(b"x")
    placeholder_digest = "x" * 64
    entry = EngineCacheEntry(
        engine_path=tflite_path,
        sha256_hex=placeholder_digest,
        sm=87,
        jp="6.2",
        trt="10.3",
        precision=PrecisionMode.FP16,
        extras={},
    )
    runtime = OnnxTrtEpRuntime(config)

    with pytest.raises(EngineDeserializeError, match="unsupported"):
        runtime.deserialize_engine(entry)
|
|
|
|
|
|
# ----------------------------------------------------------------------
|
|
# thermal_state delegation.
|
|
|
|
|
|
def test_thermal_state_default_safe(runtime_basic: OnnxTrtEpRuntime) -> None:
    """Without a publisher, ``thermal_state`` reports no telemetry and no throttle."""
    state = runtime_basic.thermal_state()
    assert isinstance(state, ThermalState)
    assert state.is_telemetry_available is False
    assert state.thermal_throttle_active is False
|
|
|
|
|
|
def test_thermal_state_delegates_to_publisher(config: Config) -> None:
    """With a publisher injected, ``thermal_state`` returns its reading unchanged."""
    reading = ThermalState(
        cpu_temp_c=44.0,
        gpu_temp_c=58.0,
        thermal_throttle_active=True,
        measured_clock_mhz=624,
        measured_at_ns=1_000_000_000,
        is_telemetry_available=True,
    )

    class _CannedPublisher:
        def read(self) -> ThermalState:
            return reading

    runtime = OnnxTrtEpRuntime(config, thermal_publisher=_CannedPublisher())
    assert runtime.thermal_state() is reading
|
|
|
|
|
|
# ----------------------------------------------------------------------
|
|
# Tier-2 placeholders — real ORT path lives in the AZ-299 microbench harness.
|
|
|
|
|
|
@_REQUIRE_ORT
@pytest.mark.tier2
def test_ac4_numerical_match_against_tensorrt_direct(
    tmp_path: Path, config: Config
) -> None:  # pragma: no cover - Tier-2 only
    """Placeholder: FP16 output parity with the TRT-direct runtime (Tier-2 harness)."""
    pytest.skip(_TIER2_REASON)
|
|
|
|
|
|
@_REQUIRE_ORT
@pytest.mark.tier2
def test_nfr_perf_session_create_first_p95_under_30s(
    tmp_path: Path, config: Config
) -> None:  # pragma: no cover - Tier-2 only
    """Placeholder: first (cold) session-creation p95 budget of 30 s (Tier-2 harness)."""
    pytest.skip(_TIER2_REASON)
|
|
|
|
|
|
@_REQUIRE_ORT
@pytest.mark.tier2
def test_nfr_perf_session_create_subsequent_p95_under_5s(
    tmp_path: Path, config: Config
) -> None:  # pragma: no cover - Tier-2 only
    """Placeholder: subsequent session-creation p95 budget of 5 s (Tier-2 harness)."""
    pytest.skip(_TIER2_REASON)
|