mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-06-22 19:11:14 +00:00
[AZ-299] C7 OnnxTrtEpRuntime: ORT + TRT EP fallback strategy
Land the fallback InferenceRuntime strategy that satisfies C7-IT-05: when the TRT-direct path (AZ-298) cannot deserialise a cached engine or when the operator explicitly selects ORT, the system stays in the air at degraded latency rather than dropping the request. Conforms to the AZ-297 Protocol; current_runtime_label() == "onnx_trt_ep". Production - onnx_trt_ep_runtime.py: compile_engine is a no-op returning an EngineCacheEntry pointing at the source .onnx; deserialize_engine is gate-first for .engine entries and gate-skip for .onnx, builds an ORT InferenceSession with the provider list [TensorrtExecutionProvider, CUDAExecutionProvider, CPUExecutionProvider], stages cached engines into the ORT TRT EP cache directory via symlink-or-copy, warms up with one session.run after construction, and honours config.inference.ort_disallow_cpu_fallback by raising EngineDeserializeError when the active provider resolves to CPU; infer emits a one-shot c7.fallback_to_onnx_trt_ep WARN log plus gcs_alert callback on first call when is_fallback=True; release_engine is idempotent. _build_provider_args is the single point that pins TRT EP option-key names (Risk-3) and caps trt_max_workspace_size at gpu_memory_budget_bytes // 4 (AC-8). - config.py: adds ort_trt_cache_dir (validated non-empty) and ort_disallow_cpu_fallback to C7InferenceConfig. - fdr_client/records.py: adds c7.fallback_to_onnx_trt_ep and c7.cpu_fallback FDR record kinds. Tests - test_onnx_trt_ep_runtime.py: covers AC-1..AC-8 + Risk-2 CPU-fallback alert + Risk-3 option-key pin + NFR-reliability error rewrap; Tier-1 via fake ORT session; Tier-2 placeholders skip on macOS dev for numerical FP16 comparison and session-creation perf NFR. - test_protocol_conformance.py: drops onnx_trt_ep from the missing-module parametrize now that the module ships. - test_az272_fdr_record_schema.py: extends per-kind fixture builder to cover the two new C7 FDR kinds in the roundtrip / schema-version AC tests.
Docs - module-layout.md: replaces the pending onnx_trt_runtime row with the shipped onnx_trt_ep_runtime row + capabilities list. - batch_32_cycle1_report.md + reviews/batch_32_review.md: full batch + self-review (PASS_WITH_WARNINGS, 4 Low findings accepted). Tests run: c7_inference 139 passing + 17 Tier-2 skips; combined unit suite (excluding pending components) 529 passing, 19 env-skipped. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -51,6 +51,14 @@ class C7InferenceConfig:
|
||||
``trtexec_timeout_s`` bounds the ``trtexec`` subprocess used by
|
||||
``TensorrtRuntime.compile_engine`` when ``BuildConfig.use_trtexec``
|
||||
is true (AZ-298 Risk 4); default 10 minutes.
|
||||
|
||||
``ort_trt_cache_dir`` is the ORT TensorRT execution-provider
|
||||
subgraph cache directory used by ``OnnxTrtEpRuntime`` (AZ-299);
|
||||
populated per-flight under ``engine_cache_dir`` to keep the cache
|
||||
bounded across runs. ``ort_disallow_cpu_fallback`` makes
|
||||
``OnnxTrtEpRuntime`` refuse to deserialise when ORT's active
|
||||
provider would be ``CPUExecutionProvider`` (Risk-2 mitigation);
|
||||
default False (warn-but-allow per AZ-299 spec).
|
||||
"""
|
||||
|
||||
runtime: str = "pytorch_fp16"
|
||||
@@ -59,6 +67,8 @@ class C7InferenceConfig:
|
||||
per_frame_debug_log: bool = False
|
||||
gpu_memory_budget_bytes: int = 4 * 1024 * 1024 * 1024
|
||||
trtexec_timeout_s: int = 600
|
||||
ort_trt_cache_dir: str = "/var/lib/gps-denied/engines/ort_trt_cache"
|
||||
ort_disallow_cpu_fallback: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.runtime not in KNOWN_RUNTIMES:
|
||||
@@ -85,3 +95,7 @@ class C7InferenceConfig:
|
||||
"C7InferenceConfig.trtexec_timeout_s must be > 0; "
|
||||
f"got {self.trtexec_timeout_s}"
|
||||
)
|
||||
if not self.ort_trt_cache_dir:
|
||||
raise ConfigError(
|
||||
"C7InferenceConfig.ort_trt_cache_dir must be non-empty"
|
||||
)
|
||||
|
||||
@@ -0,0 +1,678 @@
|
||||
"""``OnnxTrtEpRuntime`` — ONNX Runtime + TRT EP fallback strategy (AZ-299).
|
||||
|
||||
Conforms to :class:`InferenceRuntime` (AZ-297). The runtime is wired
|
||||
as either:
|
||||
|
||||
- **Operator-selected**: ``config.components['c7_inference'].runtime
|
||||
== "onnx_trt_ep"`` — the strategy serves every infer call directly.
|
||||
- **Implicit fallback**: composition root constructs it with
|
||||
``is_fallback=True`` after :class:`TensorrtRuntime.deserialize_engine`
|
||||
(AZ-298) refused a given engine; first ``infer`` call emits a
|
||||
one-shot ``kind="c7.fallback_to_onnx_trt_ep"`` WARN log + FDR record
|
||||
+ GCS alert callback (covers C7-IT-05).
|
||||
|
||||
Provider list is fixed: ``["TensorrtExecutionProvider",
|
||||
"CUDAExecutionProvider", "CPUExecutionProvider"]``. If ORT silently
|
||||
collapses to ``CPUExecutionProvider`` (both TRT and CUDA EPs refused
|
||||
to load), :meth:`deserialize_engine` either emits a
|
||||
``kind="c7.cpu_fallback"`` WARN record and continues
|
||||
(``config.ort_disallow_cpu_fallback == False``, default) or raises
|
||||
:class:`EngineDeserializeError` (when ``True``).
|
||||
|
||||
``onnxruntime`` is a **lazy import** inside the methods that need it
|
||||
so the module loads cleanly on Tier-0 / macOS dev hosts (the module's
|
||||
protocol-conformance tests stay importable without ORT installed).
|
||||
|
||||
AC mapping (see ``_docs/02_tasks/todo/AZ-299_c7_onnxrt_fallback.md``):
|
||||
|
||||
- AC-1 : :meth:`current_runtime_label` returns ``"onnx_trt_ep"``.
|
||||
- AC-2 : ``.onnx`` deserialise skips the gate.
|
||||
- AC-3 : ``.engine`` deserialise invokes the gate first.
|
||||
- AC-4 : :meth:`infer` round-trips through ``session.run`` and
|
||||
produces a dict keyed by declared output names.
|
||||
- AC-5 : first ``infer`` with ``is_fallback=True`` emits exactly one
|
||||
WARN log + ``gcs_alert`` callback + FDR record; subsequent infers
|
||||
are silent on the fallback path.
|
||||
- AC-6 : provider fallback chain respects ORT order; an INFO log
|
||||
records the actually-active provider.
|
||||
- AC-7 : :meth:`release_engine` drops the session reference once
|
||||
and is a no-op thereafter.
|
||||
- AC-8 : TRT EP workspace cap is set to
|
||||
``gpu_memory_budget_bytes // 4``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Final, Literal
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gps_denied_onboard._types.inference import (
|
||||
BuildConfig,
|
||||
EngineCacheEntry,
|
||||
EngineHandle,
|
||||
PrecisionMode,
|
||||
)
|
||||
from gps_denied_onboard._types.thermal import ThermalState
|
||||
from gps_denied_onboard.clock.wall_clock import WallClock
|
||||
from gps_denied_onboard.components.c7_inference.engine_gate import (
|
||||
EngineGate,
|
||||
HostTuple,
|
||||
read_host_tuple,
|
||||
)
|
||||
from gps_denied_onboard.components.c7_inference.errors import (
|
||||
EngineDeserializeError,
|
||||
InferenceError,
|
||||
OutOfMemoryError,
|
||||
)
|
||||
from gps_denied_onboard.components.c7_inference.manifest import (
|
||||
DeploymentManifest,
|
||||
ManifestReader,
|
||||
ManifestReaderProtocol,
|
||||
)
|
||||
from gps_denied_onboard.fdr_client.records import (
|
||||
CURRENT_SCHEMA_VERSION,
|
||||
FdrRecord,
|
||||
)
|
||||
from gps_denied_onboard.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gps_denied_onboard.clock import Clock
|
||||
from gps_denied_onboard.config.schema import Config
|
||||
from gps_denied_onboard.fdr_client import FdrClient
|
||||
|
||||
# Public surface of this module: the artifact-suffix constants, the ORT
# execution-provider identifiers, the opaque handle type and the runtime
# strategy class. Kept alphabetically sorted.
__all__ = [
    "CPU_EP",
    "CUDA_EP",
    "ENGINE_SUFFIX",
    "ONNX_SUFFIX",
    "OnnxTrtEpEngineHandle",
    "OnnxTrtEpRuntime",
    "TRT_EP",
]
|
||||
|
||||
# Label returned by ``current_runtime_label`` and the identifier used both
# as the logger name and as ``producer_id`` on FDR records.
_RUNTIME_LABEL: Final[Literal["onnx_trt_ep"]] = "onnx_trt_ep"
_PRODUCER_ID: Final[str] = "c7_inference.onnx_trt_ep"

