mirror of
https://github.com/azaion/loader.git
synced 2026-04-22 10:46:32 +00:00
[AZ-182][AZ-184][AZ-187] Batch 1
Made-with: Cursor
This commit is contained in:
@@ -0,0 +1,319 @@
|
||||
import hashlib
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
|
||||
from download_manager import (
|
||||
DownloadState,
|
||||
ResumableDownloadManager,
|
||||
backoff_seconds,
|
||||
decrypt_cbc_file,
|
||||
load_download_state,
|
||||
save_download_state,
|
||||
)
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import padding
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
|
||||
|
||||
def _encrypt_cbc(plaintext: bytes, aes_key: bytes) -> bytes:
    """Encrypt *plaintext* with AES-CBC (PKCS7 padding) and a fresh random IV.

    Returns the 16-byte IV prepended to the ciphertext, matching the wire
    format the download manager's decrypt side expects.
    """
    initialization_vector = os.urandom(16)
    pkcs7 = padding.PKCS7(128).padder()
    padded_plaintext = pkcs7.update(plaintext) + pkcs7.finalize()
    aes_cbc = Cipher(
        algorithms.AES(aes_key),
        modes.CBC(initialization_vector),
        backend=default_backend(),
    )
    enc = aes_cbc.encryptor()
    return initialization_vector + enc.update(padded_plaintext) + enc.finalize()
|
||||
|
||||
|
||||
class _StreamResponse:
    """Minimal stand-in for a streaming ``requests`` response.

    Wraps a zero-argument ``chunk_source`` callable whose iterator yields the
    body chunks; supports the context-manager protocol the manager uses.
    """

    def __init__(self, status_code: int, chunk_source):
        self.status_code = status_code
        self.headers = {}
        self._chunk_source = chunk_source

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Never suppress exceptions raised while streaming.
        return False

    def raise_for_status(self):
        """Mirror requests: raise HTTPError for any 4xx/5xx status."""
        if self.status_code < 400:
            return
        raise requests.HTTPError(response=self)

    def iter_content(self, chunk_size=1024 * 1024):
        # chunk_size is accepted only for API compatibility; the chunk
        # source decides the actual chunk boundaries.
        yield from self._chunk_source()
|
||||
|
||||
|
||||
class _MockSession:
|
||||
def __init__(self, handler):
|
||||
self._handler = handler
|
||||
|
||||
def get(self, url, headers=None, stream=True, timeout=None):
|
||||
return self._handler(url, headers=headers or {})
|
||||
|
||||
|
||||
class TestBackoff(unittest.TestCase):
    """AC5: retry backoff schedule and its use on repeated connection drops."""

    def test_ac5_exponential_backoff_sequence(self):
        """backoff_seconds follows 60/300/900/3600/14400 and caps at 14400."""
        # Arrange
        expected = (60, 300, 900, 3600, 14400)
        # Act
        values = [backoff_seconds(i) for i in range(6)]
        # Assert — index 5 repeats the cap value (no further growth).
        self.assertEqual(values[0], expected[0])
        self.assertEqual(values[1], expected[1])
        self.assertEqual(values[2], expected[2])
        self.assertEqual(values[3], expected[3])
        self.assertEqual(values[4], expected[4])
        self.assertEqual(values[5], expected[4])

    def test_ac5_sleep_invoked_with_backoff_on_repeated_failures(self):
        """Three mid-stream drops must produce sleeps of 60, 300, 900 seconds."""
        # Arrange
        sleeps = []

        def fake_sleep(seconds):
            # Record instead of sleeping so the test stays fast.
            sleeps.append(seconds)

        key = hashlib.sha256(b"k").digest()
        ciphertext = _encrypt_cbc(b"x" * 100, key)
        sha = hashlib.sha256(ciphertext).hexdigest()
        # Mutable cell so the handler closure can count down failures.
        failures_left = [3]

        def range_start(headers):
            # Parse the offset out of a "Range: bytes=N-" request header.
            r = headers.get("Range")
            if not r:
                return 0
            return int(r.split("=", 1)[1].split("-", 1)[0])

        def handler(url, headers):
            start = range_start(headers)
            if failures_left[0] > 0:
                failures_left[0] -= 1

                def chunks():
                    # Serve 8 bytes then simulate a dropped connection.
                    yield ciphertext[start : start + 8]
                    raise requests.ConnectionError("drop")

                return _StreamResponse(206 if start else 200, chunks)

            def chunks_final():
                # Final attempt: serve the remainder without failure.
                yield ciphertext[start:]

            return _StreamResponse(206 if start else 200, chunks_final)

        tmp = tempfile.mkdtemp()
        self.addCleanup(lambda: shutil.rmtree(tmp, ignore_errors=True))
        out = os.path.join(tmp, "out.bin")
        mgr = ResumableDownloadManager(
            state_directory=tmp,
            session_factory=lambda: _MockSession(handler),
            sleep_fn=fake_sleep,
            chunk_size=16,
        )
        # Act
        mgr.fetch_decrypt_verify("job-backoff", "http://x", sha, len(ciphertext), key, out)
        # Assert — one backoff sleep per failed attempt, escalating.
        self.assertEqual(sleeps, [60, 300, 900])
