mirror of
https://github.com/azaion/loader.git
synced 2026-04-22 10:06:32 +00:00
d244799f02
Made-with: Cursor
320 lines
10 KiB
Python
320 lines
10 KiB
Python
import hashlib
|
|
import os
|
|
import shutil
|
|
import tempfile
|
|
import unittest
|
|
|
|
import requests
|
|
|
|
from download_manager import (
|
|
DownloadState,
|
|
ResumableDownloadManager,
|
|
backoff_seconds,
|
|
decrypt_cbc_file,
|
|
load_download_state,
|
|
save_download_state,
|
|
)
|
|
from cryptography.hazmat.backends import default_backend
|
|
from cryptography.hazmat.primitives import padding
|
|
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
|
|
|
|
|
def _encrypt_cbc(plaintext: bytes, aes_key: bytes) -> bytes:
    """Encrypt *plaintext* with AES-CBC under *aes_key* to build test fixtures.

    The plaintext is PKCS7-padded to the 128-bit AES block size. A fresh
    random 16-byte IV is drawn on every call, so two calls with the same
    input produce different output.

    Returns:
        ``iv || ciphertext``: the IV prefixed to the encrypted, padded data.
        This is the layout ``decrypt_cbc_file`` is exercised with in these
        tests.
    """
    init_vector = os.urandom(16)

    block_padder = padding.PKCS7(128).padder()
    padded_plaintext = b"".join(
        (block_padder.update(plaintext), block_padder.finalize())
    )

    encryptor = Cipher(
        algorithms.AES(aes_key), modes.CBC(init_vector), backend=default_backend()
    ).encryptor()
    encrypted = b"".join((encryptor.update(padded_plaintext), encryptor.finalize()))

    return init_vector + encrypted
|
|
|
|
|
|
class _StreamResponse:
|
|
def __init__(self, status_code: int, chunk_source):
|
|
self.status_code = status_code
|
|
self.headers = {}
|
|
self._chunk_source = chunk_source
|
|
|
|
def __enter__(self):
|
|
return self
|
|
|
|
def __exit__(self, exc_type, exc, tb):
|
|
return False
|
|
|
|
def raise_for_status(self):
|
|
if self.status_code >= 400:
|
|
raise requests.HTTPError(response=self)
|
|
|
|
def iter_content(self, chunk_size=1024 * 1024):
|
|
yield from self._chunk_source()
|
|
|
|
|
|
class _MockSession:
|
|
def __init__(self, handler):
|
|
self._handler = handler
|
|
|
|
def get(self, url, headers=None, stream=True, timeout=None):
|
|
return self._handler(url, headers=headers or {})
|
|
|
|
|
|
class TestBackoff(unittest.TestCase):
    """AC5: retry delays follow a capped exponential backoff schedule."""

    def test_ac5_exponential_backoff_sequence(self):
        # Arrange
        # Delays for attempts 0..4; attempts beyond that stay at the last value.
        schedule = [60, 300, 900, 3600, 14400]
        expected = schedule + [schedule[-1]]

        # Act
        values = [backoff_seconds(i) for i in range(len(expected))]

        # Assert
        # One list comparison reports every mismatching attempt at once and
        # makes the "capped after the last step" expectation explicit.
        self.assertEqual(values, expected)

    def test_ac5_sleep_invoked_with_backoff_on_repeated_failures(self):
        # Arrange
        sleeps = []

        def fake_sleep(seconds):
            sleeps.append(seconds)

        key = hashlib.sha256(b"k").digest()
        plaintext = b"x" * 100
        ciphertext = _encrypt_cbc(plaintext, key)
        sha = hashlib.sha256(ciphertext).hexdigest()
        failures_left = 3

        def range_start(headers):
            """Return the start offset of a ``bytes=N-`` Range header, or 0."""
            r = headers.get("Range")
            if not r:
                return 0
            return int(r.split("=", 1)[1].split("-", 1)[0])

        def handler(url, headers):
            # The first three requests deliver 8 bytes from the requested
            # offset and then drop the connection; later requests send the
            # remainder of the ciphertext.
            nonlocal failures_left
            start = range_start(headers)
            status = 206 if start else 200
            if failures_left > 0:
                failures_left -= 1

                def chunks():
                    yield ciphertext[start : start + 8]
                    raise requests.ConnectionError("drop")

                return _StreamResponse(status, chunks)

            def chunks_final():
                yield ciphertext[start:]

            return _StreamResponse(status, chunks_final)

        tmp = tempfile.mkdtemp()
        self.addCleanup(lambda: shutil.rmtree(tmp, ignore_errors=True))
        out = os.path.join(tmp, "out.bin")
        mgr = ResumableDownloadManager(
            state_directory=tmp,
            session_factory=lambda: _MockSession(handler),
            sleep_fn=fake_sleep,
            chunk_size=16,
        )

        # Act
        mgr.fetch_decrypt_verify(
            "job-backoff", "http://x", sha, len(ciphertext), key, out
        )

        # Assert
        self.assertEqual(sleeps, [60, 300, 900])
        # The retries must actually complete the download, not just sleep.
        with open(out, "rb") as f:
            self.assertEqual(f.read(), plaintext)
