[AZ-300] Implement PytorchFp16Runtime — C7 simple-baseline strategy

AZ-300 mandatory simple-baseline InferenceRuntime (eager FP16 PyTorch).
Implements the AZ-297 Protocol; current_runtime_label returns
"pytorch_fp16". Numerical reference every fancier C7 strategy (AZ-298
TRT, AZ-299 ORT) is measured against, and the only viable runtime for
Tier-1 workstation Docker where TRT is non-trivial to install.

Production code (new):
 - components/c7_inference/pytorch_fp16_runtime.py — runtime +
   PytorchEngineHandle + output-shape adapter
 - components/c7_inference/architecture_registry.py — torch-free
   register_architecture / default_registry / ArchitectureFactory
   (Risk-1 mitigation: no L2->L3 back-edge from C7 into per-backbone
   code)
 - components/c7_inference/__init__.py — re-exports the registry
   mechanism. Still does NOT import the concrete strategy module
   (Invariant I-5)
 - components/c7_inference/config.py — adds per_frame_debug_log bool
   field (gates the DEBUG per-frame latency log)

Tests (new): tests/unit/c7_inference/test_pytorch_fp16_runtime.py
covers AC-1..AC-8 + NFRs. AC-1/2/6/7 + thermal/release/registry
guards run unconditionally (17 tests); AC-3/4/5/8 +
NFR-perf-deserialize + NFR-reliability-eval-mode require CUDA and
skip on Tier-1 CI / macOS dev.

Tests (modified):
 - test_protocol_conformance.py — narrowed
   test_ac5_build_inference_runtime_flag_on_but_module_missing
   parametrisation to exclude pytorch_fp16 (now-built); TRT / ORT
   still covered until AZ-298 / AZ-299 ship.

CI: .github/workflows/ci.yml lint + unit jobs now install
'-e .[dev,inference]' because mypy + pytest need torch + torchvision +
onnxruntime on the runner.

Three task-spec -> as-built deltas documented in
_docs/02_tasks/done/AZ-300_c7_pytorch_baseline.md Implementation Notes:
 1. Constructor conforms to AZ-297 factory shape (config positional;
    thermal_publisher + registry + clock keyword-only optionals).
    AZ-302 will update the factory to thread thermal_publisher.
 2. Architecture registry uses extras["model_name"] as lookup key
    (avoids touching the frozen BuildConfig / EngineCacheEntry DTOs).
 3. Warm-up forward deferred to AZ-300 tier-2 follow-up — the zero-arg
    registry has no per-backbone input-shape metadata.