|
||||
|
||||
|
||||
class TestStatePersistence(unittest.TestCase):
    """AC4: download progress survives a state-file save/load round trip."""

    def test_ac4_state_file_reload_restores_offset(self):
        """save_download_state/load_download_state preserve offset and metadata."""
        # Arrange
        tmp = tempfile.mkdtemp()
        self.addCleanup(lambda: shutil.rmtree(tmp, ignore_errors=True))
        tf = os.path.join(tmp, "partial.cipher.tmp")
        with open(tf, "wb") as f:
            f.write(b"a" * 400)
        state = DownloadState(
            url="http://example/a",
            expected_sha256="ab" * 32,  # any well-formed 64-hex digest
            expected_size=1000,
            bytes_downloaded=400,
            temp_file_path=tf,
            phase="paused",
        )
        path = os.path.join(tmp, "state.json")
        save_download_state(path, state)
        # Act
        loaded = load_download_state(path)
        # Assert
        self.assertEqual(loaded.bytes_downloaded, 400)
        self.assertEqual(loaded.expected_size, 1000)
        self.assertEqual(loaded.temp_file_path, tf)

    def test_ac4_manager_resumes_from_persisted_progress(self):
        """A pre-seeded state file makes the manager request Range from offset."""
        # Arrange
        tmp = tempfile.mkdtemp()
        self.addCleanup(lambda: shutil.rmtree(tmp, ignore_errors=True))
        key = hashlib.sha256(b"k2").digest()
        plaintext = b"full-plaintext-payload-xyz"
        ciphertext = _encrypt_cbc(plaintext, key)
        sha = hashlib.sha256(ciphertext).hexdigest()
        partial = int(len(ciphertext) * 0.4)
        safe_job = "job_resume"
        # Pre-create the partial temp file the state record points at.
        tf = os.path.join(tmp, f"{safe_job}.cipher.tmp")
        with open(tf, "wb") as f:
            f.write(ciphertext[:partial])
        state = DownloadState(
            url="http://cdn/blob",
            expected_sha256=sha,
            expected_size=len(ciphertext),
            bytes_downloaded=partial,
            temp_file_path=tf,
            phase="paused",
        )
        # NOTE(review): the "<job>.json" filename is presumably the manager's
        # state-file convention — confirm against ResumableDownloadManager.
        save_download_state(os.path.join(tmp, f"{safe_job}.json"), state)
        seen_ranges = []

        def handler(url, headers):
            rng = headers.get("Range")
            seen_ranges.append(rng)
            rest = ciphertext[partial:]

            def chunks():
                yield rest

            return _StreamResponse(206, chunks)

        out = os.path.join(tmp, "plain.out")
        mgr = ResumableDownloadManager(
            state_directory=tmp,
            session_factory=lambda: _MockSession(handler),
            sleep_fn=lambda s: None,
        )
        # Act
        mgr.fetch_decrypt_verify(safe_job, "http://cdn/blob", sha, len(ciphertext), key, out)
        # Assert — very first request already resumes from the saved offset.
        self.assertEqual(seen_ranges[0], f"bytes={partial}-")
        with open(out, "rb") as f:
            self.assertEqual(f.read(), plaintext)