|
|
|
|
|
|
class TestStatePersistence(unittest.TestCase):
    """AC4: download progress is persisted to disk and used to resume."""

    def test_ac4_state_file_reload_restores_offset(self):
        # Arrange: a 400-byte partial file plus a "paused" state describing it.
        work_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, work_dir, ignore_errors=True)
        partial_path = os.path.join(work_dir, "partial.cipher.tmp")
        with open(partial_path, "wb") as fh:
            fh.write(b"a" * 400)
        state_path = os.path.join(work_dir, "state.json")
        save_download_state(
            state_path,
            DownloadState(
                url="http://example/a",
                expected_sha256="ab" * 32,
                expected_size=1000,
                bytes_downloaded=400,
                temp_file_path=partial_path,
                phase="paused",
            ),
        )

        # Act
        restored = load_download_state(state_path)

        # Assert
        self.assertEqual(restored.bytes_downloaded, 400)
        self.assertEqual(restored.expected_size, 1000)
        self.assertEqual(restored.temp_file_path, partial_path)

    def test_ac4_manager_resumes_from_persisted_progress(self):
        # Arrange
        work_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, work_dir, ignore_errors=True)
        key = hashlib.sha256(b"k2").digest()
        plaintext = b"full-plaintext-payload-xyz"
        ciphertext = _encrypt_cbc(plaintext, key)
        digest = hashlib.sha256(ciphertext).hexdigest()
        already_have = int(len(ciphertext) * 0.4)
        job_id = "job_resume"

        # Seed the state directory as if an earlier run had paused at ~40%.
        # The file names assume the manager's "<job>.cipher.tmp" / "<job>.json"
        # naming convention inside state_directory.
        partial_path = os.path.join(work_dir, f"{job_id}.cipher.tmp")
        with open(partial_path, "wb") as fh:
            fh.write(ciphertext[:already_have])
        save_download_state(
            os.path.join(work_dir, f"{job_id}.json"),
            DownloadState(
                url="http://cdn/blob",
                expected_sha256=digest,
                expected_size=len(ciphertext),
                bytes_downloaded=already_have,
                temp_file_path=partial_path,
                phase="paused",
            ),
        )

        requested_ranges = []
        remainder = ciphertext[already_have:]

        def handler(url, headers):
            # Record the Range header and always serve the missing tail.
            requested_ranges.append(headers.get("Range"))
            return _StreamResponse(206, lambda: iter([remainder]))

        out_path = os.path.join(work_dir, "plain.out")
        manager = ResumableDownloadManager(
            state_directory=work_dir,
            session_factory=lambda: _MockSession(handler),
            sleep_fn=lambda s: None,
        )

        # Act
        manager.fetch_decrypt_verify(
            job_id, "http://cdn/blob", digest, len(ciphertext), key, out_path
        )

        # Assert: the very first request asks only for the bytes still missing.
        self.assertEqual(requested_ranges[0], f"bytes={already_have}-")
        with open(out_path, "rb") as fh:
            self.assertEqual(fh.read(), plaintext)
