Files
Oleksandr Bezdieniezhnykh 0ad3278b12 [AZ-299] C7 OnnxTrtEpRuntime: ORT + TRT EP fallback strategy
Land the fallback InferenceRuntime strategy that satisfies C7-IT-05:
when the TRT-direct path (AZ-298) cannot deserialise a cached engine
or when the operator explicitly selects ORT, the system stays in the
air at degraded latency rather than dropping the request. Conforms to
the AZ-297 Protocol; current_runtime_label() == "onnx_trt_ep".

Production
- onnx_trt_ep_runtime.py: compile_engine is a no-op returning an
  EngineCacheEntry pointing at the source .onnx; deserialize_engine
  is gate-first for .engine entries and gate-skip for .onnx, builds
  an ORT InferenceSession with the provider list
  [TensorrtExecutionProvider, CUDAExecutionProvider,
  CPUExecutionProvider], stages cached engines into the ORT TRT EP
  cache directory via symlink-or-copy, warms up with one session.run
  after construction, and honours config.inference.ort_disallow_cpu_
  fallback by raising EngineDeserializeError when the active provider
  resolves to CPU; infer emits a one-shot c7.fallback_to_onnx_trt_ep
  WARN log plus gcs_alert callback on first call when is_fallback=
  True; release_engine is idempotent. _build_provider_args is the
  single point that pins TRT EP option-key names (Risk-3) and caps
  trt_max_workspace_size at gpu_memory_budget_bytes // 4 (AC-8).
- config.py: adds ort_trt_cache_dir (validated non-empty) and
  ort_disallow_cpu_fallback to C7InferenceConfig.
- fdr_client/records.py: adds c7.fallback_to_onnx_trt_ep and
  c7.cpu_fallback FDR record kinds.

Tests
- test_onnx_trt_ep_runtime.py: covers AC-1..AC-8 + Risk-2 CPU-fallback
  alert + Risk-3 option-key pin + NFR-reliability error rewrap; Tier-1
  via fake ORT session; Tier-2 placeholders skip on macOS dev for
  numerical FP16 comparison and session-creation perf NFR.
- test_protocol_conformance.py: drops onnx_trt_ep from the missing-
  module parametrize now that the module ships.
- test_az272_fdr_record_schema.py: extends per-kind fixture builder
  to cover the two new C7 FDR kinds in the roundtrip / schema-version
  AC tests.

Docs
- module-layout.md: replaces the pending onnx_trt_runtime row with
  the shipped onnx_trt_ep_runtime row + capabilities list.
- batch_32_cycle1_report.md + reviews/batch_32_review.md: full batch
  + self-review (PASS_WITH_WARNINGS, 4 Low findings accepted).

Tests run: c7_inference 139 passing + 17 Tier-2 skips; combined unit
suite (excluding pending components) 529 passing, 19 env-skipped.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-12 23:55:50 +03:00

817 lines
27 KiB
Python

"""AZ-299 — :class:`OnnxTrtEpRuntime` acceptance tests.
The real ORT TRT EP path (provider negotiation, session creation,
warm-up, TRT subgraph compile) requires onnxruntime + a Tier-2 Jetson
host; those tests are guarded by :data:`_REQUIRE_ORT` and skip cleanly
on Tier-1 / macOS dev. CPU-runnable coverage uses fake ORT
:class:`InferenceSession` shims to verify:
- Protocol conformance + label (AC-1).
- ``.onnx`` deserialise skips the gate (AC-2).
- ``.engine`` deserialise invokes the gate first (AC-3).
- ``infer`` round-trips through ``session.run`` and returns named
outputs (AC-4) + the input-binding-missing rewrap.
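- First-infer fallback WARN log + GCS callback + FDR record fire
exactly once (AC-5).
- The active provider is the first ``get_providers()`` entry while the
runtime label stays ``onnx_trt_ep`` (AC-6).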
- TRT EP options carry the workspace cap from
``config.gpu_memory_budget_bytes // 4`` (AC-8).
- ``release_engine`` is idempotent (AC-7).
- NFR-reliability — ORT internal exceptions rewrap to
:class:`InferenceError`.
- CPU-fallback handling (Risk 2) emits the FDR + log; the hard-refusal
toggle raises :class:`EngineDeserializeError`.
"""
from __future__ import annotations
import hashlib
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock
import numpy as np
import pytest
from gps_denied_onboard._types.inference import (
BuildConfig,
EngineCacheEntry,
OptimizationProfile,
PrecisionMode,
)
from gps_denied_onboard._types.thermal import ThermalState
from gps_denied_onboard.components.c7_inference import (
C7InferenceConfig,
DeploymentManifest,
EngineDeserializeError,
EngineSchemaMismatchError,
HostTuple,
InferenceError,
InferenceRuntime,
OutOfMemoryError,
)
from gps_denied_onboard.components.c7_inference.onnx_trt_ep_runtime import (
CPU_EP,
CUDA_EP,
ENGINE_SUFFIX,
ONNX_SUFFIX,
OnnxTrtEpEngineHandle,
OnnxTrtEpRuntime,
TRT_EP,
_sha256_of_file,
)
from gps_denied_onboard.config.schema import Config
from gps_denied_onboard.fdr_client.client import FdrClient
from gps_denied_onboard.fdr_client.records import (
FdrRecord,
)
from gps_denied_onboard.helpers.sha256_sidecar import (
SIDECAR_SUFFIX,
)
# Probe for onnxruntime once at import time; the result only feeds the
# skip marker below.
try:
    import onnxruntime  # type: ignore[import-not-found]  # noqa: F401
except ImportError:
    _HAS_ORT = False
else:
    _HAS_ORT = True

# Marker that skips a test when onnxruntime could not be imported.
_REQUIRE_ORT = pytest.mark.skipif(
    not _HAS_ORT,
    reason="onnxruntime not installed (Tier-2 Jetson / JetPack 6.2 only)",
)

# Skip message shared by the Tier-2 placeholder tests at the end of the file.
_TIER2_REASON = (
    "AZ-299 Tier-2 microbench harness owns the real-ORT perf / numerical "
    "asserts (C7-IT-05); skipped on Tier-1 CI / macOS dev."
)

# Host tuple handed back by the ``host_tuple_provider`` stubs in the tests.
_TIER2_HOST = HostTuple(sm=87, jp="6.2", trt="10.3", precision=PrecisionMode.FP16)
# ----------------------------------------------------------------------
# Fakes.
class _FakeIoTensor:
"""Minimal stand-in for ``onnxruntime.NodeArg``."""
def __init__(self, name: str, shape: tuple[int, ...], dtype: str) -> None:
self.name = name
self.shape = list(shape)
self.type = dtype
class _FakeOrtSession:
    """Hand-rolled :class:`onnxruntime.InferenceSession` substitute.

    Records every ``run`` call in ``run_calls`` and counts ``end_profiling``
    invocations in ``profiling_ended`` so tests can assert on them.

    Args:
        active_providers: Provider names returned by ``get_providers``.
        inputs: Input descriptors returned by ``get_inputs``.
        outputs: Output descriptors returned by ``get_outputs``; ``run``
            produces one zero-filled float32 array per descriptor.
        run_side_effect: Optional exception instance; when set, ``run``
            raises it after recording the call.
    """

    def __init__(
        self,
        *,
        active_providers: tuple[str, ...] = (TRT_EP, CUDA_EP, CPU_EP),
        inputs: tuple[_FakeIoTensor, ...] = (
            _FakeIoTensor("x", (1, 3, 224, 224), "tensor(float)"),
        ),
        outputs: tuple[_FakeIoTensor, ...] = (
            _FakeIoTensor("y", (1, 1000), "tensor(float)"),
        ),
        run_side_effect: Any | None = None,
    ) -> None:
        self._active_providers = active_providers
        self._inputs = inputs
        self._outputs = outputs
        self.run_calls: list[tuple[list[str] | None, dict[str, np.ndarray]]] = []
        self.profiling_ended = 0
        self._run_side_effect = run_side_effect

    def get_providers(self) -> list[str]:
        """Return a fresh list of the configured provider names."""
        return list(self._active_providers)

    def get_inputs(self) -> list[_FakeIoTensor]:
        """Return a fresh list of the configured input descriptors."""
        return list(self._inputs)

    def get_outputs(self) -> list[_FakeIoTensor]:
        """Return a fresh list of the configured output descriptors."""
        return list(self._outputs)

    def run(
        self,
        output_names: list[str] | None,
        feed: dict[str, np.ndarray],
    ) -> list[np.ndarray]:
        """Record the call, then raise the side effect or return zero arrays.

        Every dimension that is not a positive integer is materialised as 1.
        That covers ``-1`` / ``0`` as before, and also ``None`` or a symbolic
        string, which a real ``NodeArg.shape`` may report for dynamic axes;
        the previous bare ``d > 0`` comparison raised ``TypeError`` on those.
        """
        self.run_calls.append((output_names, feed))
        if self._run_side_effect is not None:
            raise self._run_side_effect
        return [
            np.zeros(
                tuple(
                    dim if isinstance(dim, int) and dim > 0 else 1
                    for dim in out.shape
                ),
                dtype=np.float32,
            )
            for out in self._outputs
        ]

    def end_profiling(self) -> None:
        """Count the call so tests can check how often it happened."""
        self.profiling_ended += 1
class _FakeOrt:
"""Replaces the lazy-imported ``onnxruntime`` module."""
def __init__(self, session: _FakeOrtSession) -> None:
self._session = session
self.session_kwargs: dict[str, Any] = {}
self.session_path: str | None = None
self.create_side_effect: Any | None = None
def InferenceSession(
self,
path: str,
*,
providers: list[str],
provider_options: list[dict[str, Any]],
) -> _FakeOrtSession:
self.session_path = path
self.session_kwargs = {
"providers": providers,
"provider_options": provider_options,
}
if self.create_side_effect is not None:
raise self.create_side_effect
return self._session
# ----------------------------------------------------------------------
# Fixtures.
@pytest.fixture
def config(tmp_path: Path) -> Config:
    """Return a ``Config`` whose C7 block selects ``onnx_trt_ep``.

    Both cache directories are placed under the per-test ``tmp_path``.
    """
    c7_block = C7InferenceConfig(
        runtime="onnx_trt_ep",
        engine_cache_dir=str(tmp_path / "engines"),
        ort_trt_cache_dir=str(tmp_path / "ort_cache"),
    )
    return Config.with_blocks(c7_inference=c7_block)
@pytest.fixture
def runtime_basic(config: Config) -> OnnxTrtEpRuntime:
    """Return a runtime built from ``config`` alone, with no collaborators injected."""
    runtime = OnnxTrtEpRuntime(config)
    return runtime
def _make_engine_entry(
    tmp_path: Path,
    *,
    sm: int = 87,
    jp: str = "6.2",
    trt: str = "10.3",
    precision: PrecisionMode = PrecisionMode.FP16,
    payload: bytes = b"fake-engine-bytes",
) -> tuple[EngineCacheEntry, Path]:
    """Write a fake ``.engine`` file plus its SHA-256 sidecar under ``tmp_path``.

    The file name encodes ``sm`` / ``jp`` / ``trt`` / ``precision``; the
    sidecar holds the hex digest of ``payload``.

    Returns:
        ``(entry, engine_path)`` where ``entry`` describes the written file.
    """
    engine_path = (
        tmp_path / f"ultravpr__sm{sm}_jp{jp}_trt{trt}_{precision.value}.engine"
    )
    engine_path.write_bytes(payload)

    digest = hashlib.sha256(payload).hexdigest()
    sidecar_path = Path(str(engine_path) + SIDECAR_SUFFIX)
    sidecar_path.write_text(digest, encoding="utf-8")

    return (
        EngineCacheEntry(
            engine_path=engine_path,
            sha256_hex=digest,
            sm=sm,
            jp=jp,
            trt=trt,
            precision=precision,
            extras={},
        ),
        engine_path,
    )
def _make_onnx_entry(tmp_path: Path) -> tuple[EngineCacheEntry, Path]:
    """Write a two-byte placeholder ``ultravpr.onnx`` and return its cache entry.

    Returns:
        ``(entry, onnx_path)``; ``entry.extras["model_name"]`` is the file stem.
    """
    stub_bytes = b"\x08\x07"  # not a real ONNX, but file exists.
    onnx_path = tmp_path / "ultravpr.onnx"
    onnx_path.write_bytes(stub_bytes)
    digest = hashlib.sha256(stub_bytes).hexdigest()
    return (
        EngineCacheEntry(
            engine_path=onnx_path,
            sha256_hex=digest,
            sm=87,
            jp="6.2",
            trt="10.3",
            precision=PrecisionMode.FP16,
            extras={"model_name": onnx_path.stem},
        ),
        onnx_path,
    )
def _manifest_for(engine_path: Path) -> DeploymentManifest:
    """Build a one-entry manifest rooted at the file's directory.

    The single entry maps the file name to the SHA-256 of its current bytes.
    """
    digest = hashlib.sha256(engine_path.read_bytes()).hexdigest()
    entries = {engine_path.name: digest}
    return DeploymentManifest(root=engine_path.parent, entries=entries)
# ----------------------------------------------------------------------
# AC-1: Protocol + label.
def test_ac1_protocol_conformance(runtime_basic: OnnxTrtEpRuntime) -> None:
    """AC-1: the runtime is an ``InferenceRuntime`` and reports its label."""
    assert isinstance(runtime_basic, InferenceRuntime)
    label = runtime_basic.current_runtime_label()
    assert label == "onnx_trt_ep"
# ----------------------------------------------------------------------
# AC-2: deserialize from .onnx skips the gate.
def test_ac2_onnx_path_skips_gate(
    tmp_path: Path, config: Config
) -> None:
    """AC-2: deserialising a ``.onnx`` entry never touches the engine gate.

    Also checks the handle's active provider and model name, the path and
    provider list passed to ORT, and that exactly one warm-up ``run`` was
    recorded on the fake session.
    """
    # Arrange
    entry, onnx_path = _make_onnx_entry(tmp_path)
    fake_session = _FakeOrtSession()
    fake_ort = _FakeOrt(fake_session)
    # A MagicMock records calls instead of raising, so "the gate was
    # skipped" is enforced by the explicit assertions below.
    gate = MagicMock()
    runtime = OnnxTrtEpRuntime(
        config,
        engine_gate=gate,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(onnx_path),
    )
    runtime._load_ort = lambda: fake_ort  # type: ignore[method-assign]

    # Act
    handle = runtime.deserialize_engine(entry)

    # Assert — no gate interaction at all, not only via ``validate``.
    gate.validate.assert_not_called()
    gate.assert_not_called()
    assert gate.method_calls == []
    assert isinstance(handle, OnnxTrtEpEngineHandle)
    assert handle._active_provider == TRT_EP
    assert handle._model_name == onnx_path.stem
    assert fake_ort.session_path == str(onnx_path)
    assert fake_ort.session_kwargs["providers"] == [TRT_EP, CUDA_EP, CPU_EP]
    # warm-up call recorded.
    assert len(fake_session.run_calls) == 1
# ----------------------------------------------------------------------
# AC-3: deserialize from .engine invokes the gate first.
def test_ac3_engine_path_invokes_gate_first(
    tmp_path: Path, config: Config
) -> None:
    """AC-3: a refused ``.engine`` entry raises before any ORT session exists."""
    # Arrange — engine filename says sm=86 but host is sm=87 → gate refuses.
    entry, engine_path = _make_engine_entry(tmp_path, sm=86)

    class _NoSessionOrt:
        """ORT stand-in that fails the test if a session is requested."""

        def InferenceSession(self, *args: Any, **kwargs: Any) -> Any:
            raise AssertionError(
                "AC-3: ORT session must NOT be created when gate refuses"
            )

    def _host(_p: Any) -> HostTuple:
        return _TIER2_HOST

    def _manifest() -> DeploymentManifest:
        return _manifest_for(engine_path)

    runtime = OnnxTrtEpRuntime(
        config,
        host_tuple_provider=_host,
        manifest_provider=_manifest,
    )
    runtime._load_ort = lambda: _NoSessionOrt()  # type: ignore[method-assign]

    # Act / Assert
    with pytest.raises(EngineSchemaMismatchError, match="sm=86"):
        runtime.deserialize_engine(entry)
def test_ac3_engine_path_passes_gate_then_creates_session(
    tmp_path: Path, config: Config
) -> None:
    """AC-3: a matching ``.engine`` entry yields a handle and a staged file."""
    # Arrange
    entry, engine_path = _make_engine_entry(tmp_path, sm=87)
    ort_stub = _FakeOrt(_FakeOrtSession())
    runtime = OnnxTrtEpRuntime(
        config,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(engine_path),
    )
    runtime._load_ort = lambda: ort_stub  # type: ignore[method-assign]

    # Act
    handle = runtime.deserialize_engine(entry)

    # Assert
    assert isinstance(handle, OnnxTrtEpEngineHandle)
    assert ort_stub.session_path == str(engine_path)
    # Engine file should have been staged under the ORT cache dir.
    ort_cache_dir = Path(config.components["c7_inference"].ort_trt_cache_dir)
    assert (ort_cache_dir / engine_path.name).exists()
# ----------------------------------------------------------------------
# AC-4: infer round-trips and returns named outputs.
def test_ac4_infer_round_trip_named_outputs(
    tmp_path: Path, config: Config
) -> None:
    """AC-4: ``infer`` forwards the feed to ``session.run`` and names the outputs."""
    # Arrange
    entry, onnx_path = _make_onnx_entry(tmp_path)
    session = _FakeOrtSession()
    ort_stub = _FakeOrt(session)
    runtime = OnnxTrtEpRuntime(
        config,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(onnx_path),
    )
    runtime._load_ort = lambda: ort_stub  # type: ignore[method-assign]
    handle = runtime.deserialize_engine(entry)
    # Drop the warm-up call so only the explicit infer is counted.
    session.run_calls.clear()

    # Act
    batch = np.ones((1, 3, 224, 224), dtype=np.float32)
    outputs = runtime.infer(handle, {"x": batch})

    # Assert
    assert set(outputs.keys()) == {"y"}
    assert outputs["y"].shape == (1, 1000)
    assert len(session.run_calls) == 1
    requested_outputs, feed = session.run_calls[0]
    assert requested_outputs == ["y"]
    assert "x" in feed and feed["x"].shape == (1, 3, 224, 224)
def test_ac4_infer_missing_input_binding_rewraps(
    tmp_path: Path, config: Config
) -> None:
    """AC-4: an empty feed raises ``InferenceError`` naming the missing binding."""
    entry, onnx_path = _make_onnx_entry(tmp_path)
    ort_stub = _FakeOrt(_FakeOrtSession())
    runtime = OnnxTrtEpRuntime(
        config,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(onnx_path),
    )
    runtime._load_ort = lambda: ort_stub  # type: ignore[method-assign]
    handle = runtime.deserialize_engine(entry)

    empty_feed: dict[str, np.ndarray] = {}
    with pytest.raises(InferenceError, match="missing input binding"):
        runtime.infer(handle, empty_feed)
# ----------------------------------------------------------------------
# AC-5: fallback WARN + gcs_alert + FDR fire exactly once.
def test_ac5_first_infer_fallback_alert_fires_once(
    tmp_path: Path, config: Config, caplog: pytest.LogCaptureFixture
) -> None:
    """AC-5: the fallback WARN, GCS alert and FDR record fire on the first infer only."""
    # Arrange
    entry, onnx_path = _make_onnx_entry(tmp_path)
    ort_stub = _FakeOrt(_FakeOrtSession())
    gcs_calls: list[str] = []
    fdr = FdrClient(producer_id="c7_inference.onnx_trt_ep", capacity=16)
    runtime = OnnxTrtEpRuntime(
        config,
        is_fallback=True,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(onnx_path),
        fdr_client=fdr,
        gcs_alert=lambda msg: gcs_calls.append(msg),
    )
    runtime._load_ort = lambda: ort_stub  # type: ignore[method-assign]
    handle = runtime.deserialize_engine(entry)
    fallback_kind = "c7.fallback_to_onnx_trt_ep"

    # Act — first infer.
    caplog.set_level("WARNING", logger="c7_inference.onnx_trt_ep")
    runtime.infer(handle, {"x": np.ones((1, 3, 224, 224), dtype=np.float32)})

    # Assert — exactly one WARN log on the fallback kind.
    first_warns = []
    for record in caplog.records:
        if record.levelname == "WARNING" and "fallback" in record.getMessage():
            first_warns.append(record)
    assert len(first_warns) == 1
    assert len(gcs_calls) == 1

    # FDR drained: one c7.fallback_to_onnx_trt_ep record.
    first_batch = fdr.drain(max_records=4)
    fallback_records = [r for r in first_batch if r.kind == fallback_kind]
    assert len(fallback_records) == 1
    payload = fallback_records[0].payload
    assert payload["model_name"] == onnx_path.stem
    assert payload["active_provider"] == TRT_EP

    # Act — second infer should be silent.
    caplog.clear()
    runtime.infer(handle, {"x": np.ones((1, 3, 224, 224), dtype=np.float32)})

    later_warns = [r for r in caplog.records if r.levelname == "WARNING"]
    assert later_warns == []
    assert len(gcs_calls) == 1
    second_batch = fdr.drain(max_records=4)
    assert all(r.kind != fallback_kind for r in second_batch)
def test_ac5_non_fallback_runtime_is_silent(
    tmp_path: Path,
    config: Config,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """AC-5: with ``is_fallback=False`` no fallback signal is emitted.

    Checks all three channels: no log message mentioning "fallback", no GCS
    alert callback, and no ``c7.fallback_to_onnx_trt_ep`` FDR record.
    """
    # Arrange
    entry, onnx_path = _make_onnx_entry(tmp_path)
    fake_session = _FakeOrtSession()
    fake_ort = _FakeOrt(fake_session)
    fdr = FdrClient(producer_id="c7_inference.onnx_trt_ep", capacity=16)
    gcs_calls: list[str] = []
    runtime = OnnxTrtEpRuntime(
        config,
        is_fallback=False,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(onnx_path),
        fdr_client=fdr,
        gcs_alert=lambda msg: gcs_calls.append(msg),
    )
    runtime._load_ort = lambda: fake_ort  # type: ignore[method-assign]
    handle = runtime.deserialize_engine(entry)

    # Act
    caplog.set_level("WARNING", logger="c7_inference.onnx_trt_ep")
    runtime.infer(handle, {"x": np.ones((1, 3, 224, 224), dtype=np.float32)})

    # Assert
    fallback_warns = [
        r for r in caplog.records if "fallback" in r.getMessage()
    ]
    assert fallback_warns == []
    assert gcs_calls == []
    # The FDR client was wired in but previously never inspected; a stray
    # fallback record would have gone unnoticed.
    drained = fdr.drain(max_records=4)
    assert all(r.kind != "c7.fallback_to_onnx_trt_ep" for r in drained)
# ----------------------------------------------------------------------
# AC-6: provider fallback chain (TRT refuses → CUDA active; label unchanged).
def test_ac6_active_provider_is_first_get_providers_entry(
    tmp_path: Path, config: Config
) -> None:
    """AC-6: the handle reports the session's first provider; the label is unchanged."""
    entry, onnx_path = _make_onnx_entry(tmp_path)
    # Session without the TRT EP: CUDA is first in the active list.
    session = _FakeOrtSession(active_providers=(CUDA_EP, CPU_EP))
    ort_stub = _FakeOrt(session)
    runtime = OnnxTrtEpRuntime(
        config,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(onnx_path),
    )
    runtime._load_ort = lambda: ort_stub  # type: ignore[method-assign]

    handle = runtime.deserialize_engine(entry)

    assert handle._active_provider == CUDA_EP
    assert runtime.current_runtime_label() == "onnx_trt_ep"
# ----------------------------------------------------------------------
# AC-7: release_engine idempotent.
def test_ac7_release_engine_idempotent(
    tmp_path: Path, config: Config
) -> None:
    """AC-7: a second ``release_engine`` on the same handle changes nothing."""
    entry, onnx_path = _make_onnx_entry(tmp_path)
    session = _FakeOrtSession()
    ort_stub = _FakeOrt(session)
    runtime = OnnxTrtEpRuntime(
        config,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(onnx_path),
    )
    runtime._load_ort = lambda: ort_stub  # type: ignore[method-assign]
    handle = runtime.deserialize_engine(entry)

    # Act — first release.
    runtime.release_engine(handle)
    assert handle._released is True
    assert handle._session is None
    assert session.profiling_ended == 1

    # Act — second release is a no-op.
    runtime.release_engine(handle)
    assert session.profiling_ended == 1
def test_release_engine_ignores_foreign_handle(
    runtime_basic: OnnxTrtEpRuntime,
) -> None:
    """Releasing an object that is not a runtime handle must not raise."""

    class _Foreign:
        """Arbitrary type unrelated to the runtime's handle class."""

    stranger = _Foreign()
    runtime_basic.release_engine(stranger)  # type: ignore[arg-type]
# ----------------------------------------------------------------------
# AC-8: TRT EP options carry the budget // 4 workspace cap.
def test_ac8_trt_ep_options_carry_budget_workspace_cap(
    tmp_path: Path, config: Config
) -> None:
    """AC-8: the TRT EP options cap the workspace at a quarter of the GPU budget."""
    entry, onnx_path = _make_onnx_entry(tmp_path)
    ort_stub = _FakeOrt(_FakeOrtSession())
    runtime = OnnxTrtEpRuntime(
        config,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(onnx_path),
    )
    runtime._load_ort = lambda: ort_stub  # type: ignore[method-assign]

    runtime.deserialize_engine(entry)

    # First provider option dict corresponds to TRT EP.
    trt_opts = ort_stub.session_kwargs["provider_options"][0]
    budget = int(config.components["c7_inference"].gpu_memory_budget_bytes)
    expected_cap = budget // 4
    assert trt_opts["trt_max_workspace_size"] == expected_cap
    assert trt_opts["trt_engine_cache_enable"] is True
    cache_path = Path(trt_opts["trt_engine_cache_path"])
    assert cache_path.exists()
# ----------------------------------------------------------------------
# Risk 2: CPU-fallback handling.
def test_cpu_fallback_emits_record_when_warn_allowed(
    tmp_path: Path, config: Config
) -> None:
    """Risk 2: a CPU-only session yields exactly one ``c7.cpu_fallback`` FDR record."""
    entry, onnx_path = _make_onnx_entry(tmp_path)
    cpu_only_session = _FakeOrtSession(active_providers=(CPU_EP,))
    ort_stub = _FakeOrt(cpu_only_session)
    fdr = FdrClient(producer_id="c7_inference.onnx_trt_ep", capacity=16)
    runtime = OnnxTrtEpRuntime(
        config,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(onnx_path),
        fdr_client=fdr,
    )
    runtime._load_ort = lambda: ort_stub  # type: ignore[method-assign]

    runtime.deserialize_engine(entry)

    cpu_records = []
    for record in fdr.drain(max_records=4):
        if record.kind == "c7.cpu_fallback":
            cpu_records.append(record)
    assert len(cpu_records) == 1
    (only_record,) = cpu_records
    assert only_record.payload["active_provider"] == CPU_EP
    assert only_record.payload["requested_providers"] == [TRT_EP, CUDA_EP, CPU_EP]
def test_cpu_fallback_raises_when_disallowed(
    tmp_path: Path,
) -> None:
    """Risk 2: with ``ort_disallow_cpu_fallback`` set, a CPU-only session is refused."""
    entry, onnx_path = _make_onnx_entry(tmp_path)
    ort_stub = _FakeOrt(_FakeOrtSession(active_providers=(CPU_EP,)))
    base_dir = onnx_path.parent
    strict_block = C7InferenceConfig(
        runtime="onnx_trt_ep",
        engine_cache_dir=str(base_dir / "engines"),
        ort_trt_cache_dir=str(base_dir / "ort_cache"),
        ort_disallow_cpu_fallback=True,
    )
    config = Config.with_blocks(c7_inference=strict_block)
    runtime = OnnxTrtEpRuntime(
        config,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(onnx_path),
    )
    runtime._load_ort = lambda: ort_stub  # type: ignore[method-assign]

    with pytest.raises(
        EngineDeserializeError, match="CPUExecutionProvider"
    ):
        runtime.deserialize_engine(entry)
# ----------------------------------------------------------------------
# NFR-reliability: ORT internal exceptions rewrap into the AZ-297 family.
def test_nfr_reliability_session_creation_failure_rewraps(
    tmp_path: Path, config: Config
) -> None:
    """NFR-reliability: a session-creation error surfaces as ``EngineDeserializeError``."""
    entry, onnx_path = _make_onnx_entry(tmp_path)
    ort_stub = _FakeOrt(_FakeOrtSession())
    # Make the fake module raise from InferenceSession(...).
    ort_stub.create_side_effect = RuntimeError("ORT C++ exception")
    runtime = OnnxTrtEpRuntime(
        config,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(onnx_path),
    )
    runtime._load_ort = lambda: ort_stub  # type: ignore[method-assign]

    with pytest.raises(EngineDeserializeError, match="ORT C\\+\\+ exception"):
        runtime.deserialize_engine(entry)
def test_nfr_reliability_infer_rewraps_runtime_error(
    tmp_path: Path, config: Config
) -> None:
    """NFR-reliability: a ``RuntimeError`` from ``session.run`` becomes ``InferenceError``."""
    entry, onnx_path = _make_onnx_entry(tmp_path)
    session = _FakeOrtSession()
    ort_stub = _FakeOrt(session)
    runtime = OnnxTrtEpRuntime(
        config,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(onnx_path),
    )
    runtime._load_ort = lambda: ort_stub  # type: ignore[method-assign]
    handle = runtime.deserialize_engine(entry)
    # Arm the failure only after deserialisation so the warm-up run succeeds.
    session._run_side_effect = RuntimeError(
        "ORT runtime: OrtInvalidArgument"
    )

    feed = {"x": np.ones((1, 3, 224, 224), dtype=np.float32)}
    with pytest.raises(InferenceError, match="OrtInvalidArgument"):
        runtime.infer(handle, feed)
def test_nfr_reliability_infer_rewraps_memory_error(
    tmp_path: Path, config: Config
) -> None:
    """NFR-reliability: a ``MemoryError`` from ``session.run`` becomes ``OutOfMemoryError``."""
    entry, onnx_path = _make_onnx_entry(tmp_path)
    session = _FakeOrtSession()
    ort_stub = _FakeOrt(session)
    runtime = OnnxTrtEpRuntime(
        config,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(onnx_path),
    )
    runtime._load_ort = lambda: ort_stub  # type: ignore[method-assign]
    handle = runtime.deserialize_engine(entry)
    # Arm the failure only after deserialisation so the warm-up run succeeds.
    session._run_side_effect = MemoryError("ORT mid-run OOM")

    feed = {"x": np.ones((1, 3, 224, 224), dtype=np.float32)}
    with pytest.raises(OutOfMemoryError, match="OOM"):
        runtime.infer(handle, feed)
def test_infer_rejects_foreign_handle(runtime_basic: OnnxTrtEpRuntime) -> None:
    """``infer`` raises ``InferenceError`` for a handle it did not create."""
    not_a_handle = object()
    with pytest.raises(InferenceError, match="foreign handle"):
        runtime_basic.infer(not_a_handle, {})  # type: ignore[arg-type]
def test_infer_rejects_released_handle(
    tmp_path: Path, config: Config
) -> None:
    """``infer`` raises ``InferenceError`` once the handle has been released."""
    entry, onnx_path = _make_onnx_entry(tmp_path)
    ort_stub = _FakeOrt(_FakeOrtSession())
    runtime = OnnxTrtEpRuntime(
        config,
        host_tuple_provider=lambda _p: _TIER2_HOST,
        manifest_provider=lambda: _manifest_for(onnx_path),
    )
    runtime._load_ort = lambda: ort_stub  # type: ignore[method-assign]
    handle = runtime.deserialize_engine(entry)
    runtime.release_engine(handle)

    feed = {"x": np.ones((1, 3, 224, 224), dtype=np.float32)}
    with pytest.raises(InferenceError, match="released handle"):
        runtime.infer(handle, feed)
# ----------------------------------------------------------------------
# compile_engine + helpers.
def test_compile_engine_returns_onnx_entry(
    tmp_path: Path, runtime_basic: OnnxTrtEpRuntime
) -> None:
    """``compile_engine`` returns an entry that points back at the source ``.onnx``."""
    onnx_path = tmp_path / "model.onnx"
    onnx_path.write_bytes(b"\x08\x07")
    fixed_shape = (1, 3, 224, 224)
    profile = OptimizationProfile(
        input_name="x",
        min_shape=fixed_shape,
        opt_shape=fixed_shape,
        max_shape=fixed_shape,
    )
    bc = BuildConfig(
        precision=PrecisionMode.FP16,
        workspace_mb=512,
        calibration_dataset=None,
        optimization_profiles=(profile,),
    )
    runtime_basic._host_tuple_provider = lambda _p: _TIER2_HOST

    entry = runtime_basic.compile_engine(onnx_path, bc)

    assert entry.engine_path == onnx_path
    assert entry.sha256_hex == _sha256_of_file(onnx_path)
    assert entry.precision is PrecisionMode.FP16
    assert entry.extras["model_name"] == "model"
def test_compile_engine_missing_onnx_raises(
    tmp_path: Path, runtime_basic: OnnxTrtEpRuntime
) -> None:
    """``compile_engine`` raises ``EngineDeserializeError`` for a missing source file."""
    runtime_basic._host_tuple_provider = lambda _p: _TIER2_HOST
    bc = BuildConfig(
        precision=PrecisionMode.FP16,
        workspace_mb=512,
        calibration_dataset=None,
        optimization_profiles=(),
    )
    missing_path = tmp_path / "nope.onnx"

    with pytest.raises(EngineDeserializeError, match="not found"):
        runtime_basic.compile_engine(missing_path, bc)
def test_deserialize_engine_unknown_suffix_raises(
    tmp_path: Path, config: Config
) -> None:
    """An entry whose file is neither ``.onnx`` nor ``.engine`` is rejected."""
    bogus = tmp_path / "model.tflite"
    bogus.write_bytes(b"x")
    runtime = OnnxTrtEpRuntime(config)
    entry = EngineCacheEntry(
        engine_path=bogus,
        sha256_hex="x" * 64,
        sm=87,
        jp="6.2",
        trt="10.3",
        precision=PrecisionMode.FP16,
        extras={},
    )

    with pytest.raises(EngineDeserializeError, match="unsupported"):
        runtime.deserialize_engine(entry)
# ----------------------------------------------------------------------
# thermal_state delegation.
def test_thermal_state_default_safe(runtime_basic: OnnxTrtEpRuntime) -> None:
    """Without a publisher, ``thermal_state`` reports no telemetry and no throttle."""
    state = runtime_basic.thermal_state()
    assert isinstance(state, ThermalState)
    assert state.is_telemetry_available is False
    assert state.thermal_throttle_active is False
def test_thermal_state_delegates_to_publisher(config: Config) -> None:
    """``thermal_state`` returns the very object the injected publisher reads."""
    canned = ThermalState(
        cpu_temp_c=44.0,
        gpu_temp_c=58.0,
        thermal_throttle_active=True,
        measured_clock_mhz=624,
        measured_at_ns=1_000_000_000,
        is_telemetry_available=True,
    )

    class _Publisher:
        """Publisher stub whose ``read`` always yields ``canned``."""

        def read(self) -> ThermalState:
            return canned

    publisher = _Publisher()
    runtime = OnnxTrtEpRuntime(config, thermal_publisher=publisher)

    assert runtime.thermal_state() is canned
# ----------------------------------------------------------------------
# Tier-2 placeholders — real ORT path lives in the AZ-299 microbench harness.
@_REQUIRE_ORT
@pytest.mark.tier2
def test_ac4_numerical_match_against_tensorrt_direct(
    tmp_path: Path, config: Config
) -> None:  # pragma: no cover - Tier-2 only
    """Placeholder: always skips with the shared Tier-2 reason."""
    pytest.skip(_TIER2_REASON)
@_REQUIRE_ORT
@pytest.mark.tier2
def test_nfr_perf_session_create_first_p95_under_30s(
    tmp_path: Path, config: Config
) -> None:  # pragma: no cover - Tier-2 only
    """Placeholder: always skips with the shared Tier-2 reason."""
    pytest.skip(_TIER2_REASON)
@_REQUIRE_ORT
@pytest.mark.tier2
def test_nfr_perf_session_create_subsequent_p95_under_5s(
    tmp_path: Path, config: Config
) -> None:  # pragma: no cover - Tier-2 only
    """Placeholder: always skips with the shared Tier-2 reason."""
    pytest.skip(_TIER2_REASON)