|
||||
|
||||
|
||||
class TestResumeAfterDrop(unittest.TestCase):
    """AC1: a dropped transfer resumes with an HTTP Range request."""

    def test_ac1_resume_uses_range_after_partial_transfer(self):
        """First request has no Range; the retry asks for bytes={cut}-."""
        # Arrange
        tmp = tempfile.mkdtemp()
        self.addCleanup(lambda: shutil.rmtree(tmp, ignore_errors=True))
        key = hashlib.sha256(b"k3").digest()
        body = b"q" * 100
        ciphertext = _encrypt_cbc(body, key)
        sha = hashlib.sha256(ciphertext).hexdigest()
        cut = 60  # bytes delivered before the simulated drop
        headers_log = []

        def handler(url, headers):
            # Copy headers so later mutation by the caller can't alter the log.
            headers_log.append(dict(headers))
            if len(headers_log) == 1:

                def chunks():
                    # Serve a prefix, then simulate a network drop.
                    yield ciphertext[:cut]
                    raise requests.ConnectionError("starlink drop")

                return _StreamResponse(200, chunks)

            def chunks2():
                yield ciphertext[cut:]

            return _StreamResponse(206, chunks2)

        out = os.path.join(tmp, "p.out")
        mgr = ResumableDownloadManager(
            state_directory=tmp,
            session_factory=lambda: _MockSession(handler),
            sleep_fn=lambda s: None,
            chunk_size=32,
        )
        # Act
        mgr.fetch_decrypt_verify("ac1", "http://s3/o", sha, len(ciphertext), key, out)
        # Assert
        self.assertNotIn("Range", headers_log[0])
        self.assertEqual(headers_log[1].get("Range"), f"bytes={cut}-")
        with open(out, "rb") as f:
            self.assertEqual(f.read(), body)
|
||||
|
||||
|
||||
class TestShaMismatchRedownload(unittest.TestCase):
    """AC2: a SHA-256 mismatch triggers a full re-download."""

    def test_ac2_corrupt_hash_deletes_file_and_redownloads(self):
        """First response is wrong-content ciphertext; manager retries and
        the final output decrypts to the good plaintext."""
        # Arrange
        tmp = tempfile.mkdtemp()
        self.addCleanup(lambda: shutil.rmtree(tmp, ignore_errors=True))
        key = hashlib.sha256(b"k4").digest()
        # Same-length payloads so size checks pass while the hash differs.
        good_plain = b"same-len-pt-a!"
        bad_plain = b"same-len-pt-b!"
        good_ct = _encrypt_cbc(good_plain, key)
        bad_ct = _encrypt_cbc(bad_plain, key)
        sha_good = hashlib.sha256(good_ct).hexdigest()
        calls = {"n": 0}

        def handler(url, headers):
            calls["n"] += 1
            # Serve corrupt data on the first attempt only.
            data = bad_ct if calls["n"] == 1 else good_ct

            def chunks():
                yield data

            return _StreamResponse(200, chunks)

        out = os.path.join(tmp, "good.out")
        mgr = ResumableDownloadManager(
            state_directory=tmp,
            session_factory=lambda: _MockSession(handler),
            sleep_fn=lambda s: None,
        )
        # Act
        mgr.fetch_decrypt_verify("ac2", "http://x", sha_good, len(good_ct), key, out)
        # Assert — exactly one retry, and the output is the good plaintext.
        self.assertEqual(calls["n"], 2)
        with open(out, "rb") as f:
            self.assertEqual(f.read(), good_plain)
