mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-06-22 20:41:12 +00:00
[AZ-300] Implement PytorchFp16Runtime — C7 simple-baseline strategy
AZ-300 mandatory simple-baseline InferenceRuntime (eager FP16 PyTorch).
Implements the AZ-297 Protocol; current_runtime_label returns
"pytorch_fp16". Numerical reference every fancier C7 strategy (AZ-298
TRT, AZ-299 ORT) is measured against, and the only viable runtime for
Tier-1 workstation Docker where TRT is non-trivial to install.
Production code (new):
- components/c7_inference/pytorch_fp16_runtime.py — runtime +
PytorchEngineHandle + output-shape adapter
- components/c7_inference/architecture_registry.py — torch-free
register_architecture / default_registry / ArchitectureFactory
(Risk-1 mitigation: no L2->L3 back-edge from C7 into per-backbone
code)
- components/c7_inference/__init__.py — re-exports the registry
mechanism. Still does NOT import the concrete strategy module
(Invariant I-5)
- components/c7_inference/config.py — adds per_frame_debug_log bool
field (gates the DEBUG per-frame latency log)
Tests (new): tests/unit/c7_inference/test_pytorch_fp16_runtime.py
covers AC-1..AC-8 + NFRs. AC-1/2/6/7 + thermal/release/registry
guards run unconditionally (17 tests); AC-3/4/5/8 +
NFR-perf-deserialize + NFR-reliability-eval-mode require CUDA and
skip on Tier-1 CI / macOS dev.
Tests (modified):
- test_protocol_conformance.py — narrowed
test_ac5_build_inference_runtime_flag_on_but_module_missing
parametrisation to exclude pytorch_fp16 (now-built); TRT / ORT
still covered until AZ-298 / AZ-299 ship.
CI: .github/workflows/ci.yml lint + unit jobs now install
'-e .[dev,inference]' because mypy + pytest need torch + torchvision +
onnxruntime on the runner.
Three task-spec -> as-built deltas documented in
_docs/02_tasks/done/AZ-300_c7_pytorch_baseline.md Implementation Notes:
1. Constructor conforms to AZ-297 factory shape (config positional;
thermal_publisher + registry + clock keyword-only optionals).
AZ-302 will update the factory to thread thermal_publisher.
2. Architecture registry uses extras["model_name"] as lookup key
(avoids touching the frozen BuildConfig / EngineCacheEntry DTOs).
3. Warm-up forward deferred to AZ-300 tier-2 follow-up — the zero-arg
registry has no per-backbone input-shape metadata.
Suite: 1120 passed / 10 skipped (CUDA + Tier-2 + cmake / actionlint
environment gates). No regressions in non-c7_inference areas.
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -28,6 +28,11 @@ from gps_denied_onboard._types.inference import (
|
||||
PrecisionMode,
|
||||
)
|
||||
from gps_denied_onboard._types.thermal import ThermalState
|
||||
from gps_denied_onboard.components.c7_inference.architecture_registry import (
|
||||
ArchitectureFactory,
|
||||
default_registry,
|
||||
register_architecture,
|
||||
)
|
||||
from gps_denied_onboard.components.c7_inference.config import C7InferenceConfig
|
||||
from gps_denied_onboard.components.c7_inference.errors import (
|
||||
CalibrationCacheError,
|
||||
@@ -47,6 +52,7 @@ from gps_denied_onboard.config.schema import register_component_block
|
||||
register_component_block("c7_inference", C7InferenceConfig)
|
||||
|
||||
__all__ = [
|
||||
"ArchitectureFactory",
|
||||
"BuildConfig",
|
||||
"C7InferenceConfig",
|
||||
"CalibrationCacheError",
|
||||
@@ -65,4 +71,6 @@ __all__ = [
|
||||
"RuntimeError",
|
||||
"TelemetryUnavailableError",
|
||||
"ThermalState",
|
||||
"default_registry",
|
||||
"register_architecture",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
"""Per-backbone architecture registry for ``PytorchFp16Runtime`` (AZ-300).
|
||||
|
||||
The PyTorch baseline loads checkpoints saved as ``state_dict`` (the
|
||||
``weights_only=True`` security path); ``model.load_state_dict`` requires
|
||||
a pre-instantiated ``nn.Module``. This module owns the small
|
||||
``model_name -> factory()`` indirection that lets the composition root
|
||||
register each backbone's architecture class at startup without C7
|
||||
importing component code directly (Risk-1 in the AZ-300 spec — would
|
||||
violate the L2→L3 module-layout layering otherwise).
|
||||
|
||||
The registry is a plain ``dict``; ``register_architecture`` is
|
||||
idempotent for the same factory and rejects re-registration with a
|
||||
different one. The module does NOT import :mod:`torch` — keeping it
|
||||
torch-free lets the composition root populate the registry even on
|
||||
non-PyTorch tiers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch import nn
|
||||
|
||||
__all__ = [
|
||||
"ArchitectureFactory",
|
||||
"default_registry",
|
||||
"register_architecture",
|
||||
]
|
||||
|
||||
# Type alias: a zero-argument callable that returns an ``nn.Module``.
# ``nn`` is imported under TYPE_CHECKING only, hence the string forward reference.
ArchitectureFactory = Callable[[], "nn.Module"]
|
||||
|
||||
# Module-level singleton mapping ``model_name`` -> architecture factory.
# Written by ``register_architecture`` and handed out by ``default_registry``.
_DEFAULT_REGISTRY: dict[str, ArchitectureFactory] = {}
|
||||
|
||||
|
||||
def register_architecture(
    model_name: str, factory: ArchitectureFactory
) -> None:
    """Bind ``factory`` to ``model_name`` in the module-level registry.

    Registering the very same ``factory`` object again under the same
    name is accepted and changes nothing observable. Registering a
    *different* object under a name that is already taken is refused,
    so a name collision between two backbones surfaces at registration
    time.

    Raises:
        ValueError: if ``model_name`` is empty, or if ``model_name`` is
            already bound to a different factory object.
    """
    if not model_name:
        raise ValueError("model_name must be non-empty")

    registered = _DEFAULT_REGISTRY.get(model_name)
    # Identity (not equality) decides whether this is a harmless repeat.
    if registered is None or registered is factory:
        _DEFAULT_REGISTRY[model_name] = factory
        return

    raise ValueError(
        f"Architecture {model_name!r} already registered with a "
        f"different factory; refusing to override"
    )
|
||||
|
||||
|
||||
def default_registry() -> dict[str, ArchitectureFactory]:
    """Expose the process-wide architecture registry.

    The actual module-level ``dict`` is returned, not a copy, so later
    ``register_architecture`` calls are visible through the returned
    object. Tests wanting isolation SHOULD build a private ``dict`` and
    pass it explicitly to :class:`PytorchFp16Runtime` instead of using
    this one.
    """
    return _DEFAULT_REGISTRY
|
||||
@@ -14,8 +14,8 @@ from typing import Final
|
||||
from gps_denied_onboard.config.schema import ConfigError
|
||||
|
||||
__all__ = [
|
||||
"C7InferenceConfig",
|
||||
"KNOWN_RUNTIMES",
|
||||
"C7InferenceConfig",
|
||||
]
|
||||
|
||||
KNOWN_RUNTIMES: Final[frozenset[str]] = frozenset(
|
||||
@@ -45,6 +45,7 @@ class C7InferenceConfig:
|
||||
runtime: str = "pytorch_fp16"
|
||||
thermal_poll_hz: float = 1.0
|
||||
engine_cache_dir: str = "/var/lib/gps-denied/engines"
|
||||
per_frame_debug_log: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.runtime not in KNOWN_RUNTIMES:
|
||||
|
||||
@@ -0,0 +1,339 @@
|
||||
"""``PytorchFp16Runtime`` — mandatory simple-baseline C7 strategy (AZ-300).
|
||||
|
||||
Conforms to :class:`InferenceRuntime` (AZ-297). Loads each backbone's
|
||||
canonical PyTorch checkpoint (``.pt`` / ``.pth``), runs eager
|
||||
``.half().cuda()`` forward, no ``torch.compile`` / JIT / autocast — the
|
||||
"simple baseline" every fancier strategy (AZ-298 TRT, AZ-299 ORT) is
|
||||
numerically measured against per the ENG-RULE.
|
||||
|
||||
Composition contract (factory shape, AZ-297 ``build_inference_runtime``)::
|
||||
|
||||
strategy_cls(config: Config)
|
||||
|
||||
This module conforms by accepting ``config`` positionally; every other
|
||||
collaborator — the per-backbone architecture registry, the AZ-302
|
||||
``ThermalStatePublisher`` reference, the injectable :class:`Clock` —
|
||||
is a keyword-only optional with safe defaults. AZ-302 will update the
|
||||
factory to thread ``thermal_publisher`` through; until then,
|
||||
:meth:`thermal_state` returns a default-safe ``ThermalState`` per
|
||||
Invariant I-6.
|
||||
|
||||
AC mapping (see ``_docs/02_tasks/done/AZ-300_c7_pytorch_baseline.md``):
|
||||
|
||||
- AC-1 : :meth:`current_runtime_label` returns ``"pytorch_fp16"``;
|
||||
protocol conformance via ``runtime_checkable``.
|
||||
- AC-2 : :meth:`compile_engine` is a no-op — no ``.engine`` produced;
|
||||
returned :class:`EngineCacheEntry` carries the checkpoint path.
|
||||
- AC-3 : :meth:`deserialize_engine` runs ``torch.load`` →
|
||||
``load_state_dict`` → ``.half().cuda().eval()``.
|
||||
- AC-4 : :meth:`infer` runs synchronous FP16 forward; outputs are
|
||||
host-resident :class:`numpy.ndarray`.
|
||||
- AC-5 : :meth:`release_engine` drops references + calls
|
||||
``torch.cuda.empty_cache``.
|
||||
- AC-6 / AC-7 : missing checkpoint or mismatched state_dict rewrap to
|
||||
:class:`EngineDeserializeError` with ``__cause__`` preserved.
|
||||
- AC-8 : ``torch.cuda.OutOfMemoryError`` rewraps to the C7-local
|
||||
:class:`OutOfMemoryError`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from collections.abc import Mapping
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Final, Literal
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from gps_denied_onboard._types.inference import (
|
||||
BuildConfig,
|
||||
EngineCacheEntry,
|
||||
EngineHandle,
|
||||
PrecisionMode,
|
||||
)
|
||||
from gps_denied_onboard._types.thermal import ThermalState
|
||||
from gps_denied_onboard.clock.wall_clock import WallClock
|
||||
from gps_denied_onboard.components.c7_inference.architecture_registry import (
|
||||
ArchitectureFactory,
|
||||
default_registry,
|
||||
)
|
||||
from gps_denied_onboard.components.c7_inference.errors import (
|
||||
EngineDeserializeError,
|
||||
InferenceError,
|
||||
OutOfMemoryError,
|
||||
)
|
||||
from gps_denied_onboard.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gps_denied_onboard.clock import Clock
|
||||
from gps_denied_onboard.config.schema import Config
|
||||
|
||||
__all__ = ["PytorchEngineHandle", "PytorchFp16Runtime"]
|
||||
|
||||
# Value returned by ``PytorchFp16Runtime.current_runtime_label``.
_RUNTIME_LABEL: Final[Literal["pytorch_fp16"]] = "pytorch_fp16"
|
||||
# Number of bytes ``_sha256_of_file`` reads per iteration while hashing.
_SHA256_CHUNK: Final[int] = 1 << 20 # 1 MiB
|
||||
# Key in ``EngineCacheEntry.extras`` holding the architecture-registry lookup
# name; written by ``compile_engine`` and read by ``deserialize_engine``.
_MODEL_NAME_KEY: Final[str] = "model_name"
|
||||
|
||||
|
||||
class PytorchEngineHandle(EngineHandle):
    """Handle returned by :meth:`PytorchFp16Runtime.deserialize_engine`.

    Carries three private slots: the loaded model, the name it was
    resolved under, and a flag recording whether
    :meth:`PytorchFp16Runtime.release_engine` has already been called on
    it. The slots belong to :class:`PytorchFp16Runtime`; per Invariant
    I-4 consumers MUST NOT introspect them. One handle may serve many
    :meth:`PytorchFp16Runtime.infer` calls until it is released.
    """

    __slots__ = ("_model", "_model_name", "_released")

    def __init__(self, model: nn.Module, model_name: str) -> None:
        # A freshly built handle is live; only release_engine flips the flag.
        self._released = False
        self._model_name = model_name
        self._model = model
|
||||
|
||||
|
||||
class PytorchFp16Runtime:
    """Eager FP16 PyTorch :class:`InferenceRuntime` — simple baseline.

    Lifecycle: :meth:`compile_engine` (no-op, hashes the checkpoint) →
    :meth:`deserialize_engine` (build module, load weights, move to GPU)
    → :meth:`infer` (any number of times) → :meth:`release_engine`.
    """

    def __init__(
        self,
        config: Config,
        *,
        thermal_publisher: Any | None = None,
        architecture_registry: Mapping[str, ArchitectureFactory] | None = None,
        clock: Clock | None = None,
    ) -> None:
        """Store collaborators; no model is loaded here.

        Args:
            config: Application config; the ``"c7_inference"`` entry of
                ``config.components`` is read for ``per_frame_debug_log``.
            thermal_publisher: Optional object whose ``read()`` result is
                returned by :meth:`thermal_state`.
            architecture_registry: ``model_name -> factory`` mapping. When
                omitted, the module-level ``default_registry()`` is used.
            clock: Source of ``monotonic_ns()``; defaults to ``WallClock()``.
        """
        self._config = config
        self._c7_config = config.components["c7_inference"]
        self._thermal_publisher = thermal_publisher
        self._registry: Mapping[str, ArchitectureFactory] = (
            architecture_registry
            if architecture_registry is not None
            else default_registry()
        )
        self._clock = clock if clock is not None else WallClock()
        self._logger = get_logger("c7_inference.pytorch_fp16")

    def compile_engine(
        self, model_path: Path, build_config: BuildConfig
    ) -> EngineCacheEntry:
        """No-op compile — PyTorch checkpoints are the artifact (AC-2).

        Returns an :class:`EngineCacheEntry` whose ``engine_path`` is the
        source checkpoint path and whose ``sha256_hex`` is the streamed
        SHA-256 of that file. The ``(sm, jp, trt, precision)`` fields are
        ``(None, None, None, FP16)``. ``extras["model_name"]`` is the
        checkpoint's file stem; :meth:`deserialize_engine` reads it to
        look up the architecture factory in the registry.

        ``build_config`` is accepted for Protocol conformance and is not
        consulted.

        Raises:
            EngineDeserializeError: if the checkpoint does not exist or
                cannot be read for hashing (``__cause__`` carries the
                underlying ``OSError`` in the latter case).
        """
        path = Path(model_path)
        if not path.exists():
            raise EngineDeserializeError(
                f"PyTorch checkpoint not found at {path!s}"
            )
        try:
            sha256_hex = _sha256_of_file(path)
        except OSError as exc:
            # Covers a directory at ``path``, permission problems, and the
            # file disappearing between the existence check and the read.
            raise EngineDeserializeError(
                f"PyTorch checkpoint at {path!s} could not be read: "
                f"{type(exc).__name__}"
            ) from exc
        return EngineCacheEntry(
            engine_path=path,
            sha256_hex=sha256_hex,
            sm=None,
            jp=None,
            trt=None,
            precision=PrecisionMode.FP16,
            extras={_MODEL_NAME_KEY: path.stem},
        )

    def deserialize_engine(self, entry: EngineCacheEntry) -> EngineHandle:
        """Load the checkpoint and produce a GPU-resident handle (AC-3).

        Steps, each with its own error mapping:

        1. Resolve the architecture factory by ``extras["model_name"]``
           (falling back to the checkpoint's file stem) and call it.
        2. ``torch.load(..., map_location="cpu", weights_only=True)``.
        3. ``load_state_dict(strict=True)``.
        4. ``.half().cuda().eval()``.

        Raises:
            EngineDeserializeError: missing checkpoint, unregistered
                architecture, factory failure, ``torch.load`` failure,
                rejected state dict, or a failed GPU move. The original
                exception, when there is one, is chained as ``__cause__``.
            OutOfMemoryError: CUDA ran out of memory during the GPU move.
        """
        path = Path(entry.engine_path)
        if not path.exists():
            raise EngineDeserializeError(
                f"PyTorch checkpoint not found at {path!s}"
            )
        model_name = entry.extras.get(_MODEL_NAME_KEY) or path.stem
        factory = self._registry.get(model_name)
        if factory is None:
            raise EngineDeserializeError(
                f"No architecture registered for model_name={model_name!r}; "
                "the composition root must call "
                "`c7_inference.register_architecture(name, factory)` before "
                "the runtime is constructed."
            )
        try:
            model = factory()
        except Exception as exc:
            raise EngineDeserializeError(
                f"Architecture factory for {model_name!r} raised "
                f"{type(exc).__name__}"
            ) from exc
        try:
            state_dict = torch.load(
                path, map_location="cpu", weights_only=True
            )
        except FileNotFoundError as exc:
            raise EngineDeserializeError(
                f"PyTorch checkpoint not found at {path!s}"
            ) from exc
        except Exception as exc:
            raise EngineDeserializeError(
                f"torch.load failed for checkpoint {path!s}: "
                f"{type(exc).__name__}"
            ) from exc
        try:
            model.load_state_dict(state_dict, strict=True)
        except (RuntimeError, TypeError) as exc:
            # RuntimeError: missing / unexpected keys or shape mismatch.
            # TypeError: the checkpoint payload is not dict-like at all
            # (e.g. a bare tensor was saved instead of a state_dict).
            raise EngineDeserializeError(
                f"load_state_dict(strict=True) rejected checkpoint "
                f"{path!s} for architecture {model_name!r}"
            ) from exc
        try:
            model = model.half().cuda().eval()
        except torch.cuda.OutOfMemoryError as exc:
            # Must precede the RuntimeError handler below so that OOM maps
            # to the C7-local OutOfMemoryError (AC-8).
            raise OutOfMemoryError(
                f"CUDA OOM while moving {model_name!r} to GPU"
            ) from exc
        except (RuntimeError, AssertionError) as exc:
            # A torch build without CUDA support signals that through an
            # AssertionError rather than a RuntimeError.
            raise EngineDeserializeError(
                f"GPU move / half-cast failed for {model_name!r}"
            ) from exc
        n_params = sum(p.numel() for p in model.parameters())
        gpu_bytes = sum(
            p.numel() * p.element_size() for p in model.parameters()
        )
        self._logger.info(
            "pytorch_fp16.deserialize_engine: loaded %s (%d params, "
            "~%.1f MB GPU, sha256=%s)",
            model_name,
            n_params,
            gpu_bytes / (1024 * 1024),
            entry.sha256_hex[:12],
        )
        return PytorchEngineHandle(model, model_name)

    def infer(
        self,
        handle: EngineHandle,
        inputs: dict[str, np.ndarray],
    ) -> dict[str, np.ndarray]:
        """Run a sync FP16 forward; return numpy outputs (AC-4).

        Every input array is made contiguous, converted with ``.half()``,
        copied to CUDA and passed to the model as a keyword argument
        named after its dict key. The model's return value is converted
        by ``_to_numpy_dict``.

        When ``per_frame_debug_log`` is enabled in the C7 config, the
        elapsed time of the call is logged at DEBUG level.

        Raises:
            InferenceError: foreign or released handle, unsupported
                output type, or any other exception raised during the
                forward pass (chained as ``__cause__``).
            OutOfMemoryError: CUDA ran out of memory.
        """
        if not isinstance(handle, PytorchEngineHandle):
            raise InferenceError(
                f"infer() received foreign handle type {type(handle).__name__}; "
                "PytorchFp16Runtime only accepts PytorchEngineHandle"
            )
        if handle._released:
            raise InferenceError(
                "infer() called on released handle "
                f"({handle._model_name!r})"
            )
        model = handle._model
        if self._c7_config.per_frame_debug_log:
            t0_ns = self._clock.monotonic_ns()
        else:
            t0_ns = None
        try:
            torch_inputs = {
                name: torch.from_numpy(np.ascontiguousarray(arr))
                .half()
                .cuda(non_blocking=False)
                for name, arr in inputs.items()
            }
            # inference_mode already disables gradient tracking, so a
            # separate no_grad() context is unnecessary.
            with torch.inference_mode():
                raw = model(**torch_inputs)
            outputs = _to_numpy_dict(raw)
        except torch.cuda.OutOfMemoryError as exc:
            raise OutOfMemoryError(
                f"CUDA OOM during infer({handle._model_name!r})"
            ) from exc
        except InferenceError:
            # Already in the C7 error vocabulary (e.g. from _to_numpy_dict);
            # let it through without double-wrapping.
            raise
        except Exception as exc:
            raise InferenceError(
                f"forward pass for {handle._model_name!r} raised "
                f"{type(exc).__name__}"
            ) from exc
        if t0_ns is not None:
            dt_us = (self._clock.monotonic_ns() - t0_ns) / 1_000
            self._logger.debug(
                "pytorch_fp16.infer: %s took %.1f us",
                handle._model_name,
                dt_us,
            )
        return outputs

    def release_engine(self, handle: EngineHandle) -> None:
        """Drop references + ``torch.cuda.empty_cache()`` (AC-5, I-7).

        Handles of a foreign type and handles that were already released
        are ignored, so calling this repeatedly is safe.
        """
        if not isinstance(handle, PytorchEngineHandle):
            return
        if handle._released:
            return
        handle._released = True
        handle._model = None  # type: ignore[assignment]
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def thermal_state(self) -> ThermalState:
        """Delegate to the injected publisher, else default-safe (I-6).

        Without a publisher, the returned :class:`ThermalState` has no
        temperature or clock readings, ``thermal_throttle_active=False``,
        ``is_telemetry_available=False`` and ``measured_at_ns`` taken
        from the runtime's clock.
        """
        publisher = self._thermal_publisher
        if publisher is None:
            return ThermalState(
                cpu_temp_c=None,
                gpu_temp_c=None,
                thermal_throttle_active=False,
                measured_clock_mhz=None,
                measured_at_ns=self._clock.monotonic_ns(),
                is_telemetry_available=False,
            )
        return publisher.read()

    def current_runtime_label(self) -> Literal["pytorch_fp16"]:
        """Return the constant runtime label ``"pytorch_fp16"`` (AC-1)."""
        return _RUNTIME_LABEL
|
||||
|
||||
|
||||
def _sha256_of_file(path: Path) -> str:
    """Return the hex SHA-256 of the file at ``path``.

    The file is consumed in ``_SHA256_CHUNK``-sized reads so a multi-GB
    checkpoint is never held in RAM in one piece.
    """
    digest = hashlib.sha256()
    with path.open("rb") as stream:
        # An empty read marks end-of-file and ends the loop.
        while block := stream.read(_SHA256_CHUNK):
            digest.update(block)
    return digest.hexdigest()
|
||||
|
||||
|
||||
def _to_numpy_dict(raw: object) -> dict[str, np.ndarray]:
|
||||
"""Convert a forward-pass return value into the Protocol's output shape.
|
||||
|
||||
Supports four common backbone-return shapes:
|
||||
|
||||
- ``dict[str, Tensor]`` — used straight.
|
||||
- ``Tensor`` — wrapped as ``{"output": tensor}``.
|
||||
- ``tuple`` / ``list`` of tensors — keyed ``output_0``, ``output_1``, …
|
||||
- anything else → :class:`InferenceError` (caller rewraps).
|
||||
|
||||
Each tensor is moved to CPU (sync barrier per I-8) and converted to
|
||||
a contiguous :class:`numpy.ndarray`. FP16 tensors stay FP16 in numpy
|
||||
(``np.float16``); downstream consumers do the cast if they need FP32.
|
||||
"""
|
||||
if isinstance(raw, dict):
|
||||
return {name: _tensor_to_numpy(t) for name, t in raw.items()}
|
||||
if isinstance(raw, torch.Tensor):
|
||||
return {"output": _tensor_to_numpy(raw)}
|
||||
if isinstance(raw, (tuple, list)):
|
||||
return {
|
||||
f"output_{idx}": _tensor_to_numpy(t) for idx, t in enumerate(raw)
|
||||
}
|
||||
raise InferenceError(
|
||||
f"forward pass returned unsupported type {type(raw).__name__}; "
|
||||
"expected dict / Tensor / tuple / list"
|
||||
)
|
||||
|
||||
|
||||
def _tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
|
||||
if not isinstance(tensor, torch.Tensor):
|
||||
raise InferenceError(
|
||||
f"output container held non-Tensor element {type(tensor).__name__}"
|
||||
)
|
||||
return tensor.detach().cpu().numpy()
|
||||
Reference in New Issue
Block a user