|
|
|
|
|
|
class TestResumeAfterDrop(unittest.TestCase):
    """AC1: a connection dropped mid-transfer resumes with an HTTP Range request."""

    def test_ac1_resume_uses_range_after_partial_transfer(self):
        # Arrange
        work_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, work_dir, ignore_errors=True)
        key = hashlib.sha256(b"k3").digest()
        body = b"q" * 100
        ciphertext = _encrypt_cbc(body, key)
        digest = hashlib.sha256(ciphertext).hexdigest()
        cut = 60
        seen_headers = []

        def first_attempt():
            # Deliver the first `cut` bytes, then simulate a dropped link.
            yield ciphertext[:cut]
            raise requests.ConnectionError("starlink drop")

        def second_attempt():
            yield ciphertext[cut:]

        def handler(url, headers):
            # Snapshot the headers of every request; only the first one fails.
            seen_headers.append(dict(headers))
            if len(seen_headers) > 1:
                return _StreamResponse(206, second_attempt)
            return _StreamResponse(200, first_attempt)

        out_path = os.path.join(work_dir, "p.out")
        manager = ResumableDownloadManager(
            state_directory=work_dir,
            session_factory=lambda: _MockSession(handler),
            sleep_fn=lambda s: None,
            chunk_size=32,
        )

        # Act
        manager.fetch_decrypt_verify(
            "ac1", "http://s3/o", digest, len(ciphertext), key, out_path
        )

        # Assert: a fresh start sends no Range; the retry resumes at the cut.
        first_request, second_request = seen_headers[0], seen_headers[1]
        self.assertNotIn("Range", first_request)
        self.assertEqual(second_request.get("Range"), f"bytes={cut}-")
        with open(out_path, "rb") as fh:
            self.assertEqual(fh.read(), body)
|
|
|
|
|
|
class TestShaMismatchRedownload(unittest.TestCase):
    """AC2: a payload whose SHA-256 does not match is discarded and re-fetched."""

    def test_ac2_corrupt_hash_deletes_file_and_redownloads(self):
        # Arrange
        work_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, work_dir, ignore_errors=True)
        key = hashlib.sha256(b"k4").digest()
        good_plain = b"same-len-pt-a!"
        bad_plain = b"same-len-pt-b!"
        # Equal-length plaintexts yield equal-length ciphertexts, so the
        # corrupt payload matches the expected size and differs only in
        # content (and therefore in SHA-256).
        good_ct = _encrypt_cbc(good_plain, key)
        bad_ct = _encrypt_cbc(bad_plain, key)
        sha_good = hashlib.sha256(good_ct).hexdigest()
        request_count = 0

        def handler(url, headers):
            # Serve the corrupt payload first and the correct one afterwards.
            nonlocal request_count
            request_count += 1
            payload = bad_ct if request_count == 1 else good_ct
            return _StreamResponse(200, lambda: [payload])

        out_path = os.path.join(work_dir, "good.out")
        manager = ResumableDownloadManager(
            state_directory=work_dir,
            session_factory=lambda: _MockSession(handler),
            sleep_fn=lambda s: None,
        )

        # Act
        manager.fetch_decrypt_verify(
            "ac2", "http://x", sha_good, len(good_ct), key, out_path
        )

        # Assert: exactly one re-download, and the output is the good payload.
        self.assertEqual(request_count, 2)
        with open(out_path, "rb") as fh:
            self.assertEqual(fh.read(), good_plain)
|
|
|
|
|
|
class TestDecryptRoundTrip(unittest.TestCase):
    """AC3: ciphertext produced by ``_encrypt_cbc`` decrypts back to the original."""

    def test_ac3_decrypt_matches_original_plaintext(self):
        # Arrange
        work_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, work_dir, ignore_errors=True)
        key = hashlib.sha256(b"artifact-key").digest()
        original = b"payload-for-roundtrip-check"
        ciphertext = _encrypt_cbc(original, key)
        digest = hashlib.sha256(ciphertext).hexdigest()

        def handler(url, headers):
            # Serve the whole ciphertext as a single chunk.
            return _StreamResponse(200, lambda: [ciphertext])

        decrypted_path = os.path.join(work_dir, "decrypted.bin")
        manager = ResumableDownloadManager(
            state_directory=work_dir,
            session_factory=lambda: _MockSession(handler),
            sleep_fn=lambda s: None,
        )

        # Act
        manager.fetch_decrypt_verify(
            "ac3", "http://blob", digest, len(ciphertext), key, decrypted_path
        )

        # Assert
        with open(decrypted_path, "rb") as fh:
            self.assertEqual(fh.read(), original)

    def test_decrypt_cbc_file_matches_encrypt_helper(self):
        # Arrange: write an encrypted multi-block payload to disk.
        work_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, work_dir, ignore_errors=True)
        key = hashlib.sha256(b"x").digest()
        plain = b"abc" * 500
        encrypted_path = os.path.join(work_dir, "e.bin")
        with open(encrypted_path, "wb") as fh:
            fh.write(_encrypt_cbc(plain, key))
        decrypted_path = os.path.join(work_dir, "d.bin")

        # Act
        decrypt_cbc_file(encrypted_path, key, decrypted_path)

        # Assert
        with open(decrypted_path, "rb") as fh:
            self.assertEqual(fh.read(), plain)
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
|