|
||||
|
||||
|
||||
class TestDecryptRoundTrip(unittest.TestCase):
    """AC3: encrypted artifact downloads decrypt back to the original bytes."""

    def test_ac3_decrypt_matches_original_plaintext(self):
        """End-to-end: fetch ciphertext through the manager, verify plaintext."""
        # Arrange
        tmp = tempfile.mkdtemp()
        self.addCleanup(lambda: shutil.rmtree(tmp, ignore_errors=True))
        key = hashlib.sha256(b"artifact-key").digest()
        original = b"payload-for-roundtrip-check"
        ciphertext = _encrypt_cbc(original, key)
        sha = hashlib.sha256(ciphertext).hexdigest()

        def handler(url, headers):
            # Single-chunk, always-successful response.
            return _StreamResponse(200, lambda: [ciphertext])

        out = os.path.join(tmp, "decrypted.bin")
        mgr = ResumableDownloadManager(
            state_directory=tmp,
            session_factory=lambda: _MockSession(handler),
            sleep_fn=lambda s: None,
        )
        # Act
        mgr.fetch_decrypt_verify("ac3", "http://blob", sha, len(ciphertext), key, out)
        # Assert
        with open(out, "rb") as f:
            self.assertEqual(f.read(), original)

    def test_decrypt_cbc_file_matches_encrypt_helper(self):
        """decrypt_cbc_file inverts this module's _encrypt_cbc helper."""
        # Arrange
        tmp = tempfile.mkdtemp()
        self.addCleanup(lambda: shutil.rmtree(tmp, ignore_errors=True))
        key = hashlib.sha256(b"x").digest()
        plain = b"abc" * 500
        ct = _encrypt_cbc(plain, key)
        enc_path = os.path.join(tmp, "e.bin")
        with open(enc_path, "wb") as f:
            f.write(ct)
        out_path = os.path.join(tmp, "d.bin")
        # Act
        decrypt_cbc_file(enc_path, key, out_path)
        # Assert
        with open(out_path, "rb") as f:
            self.assertEqual(f.read(), plain)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly without a pytest/unittest runner.
    unittest.main()
|
||||
@@ -0,0 +1,224 @@
|
||||
import json
|
||||
import subprocess
|
||||
import threading
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
PROVISION_SCRIPT = REPO_ROOT / "scripts" / "provision_device.sh"
|
||||
|
||||
|
||||
class _ProvisionTestState:
    """Shared in-memory user store backing the mock admin API.

    Class-level state so the HTTP handler classes (which the server
    instantiates per request) all see the same data; the fixture clears it
    between tests.
    """

    # Guards access from the server thread and the test thread.
    lock = threading.Lock()
    # email -> {"password": ..., "role": ...}, populated via POST /users.
    users: dict[str, dict] = {}
|
||||
|
||||
|
||||
def _read_json_body(handler: BaseHTTPRequestHandler) -> dict:
|
||||
length = int(handler.headers.get("Content-Length", "0"))
|
||||
raw = handler.rfile.read(length) if length else b"{}"
|
||||
return json.loads(raw.decode("utf-8"))
|
||||
|
||||
|
||||
def _send_json(handler: BaseHTTPRequestHandler, code: int, payload: dict | None = None):
|
||||
body = b""
|
||||
if payload is not None:
|
||||
body = json.dumps(payload).encode("utf-8")
|
||||
handler.send_response(code)
|
||||
handler.send_header("Content-Type", "application/json")
|
||||
handler.send_header("Content-Length", str(len(body)))
|
||||
handler.end_headers()
|
||||
if body:
|
||||
handler.wfile.write(body)
|
||||
|
||||
|
||||
class _AdminMockHandler(BaseHTTPRequestHandler):
    """Mock admin API backed by _ProvisionTestState.

    Routes: POST /users (create, 409 on duplicate), PATCH /users/password
    (update, 404 when unknown). /login is wired in by _handler_factory.
    """

    def log_message(self, _format, *_args):
        # Silence per-request stderr logging during tests.
        return

    def do_POST(self):
        """Create a user from a JSON body of {email, password, role}."""
        parsed = urlparse(self.path)
        if parsed.path != "/users":
            self.send_error(404)
            return
        body = _read_json_body(self)
        email = body.get("email", "")
        password = body.get("password", "")
        role = body.get("role", "")
        with _ProvisionTestState.lock:
            if email in _ProvisionTestState.users:
                _send_json(self, 409, {"detail": "exists"})
                return
            _ProvisionTestState.users[email] = {"password": password, "role": role}
        _send_json(self, 201, {"email": email, "role": role})

    def do_PATCH(self):
        """Update an existing user's password; 404 for unknown email."""
        parsed = urlparse(self.path)
        if parsed.path != "/users/password":
            self.send_error(404)
            return
        body = _read_json_body(self)
        email = body.get("email", "")
        password = body.get("password", "")
        with _ProvisionTestState.lock:
            if email not in _ProvisionTestState.users:
                self.send_error(404)
                return
            _ProvisionTestState.users[email]["password"] = password
        _send_json(self, 200, {"status": "ok"})

    def handle_login_post(self):
        """Authenticate a user; only the CompanionPC role may log in."""
        body = _read_json_body(self)
        email = body.get("email", "")
        password = body.get("password", "")
        with _ProvisionTestState.lock:
            row = _ProvisionTestState.users.get(email)
        if not row or row["password"] != password or row["role"] != "CompanionPC":
            _send_json(self, 401, {"detail": "invalid"})
            return
        _send_json(self, 200, {"token": "provision-test-jwt"})
