[AZ-182][AZ-184][AZ-187] Batch 1

Made-with: Cursor
This commit is contained in:
Oleksandr Bezdieniezhnykh
2026-04-15 07:23:47 +03:00
parent 765d3d32c1
commit d244799f02
22 changed files with 1622 additions and 16 deletions
+280
View File
@@ -0,0 +1,280 @@
import hashlib
import json
import os
import tempfile
import time
from dataclasses import asdict, dataclass
from typing import Callable, Optional
import requests
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from loguru import logger
def backoff_seconds(failure_index: int) -> int:
sequence = (60, 300, 900, 3600, 14400)
idx = min(max(0, failure_index), len(sequence) - 1)
return sequence[idx]
@dataclass
class DownloadState:
url: str
expected_sha256: str
expected_size: int
bytes_downloaded: int
temp_file_path: str
phase: str
def to_json_dict(self) -> dict:
return asdict(self)
@classmethod
def from_json_dict(cls, data: dict) -> "DownloadState":
return cls(
url=data["url"],
expected_sha256=data["expected_sha256"],
expected_size=int(data["expected_size"]),
bytes_downloaded=int(data["bytes_downloaded"]),
temp_file_path=data["temp_file_path"],
phase=data["phase"],
)
def load_download_state(path: str) -> DownloadState:
with open(path, encoding="utf-8") as f:
return DownloadState.from_json_dict(json.load(f))
def save_download_state(path: str, state: DownloadState) -> None:
directory = os.path.dirname(path)
if directory:
os.makedirs(directory, exist_ok=True)
payload = json.dumps(state.to_json_dict(), indent=2, sort_keys=True)
fd, tmp = tempfile.mkstemp(
dir=directory or None,
prefix=".download_state_",
suffix=".tmp",
)
try:
with os.fdopen(fd, "w", encoding="utf-8") as f:
f.write(payload)
os.replace(tmp, path)
except Exception:
try:
os.unlink(tmp)
except OSError:
pass
raise
def _sha256_file(path: str, chunk_size: int = 1024 * 1024) -> str:
h = hashlib.sha256()
with open(path, "rb") as f:
while True:
block = f.read(chunk_size)
if not block:
break
h.update(block)
return h.hexdigest().lower()
def _safe_job_id(job_id: str) -> str:
return "".join(c if c.isalnum() or c in "-_" else "_" for c in job_id)
def decrypt_cbc_file(encrypted_path: str, aes_key: bytes, output_path: str) -> None:
with open(encrypted_path, "rb") as f_in:
iv = f_in.read(16)
if len(iv) != 16:
raise ValueError("invalid ciphertext: missing iv")
cipher = Cipher(algorithms.AES(aes_key), modes.CBC(iv), backend=default_backend())
decryptor = cipher.decryptor()
unpadder = padding.PKCS7(128).unpadder()
with open(output_path, "wb") as f_out:
while True:
chunk = f_in.read(64 * 1024)
if not chunk:
break
decrypted = decryptor.update(chunk)
if decrypted:
f_out.write(unpadder.update(decrypted))
final_decrypted = decryptor.finalize()
f_out.write(unpadder.update(final_decrypted) + unpadder.finalize())
class ResumableDownloadManager:
def __init__(
self,
state_directory: Optional[str] = None,
*,
session_factory: Optional[Callable[[], requests.Session]] = None,
sleep_fn: Optional[Callable[[float], None]] = None,
chunk_size: int = 1024 * 1024,
) -> None:
resolved = state_directory or os.environ.get("LOADER_DOWNLOAD_STATE_DIR")
if not resolved:
raise ValueError("state_directory or LOADER_DOWNLOAD_STATE_DIR is required")
self._state_directory = resolved
self._session_factory = session_factory or requests.Session
self._sleep = sleep_fn or time.sleep
self._chunk_size = chunk_size
os.makedirs(self._state_directory, exist_ok=True)
def _state_path(self, job_id: str) -> str:
safe = _safe_job_id(job_id)
return os.path.join(self._state_directory, f"{safe}.json")
def _persist(self, path: str, state: DownloadState) -> None:
save_download_state(path, state)
def fetch_decrypt_verify(
self,
job_id: str,
url: str,
expected_sha256: str,
expected_size: int,
decryption_key: bytes,
output_plaintext_path: str,
) -> None:
state_path = self._state_path(job_id)
safe = _safe_job_id(job_id)
temp_file_path = os.path.join(self._state_directory, f"{safe}.cipher.tmp")
if os.path.isfile(state_path):
state = load_download_state(state_path)
if state.url != url:
raise ValueError("state url mismatch")
else:
state = DownloadState(
url=url,
expected_sha256=expected_sha256,
expected_size=expected_size,
bytes_downloaded=0,
temp_file_path=temp_file_path,
phase="pending",
)
self._persist(state_path, state)
state.expected_sha256 = expected_sha256
state.expected_size = expected_size
state.temp_file_path = temp_file_path
if os.path.isfile(state.temp_file_path):
on_disk = os.path.getsize(state.temp_file_path)
state.bytes_downloaded = min(on_disk, state.expected_size)
else:
state.bytes_downloaded = 0
network_failures = 0
session = self._session_factory()
try:
while True:
while state.bytes_downloaded < state.expected_size:
state.phase = "downloading"
self._persist(state_path, state)
try:
self._stream_download(session, state, state_path)
network_failures = 0
except requests.RequestException as exc:
logger.exception("download request failed: {}", exc)
state.phase = "paused"
self._persist(state_path, state)
wait_s = backoff_seconds(network_failures)
self._sleep(wait_s)
network_failures += 1
state.phase = "verifying"
self._persist(state_path, state)
if _sha256_file(state.temp_file_path) != state.expected_sha256.lower().strip():
try:
os.remove(state.temp_file_path)
except OSError as exc:
logger.exception("failed to remove corrupt download: {}", exc)
state.bytes_downloaded = 0
state.phase = "downloading"
self._persist(state_path, state)
continue
state.phase = "decrypting"
self._persist(state_path, state)
decrypt_cbc_file(state.temp_file_path, decryption_key, output_plaintext_path)
state.phase = "complete"
self._persist(state_path, state)
return
except Exception:
state.phase = "failed"
try:
self._persist(state_path, state)
except Exception as persist_exc:
logger.exception("failed to persist failed state: {}", persist_exc)
raise
def _stream_download(
self,
session: requests.Session,
state: DownloadState,
state_path: str,
) -> None:
headers = {}
if state.bytes_downloaded > 0:
headers["Range"] = f"bytes={state.bytes_downloaded}-"
with session.get(
state.url,
headers=headers,
stream=True,
timeout=(30, 120),
) as resp:
if state.bytes_downloaded > 0 and resp.status_code == 200:
try:
os.remove(state.temp_file_path)
except OSError:
pass
state.bytes_downloaded = 0
self._persist(state_path, state)
with session.get(
state.url,
headers={},
stream=True,
timeout=(30, 120),
) as resp_full:
self._write_response_stream(resp_full, state, state_path, append=False)
return
if state.bytes_downloaded > 0 and resp.status_code != 206:
resp.raise_for_status()
raise requests.HTTPError("expected 206 Partial Content when resuming")
if state.bytes_downloaded == 0 and resp.status_code not in (200, 206):
resp.raise_for_status()
append = state.bytes_downloaded > 0
self._write_response_stream(resp, state, state_path, append=append)
def _write_response_stream(
self,
resp: requests.Response,
state: DownloadState,
state_path: str,
*,
append: bool,
) -> None:
mode = "ab" if append else "wb"
written_since_persist = 0
with open(state.temp_file_path, mode) as out:
for chunk in resp.iter_content(chunk_size=self._chunk_size):
if not chunk:
continue
room = state.expected_size - state.bytes_downloaded
if room <= 0:
break
if len(chunk) > room:
chunk = chunk[:room]
out.write(chunk)
state.bytes_downloaded += len(chunk)
written_since_persist += len(chunk)
if written_since_persist >= self._chunk_size:
self._persist(state_path, state)
written_since_persist = 0
if state.bytes_downloaded >= state.expected_size:
break
if written_since_persist:
self._persist(state_path, state)
+37
View File
@@ -0,0 +1,37 @@
from credentials import Credentials
from security import (
security_calc_hash,
security_decrypt_to,
security_encrypt_to,
security_get_api_encryption_key,
security_get_hw_hash,
security_get_resource_encryption_key,
)
from security_provider import SecurityProvider
class LegacySecurityProvider(SecurityProvider):
@property
def kind(self) -> str:
return "legacy"
def encrypt_to(self, input_bytes: bytes, key: str) -> bytes:
return security_encrypt_to(input_bytes, key)
def decrypt_to(self, ciphertext_with_iv_bytes: bytes, key: str) -> bytes:
return security_decrypt_to(ciphertext_with_iv_bytes, key)
def get_hw_hash(self, hardware: str) -> str:
return security_get_hw_hash(hardware)
def get_api_encryption_key(
self, creds_email: str, creds_password: str, hardware_hash: str
) -> str:
creds = Credentials(creds_email, creds_password)
return security_get_api_encryption_key(creds, hardware_hash)
def get_resource_encryption_key(self) -> str:
return security_get_resource_encryption_key()
def calc_hash(self, key: str) -> str:
return security_calc_hash(key)
+3
View File
@@ -7,9 +7,12 @@ from fastapi.responses import Response
from pydantic import BaseModel
from unlock_state import UnlockState
from security_provider import create_security_provider
app = FastAPI(title="Azaion.Loader")
security_provider = create_security_provider()
RESOURCE_API_URL = os.environ.get("RESOURCE_API_URL", "https://api.azaion.com")
IMAGES_PATH = os.environ.get("IMAGES_PATH", "/opt/azaion/images.enc")
API_VERSION = os.environ.get("API_VERSION", "latest")
+24
View File
@@ -61,3 +61,27 @@ cdef class Security:
hash_bytes = sha384(str_bytes).digest()
cdef str h = base64.b64encode(hash_bytes).decode('utf-8')
return h
cpdef bytes security_encrypt_to(bytes input_bytes, str key):
return Security.encrypt_to(input_bytes, key)
cpdef bytes security_decrypt_to(bytes ciphertext_with_iv_bytes, str key):
return Security.decrypt_to(ciphertext_with_iv_bytes, key)
cpdef str security_get_hw_hash(str hardware):
return Security.get_hw_hash(hardware)
cpdef str security_get_api_encryption_key(Credentials credentials, str hardware_hash):
return Security.get_api_encryption_key(credentials, hardware_hash)
cpdef str security_get_resource_encryption_key():
return Security.get_resource_encryption_key()
cpdef str security_calc_hash(str key):
return Security.calc_hash(key)
+91
View File
@@ -0,0 +1,91 @@
import os
from abc import ABC, abstractmethod
from typing import Callable, Mapping, Optional
from loguru import logger
def _security_provider_override(environ: Mapping[str, str]) -> Optional[str]:
raw = environ.get("SECURITY_PROVIDER")
if raw is None:
return None
s = raw.strip().lower()
return s if s else None
def _tpm_device_visible(path_exists: Callable[[str], bool]) -> bool:
return path_exists("/dev/tpm0") or path_exists("/dev/tpmrm0")
def _tpm_transport_configured(environ: Mapping[str, str]) -> bool:
t = (environ.get("TSS2_TCTI") or environ.get("TPM2TOOLS_TCTI") or "").strip()
if t:
return True
return bool((environ.get("TSS2_FAPICONF") or "").strip())
def should_attempt_tpm(
environ: Mapping[str, str],
path_exists: Callable[[str], bool],
) -> bool:
return _tpm_device_visible(path_exists) or _tpm_transport_configured(environ)
class SecurityProvider(ABC):
@property
@abstractmethod
def kind(self) -> str: ...
@abstractmethod
def encrypt_to(self, input_bytes: bytes, key: str) -> bytes: ...
@abstractmethod
def decrypt_to(self, ciphertext_with_iv_bytes: bytes, key: str) -> bytes: ...
@abstractmethod
def get_hw_hash(self, hardware: str) -> str: ...
@abstractmethod
def get_api_encryption_key(
self, creds_email: str, creds_password: str, hardware_hash: str
) -> str: ...
@abstractmethod
def get_resource_encryption_key(self) -> str: ...
@abstractmethod
def calc_hash(self, key: str) -> str: ...
def seal(self, object_path: str, data: bytes) -> None:
raise NotImplementedError
def unseal(self, object_path: str) -> bytes:
raise NotImplementedError
def create_security_provider(
*,
environ: Optional[Mapping[str, str]] = None,
path_exists: Optional[Callable[[str], bool]] = None,
) -> SecurityProvider:
from legacy_security_provider import LegacySecurityProvider
if path_exists is None:
path_exists = os.path.exists
env = environ if environ is not None else os.environ
override = _security_provider_override(env)
if override == "legacy":
logger.info("security provider: legacy (SECURITY_PROVIDER override)")
return LegacySecurityProvider()
if not should_attempt_tpm(env, path_exists):
logger.info("security provider: legacy (no TPM device or TCTI)")
return LegacySecurityProvider()
try:
from tpm_security_provider import TpmSecurityProvider
provider = TpmSecurityProvider()
logger.info("security provider: tpm")
return provider
except Exception as e:
logger.warning("TPM security provider failed ({}), using legacy", e)
return LegacySecurityProvider()
+57
View File
@@ -0,0 +1,57 @@
from security import (
security_calc_hash,
security_decrypt_to,
security_encrypt_to,
security_get_api_encryption_key,
security_get_hw_hash,
security_get_resource_encryption_key,
)
from credentials import Credentials
from security_provider import SecurityProvider
class TpmSecurityProvider(SecurityProvider):
def __init__(self):
try:
from tpm2_pytss import FAPI
from tpm2_pytss import TSS2_Exception
except (ImportError, NotImplementedError) as e:
raise RuntimeError("tpm2-pytss FAPI is not available") from e
self._TSS2_Exception = TSS2_Exception
self._fapi = FAPI()
try:
self._fapi.provision(is_provisioned_ok=True)
except TSS2_Exception:
pass
self._fapi.get_random(1)
@property
def kind(self) -> str:
return "tpm"
def encrypt_to(self, input_bytes: bytes, key: str) -> bytes:
return security_encrypt_to(input_bytes, key)
def decrypt_to(self, ciphertext_with_iv_bytes: bytes, key: str) -> bytes:
return security_decrypt_to(ciphertext_with_iv_bytes, key)
def get_hw_hash(self, hardware: str) -> str:
return security_get_hw_hash(hardware)
def get_api_encryption_key(
self, creds_email: str, creds_password: str, hardware_hash: str
) -> str:
creds = Credentials(creds_email, creds_password)
return security_get_api_encryption_key(creds, hardware_hash)
def get_resource_encryption_key(self) -> str:
return security_get_resource_encryption_key()
def calc_hash(self, key: str) -> str:
return security_calc_hash(key)
def seal(self, object_path: str, data: bytes) -> None:
self._fapi.create_seal(object_path, data=data, exists_ok=True)
def unseal(self, object_path: str) -> bytes:
return self._fapi.unseal(object_path)