Suite: 1120 passed / 10 skipped (CUDA + Tier-2 + cmake / actionlint
environment gates). No regressions in non-c7_inference areas.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
Oleksandr Bezdieniezhnykh
2026-05-12 10:13:21 +03:00
parent fce80290bc
commit 65ad2168ed
10 changed files with 1079 additions and 9 deletions
@@ -28,6 +28,11 @@ from gps_denied_onboard._types.inference import (
PrecisionMode,
)
from gps_denied_onboard._types.thermal import ThermalState
from gps_denied_onboard.components.c7_inference.architecture_registry import (
ArchitectureFactory,
default_registry,
register_architecture,
)
from gps_denied_onboard.components.c7_inference.config import C7InferenceConfig
from gps_denied_onboard.components.c7_inference.errors import (
CalibrationCacheError,
@@ -47,6 +52,7 @@ from gps_denied_onboard.config.schema import register_component_block
register_component_block("c7_inference", C7InferenceConfig)
__all__ = [
"ArchitectureFactory",
"BuildConfig",
"C7InferenceConfig",
"CalibrationCacheError",
@@ -65,4 +71,6 @@ __all__ = [
"RuntimeError",
"TelemetryUnavailableError",
"ThermalState",
"default_registry",
"register_architecture",
]
@@ -0,0 +1,65 @@
"""Per-backbone architecture registry for ``PytorchFp16Runtime`` (AZ-300).
The PyTorch baseline loads checkpoints saved as ``state_dict`` (the
``weights_only=True`` security path); ``model.load_state_dict`` requires
a pre-instantiated ``nn.Module``. This module owns the small
``model_name -> factory()`` indirection that lets the composition root
register each backbone's architecture class at startup without C7
importing component code directly (Risk-1 in the AZ-300 spec — would
violate the L2→L3 module-layout layering otherwise).
The registry is a plain ``dict``; ``register_architecture`` is
idempotent for the same factory and rejects re-registration with a
different one. The module does NOT import :mod:`torch` — keeping it
torch-free lets the composition root populate the registry even on
non-PyTorch tiers.
"""
from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from torch import nn
__all__ = [
"ArchitectureFactory",
"default_registry",
"register_architecture",
]
ArchitectureFactory = Callable[[], "nn.Module"]
_DEFAULT_REGISTRY: dict[str, ArchitectureFactory] = {}
def register_architecture(
model_name: str, factory: ArchitectureFactory
) -> None:
"""Register a model-architecture factory under ``model_name``.
Idempotent for the same ``factory`` object; raises ``ValueError`` if
a *different* factory is registered under the same name (catches the
common "two backbones colliding on the same stem" bug at composition
time rather than at first ``deserialize_engine`` call).
"""
if not model_name:
raise ValueError("model_name must be non-empty")
existing = _DEFAULT_REGISTRY.get(model_name)
if existing is not None and existing is not factory:
raise ValueError(
f"Architecture {model_name!r} already registered with a "
f"different factory; refusing to override"
)
_DEFAULT_REGISTRY[model_name] = factory
def default_registry() -> dict[str, ArchitectureFactory]:
"""Return the module-level singleton registry.
Returned by reference — callers see live registrations. Tests that
need an isolated registry SHOULD construct their own ``dict`` and
pass it explicitly to :class:`PytorchFp16Runtime`.
"""
return _DEFAULT_REGISTRY
@@ -14,8 +14,8 @@ from typing import Final
from gps_denied_onboard.config.schema import ConfigError
__all__ = [
"C7InferenceConfig",
"KNOWN_RUNTIMES",
"C7InferenceConfig",
]
KNOWN_RUNTIMES: Final[frozenset[str]] = frozenset(
@@ -45,6 +45,7 @@ class C7InferenceConfig:
runtime: str = "pytorch_fp16"
thermal_poll_hz: float = 1.0
engine_cache_dir: str = "/var/lib/gps-denied/engines"
per_frame_debug_log: bool = False
def __post_init__(self) -> None:
if self.runtime not in KNOWN_RUNTIMES:
@@ -0,0 +1,339 @@
"""``PytorchFp16Runtime`` — mandatory simple-baseline C7 strategy (AZ-300).
Conforms to :class:`InferenceRuntime` (AZ-297). Loads each backbone's
canonical PyTorch checkpoint (``.pt`` / ``.pth``), runs eager
``.half().cuda()`` forward, no ``torch.compile`` / JIT / autocast — the
"simple baseline" every fancier strategy (AZ-298 TRT, AZ-299 ORT) is
numerically measured against per the ENG-RULE.
Composition contract (factory shape, AZ-297 ``build_inference_runtime``)::
strategy_cls(config: Config)
This module conforms by accepting ``config`` positionally; every other
collaborator — the per-backbone architecture registry, the AZ-302
``ThermalStatePublisher`` reference, the injectable :class:`Clock` —
is a keyword-only optional with safe defaults. AZ-302 will update the
factory to thread ``thermal_publisher`` through; until then,
:meth:`thermal_state` returns a default-safe ``ThermalState`` per
Invariant I-6.
AC mapping (see ``_docs/02_tasks/todo/AZ-300_c7_pytorch_baseline.md``):
- AC-1 : :meth:`current_runtime_label` returns ``"pytorch_fp16"``;
protocol conformance via ``runtime_checkable``.
- AC-2 : :meth:`compile_engine` is a no-op — no ``.engine`` produced;
returned :class:`EngineCacheEntry` carries the checkpoint path.
- AC-3 : :meth:`deserialize_engine` runs ``torch.load`` →
``load_state_dict`` → ``.half().cuda().eval()``.
- AC-4 : :meth:`infer` runs synchronous FP16 forward; outputs are
host-resident :class:`numpy.ndarray`.
- AC-5 : :meth:`release_engine` drops references + calls
``torch.cuda.empty_cache``.
- AC-6 / AC-7 : missing checkpoint or mismatched state_dict rewrap to
:class:`EngineDeserializeError` with ``__cause__`` preserved.
- AC-8 : ``torch.cuda.OutOfMemoryError`` rewraps to the C7-local
:class:`OutOfMemoryError`.
"""
from __future__ import annotations
import hashlib
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Final, Literal
import numpy as np
import torch
from torch import nn
from gps_denied_onboard._types.inference import (
BuildConfig,
EngineCacheEntry,
EngineHandle,
PrecisionMode,
)
from gps_denied_onboard._types.thermal import ThermalState
from gps_denied_onboard.clock.wall_clock import WallClock
from gps_denied_onboard.components.c7_inference.architecture_registry import (
ArchitectureFactory,
default_registry,
)
from gps_denied_onboard.components.c7_inference.errors import (
EngineDeserializeError,
InferenceError,
OutOfMemoryError,
)
from gps_denied_onboard.logging import get_logger
if TYPE_CHECKING:
from gps_denied_onboard.clock import Clock
from gps_denied_onboard.config.schema import Config
__all__ = ["PytorchEngineHandle", "PytorchFp16Runtime"]
_RUNTIME_LABEL: Final[Literal["pytorch_fp16"]] = "pytorch_fp16"
_SHA256_CHUNK: Final[int] = 1 << 20 # 1 MiB
_MODEL_NAME_KEY: Final[str] = "model_name"
class PytorchEngineHandle(EngineHandle):
"""Opaque handle wrapping a deserialised + GPU-resident PyTorch model.
Fields are private to :class:`PytorchFp16Runtime`; per Invariant I-4
consumers MUST NOT introspect. The handle is reusable across many
:meth:`PytorchFp16Runtime.infer` calls until
:meth:`PytorchFp16Runtime.release_engine`.
"""
__slots__ = ("_model", "_model_name", "_released")
def __init__(self, model: nn.Module, model_name: str) -> None:
self._model = model
self._model_name = model_name
self._released = False
class PytorchFp16Runtime:
"""Eager FP16 PyTorch :class:`InferenceRuntime` — simple baseline."""
def __init__(
self,
config: Config,
*,
thermal_publisher: Any | None = None,
architecture_registry: Mapping[str, ArchitectureFactory] | None = None,
clock: Clock | None = None,
) -> None:
self._config = config
self._c7_config = config.components["c7_inference"]
self._thermal_publisher = thermal_publisher
self._registry: Mapping[str, ArchitectureFactory] = (
architecture_registry
if architecture_registry is not None
else default_registry()
)
self._clock = clock if clock is not None else WallClock()
self._logger = get_logger("c7_inference.pytorch_fp16")
def compile_engine(
self, model_path: Path, build_config: BuildConfig
) -> EngineCacheEntry:
"""No-op compile — PyTorch checkpoints are the artifact (AC-2).
Returns an :class:`EngineCacheEntry` whose ``engine_path`` is
the source ``.pt`` path. The five-tuple ``(sm, jp, trt, precision)``
is ``(None, None, None, FP16)`` — PyTorch is hardware-portable
across SM levels. ``extras["model_name"]`` is the checkpoint's
file stem; :meth:`deserialize_engine` reads it to look up the
architecture factory in the registry.
"""
path = Path(model_path)
if not path.exists():
raise EngineDeserializeError(
f"PyTorch checkpoint not found at {path!s}"
)
sha256_hex = _sha256_of_file(path)
return EngineCacheEntry(
engine_path=path,
sha256_hex=sha256_hex,
sm=None,
jp=None,
trt=None,
precision=PrecisionMode.FP16,
extras={_MODEL_NAME_KEY: path.stem},
)
def deserialize_engine(self, entry: EngineCacheEntry) -> EngineHandle:
"""Load the checkpoint and produce a GPU-resident handle (AC-3)."""
path = Path(entry.engine_path)
if not path.exists():
raise EngineDeserializeError(
f"PyTorch checkpoint not found at {path!s}"
)
model_name = entry.extras.get(_MODEL_NAME_KEY) or path.stem
factory = self._registry.get(model_name)
if factory is None:
raise EngineDeserializeError(
f"No architecture registered for model_name={model_name!r}; "
"the composition root must call "
"`c7_inference.register_architecture(name, factory)` before "
"the runtime is constructed."
)
try:
model = factory()
except Exception as exc:
raise EngineDeserializeError(
f"Architecture factory for {model_name!r} raised "
f"{type(exc).__name__}"
) from exc
try:
state_dict = torch.load(
path, map_location="cpu", weights_only=True
)
except FileNotFoundError as exc:
raise EngineDeserializeError(
f"PyTorch checkpoint not found at {path!s}"
) from exc
except Exception as exc:
raise EngineDeserializeError(
f"torch.load failed for checkpoint {path!s}: "
f"{type(exc).__name__}"
) from exc
try:
model.load_state_dict(state_dict, strict=True)
except RuntimeError as exc:
raise EngineDeserializeError(
f"load_state_dict(strict=True) rejected checkpoint "
f"{path!s} for architecture {model_name!r}"
) from exc
try:
model = model.half().cuda().eval()
except torch.cuda.OutOfMemoryError as exc:
raise OutOfMemoryError(
f"CUDA OOM while moving {model_name!r} to GPU"
) from exc
except RuntimeError as exc:
raise EngineDeserializeError(
f"GPU move / half-cast failed for {model_name!r}"
) from exc
n_params = sum(p.numel() for p in model.parameters())
gpu_bytes = sum(
p.numel() * p.element_size() for p in model.parameters()
)
self._logger.info(
"pytorch_fp16.deserialize_engine: loaded %s (%d params, "
"~%.1f MB GPU, sha256=%s)",
model_name,
n_params,
gpu_bytes / (1024 * 1024),
entry.sha256_hex[:12],
)
return PytorchEngineHandle(model, model_name)
def infer(
self,
handle: EngineHandle,
inputs: dict[str, np.ndarray],
) -> dict[str, np.ndarray]:
"""Run a sync FP16 forward; return numpy outputs (AC-4)."""
if not isinstance(handle, PytorchEngineHandle):
raise InferenceError(
f"infer() received foreign handle type {type(handle).__name__}; "
"PytorchFp16Runtime only accepts PytorchEngineHandle"
)
if handle._released:
raise InferenceError(
"infer() called on released handle "
f"({handle._model_name!r})"
)
model = handle._model
if self._c7_config.per_frame_debug_log:
t0_ns = self._clock.monotonic_ns()
else:
t0_ns = None
try:
torch_inputs = {
name: torch.from_numpy(np.ascontiguousarray(arr))
.half()
.cuda(non_blocking=False)
for name, arr in inputs.items()
}
with torch.no_grad(), torch.inference_mode():
raw = model(**torch_inputs)
outputs = _to_numpy_dict(raw)
except torch.cuda.OutOfMemoryError as exc:
raise OutOfMemoryError(
f"CUDA OOM during infer({handle._model_name!r})"
) from exc
except InferenceError:
raise
except Exception as exc:
raise InferenceError(
f"forward pass for {handle._model_name!r} raised "
f"{type(exc).__name__}"
) from exc
if t0_ns is not None:
dt_us = (self._clock.monotonic_ns() - t0_ns) / 1_000
self._logger.debug(
"pytorch_fp16.infer: %s took %.1f us",
handle._model_name,
dt_us,
)
return outputs
def release_engine(self, handle: EngineHandle) -> None:
"""Drop references + ``torch.cuda.empty_cache()`` (AC-5, I-7)."""
if not isinstance(handle, PytorchEngineHandle):
return
if handle._released:
return
handle._released = True
handle._model = None # type: ignore[assignment]
if torch.cuda.is_available():
torch.cuda.empty_cache()
def thermal_state(self) -> ThermalState:
"""Delegate to the injected publisher, else default-safe (I-6)."""
publisher = self._thermal_publisher
if publisher is None:
return ThermalState(
cpu_temp_c=None,
gpu_temp_c=None,
thermal_throttle_active=False,
measured_clock_mhz=None,
measured_at_ns=self._clock.monotonic_ns(),
is_telemetry_available=False,
)
return publisher.read()
def current_runtime_label(self) -> Literal["pytorch_fp16"]:
return _RUNTIME_LABEL
def _sha256_of_file(path: Path) -> str:
"""Stream-hash ``path`` so we never load multi-GB checkpoints into RAM."""
hasher = hashlib.sha256()
with path.open("rb") as fh:
while True:
chunk = fh.read(_SHA256_CHUNK)
if not chunk:
break
hasher.update(chunk)
return hasher.hexdigest()
def _to_numpy_dict(raw: object) -> dict[str, np.ndarray]:
"""Convert a forward-pass return value into the Protocol's output shape.
Supports four common backbone-return shapes:
- ``dict[str, Tensor]`` — used straight.
- ``Tensor`` — wrapped as ``{"output": tensor}``.
- ``tuple`` / ``list`` of tensors — keyed ``output_0``, ``output_1``, …
- anything else → :class:`InferenceError` (caller rewraps).
Each tensor is moved to CPU (sync barrier per I-8) and converted to
a contiguous :class:`numpy.ndarray`. FP16 tensors stay FP16 in numpy
(``np.float16``); downstream consumers do the cast if they need FP32.
"""
if isinstance(raw, dict):
return {name: _tensor_to_numpy(t) for name, t in raw.items()}
if isinstance(raw, torch.Tensor):
return {"output": _tensor_to_numpy(raw)}
if isinstance(raw, (tuple, list)):
return {
f"output_{idx}": _tensor_to_numpy(t) for idx, t in enumerate(raw)
}
raise InferenceError(
f"forward pass returned unsupported type {type(raw).__name__}; "
"expected dict / Tensor / tuple / list"
)
def _tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
if not isinstance(tensor, torch.Tensor):
raise InferenceError(
f"output container held non-Tensor element {type(tensor).__name__}"
)
return tensor.detach().cpu().numpy()