|
||||
|
||||
|
||||
def _handler_factory():
|
||||
class H(_AdminMockHandler):
|
||||
def do_POST(self):
|
||||
parsed = urlparse(self.path)
|
||||
if parsed.path == "/login":
|
||||
self.handle_login_post()
|
||||
return
|
||||
super().do_POST()
|
||||
|
||||
return H
|
||||
|
||||
|
||||
@pytest.fixture
def mock_admin_server():
    """Run the mock admin API on an ephemeral port; yields its base URL."""
    # Arrange — reset shared user state so tests are isolated.
    with _ProvisionTestState.lock:
        _ProvisionTestState.users.clear()
    # Port 0 lets the OS pick a free port.
    server = HTTPServer(("127.0.0.1", 0), _handler_factory())
    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()
    host, port = server.server_address
    base = f"http://{host}:{port}"
    yield base
    # Teardown — stop serving, release the socket, reap the thread.
    server.shutdown()
    server.server_close()
    thread.join(timeout=5)
|
||||
|
||||
|
||||
def _run_provision(serial: str, api_url: str, rootfs: Path) -> subprocess.CompletedProcess:
    """Run provision_device.sh for *serial* against *api_url* into *rootfs*.

    Output is captured as text; the caller inspects returncode itself
    (check=False so a failing script does not raise).
    """
    command = [
        str(PROVISION_SCRIPT),
        "--serial",
        serial,
        "--api-url",
        api_url,
        "--rootfs-dir",
        str(rootfs),
    ]
    return subprocess.run(command, capture_output=True, text=True, check=False)
|
||||
|
||||
|
||||
def _parse_device_conf(path: Path) -> dict[str, str]:
|
||||
out: dict[str, str] = {}
|
||||
for line in path.read_text(encoding="utf-8").splitlines():
|
||||
if "=" not in line:
|
||||
continue
|
||||
key, _, val = line.partition("=")
|
||||
out[key.strip()] = val.strip()
|
||||
return out
|
||||
|
||||
|
||||
def test_provision_creates_companionpc_user(mock_admin_server, tmp_path):
    """Provisioning a serial creates a CompanionPC user with a 32-char password."""
    # Arrange
    rootfs = tmp_path / "rootfs"
    serial = "AZJN-0042"
    # Email derived from the serial; presumably the script lowercases and
    # strips the "AZJN-" prefix — confirm against provision_device.sh.
    expected_email = "azaion-jetson-0042@azaion.com"

    # Act
    result = _run_provision(serial, mock_admin_server, rootfs)

    # Assert — include script output in the failure message for debugging.
    assert result.returncode == 0, result.stderr + result.stdout
    with _ProvisionTestState.lock:
        row = _ProvisionTestState.users.get(expected_email)
    assert row is not None
    assert row["role"] == "CompanionPC"
    assert len(row["password"]) == 32
|
||||
|
||||
|
||||
def test_provision_writes_device_conf(mock_admin_server, tmp_path):
    """Provisioning writes etc/azaion/device.conf with email and password."""
    # Arrange
    rootfs = tmp_path / "rootfs"
    serial = "AZJN-0042"
    conf_path = rootfs / "etc" / "azaion" / "device.conf"

    # Act
    result = _run_provision(serial, mock_admin_server, rootfs)

    # Assert
    assert result.returncode == 0, result.stderr + result.stdout
    assert conf_path.is_file()
    data = _parse_device_conf(conf_path)
    assert data["AZAION_DEVICE_EMAIL"] == "azaion-jetson-0042@azaion.com"
    # Password: 32 alphanumeric characters.
    assert len(data["AZAION_DEVICE_PASSWORD"]) == 32
    assert data["AZAION_DEVICE_PASSWORD"].isalnum()