# FDR record kinds emitted by this module.
_FALLBACK_KIND: Final[str] = "c7.fallback_to_onnx_trt_ep"
_CPU_FALLBACK_KIND: Final[str] = "c7.cpu_fallback"

# Read size (1 MiB) used when streaming a file through sha256.
_SHA256_CHUNK: Final[int] = 1024 * 1024

# Artifact suffixes accepted by ``deserialize_engine``.
ONNX_SUFFIX: Final[str] = ".onnx"
ENGINE_SUFFIX: Final[str] = ".engine"

# ONNX Runtime execution-provider identifiers.
TRT_EP: Final[str] = "TensorrtExecutionProvider"
CUDA_EP: Final[str] = "CUDAExecutionProvider"
CPU_EP: Final[str] = "CPUExecutionProvider"

# Provider order requested from ORT unless the constructor overrides it.
_DEFAULT_PROVIDER_LIST: Final[tuple[str, ...]] = (TRT_EP, CUDA_EP, CPU_EP)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Opaque handle (I-4).
|
||||
|
||||
|
||||
class OnnxTrtEpEngineHandle(EngineHandle):
    """Opaque handle around an ORT ``InferenceSession`` and its IO names.

    All fields are underscore-private and are meant to be touched only by
    :class:`OnnxTrtEpRuntime` (Invariant I-4). The session reference lives
    on the handle until :meth:`OnnxTrtEpRuntime.release_engine` sets it to
    ``None`` and flips ``_released``.
    """

    # NOTE(review): ``__slots__`` only removes the per-instance ``__dict__``
    # if ``EngineHandle`` declares ``__slots__`` too — confirm in _types.
    __slots__ = (
        "_session",
        "_input_names",
        "_output_names",
        "_active_provider",
        "_model_name",
        "_released",
    )

    def __init__(
        self,
        *,
        session: Any,
        input_names: tuple[str, ...],
        output_names: tuple[str, ...],
        active_provider: str,
        model_name: str,
    ) -> None:
        """Store the session plus the metadata captured at deserialise time.

        Args:
            session: The live ORT session object (or a test double).
            input_names: Input binding names declared by the session.
            output_names: Output binding names declared by the session.
            active_provider: First provider reported by the session.
            model_name: Stem of the artifact path, used in log messages.
        """
        # A fresh handle is always live; only release_engine flips this.
        self._released = False
        self._model_name = model_name
        self._active_provider = active_provider
        self._output_names = output_names
        self._input_names = input_names
        self._session = session
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Helpers.
|
||||
|
||||
|
||||
def _sha256_of_file(path: Path) -> str:
    """Return the hex sha256 digest of the file at *path*.

    The file is consumed ``_SHA256_CHUNK`` bytes at a time, so an engine
    or ONNX artifact is never held in memory as a whole.
    """
    hasher = hashlib.sha256()
    with path.open("rb") as stream:
        # ``read`` returns b"" at EOF, which is the iterator's sentinel.
        for block in iter(lambda: stream.read(_SHA256_CHUNK), b""):
            hasher.update(block)
    return hasher.hexdigest()
