diff --git a/.cursor/rules/workspace-boundary.mdc b/.cursor/rules/workspace-boundary.mdc new file mode 100644 index 0000000..043dd6a --- /dev/null +++ b/.cursor/rules/workspace-boundary.mdc @@ -0,0 +1,7 @@ +# Workspace Boundary + +- Only modify files within the current repository (workspace root). +- Never write, edit, or delete files in sibling repositories or parent directories outside the workspace. +- When a task requires changes in another repository (e.g., admin API, flights, UI), **document** the required changes in the task's implementation notes or a dedicated cross-repo doc — do not implement them. +- The mock API at `e2e/mocks/mock_api/` may be updated to reflect the expected contract of external services, but this is a test mock — not the real implementation. +- If a task is entirely scoped to another repository, mark it as out-of-scope for this workspace and note the target repository. diff --git a/Dockerfile b/Dockerfile index 1742f7f..850f312 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,6 @@ FROM python:3.11-slim -RUN apt-get update && apt-get install -y python3-dev gcc pciutils curl gnupg && \ +RUN apt-get update && apt-get install -y python3-dev gcc pciutils curl gnupg pkg-config \ + uuid-dev libtss2-dev libtss2-fapi1 libtss2-tcti-device0 libtss2-tcti-mssim0 && \ install -m 0755 -d /etc/apt/keyrings && \ curl -fsSL https://download.docker.com/linux/debian/gpg -o /etc/apt/keyrings/docker.asc && \ chmod a+r /etc/apt/keyrings/docker.asc && \ @@ -8,7 +9,12 @@ RUN apt-get update && apt-get install -y python3-dev gcc pciutils curl gnupg && rm -rf /var/lib/apt/lists/* WORKDIR /app COPY requirements.txt . -RUN pip install --no-cache-dir -r requirements.txt +RUN pip install --no-cache-dir -r requirements.txt && \ + TSSPC="$(find /usr/lib -path '*/pkgconfig/tss2-fapi.pc' -print -quit)" && \ + export PKG_CONFIG_PATH="$(dirname "$TSSPC"):/usr/share/pkgconfig:/usr/lib/pkgconfig" && \ + pkg-config --exists tss2-fapi && \ + pip install --no-cache-dir setuptools wheel pkgconfig pycparser cffi packaging && \ + PIP_NO_BUILD_ISOLATION=1 pip install --no-cache-dir --force-reinstall --no-binary tpm2-pytss --no-deps tpm2-pytss==2.3.0 COPY . . RUN python setup.py build_ext --inplace EXPOSE 8080 diff --git a/_docs/02_document/deployment/provisioning_runbook.md b/_docs/02_document/deployment/provisioning_runbook.md new file mode 100644 index 0000000..29d4ce1 --- /dev/null +++ b/_docs/02_document/deployment/provisioning_runbook.md @@ -0,0 +1,102 @@ +# Jetson device provisioning runbook + +This runbook describes the end-to-end flow to fuse, flash, provision a device identity, and reach a state where the Azaion Loader can authenticate against the admin/resource APIs. It targets a Jetson Orin Nano class device; adapt paths and NVIDIA bundle versions to your manufacturing image. + +## Prerequisites + +- Provisioning workstation with bash, curl, openssl, python3, and USB/network access to the Jetson in recovery or mass-storage mode as required by your flash tools. +- Admin API reachable from the workstation (base URL, for example `https://admin.internal.example.com`). +- NVIDIA Jetson Linux Driver Package (L4T) and flash scripts for your SKU (for example `odmfuse.sh`, `flash.sh` from the board support package). +- Root filesystem staging directory on the workstation that will be merged into the image before `flash.sh` (often a `Linux_for_Tegra/rootfs/` tree or an extracted sample rootfs overlay). + +## Admin API contract (provisioning) + +The `scripts/provision_device.sh` script expects: + +1. **POST** `{admin_base}/users` with JSON body `{"email":"","password":"","role":"CompanionPC"}` + - **201** or **200**: user created. + - **409**: user with this email already exists (idempotent re-run). + +2. **PATCH** `{admin_base}/users/password` with JSON body `{"email":"","password":""}` + - Used when POST returns **409** so the password in `device.conf` matches the account after re-provisioning. + - **200** or **204**: password updated. + +Adjust URL paths or JSON field names in the script if your deployment uses a different but equivalent contract. + +## Device identity and `device.conf` + +For serial **AZJN-0042**, the script creates email **azaion-jetson-0042@azaion.com** (suffix is the segment after the last hyphen in the serial, lowercased). The password is 32 hexadecimal characters from `openssl rand -hex 16`. + +The script writes: + +`{rootfs_staging}/etc/azaion/device.conf` + +On the flashed device this becomes **`/etc/azaion/device.conf`** with: + +- `AZAION_DEVICE_EMAIL=...` +- `AZAION_DEVICE_PASSWORD=...` + +File permissions on the staging file are set to **600**. Ensure your image build preserves ownership and permissions appropriate for the service user that runs the Loader. + +## Step-by-step flow + +### 1. Unbox and record the serial + +Read the manufacturing label or use your factory barcode process. Example serial: `AZJN-0042`. + +### 2. Fuse (if your product requires it) + +Run your approved **fuse** workflow (for example NVIDIA `odmfuse.sh` or internal wrapper). This task does not replace secure boot or fTPM scripts; complete them per your security phase checklist before or after provisioning, according to your process. + +### 3. Prepare the rootfs staging tree + +Extract or sync the rootfs you will flash into a directory on the workstation, for example: + +`/work/images/orin-nano/rootfs-staging/` + +Ensure `etc/` exists or can be created under this tree. + +### 4. Provision the CompanionPC user and embed credentials + +From the Loader repository root (or using an absolute path to the script): + +```bash +./scripts/provision_device.sh \ + --serial AZJN-0042 \ + --api-url "https://admin.internal.example.com" \ + --rootfs-dir "/work/images/orin-nano/rootfs-staging" +``` + +Confirm the script prints success and that `rootfs-staging/etc/azaion/device.conf` exists. + +Re-running the same command for the same serial must not create a duplicate user; the script updates the password via **PATCH** when POST returns **409**. + +If the admin API requires authentication (Bearer token, mTLS), extend the script or shell wrapper to pass the required `curl` headers or use a local proxy; the stock script assumes network-restricted admin access without extra headers. + +### 5. Flash the device + +Run your normal **flash** procedure (for example `flash.sh` or SDK Manager) so the staged rootfs—including `etc/azaion/device.conf`—is written to the device storage. + +### 6. First boot + +Power the Jetson, complete first-boot configuration if any, and verify the Loader service starts. The Loader should read `AZAION_DEVICE_EMAIL` and `AZAION_DEVICE_PASSWORD` from `/etc/azaion/device.conf`, then use them when calling **POST /login** on the Loader HTTP API (which forwards credentials to the configured resource API per your deployment). After a successful login path, the device can request resources and unlock flows as designed. + +### 7. Smoke verification + +- From another host: Loader **GET /health** returns healthy. +- **POST /login** on the Loader with the same email and password as in `device.conf` returns success (for example `{"status":"ok"}` in the reference implementation). +- Optional: trigger your normal resource or unlock smoke test against a staging API. + +## Troubleshooting + +| Symptom | Check | +|--------|--------| +| curl fails to reach admin API | DNS, VPN, firewall, and `API_URL` trailing slash (script strips one trailing slash). | +| HTTP 4xx/5xx from POST /users | Admin logs; confirm role value **CompanionPC** and email uniqueness rules. | +| 409 then failure on PATCH | Implement or enable **PATCH /users/password** (or change script to match your upsert API). | +| Loader cannot log in | `device.conf` path, permissions, and that the password in the file matches the account after the last successful provision. | + +## Security notes + +- Treat `device.conf` as a secret at rest; restrict file permissions and disk encryption per your product policy. +- Prefer short-lived credentials or key rotation if the admin API supports it; this runbook describes the baseline manufacturing flow. diff --git a/_docs/02_tasks/_dependencies_table.md b/_docs/02_tasks/_dependencies_table.md index ae4b72f..ee8a363 100644 --- a/_docs/02_tasks/_dependencies_table.md +++ b/_docs/02_tasks/_dependencies_table.md @@ -1,8 +1,10 @@ # Dependencies Table -**Date**: 2026-04-13 -**Total Tasks**: 8 -**Total Complexity Points**: 29 +**Date**: 2026-04-15 +**Total Tasks**: 14 +**Total Complexity Points**: 55 + +## Completed Tasks (Blackbox Tests & Refactoring) | Task | Name | Complexity | Dependencies | Epic | |------|------|-----------|-------------|------| @@ -15,17 +17,31 @@ | 07 | refactor_thread_safety | 3 | None | 01-quality-cleanup | | 08 | refactor_cleanup | 2 | 06 | 01-quality-cleanup | -## Execution Batches +## Active Tasks (Loader Security Modernization) -| Batch | Tasks | Parallel? | Total Points | -|-------|-------|-----------|-------------| -| 1 | 01_test_infrastructure | No | 5 | -| 2 | 02_test_health_auth | No | 3 | -| 3 | 03_test_resources, 04_test_unlock, 05_test_resilience_perf | Yes (parallel) | 13 | -| 4 | 06_refactor_crypto_uploads, 07_refactor_thread_safety | Yes (parallel) | 6 | -| 5 | 08_refactor_cleanup | No | 2 | +| Task | Name | Complexity | Dependencies | Epic | +|------|------|-----------|-------------|------| +| AZ-182 | tpm_security_provider | 5 | None | AZ-181 | +| AZ-183 | resources_table_update_api | 3 | None | AZ-181 | +| AZ-184 | resumable_download_manager | 3 | None | AZ-181 | +| AZ-185 | update_manager | 5 | AZ-183, AZ-184 | AZ-181 | +| AZ-186 | cicd_artifact_publish | 3 | AZ-183 | AZ-181 | +| AZ-187 | device_provisioning_script | 2 | None | AZ-181 | -## Test Scenario Coverage +## Execution Batches (AZ-181 Epic) + +| Batch | Tasks | Parallel? | Total Points | Notes | +|-------|-------|-----------|-------------|-------| +| 1 | AZ-182, AZ-184, AZ-187 | Yes (no dependencies between them) | 10 | AZ-183 excluded: admin API repo | +| 2 | AZ-185, AZ-186 | Yes (both depend on batch 1) | 8 | AZ-185 depends on AZ-183 (cross-repo) | + +## Out-of-Repo Tasks + +| Task | Name | Target Repo | Status | +|------|------|------------|--------| +| AZ-183 | resources_table_update_api | admin/ | To Do — implement in admin API workspace | + +## Test Scenario Coverage (Blackbox Tests - completed) | Test Scenario | Task | |--------------|------| diff --git a/_docs/03_implementation/reviews/batch_01_review.md b/_docs/03_implementation/reviews/batch_01_review.md new file mode 100644 index 0000000..be05379 --- /dev/null +++ b/_docs/03_implementation/reviews/batch_01_review.md @@ -0,0 +1,60 @@ +# Code Review Report + +**Batch**: 1 (AZ-182, AZ-184, AZ-187) — loader repo only +**Date**: 2026-04-15 +**Verdict**: PASS_WITH_WARNINGS + +**Note**: AZ-183 (Resources Table & Update API) is scoped to the admin API repository and was excluded from this batch. A mock /get-update endpoint was added to the loader's e2e mock API. See cross-repo notes below. + +## Spec Compliance + +All 16 acceptance criteria across 3 tasks are satisfied with corresponding tests. + +| Task | ACs | Covered | Status | +|------|-----|---------|--------| +| AZ-182 TPM Security Provider | 6 | 6/6 | All pass (AC-2 skips without swtpm) | +| AZ-184 Resumable Download Manager | 5 | 5/5 | All pass (8/8 unittest) | +| AZ-187 Device Provisioning Script | 5 | 5/5 | All pass (5/5 pytest) | + +## Findings + +| # | Severity | Category | File:Line | Title | +|---|----------|----------|-----------|-------| +| 1 | Medium | Style | src/download_manager.py:113 | Union syntax inconsistency | +| 2 | Low | Style | tests/test_download_manager.py:9-11 | Redundant sys.path manipulation | +| 3 | Low | Scope | AZ-183 | Out-of-repo task excluded | + +### Finding Details + +**F1: Union syntax inconsistency** (Medium / Style) +- Location: `src/download_manager.py:113` +- Description: Uses `Callable[[], requests.Session] | None` syntax while the rest of the project uses `Optional[...]` (e.g., `main.py` uses `Optional[str]`) +- Suggestion: Use `Optional[Callable[[], requests.Session]]` for consistency +- Task: AZ-184 + +**F2: Redundant sys.path manipulation** (Low / Style) +- Location: `tests/test_download_manager.py:9-11` +- Description: `sys.path.insert(0, str(SRC))` is redundant — `pytest.ini` already sets `pythonpath = src` +- Suggestion: Remove the sys.path block; tests run via pytest which handles the path +- Task: AZ-184 + +**F3: Out-of-repo task excluded** (Low / Scope) +- Location: AZ-183 task spec +- Description: AZ-183 (Resources Table & Update API) targets the admin API repository, not the loader. Excluded from this batch. +- Suggestion: Implement in the admin API workspace. A mock /get-update endpoint was added to `e2e/mocks/mock_api/app.py` for loader e2e tests. + +## Cross-Task Consistency + +- AZ-182 and AZ-184 both add loader-side capabilities; no interface conflicts +- AZ-187 standalone provisioning script has no coupling issues +- Mock /get-update endpoint response format (cdnUrl, sha256, encryptionKey) aligns with AZ-184 download manager expectations + +## Cross-Repo Notes (AZ-183) + +AZ-183 requires implementation in the **admin API repository** (`admin/`): +- Resources table migration (resource_name, dev_stage, architecture, version, cdn_url, sha256, encryption_key, size_bytes, created_at) +- POST /get-update endpoint: accepts device's current versions + architecture + dev_stage, returns only newer resources +- Server-side memory cache invalidated on CI/CD publish +- Internal endpoint for CI/CD to publish new resource versions +- encryption_key column must be encrypted at rest +- Response must include encryption_key only over HTTPS with valid JWT diff --git a/e2e/docker-compose.test.yml b/e2e/docker-compose.test.yml index 4d2e144..baf1122 100644 --- a/e2e/docker-compose.test.yml +++ b/e2e/docker-compose.test.yml @@ -1,4 +1,14 @@ +x-tpm-device-mounts-for-jetson: + devices: + - /dev/tpm0 + - /dev/tpmrm0 + services: + swtpm: + image: danieltrick/swtpm-docker:latest + networks: + - e2e-net + mock-api: build: ./mocks/mock_api ports: @@ -27,14 +37,20 @@ services: ports: - "8080:8080" depends_on: - - mock-api - - mock-cdn + swtpm: + condition: service_started + mock-api: + condition: service_started + mock-cdn: + condition: service_started environment: RESOURCE_API_URL: http://mock-api:9090 IMAGES_PATH: /tmp/test.enc API_VERSION: test + TSS2_FAPICONF: /etc/tpm2-tss/fapi-config-azaion-swtpm.json volumes: - /var/run/docker.sock:/var/run/docker.sock + - ./fapi-config.swtpm.json:/etc/tpm2-tss/fapi-config-azaion-swtpm.json:ro networks: - e2e-net diff --git a/e2e/fapi-config.swtpm.json b/e2e/fapi-config.swtpm.json new file mode 100644 index 0000000..e6fa606 --- /dev/null +++ b/e2e/fapi-config.swtpm.json @@ -0,0 +1,12 @@ +{ + "profile_name": "P_ECCP256SHA256", + "profile_dir": "/etc/tpm2-tss/fapi-profiles/", + "user_dir": "/tmp/tpm2-tss/user/keystore", + "system_dir": "/tmp/tpm2-tss/system/keystore", + "tcti": "swtpm:host=swtpm,port=2321", + "ek_cert_less": "yes", + "system_pcrs": [], + "log_dir": "/tmp/tpm2-tss/eventlog", + "firmware_log_file": "/dev/null", + "ima_log_file": "/dev/null" +} diff --git a/e2e/mocks/mock_api/app.py b/e2e/mocks/mock_api/app.py index 43389ce..79703ff 100644 --- a/e2e/mocks/mock_api/app.py +++ b/e2e/mocks/mock_api/app.py @@ -35,6 +35,12 @@ class LoginBody(BaseModel): password: str +class GetUpdateBody(BaseModel): + dev_stage: str = "" + architecture: str = "" + current_versions: dict[str, str] = {} + + def _calc_hash(key: str) -> str: h = hashlib.sha384(key.encode("utf-8")).digest() return base64.b64encode(h).decode("utf-8") @@ -117,3 +123,19 @@ def binary_split_key_fragment(): async def resources_check(request: Request): await request.body() return Response(status_code=200) + + +@app.post("/get-update") +def get_update(body: GetUpdateBody): + ann = body.current_versions.get("annotations", "") + if not ann or ann < "2026-04-13": + return [ + { + "resourceName": "annotations", + "version": "2026-04-13", + "cdnUrl": f"{CDN_HOST}/fleet/annotations", + "sha256": "a" * 64, + "encryptionKey": "mock-fleet-encryption-key", + } + ] + return [] diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..442bb94 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +pythonpath = src +testpaths = tests diff --git a/requirements-test.txt b/requirements-test.txt new file mode 100644 index 0000000..33438ef --- /dev/null +++ b/requirements-test.txt @@ -0,0 +1,2 @@ +pytest +PyYAML diff --git a/requirements.txt b/requirements.txt index e391c4b..6c9dfca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ loguru==0.7.3 pyyaml==6.0.2 psutil==7.0.0 python-multipart +tpm2-pytss==2.3.0 diff --git a/scripts/provision_device.sh b/scripts/provision_device.sh new file mode 100755 index 0000000..fc9df8e --- /dev/null +++ b/scripts/provision_device.sh @@ -0,0 +1,111 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +SERIAL="" +API_URL="" +ROOTFS_DIR="" + +usage() { + echo "Usage: provision_device.sh --serial --api-url --rootfs-dir " >&2 +} + +while [[ $# -gt 0 ]]; do + case "$1" in + --serial) + SERIAL="${2:-}" + shift 2 + ;; + --api-url) + API_URL="${2:-}" + shift 2 + ;; + --rootfs-dir) + ROOTFS_DIR="${2:-}" + shift 2 + ;; + --help|-h) + usage + exit 0 + ;; + *) + echo "Unknown option: $1" >&2 + usage + exit 1 + ;; + esac +done + +if [[ -z "$SERIAL" || -z "$API_URL" || -z "$ROOTFS_DIR" ]]; then + echo "Missing required arguments." >&2 + usage + exit 1 +fi + +API_URL="${API_URL%/}" + +normalize_serial_suffix() { + local s + s="$(printf '%s' "$1" | tr '[:upper:]' '[:lower:]')" + if [[ "$s" == *-* ]]; then + printf '%s' "${s##*-}" + else + printf '%s' "${s//-/}" + fi +} + +EMAIL_SUFFIX="$(normalize_serial_suffix "$SERIAL")" +EMAIL="azaion-jetson-${EMAIL_SUFFIX}@azaion.com" +PASSWORD="$(openssl rand -hex 16)" + +echo "Provisioning device identity for serial: $SERIAL" +echo "Target admin API: $API_URL" +echo "Device email: $EMAIL" + +build_post_json() { + python3 -c 'import json,sys; print(json.dumps({"email":sys.argv[1],"password":sys.argv[2],"role":"CompanionPC"}))' "$1" "$2" +} + +POST_JSON="$(build_post_json "$EMAIL" "$PASSWORD")" +TMP_BODY="$(mktemp)" +trap 'rm -f "$TMP_BODY"' EXIT + +HTTP_CODE="$( + curl -sS -o "$TMP_BODY" -w "%{http_code}" \ + -X POST "${API_URL}/users" \ + -H "Content-Type: application/json" \ + -d "$POST_JSON" +)" + +if [[ "$HTTP_CODE" == "409" ]]; then + echo "User already exists; updating password for re-provision" + PATCH_JSON="$(build_post_json "$EMAIL" "$PASSWORD")" + HTTP_CODE="$( + curl -sS -o "$TMP_BODY" -w "%{http_code}" \ + -X PATCH "${API_URL}/users/password" \ + -H "Content-Type: application/json" \ + -d "$PATCH_JSON" + )" +fi + +if [[ "$HTTP_CODE" != "200" && "$HTTP_CODE" != "201" ]]; then + echo "Admin API error HTTP $HTTP_CODE" >&2 + cat "$TMP_BODY" >&2 + echo >&2 + exit 1 +fi + +CONF_DIR="${ROOTFS_DIR}/etc/azaion" +mkdir -p "$CONF_DIR" +CONF_PATH="${CONF_DIR}/device.conf" + +{ + printf 'AZAION_DEVICE_EMAIL=%s\n' "$EMAIL" + printf 'AZAION_DEVICE_PASSWORD=%s\n' "$PASSWORD" +} > "$CONF_PATH" + +chmod 600 "$CONF_PATH" + +echo "Wrote $CONF_PATH" +echo "Provisioning finished successfully" diff --git a/src/download_manager.py b/src/download_manager.py new file mode 100644 index 0000000..4c2577b --- /dev/null +++ b/src/download_manager.py @@ -0,0 +1,280 @@ +import hashlib +import json +import os +import tempfile +import time +from dataclasses import asdict, dataclass +from typing import Callable, Optional + +import requests +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import padding +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from loguru import logger + + +def backoff_seconds(failure_index: int) -> int: + sequence = (60, 300, 900, 3600, 14400) + idx = min(max(0, failure_index), len(sequence) - 1) + return sequence[idx] + + +@dataclass +class DownloadState: + url: str + expected_sha256: str + expected_size: int + bytes_downloaded: int + temp_file_path: str + phase: str + + def to_json_dict(self) -> dict: + return asdict(self) + + @classmethod + def from_json_dict(cls, data: dict) -> "DownloadState": + return cls( + url=data["url"], + expected_sha256=data["expected_sha256"], + expected_size=int(data["expected_size"]), + bytes_downloaded=int(data["bytes_downloaded"]), + temp_file_path=data["temp_file_path"], + phase=data["phase"], + ) + + +def load_download_state(path: str) -> DownloadState: + with open(path, encoding="utf-8") as f: + return DownloadState.from_json_dict(json.load(f)) + + +def save_download_state(path: str, state: DownloadState) -> None: + directory = os.path.dirname(path) + if directory: + os.makedirs(directory, exist_ok=True) + payload = json.dumps(state.to_json_dict(), indent=2, sort_keys=True) + fd, tmp = tempfile.mkstemp( + dir=directory or None, + prefix=".download_state_", + suffix=".tmp", + ) + try: + with os.fdopen(fd, "w", encoding="utf-8") as f: + f.write(payload) + os.replace(tmp, path) + except Exception: + try: + os.unlink(tmp) + except OSError: + pass + raise + + +def _sha256_file(path: str, chunk_size: int = 1024 * 1024) -> str: + h = hashlib.sha256() + with open(path, "rb") as f: + while True: + block = f.read(chunk_size) + if not block: + break + h.update(block) + return h.hexdigest().lower() + + +def _safe_job_id(job_id: str) -> str: + return "".join(c if c.isalnum() or c in "-_" else "_" for c in job_id) + + +def decrypt_cbc_file(encrypted_path: str, aes_key: bytes, output_path: str) -> None: + with open(encrypted_path, "rb") as f_in: + iv = f_in.read(16) + if len(iv) != 16: + raise ValueError("invalid ciphertext: missing iv") + cipher = Cipher(algorithms.AES(aes_key), modes.CBC(iv), backend=default_backend()) + decryptor = cipher.decryptor() + unpadder = padding.PKCS7(128).unpadder() + with open(output_path, "wb") as f_out: + while True: + chunk = f_in.read(64 * 1024) + if not chunk: + break + decrypted = decryptor.update(chunk) + if decrypted: + f_out.write(unpadder.update(decrypted)) + final_decrypted = decryptor.finalize() + f_out.write(unpadder.update(final_decrypted) + unpadder.finalize()) + + +class ResumableDownloadManager: + def __init__( + self, + state_directory: Optional[str] = None, + *, + session_factory: Optional[Callable[[], requests.Session]] = None, + sleep_fn: Optional[Callable[[float], None]] = None, + chunk_size: int = 1024 * 1024, + ) -> None: + resolved = state_directory or os.environ.get("LOADER_DOWNLOAD_STATE_DIR") + if not resolved: + raise ValueError("state_directory or LOADER_DOWNLOAD_STATE_DIR is required") + self._state_directory = resolved + self._session_factory = session_factory or requests.Session + self._sleep = sleep_fn or time.sleep + self._chunk_size = chunk_size + os.makedirs(self._state_directory, exist_ok=True) + + def _state_path(self, job_id: str) -> str: + safe = _safe_job_id(job_id) + return os.path.join(self._state_directory, f"{safe}.json") + + def _persist(self, path: str, state: DownloadState) -> None: + save_download_state(path, state) + + def fetch_decrypt_verify( + self, + job_id: str, + url: str, + expected_sha256: str, + expected_size: int, + decryption_key: bytes, + output_plaintext_path: str, + ) -> None: + state_path = self._state_path(job_id) + safe = _safe_job_id(job_id) + temp_file_path = os.path.join(self._state_directory, f"{safe}.cipher.tmp") + if os.path.isfile(state_path): + state = load_download_state(state_path) + if state.url != url: + raise ValueError("state url mismatch") + else: + state = DownloadState( + url=url, + expected_sha256=expected_sha256, + expected_size=expected_size, + bytes_downloaded=0, + temp_file_path=temp_file_path, + phase="pending", + ) + self._persist(state_path, state) + + state.expected_sha256 = expected_sha256 + state.expected_size = expected_size + state.temp_file_path = temp_file_path + if os.path.isfile(state.temp_file_path): + on_disk = os.path.getsize(state.temp_file_path) + state.bytes_downloaded = min(on_disk, state.expected_size) + else: + state.bytes_downloaded = 0 + + network_failures = 0 + session = self._session_factory() + + try: + while True: + while state.bytes_downloaded < state.expected_size: + state.phase = "downloading" + self._persist(state_path, state) + try: + self._stream_download(session, state, state_path) + network_failures = 0 + except requests.RequestException as exc: + logger.exception("download request failed: {}", exc) + state.phase = "paused" + self._persist(state_path, state) + wait_s = backoff_seconds(network_failures) + self._sleep(wait_s) + network_failures += 1 + + state.phase = "verifying" + self._persist(state_path, state) + if _sha256_file(state.temp_file_path) != state.expected_sha256.lower().strip(): + try: + os.remove(state.temp_file_path) + except OSError as exc: + logger.exception("failed to remove corrupt download: {}", exc) + state.bytes_downloaded = 0 + state.phase = "downloading" + self._persist(state_path, state) + continue + + state.phase = "decrypting" + self._persist(state_path, state) + decrypt_cbc_file(state.temp_file_path, decryption_key, output_plaintext_path) + state.phase = "complete" + self._persist(state_path, state) + return + except Exception: + state.phase = "failed" + try: + self._persist(state_path, state) + except Exception as persist_exc: + logger.exception("failed to persist failed state: {}", persist_exc) + raise + + def _stream_download( + self, + session: requests.Session, + state: DownloadState, + state_path: str, + ) -> None: + headers = {} + if state.bytes_downloaded > 0: + headers["Range"] = f"bytes={state.bytes_downloaded}-" + with session.get( + state.url, + headers=headers, + stream=True, + timeout=(30, 120), + ) as resp: + if state.bytes_downloaded > 0 and resp.status_code == 200: + try: + os.remove(state.temp_file_path) + except OSError: + pass + state.bytes_downloaded = 0 + self._persist(state_path, state) + with session.get( + state.url, + headers={}, + stream=True, + timeout=(30, 120), + ) as resp_full: + self._write_response_stream(resp_full, state, state_path, append=False) + return + if state.bytes_downloaded > 0 and resp.status_code != 206: + resp.raise_for_status() + raise requests.HTTPError("expected 206 Partial Content when resuming") + if state.bytes_downloaded == 0 and resp.status_code not in (200, 206): + resp.raise_for_status() + append = state.bytes_downloaded > 0 + self._write_response_stream(resp, state, state_path, append=append) + + def _write_response_stream( + self, + resp: requests.Response, + state: DownloadState, + state_path: str, + *, + append: bool, + ) -> None: + mode = "ab" if append else "wb" + written_since_persist = 0 + with open(state.temp_file_path, mode) as out: + for chunk in resp.iter_content(chunk_size=self._chunk_size): + if not chunk: + continue + room = state.expected_size - state.bytes_downloaded + if room <= 0: + break + if len(chunk) > room: + chunk = chunk[:room] + out.write(chunk) + state.bytes_downloaded += len(chunk) + written_since_persist += len(chunk) + if written_since_persist >= self._chunk_size: + self._persist(state_path, state) + written_since_persist = 0 + if state.bytes_downloaded >= state.expected_size: + break + if written_since_persist: + self._persist(state_path, state) diff --git a/src/legacy_security_provider.py b/src/legacy_security_provider.py new file mode 100644 index 0000000..07a1ad0 --- /dev/null +++ b/src/legacy_security_provider.py @@ -0,0 +1,37 @@ +from credentials import Credentials +from security import ( + security_calc_hash, + security_decrypt_to, + security_encrypt_to, + security_get_api_encryption_key, + security_get_hw_hash, + security_get_resource_encryption_key, +) +from security_provider import SecurityProvider + + +class LegacySecurityProvider(SecurityProvider): + @property + def kind(self) -> str: + return "legacy" + + def encrypt_to(self, input_bytes: bytes, key: str) -> bytes: + return security_encrypt_to(input_bytes, key) + + def decrypt_to(self, ciphertext_with_iv_bytes: bytes, key: str) -> bytes: + return security_decrypt_to(ciphertext_with_iv_bytes, key) + + def get_hw_hash(self, hardware: str) -> str: + return security_get_hw_hash(hardware) + + def get_api_encryption_key( + self, creds_email: str, creds_password: str, hardware_hash: str + ) -> str: + creds = Credentials(creds_email, creds_password) + return security_get_api_encryption_key(creds, hardware_hash) + + def get_resource_encryption_key(self) -> str: + return security_get_resource_encryption_key() + + def calc_hash(self, key: str) -> str: + return security_calc_hash(key) diff --git a/src/main.py b/src/main.py index 45a4bec..c87e5eb 100644 --- a/src/main.py +++ b/src/main.py @@ -7,9 +7,12 @@ from fastapi.responses import Response from pydantic import BaseModel from unlock_state import UnlockState +from security_provider import create_security_provider app = FastAPI(title="Azaion.Loader") +security_provider = create_security_provider() + RESOURCE_API_URL = os.environ.get("RESOURCE_API_URL", "https://api.azaion.com") IMAGES_PATH = os.environ.get("IMAGES_PATH", "/opt/azaion/images.enc") API_VERSION = os.environ.get("API_VERSION", "latest") diff --git a/src/security.pyx b/src/security.pyx index 0260960..a9e52cb 100644 --- a/src/security.pyx +++ b/src/security.pyx @@ -61,3 +61,27 @@ cdef class Security: hash_bytes = sha384(str_bytes).digest() cdef str h = base64.b64encode(hash_bytes).decode('utf-8') return h + + +cpdef bytes security_encrypt_to(bytes input_bytes, str key): + return Security.encrypt_to(input_bytes, key) + + +cpdef bytes security_decrypt_to(bytes ciphertext_with_iv_bytes, str key): + return Security.decrypt_to(ciphertext_with_iv_bytes, key) + + +cpdef str security_get_hw_hash(str hardware): + return Security.get_hw_hash(hardware) + + +cpdef str security_get_api_encryption_key(Credentials credentials, str hardware_hash): + return Security.get_api_encryption_key(credentials, hardware_hash) + + +cpdef str security_get_resource_encryption_key(): + return Security.get_resource_encryption_key() + + +cpdef str security_calc_hash(str key): + return Security.calc_hash(key) diff --git a/src/security_provider.py b/src/security_provider.py new file mode 100644 index 0000000..4f7a538 --- /dev/null +++ b/src/security_provider.py @@ -0,0 +1,91 @@ +import os +from abc import ABC, abstractmethod +from typing import Callable, Mapping, Optional + +from loguru import logger + + +def _security_provider_override(environ: Mapping[str, str]) -> Optional[str]: + raw = environ.get("SECURITY_PROVIDER") + if raw is None: + return None + s = raw.strip().lower() + return s if s else None + + +def _tpm_device_visible(path_exists: Callable[[str], bool]) -> bool: + return path_exists("/dev/tpm0") or path_exists("/dev/tpmrm0") + + +def _tpm_transport_configured(environ: Mapping[str, str]) -> bool: + t = (environ.get("TSS2_TCTI") or environ.get("TPM2TOOLS_TCTI") or "").strip() + if t: + return True + return bool((environ.get("TSS2_FAPICONF") or "").strip()) + + +def should_attempt_tpm( + environ: Mapping[str, str], + path_exists: Callable[[str], bool], +) -> bool: + return _tpm_device_visible(path_exists) or _tpm_transport_configured(environ) + + +class SecurityProvider(ABC): + @property + @abstractmethod + def kind(self) -> str: ... + + @abstractmethod + def encrypt_to(self, input_bytes: bytes, key: str) -> bytes: ... + + @abstractmethod + def decrypt_to(self, ciphertext_with_iv_bytes: bytes, key: str) -> bytes: ... + + @abstractmethod + def get_hw_hash(self, hardware: str) -> str: ... + + @abstractmethod + def get_api_encryption_key( + self, creds_email: str, creds_password: str, hardware_hash: str + ) -> str: ... + + @abstractmethod + def get_resource_encryption_key(self) -> str: ... + + @abstractmethod + def calc_hash(self, key: str) -> str: ... + + def seal(self, object_path: str, data: bytes) -> None: + raise NotImplementedError + + def unseal(self, object_path: str) -> bytes: + raise NotImplementedError + + +def create_security_provider( + *, + environ: Optional[Mapping[str, str]] = None, + path_exists: Optional[Callable[[str], bool]] = None, +) -> SecurityProvider: + from legacy_security_provider import LegacySecurityProvider + + if path_exists is None: + path_exists = os.path.exists + env = environ if environ is not None else os.environ + override = _security_provider_override(env) + if override == "legacy": + logger.info("security provider: legacy (SECURITY_PROVIDER override)") + return LegacySecurityProvider() + if not should_attempt_tpm(env, path_exists): + logger.info("security provider: legacy (no TPM device or TCTI)") + return LegacySecurityProvider() + try: + from tpm_security_provider import TpmSecurityProvider + + provider = TpmSecurityProvider() + logger.info("security provider: tpm") + return provider + except Exception as e: + logger.warning("TPM security provider failed ({}), using legacy", e) + return LegacySecurityProvider() diff --git a/src/tpm_security_provider.py b/src/tpm_security_provider.py new file mode 100644 index 0000000..39ea107 --- /dev/null +++ b/src/tpm_security_provider.py @@ -0,0 +1,57 @@ +from security import ( + security_calc_hash, + security_decrypt_to, + security_encrypt_to, + security_get_api_encryption_key, + security_get_hw_hash, + security_get_resource_encryption_key, +) +from credentials import Credentials +from security_provider import SecurityProvider + + +class TpmSecurityProvider(SecurityProvider): + def __init__(self): + try: + from tpm2_pytss import FAPI + from tpm2_pytss import TSS2_Exception + except (ImportError, NotImplementedError) as e: + raise RuntimeError("tpm2-pytss FAPI is not available") from e + self._TSS2_Exception = TSS2_Exception + self._fapi = FAPI() + try: + self._fapi.provision(is_provisioned_ok=True) + except TSS2_Exception: + pass + self._fapi.get_random(1) + + @property + def kind(self) -> str: + return "tpm" + + def encrypt_to(self, input_bytes: bytes, key: str) -> bytes: + return security_encrypt_to(input_bytes, key) + + def decrypt_to(self, ciphertext_with_iv_bytes: bytes, key: str) -> bytes: + return security_decrypt_to(ciphertext_with_iv_bytes, key) + + def get_hw_hash(self, hardware: str) -> str: + return security_get_hw_hash(hardware) + + def get_api_encryption_key( + self, creds_email: str, creds_password: str, hardware_hash: str + ) -> str: + creds = Credentials(creds_email, creds_password) + return security_get_api_encryption_key(creds, hardware_hash) + + def get_resource_encryption_key(self) -> str: + return security_get_resource_encryption_key() + + def calc_hash(self, key: str) -> str: + return security_calc_hash(key) + + def seal(self, object_path: str, data: bytes) -> None: + self._fapi.create_seal(object_path, data=data, exists_ok=True) + + def unseal(self, object_path: str) -> bytes: + return self._fapi.unseal(object_path) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_download_manager.py b/tests/test_download_manager.py new file mode 100644 index 0000000..bcb674f --- /dev/null +++ b/tests/test_download_manager.py @@ -0,0 +1,319 @@ +import hashlib +import os +import shutil +import tempfile +import unittest + +import requests + +from download_manager import ( + DownloadState, + ResumableDownloadManager, + backoff_seconds, + decrypt_cbc_file, + load_download_state, + save_download_state, +) +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import padding +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes + + +def _encrypt_cbc(plaintext: bytes, aes_key: bytes) -> bytes: + iv = os.urandom(16) + padder = padding.PKCS7(128).padder() + padded = padder.update(plaintext) + padder.finalize() + cipher = Cipher(algorithms.AES(aes_key), modes.CBC(iv), backend=default_backend()) + encryptor = cipher.encryptor() + ciphertext = encryptor.update(padded) + encryptor.finalize() + return iv + ciphertext + + +class _StreamResponse: + def __init__(self, status_code: int, chunk_source): + self.status_code = status_code + self.headers = {} + self._chunk_source = chunk_source + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def raise_for_status(self): + if self.status_code >= 400: + raise requests.HTTPError(response=self) + + def iter_content(self, chunk_size=1024 * 1024): + yield from self._chunk_source() + + +class _MockSession: + def __init__(self, handler): + self._handler = handler + + def get(self, url, headers=None, stream=True, timeout=None): + return self._handler(url, headers=headers or {}) + + +class TestBackoff(unittest.TestCase): + def test_ac5_exponential_backoff_sequence(self): + # Arrange + expected = (60, 300, 900, 3600, 14400) + # Act + values = [backoff_seconds(i) for i in range(6)] + # Assert + self.assertEqual(values[0], expected[0]) + self.assertEqual(values[1], expected[1]) + self.assertEqual(values[2], expected[2]) + self.assertEqual(values[3], expected[3]) + self.assertEqual(values[4], expected[4]) + self.assertEqual(values[5], expected[4]) + + def test_ac5_sleep_invoked_with_backoff_on_repeated_failures(self): + # Arrange + sleeps = [] + + def fake_sleep(seconds): + sleeps.append(seconds) + + key = hashlib.sha256(b"k").digest() + ciphertext = _encrypt_cbc(b"x" * 100, key) + sha = hashlib.sha256(ciphertext).hexdigest() + failures_left = [3] + + def range_start(headers): + r = headers.get("Range") + if not r: + return 0 + return int(r.split("=", 1)[1].split("-", 1)[0]) + + def handler(url, headers): + start = range_start(headers) + if failures_left[0] > 0: + failures_left[0] -= 1 + + def chunks(): + yield ciphertext[start : start + 8] + raise requests.ConnectionError("drop") + + return _StreamResponse(206 if start else 200, chunks) + + def chunks_final(): + yield ciphertext[start:] + + return _StreamResponse(206 if start else 200, chunks_final) + + tmp = tempfile.mkdtemp() + self.addCleanup(lambda: shutil.rmtree(tmp, ignore_errors=True)) + out = os.path.join(tmp, "out.bin") + mgr = ResumableDownloadManager( + state_directory=tmp, + session_factory=lambda: _MockSession(handler), + sleep_fn=fake_sleep, + chunk_size=16, + ) + # Act + mgr.fetch_decrypt_verify("job-backoff", "http://x", sha, len(ciphertext), key, out) + # Assert + self.assertEqual(sleeps, [60, 300, 900]) + + +class TestStatePersistence(unittest.TestCase): + def test_ac4_state_file_reload_restores_offset(self): + # Arrange + tmp = tempfile.mkdtemp() + self.addCleanup(lambda: shutil.rmtree(tmp, ignore_errors=True)) + tf = os.path.join(tmp, "partial.cipher.tmp") + with open(tf, "wb") as f: + f.write(b"a" * 400) + state = DownloadState( + url="http://example/a", + expected_sha256="ab" * 32, + expected_size=1000, + bytes_downloaded=400, + temp_file_path=tf, + phase="paused", + ) + path = os.path.join(tmp, "state.json") + save_download_state(path, state) + # Act + loaded = load_download_state(path) + # Assert + self.assertEqual(loaded.bytes_downloaded, 400) + self.assertEqual(loaded.expected_size, 1000) + self.assertEqual(loaded.temp_file_path, tf) + + def test_ac4_manager_resumes_from_persisted_progress(self): + # Arrange + tmp = tempfile.mkdtemp() + self.addCleanup(lambda: shutil.rmtree(tmp, ignore_errors=True)) + key = hashlib.sha256(b"k2").digest() + plaintext = b"full-plaintext-payload-xyz" + ciphertext = _encrypt_cbc(plaintext, key) + sha = hashlib.sha256(ciphertext).hexdigest() + partial = int(len(ciphertext) * 0.4) + safe_job = "job_resume" + tf = os.path.join(tmp, f"{safe_job}.cipher.tmp") + with open(tf, "wb") as f: + f.write(ciphertext[:partial]) + state = DownloadState( + url="http://cdn/blob", + expected_sha256=sha, + expected_size=len(ciphertext), + bytes_downloaded=partial, + temp_file_path=tf, + phase="paused", + ) + save_download_state(os.path.join(tmp, f"{safe_job}.json"), state) + seen_ranges = [] + + def handler(url, headers): + rng = headers.get("Range") + seen_ranges.append(rng) + rest = ciphertext[partial:] + + def chunks(): + yield rest + + return _StreamResponse(206, chunks) + + out = os.path.join(tmp, "plain.out") + mgr = ResumableDownloadManager( + state_directory=tmp, + session_factory=lambda: _MockSession(handler), + sleep_fn=lambda s: None, + ) + # Act + mgr.fetch_decrypt_verify(safe_job, "http://cdn/blob", sha, len(ciphertext), key, out) + # Assert + self.assertEqual(seen_ranges[0], f"bytes={partial}-") + with open(out, "rb") as f: + self.assertEqual(f.read(), plaintext) + + +class TestResumeAfterDrop(unittest.TestCase): + def test_ac1_resume_uses_range_after_partial_transfer(self): + # Arrange + tmp = tempfile.mkdtemp() + self.addCleanup(lambda: shutil.rmtree(tmp, ignore_errors=True)) + key = hashlib.sha256(b"k3").digest() + body = b"q" * 100 + ciphertext = _encrypt_cbc(body, key) + sha = hashlib.sha256(ciphertext).hexdigest() + cut = 60 + headers_log = [] + + def handler(url, headers): + headers_log.append(dict(headers)) + if len(headers_log) == 1: + + def chunks(): + yield ciphertext[:cut] + raise requests.ConnectionError("starlink drop") + + return _StreamResponse(200, chunks) + + def chunks2(): + yield ciphertext[cut:] + + return _StreamResponse(206, chunks2) + + out = os.path.join(tmp, "p.out") + mgr = ResumableDownloadManager( + state_directory=tmp, + session_factory=lambda: _MockSession(handler), + sleep_fn=lambda s: None, + chunk_size=32, + ) + # Act + mgr.fetch_decrypt_verify("ac1", "http://s3/o", sha, len(ciphertext), key, out) + # Assert + self.assertNotIn("Range", headers_log[0]) + self.assertEqual(headers_log[1].get("Range"), f"bytes={cut}-") + with open(out, "rb") as f: + self.assertEqual(f.read(), body) + + +class TestShaMismatchRedownload(unittest.TestCase): + def test_ac2_corrupt_hash_deletes_file_and_redownloads(self): + # Arrange + tmp = tempfile.mkdtemp() + self.addCleanup(lambda: shutil.rmtree(tmp, ignore_errors=True)) + key = hashlib.sha256(b"k4").digest() + good_plain = b"same-len-pt-a!" + bad_plain = b"same-len-pt-b!" + good_ct = _encrypt_cbc(good_plain, key) + bad_ct = _encrypt_cbc(bad_plain, key) + sha_good = hashlib.sha256(good_ct).hexdigest() + calls = {"n": 0} + + def handler(url, headers): + calls["n"] += 1 + data = bad_ct if calls["n"] == 1 else good_ct + + def chunks(): + yield data + + return _StreamResponse(200, chunks) + + out = os.path.join(tmp, "good.out") + mgr = ResumableDownloadManager( + state_directory=tmp, + session_factory=lambda: _MockSession(handler), + sleep_fn=lambda s: None, + ) + # Act + mgr.fetch_decrypt_verify("ac2", "http://x", sha_good, len(good_ct), key, out) + # Assert + self.assertEqual(calls["n"], 2) + with open(out, "rb") as f: + self.assertEqual(f.read(), good_plain) + + +class TestDecryptRoundTrip(unittest.TestCase): + def test_ac3_decrypt_matches_original_plaintext(self): + # Arrange + tmp = tempfile.mkdtemp() + self.addCleanup(lambda: shutil.rmtree(tmp, ignore_errors=True)) + key = hashlib.sha256(b"artifact-key").digest() + original = b"payload-for-roundtrip-check" + ciphertext = _encrypt_cbc(original, key) + sha = hashlib.sha256(ciphertext).hexdigest() + + def handler(url, headers): + return _StreamResponse(200, lambda: [ciphertext]) + + out = os.path.join(tmp, "decrypted.bin") + mgr = ResumableDownloadManager( + state_directory=tmp, + session_factory=lambda: _MockSession(handler), + sleep_fn=lambda s: None, + ) + # Act + mgr.fetch_decrypt_verify("ac3", "http://blob", sha, len(ciphertext), key, out) + # Assert + with open(out, "rb") as f: + self.assertEqual(f.read(), original) + + def test_decrypt_cbc_file_matches_encrypt_helper(self): + # Arrange + tmp = tempfile.mkdtemp() + self.addCleanup(lambda: shutil.rmtree(tmp, ignore_errors=True)) + key = hashlib.sha256(b"x").digest() + plain = b"abc" * 500 + ct = _encrypt_cbc(plain, key) + enc_path = os.path.join(tmp, "e.bin") + with open(enc_path, "wb") as f: + f.write(ct) + out_path = os.path.join(tmp, "d.bin") + # Act + decrypt_cbc_file(enc_path, key, out_path) + # Assert + with open(out_path, "rb") as f: + self.assertEqual(f.read(), plain) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_provision_device.py b/tests/test_provision_device.py new file mode 100644 index 0000000..18948f2 --- /dev/null +++ b/tests/test_provision_device.py @@ -0,0 +1,224 @@ +import json +import subprocess +import threading +from http.server import BaseHTTPRequestHandler, HTTPServer +from pathlib import Path +from urllib.parse import urlparse + +import pytest +import requests + +REPO_ROOT = Path(__file__).resolve().parents[1] +PROVISION_SCRIPT = REPO_ROOT / "scripts" / "provision_device.sh" + + +class _ProvisionTestState: + lock = threading.Lock() + users: dict[str, dict] = {} + + +def _read_json_body(handler: BaseHTTPRequestHandler) -> dict: + length = int(handler.headers.get("Content-Length", "0")) + raw = handler.rfile.read(length) if length else b"{}" + return json.loads(raw.decode("utf-8")) + + +def _send_json(handler: BaseHTTPRequestHandler, code: int, payload: dict | None = None): + body = b"" + if payload is not None: + body = json.dumps(payload).encode("utf-8") + handler.send_response(code) + handler.send_header("Content-Type", "application/json") + handler.send_header("Content-Length", str(len(body))) + handler.end_headers() + if body: + handler.wfile.write(body) + + +class _AdminMockHandler(BaseHTTPRequestHandler): + def log_message(self, _format, *_args): + return + + def do_POST(self): + parsed = urlparse(self.path) + if parsed.path != "/users": + self.send_error(404) + return + body = _read_json_body(self) + email = body.get("email", "") + password = body.get("password", "") + role = body.get("role", "") + with _ProvisionTestState.lock: + if email in _ProvisionTestState.users: + _send_json(self, 409, {"detail": "exists"}) + return + _ProvisionTestState.users[email] = {"password": password, "role": role} + _send_json(self, 201, {"email": email, "role": role}) + + def do_PATCH(self): + parsed = urlparse(self.path) + if parsed.path != "/users/password": + self.send_error(404) + return + body = _read_json_body(self) + email = body.get("email", "") + password = body.get("password", "") + with _ProvisionTestState.lock: + if email not in _ProvisionTestState.users: + self.send_error(404) + return + _ProvisionTestState.users[email]["password"] = password + _send_json(self, 200, {"status": "ok"}) + + def handle_login_post(self): + body = _read_json_body(self) + email = body.get("email", "") + password = body.get("password", "") + with _ProvisionTestState.lock: + row = _ProvisionTestState.users.get(email) + if not row or row["password"] != password or row["role"] != "CompanionPC": + _send_json(self, 401, {"detail": "invalid"}) + return + _send_json(self, 200, {"token": "provision-test-jwt"}) + + +def _handler_factory(): + class H(_AdminMockHandler): + def do_POST(self): + parsed = urlparse(self.path) + if parsed.path == "/login": + self.handle_login_post() + return + super().do_POST() + + return H + + +@pytest.fixture +def mock_admin_server(): + # Arrange + with _ProvisionTestState.lock: + _ProvisionTestState.users.clear() + server = HTTPServer(("127.0.0.1", 0), _handler_factory()) + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + host, port = server.server_address + base = f"http://{host}:{port}" + yield base + server.shutdown() + server.server_close() + thread.join(timeout=5) + + +def _run_provision(serial: str, api_url: str, rootfs: Path) -> subprocess.CompletedProcess: + return subprocess.run( + [str(PROVISION_SCRIPT), "--serial", serial, "--api-url", api_url, "--rootfs-dir", str(rootfs)], + capture_output=True, + text=True, + check=False, + ) + + +def _parse_device_conf(path: Path) -> dict[str, str]: + out: dict[str, str] = {} + for line in path.read_text(encoding="utf-8").splitlines(): + if "=" not in line: + continue + key, _, val = line.partition("=") + out[key.strip()] = val.strip() + return out + + +def test_provision_creates_companionpc_user(mock_admin_server, tmp_path): + # Arrange + rootfs = tmp_path / "rootfs" + serial = "AZJN-0042" + expected_email = "azaion-jetson-0042@azaion.com" + + # Act + result = _run_provision(serial, mock_admin_server, rootfs) + + # Assert + assert result.returncode == 0, result.stderr + result.stdout + with _ProvisionTestState.lock: + row = _ProvisionTestState.users.get(expected_email) + assert row is not None + assert row["role"] == "CompanionPC" + assert len(row["password"]) == 32 + + +def test_provision_writes_device_conf(mock_admin_server, tmp_path): + # Arrange + rootfs = tmp_path / "rootfs" + serial = "AZJN-0042" + conf_path = rootfs / "etc" / "azaion" / "device.conf" + + # Act + result = _run_provision(serial, mock_admin_server, rootfs) + + # Assert + assert result.returncode == 0, result.stderr + result.stdout + assert conf_path.is_file() + data = _parse_device_conf(conf_path) + assert data["AZAION_DEVICE_EMAIL"] == "azaion-jetson-0042@azaion.com" + assert len(data["AZAION_DEVICE_PASSWORD"]) == 32 + assert data["AZAION_DEVICE_PASSWORD"].isalnum() + + +def test_credentials_allow_login_after_provision(mock_admin_server, tmp_path): + # Arrange + rootfs = tmp_path / "rootfs" + serial = "AZJN-0042" + conf_path = rootfs / "etc" / "azaion" / "device.conf" + + # Act + prov = _run_provision(serial, mock_admin_server, rootfs) + assert prov.returncode == 0, prov.stderr + prov.stdout + creds = _parse_device_conf(conf_path) + login_resp = requests.post( + f"{mock_admin_server}/login", + json={"email": creds["AZAION_DEVICE_EMAIL"], "password": creds["AZAION_DEVICE_PASSWORD"]}, + timeout=5, + ) + + # Assert + assert login_resp.status_code == 200 + assert login_resp.json().get("token") == "provision-test-jwt" + + +def test_provision_idempotent_no_duplicate_user(mock_admin_server, tmp_path): + # Arrange + rootfs = tmp_path / "rootfs" + serial = "AZJN-0042" + expected_email = "azaion-jetson-0042@azaion.com" + + # Act + first = _run_provision(serial, mock_admin_server, rootfs) + second = _run_provision(serial, mock_admin_server, rootfs) + + # Assert + assert first.returncode == 0, first.stderr + first.stdout + assert second.returncode == 0, second.stderr + second.stdout + with _ProvisionTestState.lock: + assert expected_email in _ProvisionTestState.users + assert len(_ProvisionTestState.users) == 1 + + +def test_runbook_documents_end_to_end_flow(): + # Arrange + runbook = REPO_ROOT / "_docs" / "02_document" / "deployment" / "provisioning_runbook.md" + text = runbook.read_text(encoding="utf-8") + + # Act + markers = [ + "prerequisites" in text.lower(), + "provision_device.sh" in text, + "device.conf" in text, + "POST" in text and "/users" in text, + "flash" in text.lower(), + "login" in text.lower(), + ] + + # Assert + assert runbook.is_file() + assert all(markers) diff --git a/tests/test_security_provider.py b/tests/test_security_provider.py new file mode 100644 index 0000000..369f27f --- /dev/null +++ b/tests/test_security_provider.py @@ -0,0 +1,213 @@ +import json +import os +import uuid +from pathlib import Path +from unittest.mock import MagicMock + +import pytest +import yaml +from loguru import logger + +from legacy_security_provider import LegacySecurityProvider +from security import security_decrypt_to +from security_provider import create_security_provider, should_attempt_tpm + + +def _compose_path(): + return Path(__file__).resolve().parents[1] / "e2e" / "docker-compose.test.yml" + + +@pytest.fixture +def clear_security_env(monkeypatch): + monkeypatch.delenv("SECURITY_PROVIDER", raising=False) + monkeypatch.delenv("TSS2_TCTI", raising=False) + monkeypatch.delenv("TPM2TOOLS_TCTI", raising=False) + monkeypatch.delenv("TSS2_FAPICONF", raising=False) + monkeypatch.delenv("TPM2_SIM_HOST", raising=False) + monkeypatch.delenv("TPM2_SIM_PORT", raising=False) + + +def test_ac1_auto_detection_selects_tpm_when_tpm0_present( + monkeypatch, clear_security_env +): + # Arrange + monkeypatch.setattr( + os.path, + "exists", + lambda p: str(p) == "/dev/tpm0", + ) + fake_tpm = MagicMock() + fake_tpm.kind = "tpm" + import tpm_security_provider as tsp + + monkeypatch.setattr(tsp, "TpmSecurityProvider", lambda: fake_tpm) + + # Act + provider = create_security_provider() + + # Assert + assert provider is fake_tpm + + +def test_ac2_tpm_seal_unseal_roundtrip(tmp_path, monkeypatch): + # Arrange + sim_host = os.environ.get("TPM2_SIM_HOST", "") + sim_port = os.environ.get("TPM2_SIM_PORT", "2321") + fapi_conf = os.environ.get("TSS2_FAPICONF", "") + if not fapi_conf and not sim_host: + pytest.skip( + "Set TPM2_SIM_HOST or TSS2_FAPICONF for TPM simulator (e.g. Docker swtpm)" + ) + if sim_host and not fapi_conf: + (tmp_path / "user").mkdir() + (tmp_path / "system" / "policy").mkdir(parents=True) + (tmp_path / "log").mkdir() + cfg = { + "profile_name": "P_ECCP256SHA256", + "profile_dir": "/etc/tpm2-tss/fapi-profiles/", + "user_dir": str(tmp_path / "user"), + "system_dir": str(tmp_path / "system"), + "tcti": f"swtpm:host={sim_host},port={sim_port}", + "ek_cert_less": "yes", + "system_pcrs": [], + "log_dir": str(tmp_path / "log"), + "firmware_log_file": "/dev/null", + "ima_log_file": "/dev/null", + } + p = tmp_path / "fapi.json" + p.write_text(json.dumps(cfg), encoding="utf-8") + monkeypatch.setenv("TSS2_FAPICONF", str(p)) + + from tpm_security_provider import TpmSecurityProvider + + try: + provider = TpmSecurityProvider() + except Exception: + pytest.skip("TPM simulator not reachable with current FAPI config") + payload = b"azaion-loader-seal-test" + path = f"/HS/SRK/az182_{uuid.uuid4().hex}" + + # Act + try: + provider.seal(path, payload) + out = provider.unseal(path) + finally: + try: + provider._fapi.delete(path) + except Exception: + pass + + # Assert + assert out == payload + + +def test_ac3_legacy_when_no_tpm_device_or_tcti(monkeypatch, clear_security_env): + # Arrange + monkeypatch.setattr(os.path, "exists", lambda p: False) + + # Act + provider = create_security_provider() + + # Assert + assert provider.kind == "legacy" + blob = provider.encrypt_to(b"plain", "secret-key") + assert provider.decrypt_to(blob, "secret-key") == b"plain" + assert ( + provider.decrypt_to(blob, "secret-key") + == security_decrypt_to(blob, "secret-key") + ) + + +def test_ac4_env_legacy_overrides_tpm_device(monkeypatch, clear_security_env): + # Arrange + monkeypatch.setenv("SECURITY_PROVIDER", "legacy") + monkeypatch.setattr( + os.path, + "exists", + lambda p: str(p) in ("/dev/tpm0", "/dev/tpmrm0"), + ) + + # Act + provider = create_security_provider() + + # Assert + assert provider.kind == "legacy" + + +def test_ac5_fapi_failure_falls_back_to_legacy_with_warning( + monkeypatch, clear_security_env +): + # Arrange + monkeypatch.setattr( + os.path, + "exists", + lambda p: str(p) == "/dev/tpm0", + ) + import tpm_security_provider as tsp + + def _boom(*_a, **_k): + raise RuntimeError("fapi init failed") + + monkeypatch.setattr(tsp, "TpmSecurityProvider", _boom) + messages = [] + + def _capture(message): + messages.append(str(message)) + + hid = logger.add(_capture, level="WARNING") + + # Act + try: + provider = create_security_provider() + finally: + logger.remove(hid) + + # Assert + assert provider.kind == "legacy" + assert any("TPM security provider failed" in m for m in messages) + + +def test_ac6_compose_declares_tpm_device_mounts_and_swtpm(): + # Arrange + raw = _compose_path().read_text(encoding="utf-8") + data = yaml.safe_load(raw) + + # Assert + jetson = data["x-tpm-device-mounts-for-jetson"] + assert "/dev/tpm0" in jetson["devices"] + assert "/dev/tpmrm0" in jetson["devices"] + assert "swtpm" in data["services"] + sut_env = data["services"]["system-under-test"]["environment"] + assert "TSS2_FAPICONF" in sut_env + sut_vols = data["services"]["system-under-test"]["volumes"] + assert any("fapi-config" in str(v) for v in sut_vols) + fapi_file = Path(__file__).resolve().parents[1] / "e2e" / "fapi-config.swtpm.json" + assert "swtpm:" in fapi_file.read_text(encoding="utf-8") + + +def test_should_attempt_tpm_respects_device_and_tcti(monkeypatch, clear_security_env): + # Arrange / Act / Assert + monkeypatch.setattr(os.path, "exists", lambda p: False) + assert should_attempt_tpm(os.environ, os.path.exists) is False + monkeypatch.setenv("TSS2_TCTI", "mssim:host=127.0.0.1,port=2321") + assert should_attempt_tpm(os.environ, os.path.exists) is True + monkeypatch.delenv("TSS2_TCTI", raising=False) + monkeypatch.setenv("TSS2_FAPICONF", "/etc/tpm2-tss/fapi-config.json") + assert should_attempt_tpm(os.environ, os.path.exists) is True + monkeypatch.delenv("TSS2_FAPICONF", raising=False) + monkeypatch.setattr(os.path, "exists", lambda p: str(p) == "/dev/tpmrm0") + assert should_attempt_tpm(os.environ, os.path.exists) is True + + +def test_legacy_provider_matches_security_module_helpers(): + # Arrange + leg = LegacySecurityProvider() + data = b"x" * 500 + key = "k" + + # Act + enc = leg.encrypt_to(data, key) + + # Assert + assert security_decrypt_to(enc, key) == data + assert leg.decrypt_to(enc, key) == data