|
||||
|
||||
|
||||
def test_credentials_allow_login_after_provision(mock_admin_server, tmp_path):
    """Credentials written to device.conf authenticate against /login."""
    # Arrange
    rootfs = tmp_path / "rootfs"
    serial = "AZJN-0042"
    conf_path = rootfs / "etc" / "azaion" / "device.conf"

    # Act — provision, then log in with exactly what the device would read.
    prov = _run_provision(serial, mock_admin_server, rootfs)
    assert prov.returncode == 0, prov.stderr + prov.stdout
    creds = _parse_device_conf(conf_path)
    login_resp = requests.post(
        f"{mock_admin_server}/login",
        json={"email": creds["AZAION_DEVICE_EMAIL"], "password": creds["AZAION_DEVICE_PASSWORD"]},
        timeout=5,
    )

    # Assert
    assert login_resp.status_code == 200
    assert login_resp.json().get("token") == "provision-test-jwt"
|
||||
|
||||
|
||||
def test_provision_idempotent_no_duplicate_user(mock_admin_server, tmp_path):
    """Re-running provisioning for the same serial succeeds without duplicates."""
    # Arrange
    rootfs = tmp_path / "rootfs"
    serial = "AZJN-0042"
    expected_email = "azaion-jetson-0042@azaion.com"

    # Act — second run must tolerate the 409 from the already-created user.
    first = _run_provision(serial, mock_admin_server, rootfs)
    second = _run_provision(serial, mock_admin_server, rootfs)

    # Assert
    assert first.returncode == 0, first.stderr + first.stdout
    assert second.returncode == 0, second.stderr + second.stdout
    with _ProvisionTestState.lock:
        assert expected_email in _ProvisionTestState.users
        assert len(_ProvisionTestState.users) == 1
|
||||
|
||||
|
||||
def test_runbook_documents_end_to_end_flow():
    """The provisioning runbook exists and covers the end-to-end device flow.

    Fixes two defects in the original:
    - ``assert runbook.is_file()`` ran AFTER ``read_text()``, so a missing
      file surfaced as FileNotFoundError before the assertion — check
      existence first so the failure is a clear assert.
    - ``assert all(markers)`` hid which marker was missing; assert each
      requirement individually so the failure names the absent section.
    """
    # Arrange
    runbook = REPO_ROOT / "_docs" / "02_document" / "deployment" / "provisioning_runbook.md"

    # Assert existence before reading.
    assert runbook.is_file(), f"missing runbook: {runbook}"

    # Act
    text = runbook.read_text(encoding="utf-8")
    lowered = text.lower()

    # Assert — each required topic is mentioned somewhere in the runbook.
    assert "prerequisites" in lowered
    assert "provision_device.sh" in text
    assert "device.conf" in text
    assert "POST" in text and "/users" in text
    assert "flash" in lowered
    assert "login" in lowered
|
||||
@@ -0,0 +1,213 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from loguru import logger
|
||||
|
||||
from legacy_security_provider import LegacySecurityProvider
|
||||
from security import security_decrypt_to
|
||||
from security_provider import create_security_provider, should_attempt_tpm
|
||||
|
||||
|
||||
def _compose_path():
|
||||
return Path(__file__).resolve().parents[1] / "e2e" / "docker-compose.test.yml"
|
||||
|
||||
|
||||
@pytest.fixture
def clear_security_env(monkeypatch):
    """Strip all security/TPM environment variables so detection starts clean."""
    for variable in (
        "SECURITY_PROVIDER",
        "TSS2_TCTI",
        "TPM2TOOLS_TCTI",
        "TSS2_FAPICONF",
        "TPM2_SIM_HOST",
        "TPM2_SIM_PORT",
    ):
        monkeypatch.delenv(variable, raising=False)