|
||||
|
||||
|
||||
def _iso_ts_now() -> str:
    """Return the current UTC time as RFC 3339 text ending in ``Z``.

    The value always carries six fractional-second digits, e.g.
    ``2026-06-22T19:11:14.000123Z``.

    This duplicates :func:`components.c6_tile_cache._timestamp.iso_ts_now`
    on purpose: importing it from c6 would be a peer import between
    layer-2 components, which ``module-layout.md`` discourages. Folding
    both copies into ``helpers.iso_timestamp`` is left for the next
    cross-component hygiene pass.
    """
    now_utc = datetime.now(timezone.utc)
    # isoformat() renders the UTC offset as "+00:00"; swap it for "Z".
    return now_utc.isoformat(timespec="microseconds").replace("+00:00", "Z")
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Runtime.
|
||||
|
||||
|
||||
class OnnxTrtEpRuntime:
|
||||
"""ONNX Runtime + TensorRT-EP-led :class:`InferenceRuntime` strategy.
|
||||
|
||||
Constructor matches the composition-root factory shape
|
||||
(``strategy_cls(config)`` plus optional keyword collaborators).
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    config: Config,
    *,
    is_fallback: bool = False,
    thermal_publisher: Any | None = None,
    engine_gate: EngineGate | None = None,
    host_tuple_provider: Any | None = None,
    manifest_provider: Any | None = None,
    fdr_client: FdrClient | None = None,
    gcs_alert: Callable[[str], None] | None = None,
    clock: Clock | None = None,
    provider_list: tuple[str, ...] = _DEFAULT_PROVIDER_LIST,
) -> None:
    """Wire the runtime from *config* plus optional collaborators.

    Args:
        config: Root config; ``components["c7_inference"]`` is read once
            and cached.
        is_fallback: When true, the first ``infer`` call emits the
            one-shot fallback alert.
        thermal_publisher: Object whose ``read()`` backs
            ``thermal_state``; ``None`` yields a "no telemetry" state.
        engine_gate: Gate used for ``.engine`` entries; a default
            ``EngineGate()`` is built when omitted.
        host_tuple_provider: Callable ``(precision) -> HostTuple`` that
            overrides ``read_host_tuple``.
        manifest_provider: Zero-arg callable that overrides the on-disk
            manifest reader.
        fdr_client: Optional sink for FDR records.
        gcs_alert: Optional callback receiving the fallback message.
        clock: Monotonic clock; a ``WallClock()`` is built when omitted.
        provider_list: ORT execution providers, in priority order.
    """
    self._config = config
    self._c7_config = config.components["c7_inference"]
    self._is_fallback = bool(is_fallback)

    # Plain collaborators are stored as given (each may be None).
    self._thermal_publisher = thermal_publisher
    self._host_tuple_provider = host_tuple_provider
    self._manifest_provider = manifest_provider
    self._fdr_client = fdr_client
    self._gcs_alert = gcs_alert

    # Collaborators that get a production default when not injected.
    if engine_gate is None:
        engine_gate = EngineGate()
    self._engine_gate = engine_gate
    if clock is None:
        clock = WallClock()
    self._clock = clock

    self._provider_list = tuple(provider_list)
    self._logger = get_logger(_PRODUCER_ID)

    # Latch for the one-shot fallback alert raised from ``infer``.
    self._fallback_warned: bool = False
|
||||
|
||||
# -- Helpers exposed for testing / monkey-patching ---------------------
|
||||
|
||||
def _load_ort(self) -> Any:
    """Import ``onnxruntime`` on demand and return the module object.

    The import is deferred to call time so that merely importing this
    file works on hosts without ORT installed.

    Raises:
        EngineDeserializeError: ``onnxruntime`` cannot be imported.
    """
    try:
        import onnxruntime as ort_module  # type: ignore[import-not-found]
    except ImportError as import_err:
        reason = (
            "onnxruntime python binding not installed on this host "
            "(Tier-2 Jetson / JetPack 6.2 only)"
        )
        raise EngineDeserializeError(reason) from import_err
    return ort_module
|
||||
|
||||
def _resolve_host_tuple(self, precision: PrecisionMode) -> HostTuple:
    """Return the host tuple for *precision*.

    Uses the injected ``host_tuple_provider`` when one was supplied to the
    constructor; otherwise falls back to :func:`read_host_tuple`.
    """
    if self._host_tuple_provider is None:
        return read_host_tuple(precision=precision)
    return self._host_tuple_provider(precision)
|
||||
|
||||
def _resolve_manifest(
    self,
) -> DeploymentManifest | ManifestReaderProtocol:
    """Return the deployment manifest handed to the engine gate.

    Uses the injected ``manifest_provider`` when one was supplied;
    otherwise reads ``manifest.json`` from ``engine_cache_dir`` through
    :class:`ManifestReader`.
    """
    if self._manifest_provider is not None:
        return self._manifest_provider()
    cache_root = Path(self._c7_config.engine_cache_dir)
    return ManifestReader(cache_root / "manifest.json").read()
|
||||
|
||||
def _trt_ep_options(self) -> dict[str, Any]:
    """Build the option dict passed to ORT's TensorRT execution provider.

    This is the one place the TRT EP option-key spellings live (Risk 3).
    The workspace cap is a quarter of ``gpu_memory_budget_bytes`` (AC-8),
    and engine caching is switched on and pointed at
    ``ort_trt_cache_dir``.
    """
    c7 = self._c7_config
    workspace_cap = int(c7.gpu_memory_budget_bytes) // 4
    return dict(
        trt_engine_cache_enable=True,
        trt_engine_cache_path=str(c7.ort_trt_cache_dir),
        trt_max_workspace_size=workspace_cap,
    )
|
||||
|
||||
# -- Compile (no-op for ORT) -------------------------------------------
|
||||
|
||||
def compile_engine(
    self, model_path: Path, build_config: BuildConfig
) -> EngineCacheEntry:
    """No-op compile — the artifact is the ``.onnx`` file (AC-1 / scope).

    ORT will lazy-compile a TRT subgraph in-session; the EP cache
    directory (``config.ort_trt_cache_dir``) holds those subgraph
    caches transparently. The returned :class:`EngineCacheEntry`
    carries the source ``.onnx`` path + its sha256 (AZ-280 trust
    chain reuses this when the file is later loaded).

    Args:
        model_path: Location of the source ONNX model.
        build_config: Only ``precision`` is read; it selects the host
            tuple and is copied onto the entry.

    Returns:
        An entry whose ``engine_path`` is *model_path* itself and whose
        ``extras["model_name"]`` is the path stem.

    Raises:
        EngineDeserializeError: The path does not exist, is not a
            regular file, or cannot be read while hashing.
    """
    path = Path(model_path)
    if not path.exists():
        raise EngineDeserializeError(
            f"ONNX model not found at {path!s}"
        )
    # A directory (or other non-file) passes ``exists()`` but would make
    # the hash step fail with a raw OSError; refuse it up front.
    if not path.is_file():
        raise EngineDeserializeError(
            f"ONNX model path {path!s} is not a regular file"
        )
    try:
        sha_hex = _sha256_of_file(path)
    except OSError as exc:
        # Keep I/O faults inside the component's own error hierarchy,
        # matching how deserialize_engine rewraps ORT failures.
        raise EngineDeserializeError(
            f"ONNX model at {path!s} could not be read: "
            f"{type(exc).__name__}: {exc}"
        ) from exc
    host_tuple = self._resolve_host_tuple(build_config.precision)
    return EngineCacheEntry(
        engine_path=path,
        sha256_hex=sha_hex,
        sm=host_tuple.sm,
        jp=host_tuple.jp,
        trt=host_tuple.trt,
        precision=build_config.precision,
        extras={"model_name": path.stem},
    )
|
||||
|
||||
# -- Deserialize -------------------------------------------------------
|
||||
|
||||
def deserialize_engine(self, entry: EngineCacheEntry) -> EngineHandle:
    """Create an ORT ``InferenceSession`` for *entry* and wrap it in a handle.

    Steps, in order:

    1. Reject any ``engine_path`` whose suffix is not ``.onnx`` or
       ``.engine``.
    2. For ``.engine`` only: run :meth:`EngineGate.validate` (AC-3) and,
       if it passes, stage the binary under ``ort_trt_cache_dir``.
       ``.onnx`` entries skip the gate (AC-2).
    3. Build the session with the configured provider list and options.
    4. Read back the providers ORT reports; the first one is treated as
       the active provider (AC-6). A CPU-only result is routed through
       :meth:`_handle_cpu_fallback`.
    5. Run one zero-filled warm-up inference so a broken session is
       caught before a handle is returned.

    Raises:
        EngineDeserializeError: Unsupported suffix, ORT missing, session
            creation failure, empty provider list, refused CPU fallback,
            or a warm-up failure other than out-of-memory.
        OutOfMemoryError: Propagated unchanged from the warm-up run.
    """
    artifact = Path(entry.engine_path)
    kind = artifact.suffix.lower()
    if kind not in (ONNX_SUFFIX, ENGINE_SUFFIX):
        raise EngineDeserializeError(
            f"OnnxTrtEpRuntime.deserialize_engine: unsupported "
            f"engine_path suffix {kind!r} (expected .onnx or .engine)"
        )

    if kind == ENGINE_SUFFIX:
        # Gate first: a refusal raises here and no session is created.
        self._engine_gate.validate(
            entry,
            self._resolve_host_tuple(entry.precision),
            self._resolve_manifest(),
        )
        self._stage_engine_for_ort(artifact)
    # NOTE(review): both suffixes hand the same path to InferenceSession.
    # ORT presumably expects an ONNX / ORT-format model there — confirm
    # that a raw TensorRT ``.engine`` is actually loadable on this path.

    ort = self._load_ort()
    requested, per_provider_opts = self._build_provider_args()

    try:
        session = ort.InferenceSession(
            str(artifact),
            providers=list(requested),
            provider_options=list(per_provider_opts),
        )
    except Exception as exc:
        raise EngineDeserializeError(
            f"ORT InferenceSession creation failed for "
            f"{artifact.name!r}: {type(exc).__name__}: {exc}"
        ) from exc

    reported = tuple(session.get_providers())
    if not reported:
        raise EngineDeserializeError(
            f"ORT session for {artifact.name!r} reports an empty "
            "provider list — the EP shim is broken"
        )
    active_provider = reported[0]
    if active_provider == CPU_EP:
        # Warns/records, and raises when CPU fallback is disallowed.
        self._handle_cpu_fallback(artifact, requested)

    input_names = tuple(node.name for node in session.get_inputs())
    output_names = tuple(node.name for node in session.get_outputs())

    try:
        self._warm_up_session(session, input_names)
    except Exception as exc:
        if isinstance(exc, OutOfMemoryError):
            # OOM keeps its own type so callers can tell it apart.
            raise
        raise EngineDeserializeError(
            f"ORT warm-up run failed for {artifact.name!r}: "
            f"{type(exc).__name__}: {exc}"
        ) from exc

    self._logger.info(
        "deserialize_engine: ORT session for %s active_provider=%s "
        "(requested: %s) inputs=%d outputs=%d",
        artifact.name,
        active_provider,
        ",".join(requested),
        len(input_names),
        len(output_names),
    )
    return OnnxTrtEpEngineHandle(
        session=session,
        input_names=input_names,
        output_names=output_names,
        active_provider=active_provider,
        model_name=artifact.stem,
    )
|
||||
|
||||
def _build_provider_args(
    self,
) -> tuple[tuple[str, ...], tuple[dict[str, Any], ...]]:
    """Pair each provider with its options dict (TRT carries the cache).

    Returns:
        ``(providers, options)`` where ``options[i]`` belongs to
        ``providers[i]``. The TensorRT EP gets a private copy of
        :meth:`_trt_ep_options`; every other provider gets ``{}``.

    Side effects:
        Tries to create ``ort_trt_cache_dir``. A failure is logged and
        tolerated, the same best-effort policy ``_stage_engine_for_ort``
        applies to the same directory, so a raw ``OSError`` never
        escapes ``deserialize_engine`` from here. If ORT cannot work
        without the directory, the failure surfaces at session creation,
        where it is rewrapped into ``EngineDeserializeError``.
    """
    trt_opts = self._trt_ep_options()
    cache_dir = Path(self._c7_config.ort_trt_cache_dir)
    try:
        cache_dir.mkdir(parents=True, exist_ok=True)
    except OSError as exc:
        self._logger.warning(
            "_build_provider_args: cannot create ORT TRT cache dir at "
            "%s: %s — continuing without a pre-created cache dir",
            cache_dir,
            exc,
        )
    # Copy per occurrence so no two entries share one mutable dict.
    opts = tuple(
        dict(trt_opts) if ep == TRT_EP else {}
        for ep in self._provider_list
    )
    return self._provider_list, opts
|
||||
|
||||
def _stage_engine_for_ort(self, engine_path: Path) -> None:
    """Copy/link the .engine into ``ort_trt_cache_dir`` (AC-3 follow-up).

    ORT's TRT EP keys its cache by ORT-internal subgraph hashes, so a
    ``.engine`` produced by AZ-298 ``TensorrtRuntime`` is NOT directly
    reusable as an ORT-EP cache entry — but staging it under the EP
    cache directory at least gives ORT's TRT EP a place to look for a
    prior subgraph if one matches.

    Everything here is best-effort: every ``OSError`` is logged as a
    warning and swallowed, and the method always returns ``None``.

    Behaviour:

    - An entry that already resolves to a file is left untouched.
    - A dangling symlink left at the staged name is removed first. If it
      were kept, ``os.symlink`` would fail and the copy fallback would
      write *through* the link to its stale target.
    - The link target is the absolute engine path, because a relative
      target would be resolved against the cache directory.
    - The copy fallback streams via ``shutil.copyfile`` instead of
      reading the whole engine into memory.
    """
    import shutil  # local import: only this best-effort path needs it

    cache_dir = Path(self._c7_config.ort_trt_cache_dir)
    try:
        cache_dir.mkdir(parents=True, exist_ok=True)
    except OSError as exc:
        self._logger.warning(
            "deserialize_engine: cannot create ORT TRT cache dir at "
            "%s: %s — ORT will lazy-compile",
            cache_dir,
            exc,
        )
        return

    staged = cache_dir / engine_path.name
    if staged.exists():
        return
    if staged.is_symlink():
        # exists() follows links, so reaching here means a dangling link.
        try:
            staged.unlink()
        except OSError as exc:
            self._logger.warning(
                "deserialize_engine: cannot remove stale link %s in ORT "
                "TRT cache dir: %s — ORT will lazy-compile",
                staged.name,
                exc,
            )
            return

    source = engine_path.absolute()
    try:
        os.symlink(source, staged)
        return
    except OSError as exc:
        self._logger.warning(
            "deserialize_engine: cannot symlink engine %s into ORT "
            "TRT cache dir: %s — falling back to copy",
            engine_path.name,
            exc,
        )

    try:
        shutil.copyfile(source, staged)
    except OSError as copy_exc:
        self._logger.warning(
            "deserialize_engine: copy of engine %s to ORT TRT "
            "cache dir also failed: %s — ORT will lazy-compile",
            engine_path.name,
            copy_exc,
        )
|
||||
|
||||
def _warm_up_session(
    self, session: Any, input_names: tuple[str, ...]
) -> None:
    """Push one all-zero batch through *session* to expose init faults.

    Each declared input gets a zero array. Any dimension that is not a
    positive ``int`` (symbolic names, ``None``, zero or negative values)
    is replaced by ``1``. An input with no dimensions at all is fed as
    shape ``(1,)``.

    Does nothing when *input_names* is empty.

    Raises:
        OutOfMemoryError: ``session.run`` raised ``MemoryError``.
    """
    if not input_names:
        return

    def _concrete(dim: Any) -> int:
        # Dynamic / unknown axes collapse to a single element.
        return int(dim) if isinstance(dim, int) and dim > 0 else 1

    zero_feed: dict[str, np.ndarray] = {}
    for node in session.get_inputs():
        # NOTE(review): a true rank-0 input is also fed as (1,) here —
        # confirm ORT accepts that for scalar inputs.
        dims = tuple(_concrete(d) for d in (node.shape or ())) or (1,)
        zero_feed[node.name] = np.zeros(
            dims, dtype=_ort_dtype_to_numpy(node.type)
        )

    try:
        session.run(None, zero_feed)
    except MemoryError as exc:
        raise OutOfMemoryError(
            f"ORT warm-up raised MemoryError: {exc}"
        ) from exc
|
||||
|
||||
def _handle_cpu_fallback(
    self, engine_path: Path, requested: tuple[str, ...]
) -> None:
    """React to ORT ending up on the CPU provider (Risk 2 mitigation).

    Always logs a WARN. When an FDR client is wired, also enqueues a
    ``c7.cpu_fallback`` record; a failing enqueue is logged and ignored.
    When ``ort_disallow_cpu_fallback`` is set, the same message is then
    raised as an error.

    Args:
        engine_path: Artifact being deserialised (name/stem are reported).
        requested: Provider list that was requested from ORT.

    Raises:
        EngineDeserializeError: CPU fallback is disallowed by config.
    """
    msg = (
        f"OnnxTrtEpRuntime: ORT collapsed to CPUExecutionProvider for "
        f"{engine_path.name!r}; latency budget will NOT be met"
    )
    self._logger.warning(msg)

    fdr = self._fdr_client
    if fdr is not None:
        try:
            record = FdrRecord(
                schema_version=CURRENT_SCHEMA_VERSION,
                ts=_iso_ts_now(),
                producer_id=_PRODUCER_ID,
                kind=_CPU_FALLBACK_KIND,
                payload={
                    "model_name": engine_path.stem,
                    "requested_providers": list(requested),
                    "active_provider": CPU_EP,
                },
            )
            fdr.enqueue(record)
        except Exception as exc:
            # Recording is advisory; it must never mask the decision below.
            self._logger.warning(
                "OnnxTrtEpRuntime: FDR enqueue for cpu_fallback failed: %s",
                exc,
            )

    if self._c7_config.ort_disallow_cpu_fallback:
        raise EngineDeserializeError(msg)
|
||||
|
||||
# -- Infer -------------------------------------------------------------
|
||||
|
||||
def infer(
    self,
    handle: EngineHandle,
    inputs: dict[str, np.ndarray],
) -> dict[str, np.ndarray]:
    """Run one synchronous ``session.run`` (AC-4 / I-8).

    On the first call of a fallback-constructed runtime this also fires
    the one-shot fallback alert (AC-5). The alert is raised before the
    inputs are checked.

    Args:
        handle: A live handle produced by this runtime.
        inputs: Arrays keyed by input name. Every declared input must be
            present; keys the session does not declare are ignored.

    Returns:
        Output arrays keyed by the session's declared output names.

    Raises:
        InferenceError: Foreign or released handle, missing input,
            output-count mismatch, or any non-memory failure in ORT.
        OutOfMemoryError: ``session.run`` raised ``MemoryError``.
    """
    if not isinstance(handle, OnnxTrtEpEngineHandle):
        raise InferenceError(
            f"infer() received foreign handle type "
            f"{type(handle).__name__}; OnnxTrtEpRuntime only accepts "
            "OnnxTrtEpEngineHandle"
        )
    if handle._released:
        raise InferenceError(
            "infer() called on released handle "
            f"({handle._model_name!r})"
        )
    self._maybe_emit_fallback_alert(handle)

    wanted_inputs = handle._input_names
    wanted_outputs = handle._output_names

    # Validate the whole feed before converting any array.
    for name in wanted_inputs:
        if name not in inputs:
            raise InferenceError(
                f"infer({handle._model_name!r}) missing input "
                f"binding {name!r}"
            )
    feed = {}
    for name in wanted_inputs:
        feed[name] = np.ascontiguousarray(inputs[name])

    # Timing is sampled only when per-frame debug logging is enabled.
    t0_ns = None
    if self._c7_config.per_frame_debug_log:
        t0_ns = self._clock.monotonic_ns()

    try:
        raw_outputs = handle._session.run(list(wanted_outputs), feed)
    except MemoryError as exc:
        raise OutOfMemoryError(
            f"ORT session.run raised MemoryError for "
            f"{handle._model_name!r}: {exc}"
        ) from exc
    except Exception as exc:
        raise InferenceError(
            f"infer({handle._model_name!r}) raised "
            f"{type(exc).__name__}: {exc}"
        ) from exc

    if len(raw_outputs) != len(wanted_outputs):
        raise InferenceError(
            f"infer({handle._model_name!r}) returned "
            f"{len(raw_outputs)} outputs but the session declared "
            f"{len(wanted_outputs)}"
        )

    result: dict[str, np.ndarray] = {}
    for name, value in zip(wanted_outputs, raw_outputs):
        result[name] = np.asarray(value)

    if t0_ns is not None:
        elapsed_ms = (self._clock.monotonic_ns() - t0_ns) / 1_000_000
        self._logger.debug(
            "infer: %s took %.1f ms (provider=%s)",
            handle._model_name,
            elapsed_ms,
            handle._active_provider,
        )
    return result
|
||||
|
||||
def _maybe_emit_fallback_alert(
    self, handle: OnnxTrtEpEngineHandle
) -> None:
    """Raise the fallback alert once per runtime instance (AC-5).

    Does nothing unless the runtime was built with ``is_fallback=True``
    and the alert has not fired yet. Otherwise it logs a WARN, calls
    ``gcs_alert`` with the same message and enqueues a
    ``c7.fallback_to_onnx_trt_ep`` FDR record. The two collaborators are
    optional, and a failure in either is logged and swallowed.
    """
    already_done = self._fallback_warned or not self._is_fallback
    if already_done:
        return
    # Latch first so a failing collaborator cannot cause a repeat alert.
    self._fallback_warned = True

    msg = (
        f"OnnxTrtEpRuntime: serving {handle._model_name!r} as a "
        "fallback (degraded latency); operator action recommended"
    )
    self._logger.warning(msg)

    notify = self._gcs_alert
    if notify is not None:
        try:
            notify(msg)
        except Exception as exc:
            self._logger.warning(
                "OnnxTrtEpRuntime: gcs_alert callback raised %s",
                exc,
            )

    fdr = self._fdr_client
    if fdr is None:
        return
    try:
        record = FdrRecord(
            schema_version=CURRENT_SCHEMA_VERSION,
            ts=_iso_ts_now(),
            producer_id=_PRODUCER_ID,
            kind=_FALLBACK_KIND,
            payload={
                "model_name": handle._model_name,
                "reason": "composition_root_explicit",
                "active_provider": handle._active_provider,
            },
        )
        fdr.enqueue(record)
    except Exception as exc:
        self._logger.warning(
            "OnnxTrtEpRuntime: FDR enqueue for fallback alert "
            "failed: %s",
            exc,
        )
|
||||
|
||||
# -- Release -----------------------------------------------------------
|
||||
|
||||
def release_engine(self, handle: EngineHandle) -> None:
    """Drop the handle's session reference; safe to call repeatedly (AC-7 / I-7).

    Foreign handle types and already-released handles are ignored. If the
    session exposes a callable ``end_profiling`` it is invoked once
    before the reference is cleared; an exception from it is logged and
    does not prevent the release.
    """
    if not isinstance(handle, OnnxTrtEpEngineHandle) or handle._released:
        return
    # Mark released before any cleanup so the handle can never be reused,
    # even if the cleanup below misbehaves.
    handle._released = True

    session = handle._session
    stop_profiling = getattr(session, "end_profiling", None)
    if callable(stop_profiling):
        try:
            stop_profiling()
        except Exception as exc:
            self._logger.warning(
                "release_engine: %s.end_profiling() raised %s",
                type(session).__name__,
                exc,
            )
    handle._session = None
|
||||
|
||||
# -- Thermal / label ---------------------------------------------------
|
||||
|
||||
def thermal_state(self) -> ThermalState:
    """Report thermal telemetry via the injected AZ-302 publisher (I-6).

    With a publisher wired, its ``read()`` result is returned as-is.
    Without one, a placeholder is returned: no temperatures, no clock
    reading, throttle inactive, ``is_telemetry_available=False``,
    stamped with the runtime clock's current monotonic time.
    """
    if self._thermal_publisher is not None:
        return self._thermal_publisher.read()
    return ThermalState(
        cpu_temp_c=None,
        gpu_temp_c=None,
        thermal_throttle_active=False,
        measured_clock_mhz=None,
        measured_at_ns=self._clock.monotonic_ns(),
        is_telemetry_available=False,
    )
|
||||
|
||||
def current_runtime_label(self) -> Literal["onnx_trt_ep"]:
    """Return this strategy's fixed runtime label, ``"onnx_trt_ep"`` (AC-1)."""
    label = _RUNTIME_LABEL
    return label
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# ORT tensor-type → numpy dtype shim (avoids a hard import of ort.numpy_helper).
|
||||
|
||||
|
||||
# ORT reports element types as strings of the form ``tensor(<name>)``.
# The table lists the <name> → numpy scalar type pairs this module knows.
_ORT_DTYPE_MAP: Final[dict[str, type]] = {
    f"tensor({ort_name})": np_type
    for ort_name, np_type in (
        ("float16", np.float16),
        ("float", np.float32),
        ("double", np.float64),
        ("int8", np.int8),
        ("int16", np.int16),
        ("int32", np.int32),
        ("int64", np.int64),
        ("uint8", np.uint8),
        ("uint16", np.uint16),
        ("uint32", np.uint32),
        ("uint64", np.uint64),
        ("bool", np.bool_),
    )
}


def _ort_dtype_to_numpy(ort_type: str) -> Any:
    """Translate an ORT tensor-type string into a numpy scalar type.

    Strings missing from ``_ORT_DTYPE_MAP`` resolve to ``np.float32``.
    """
    try:
        return _ORT_DTYPE_MAP[ort_type]
    except KeyError:
        return np.float32
|
||||
Reference in New Issue
Block a user