|
||||
|
||||
|
||||
def test_ac1_auto_detection_selects_tpm_when_tpm0_present(
    monkeypatch, clear_security_env
):
    """With /dev/tpm0 'present', the factory returns the TPM provider."""
    # Arrange — pretend only /dev/tpm0 exists on the filesystem.
    monkeypatch.setattr(
        os.path,
        "exists",
        lambda p: str(p) == "/dev/tpm0",
    )
    fake_tpm = MagicMock()
    fake_tpm.kind = "tpm"
    import tpm_security_provider as tsp

    # Replace the real provider class so no actual TPM is needed.
    monkeypatch.setattr(tsp, "TpmSecurityProvider", lambda: fake_tpm)

    # Act
    provider = create_security_provider()

    # Assert — identity check: the factory returned our fake unchanged.
    assert provider is fake_tpm
|
||||
|
||||
|
||||
def test_ac2_tpm_seal_unseal_roundtrip(tmp_path, monkeypatch):
    """Seal then unseal a payload against a real TPM simulator (swtpm).

    Skipped unless TPM2_SIM_HOST or TSS2_FAPICONF is set; when only the
    simulator host is given, a throwaway FAPI config is generated under
    tmp_path pointing at the swtpm TCTI.
    """
    # Arrange
    sim_host = os.environ.get("TPM2_SIM_HOST", "")
    sim_port = os.environ.get("TPM2_SIM_PORT", "2321")
    fapi_conf = os.environ.get("TSS2_FAPICONF", "")
    if not fapi_conf and not sim_host:
        pytest.skip(
            "Set TPM2_SIM_HOST or TSS2_FAPICONF for TPM simulator (e.g. Docker swtpm)"
        )
    if sim_host and not fapi_conf:
        # Build a minimal, isolated FAPI keystore layout.
        (tmp_path / "user").mkdir()
        (tmp_path / "system" / "policy").mkdir(parents=True)
        (tmp_path / "log").mkdir()
        cfg = {
            "profile_name": "P_ECCP256SHA256",
            "profile_dir": "/etc/tpm2-tss/fapi-profiles/",
            "user_dir": str(tmp_path / "user"),
            "system_dir": str(tmp_path / "system"),
            "tcti": f"swtpm:host={sim_host},port={sim_port}",
            # Simulators have no EK certificate chain.
            "ek_cert_less": "yes",
            "system_pcrs": [],
            "log_dir": str(tmp_path / "log"),
            "firmware_log_file": "/dev/null",
            "ima_log_file": "/dev/null",
        }
        p = tmp_path / "fapi.json"
        p.write_text(json.dumps(cfg), encoding="utf-8")
        monkeypatch.setenv("TSS2_FAPICONF", str(p))

    # Import late so TSS2_FAPICONF is already set when the module loads.
    from tpm_security_provider import TpmSecurityProvider

    try:
        provider = TpmSecurityProvider()
    except Exception:
        pytest.skip("TPM simulator not reachable with current FAPI config")
    payload = b"azaion-loader-seal-test"
    # Unique FAPI path per run so reruns never collide with stale objects.
    path = f"/HS/SRK/az182_{uuid.uuid4().hex}"

    # Act
    try:
        provider.seal(path, payload)
        out = provider.unseal(path)
    finally:
        # Best-effort cleanup of the sealed object in the simulator.
        try:
            provider._fapi.delete(path)
        except Exception:
            pass

    # Assert
    assert out == payload
|
||||
|
||||
|
||||
def test_ac3_legacy_when_no_tpm_device_or_tcti(monkeypatch, clear_security_env):
    """No TPM device and no TCTI env -> legacy provider, round-trip works."""
    # Arrange — no device node exists anywhere.
    monkeypatch.setattr(os.path, "exists", lambda p: False)

    # Act
    provider = create_security_provider()

    # Assert
    assert provider.kind == "legacy"
    blob = provider.encrypt_to(b"plain", "secret-key")
    assert provider.decrypt_to(blob, "secret-key") == b"plain"
    # Provider output must stay interoperable with the security module helper.
    assert (
        provider.decrypt_to(blob, "secret-key")
        == security_decrypt_to(blob, "secret-key")
    )
|
||||
|
||||
|
||||
def test_ac4_env_legacy_overrides_tpm_device(monkeypatch, clear_security_env):
    """SECURITY_PROVIDER=legacy wins even when TPM device nodes are present."""
    # Arrange
    monkeypatch.setenv("SECURITY_PROVIDER", "legacy")
    monkeypatch.setattr(
        os.path,
        "exists",
        lambda p: str(p) in ("/dev/tpm0", "/dev/tpmrm0"),
    )

    # Act
    provider = create_security_provider()

    # Assert — explicit env override beats device auto-detection.
    assert provider.kind == "legacy"
|
||||
|
||||
|
||||
def test_ac5_fapi_failure_falls_back_to_legacy_with_warning(
    monkeypatch, clear_security_env
):
    """TPM init failure falls back to legacy and emits a WARNING log."""
    # Arrange — device looks present, but provider construction blows up.
    monkeypatch.setattr(
        os.path,
        "exists",
        lambda p: str(p) == "/dev/tpm0",
    )
    import tpm_security_provider as tsp

    def _boom(*_a, **_k):
        raise RuntimeError("fapi init failed")

    monkeypatch.setattr(tsp, "TpmSecurityProvider", _boom)
    messages = []

    def _capture(message):
        # Loguru sink: capture the formatted record text.
        messages.append(str(message))

    hid = logger.add(_capture, level="WARNING")

    # Act — always detach the sink, even if the factory raises.
    try:
        provider = create_security_provider()
    finally:
        logger.remove(hid)

    # Assert
    assert provider.kind == "legacy"
    assert any("TPM security provider failed" in m for m in messages)
|
||||
|
||||
|
||||
def test_ac6_compose_declares_tpm_device_mounts_and_swtpm():
    """The e2e compose file wires TPM devices, swtpm, and FAPI config."""
    # Arrange
    raw = _compose_path().read_text(encoding="utf-8")
    data = yaml.safe_load(raw)

    # Assert — Jetson extension block exposes both TPM device nodes.
    jetson = data["x-tpm-device-mounts-for-jetson"]
    assert "/dev/tpm0" in jetson["devices"]
    assert "/dev/tpmrm0" in jetson["devices"]
    # The software TPM simulator service must be declared.
    assert "swtpm" in data["services"]
    sut_env = data["services"]["system-under-test"]["environment"]
    assert "TSS2_FAPICONF" in sut_env
    sut_vols = data["services"]["system-under-test"]["volumes"]
    assert any("fapi-config" in str(v) for v in sut_vols)
    # And the checked-in FAPI config must point at a swtpm TCTI.
    fapi_file = Path(__file__).resolve().parents[1] / "e2e" / "fapi-config.swtpm.json"
    assert "swtpm:" in fapi_file.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
def test_should_attempt_tpm_respects_device_and_tcti(monkeypatch, clear_security_env):
    """should_attempt_tpm is True for a TCTI env var, a FAPI config, or a
    device node — and False when none of those are present."""
    # Arrange / Act / Assert — stepped scenarios on one predicate.
    monkeypatch.setattr(os.path, "exists", lambda p: False)
    assert should_attempt_tpm(os.environ, os.path.exists) is False
    monkeypatch.setenv("TSS2_TCTI", "mssim:host=127.0.0.1,port=2321")
    assert should_attempt_tpm(os.environ, os.path.exists) is True
    monkeypatch.delenv("TSS2_TCTI", raising=False)
    monkeypatch.setenv("TSS2_FAPICONF", "/etc/tpm2-tss/fapi-config.json")
    assert should_attempt_tpm(os.environ, os.path.exists) is True
    monkeypatch.delenv("TSS2_FAPICONF", raising=False)
    # Resource-manager node alone is also sufficient.
    monkeypatch.setattr(os.path, "exists", lambda p: str(p) == "/dev/tpmrm0")
    assert should_attempt_tpm(os.environ, os.path.exists) is True
|
||||
|
||||
|
||||
def test_legacy_provider_matches_security_module_helpers():
    """LegacySecurityProvider output is decryptable by the security module."""
    # Arrange
    leg = LegacySecurityProvider()
    data = b"x" * 500
    key = "k"

    # Act
    enc = leg.encrypt_to(data, key)

    # Assert — cross-compatible with the module-level helper and self.
    assert security_decrypt_to(enc, key) == data
    assert leg.decrypt_to(enc, key) == data
|
||||
Reference in New Issue
Block a user