mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 22:36:32 +00:00
834f846dc8
- Added a new Cython extension for the engine factory to the setup configuration. - Updated the inference module to include additional logging for video batch processing and annotation callbacks. - Refactored test cases to standardize the detection endpoint responses and include channel IDs in headers for better event handling.
438 lines
12 KiB
Python
438 lines
12 KiB
Python
import asyncio
|
|
import os
|
|
import tempfile
|
|
import threading
|
|
import time
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
|
|
class TestStreamingBuffer:
    """Unit tests for ``streaming_buffer.StreamingBuffer``.

    Every test closes the buffer and deletes its backing file (``buf.path``)
    in a ``finally`` clause, so a failing assertion does not leave temporary
    files behind. Helper threads are daemons and are checked with
    ``is_alive()`` after ``join(timeout=...)``. A stuck thread therefore gives
    a clear failure instead of a partial-data assertion or a hung test run.
    """

    def test_sequential_write_read(self):
        # Arrange
        from streaming_buffer import StreamingBuffer

        buf = StreamingBuffer()
        try:
            buf.append(b"hello")
            buf.append(b" world")
            buf.close_writer()

            # Act
            result = buf.read(-1)

            # Assert
            assert result == b"hello world"
        finally:
            buf.close()
            os.unlink(buf.path)

    def test_read_blocks_until_data_available(self):
        # Arrange
        from streaming_buffer import StreamingBuffer

        buf = StreamingBuffer()

        def writer():
            # Delay so the read() below starts before any data exists.
            time.sleep(0.1)
            buf.append(b"data")
            buf.close_writer()

        t = threading.Thread(target=writer, daemon=True)
        try:
            t.start()

            # Act
            result = buf.read(4)
            t.join(timeout=5)

            # Assert
            assert not t.is_alive(), "writer thread did not finish"
            assert result == b"data"
        finally:
            buf.close()
            os.unlink(buf.path)

    def test_read_returns_empty_on_eof(self):
        # Arrange
        from streaming_buffer import StreamingBuffer

        buf = StreamingBuffer()
        try:
            buf.close_writer()

            # Act
            result = buf.read(1024)

            # Assert
            assert result == b""
        finally:
            buf.close()
            os.unlink(buf.path)

    def test_concurrent_write_read_chunked(self):
        # Arrange
        from streaming_buffer import StreamingBuffer

        buf = StreamingBuffer()
        chunks_written = [b"aaa", b"bbb", b"ccc"]
        read_data = bytearray()

        def writer():
            for c in chunks_written:
                time.sleep(0.02)
                buf.append(c)
            buf.close_writer()

        def reader():
            # An empty read signals EOF once the writer has closed.
            while chunk := buf.read(1024):
                read_data.extend(chunk)

        wt = threading.Thread(target=writer, daemon=True)
        rt = threading.Thread(target=reader, daemon=True)
        try:
            # Act
            wt.start()
            rt.start()
            wt.join(timeout=5)
            rt.join(timeout=5)

            # Assert
            assert not wt.is_alive(), "writer thread did not finish"
            assert not rt.is_alive(), "reader thread did not reach EOF"
            assert bytes(read_data) == b"aaabbbccc"
        finally:
            buf.close()
            os.unlink(buf.path)

    def test_seek_set_and_reread(self):
        # Arrange
        from streaming_buffer import StreamingBuffer

        buf = StreamingBuffer()
        try:
            buf.append(b"0123456789")
            buf.close_writer()

            # Act: read past offset 2, then rewind to it and read again
            buf.read(5)
            buf.seek(2, os.SEEK_SET)
            result = buf.read(3)

            # Assert
            assert result == b"234"
        finally:
            buf.close()
            os.unlink(buf.path)

    def test_seek_end_blocks_until_eof(self):
        # Arrange
        from streaming_buffer import StreamingBuffer

        buf = StreamingBuffer()

        def writer():
            time.sleep(0.1)
            buf.append(b"abcdef")
            buf.close_writer()

        t = threading.Thread(target=writer, daemon=True)
        try:
            t.start()

            # Act: a SEEK_END seek is expected to block until the writer has
            # closed, so it reports the final size rather than the size so far.
            pos = buf.seek(0, os.SEEK_END)
            t.join(timeout=5)

            # Assert
            assert not t.is_alive(), "writer thread did not finish"
            assert pos == 6
        finally:
            buf.close()
            os.unlink(buf.path)

    def test_tell_tracks_position(self):
        # Arrange
        from streaming_buffer import StreamingBuffer

        buf = StreamingBuffer()
        try:
            buf.append(b"data")
            buf.close_writer()

            # Act / Assert
            assert buf.tell() == 0
            buf.read(2)
            assert buf.tell() == 2
        finally:
            buf.close()
            os.unlink(buf.path)

    def test_file_persisted_to_disk(self):
        # Arrange
        from streaming_buffer import StreamingBuffer

        buf = StreamingBuffer()
        payload = b"x" * 10000
        try:
            # Act
            buf.append(payload)
            buf.close_writer()

            # Assert: the bytes can be read back from buf.path independently
            with open(buf.path, "rb") as f:
                assert f.read() == payload
        finally:
            buf.close()
            os.unlink(buf.path)

    def test_written_property(self):
        # Arrange
        from streaming_buffer import StreamingBuffer

        buf = StreamingBuffer()
        try:
            buf.append(b"abc")
            buf.append(b"defgh")
            buf.close_writer()

            # Assert: ``written`` is the total number of appended bytes
            assert buf.written == 8
        finally:
            buf.close()
            os.unlink(buf.path)

    def test_seekable_readable(self):
        # Arrange
        from streaming_buffer import StreamingBuffer

        buf = StreamingBuffer()
        try:
            buf.close_writer()

            # Assert
            assert buf.seekable() is True
            assert buf.readable() is True
            assert buf.writable() is False
        finally:
            buf.close()
            os.unlink(buf.path)
|
|
|
|
|
|
class TestMediaContentHashFromFile:
    """Check ``compute_media_content_hash_from_file`` against the in-memory hash.

    Temporary input files are removed in ``finally`` clauses, so a failing
    assertion does not leak them.
    """

    @staticmethod
    def _write_temp_file(data: bytes) -> str:
        """Write *data* to a new named temporary file and return its path.

        ``delete=False`` keeps the file after it is closed so it can be
        reopened by path; the caller is responsible for ``os.unlink``.
        """
        with tempfile.NamedTemporaryFile(delete=False) as f:
            f.write(data)
            return f.name

    def _assert_file_hash_matches_bytes(self, data: bytes) -> None:
        """Assert that hashing *data* from a file equals hashing it in memory."""
        # Arrange
        from media_hash import compute_media_content_hash, compute_media_content_hash_from_file

        path = self._write_temp_file(data)
        try:
            # Act
            hash_bytes = compute_media_content_hash(data)
            hash_file = compute_media_content_hash_from_file(path)

            # Assert
            assert hash_file == hash_bytes
        finally:
            os.unlink(path)

    def test_small_file_matches_bytes_version(self):
        self._assert_file_hash_matches_bytes(b"hello world")

    def test_large_file_matches_bytes_version(self):
        self._assert_file_hash_matches_bytes(os.urandom(50_000))

    def test_virtual_flag(self):
        # Arrange
        from media_hash import compute_media_content_hash_from_file

        path = self._write_temp_file(b"test")
        try:
            # Act
            normal = compute_media_content_hash_from_file(path, virtual=False)
            virtual = compute_media_content_hash_from_file(path, virtual=True)

            # Assert: the virtual hash is the normal hash prefixed with "V"
            assert virtual == f"V{normal}"
        finally:
            os.unlink(path)

    def test_exact_boundary_3072_bytes(self):
        # NOTE(review): 3072 presumably matches a size threshold inside
        # media_hash (per the test name) — confirm against that module.
        self._assert_file_hash_matches_bytes(os.urandom(3072))
|
|
|
|
|
|
# Shared HS256 signing secret; tests patch main.JWT_SECRET to this value.
_TEST_JWT_SECRET = "az-test-secret-for-unit-tests-only-32b"


def _access_jwt(sub: str = "u1") -> str:
    """Return an HS256 token for subject *sub*, signed with ``_TEST_JWT_SECRET``.

    The ``exp`` claim is set one hour (3600 s) after the current time.
    """
    import jwt as pyjwt

    expires_at = int(time.time()) + 3600
    claims = {"exp": expires_at, "sub": sub}
    return pyjwt.encode(claims, _TEST_JWT_SECRET, algorithm="HS256")
|
|
|
|
|
|
class _FakeInfStream:
|
|
is_engine_ready = True
|
|
|
|
def run_detect_video_stream(
|
|
self, readable, ai_cfg, media_name, on_annotation, status_callback=None
|
|
):
|
|
while True:
|
|
chunk = readable.read(4096)
|
|
if not chunk:
|
|
break
|
|
if status_callback:
|
|
status_callback(media_name, 0)
|
|
|
|
def run_detect_video(self, *a, **kw):
|
|
pass
|
|
|
|
def run_detect_image(self, *a, **kw):
|
|
pass
|
|
|
|
|
|
class TestDetectVideoEndpoint:
    """Endpoint tests for ``POST /detect/video`` driven through FastAPI's TestClient.

    Inference is replaced with fake engines. ``VIDEOS_DIR`` is set with
    ``patch.dict`` for the lifetime of each test's temporary directory, so the
    variable is restored afterwards. A path to a deleted directory therefore
    never leaks into later tests.
    """

    @pytest.fixture(autouse=True)
    def reset_inference(self):
        # Clear main's module-level ``inference`` before and after every test
        # so an engine set by one test is never seen by the next.
        import main

        main.inference = None
        yield
        main.inference = None

    def test_streaming_upload_returns_started(self):
        # Arrange
        import main
        from media_hash import compute_media_content_hash

        video_body = b"fake-video-" * 200
        content_hash = compute_media_content_hash(video_body)
        # Outbound calls made through main.http_requests are mocked.
        mock_post = MagicMock()
        mock_post.return_value.status_code = 201
        mock_put = MagicMock()
        mock_put.return_value.status_code = 204
        token = _access_jwt()

        with (
            tempfile.TemporaryDirectory() as vd,
            patch.dict(os.environ, {"VIDEOS_DIR": vd}),
        ):
            from fastapi.testclient import TestClient

            client = TestClient(main.app)
            with (
                patch.object(main, "JWT_SECRET", _TEST_JWT_SECRET),
                patch.object(main, "get_inference", return_value=_FakeInfStream()),
                patch.object(main.http_requests, "post", mock_post),
                patch.object(main.http_requests, "put", mock_put),
            ):
                # Act
                r = client.post(
                    "/detect/video",
                    content=video_body,
                    headers={
                        "X-Filename": "test.mp4",
                        "X-Channel-Id": "test-channel",
                        "Authorization": f"Bearer {token}",
                    },
                )

                # Assert: accepted, and the upload is stored under its content hash
                assert r.status_code == 202
                stored = os.path.join(vd, f"{content_hash}.mp4")
                assert os.path.isfile(stored)
                with open(stored, "rb") as f:
                    assert f.read() == video_body

    def test_non_auth_cleanup(self):
        # Arrange
        import main

        video_body = b"noauth-vid-" * 100
        with (
            tempfile.TemporaryDirectory() as vd,
            patch.dict(os.environ, {"VIDEOS_DIR": vd}),
        ):
            from fastapi.testclient import TestClient

            client = TestClient(main.app)
            # JWT_SECRET is patched to "" so auth does not block the request;
            # no Authorization header is sent.
            with (
                patch.object(main, "JWT_SECRET", ""),
                patch.object(main, "get_inference", return_value=_FakeInfStream()),
            ):
                # Act
                r = client.post(
                    "/detect/video",
                    content=video_body,
                    headers={"X-Filename": "test.mp4", "X-Channel-Id": "test-channel"},
                )

                # Assert
                # NOTE(review): despite its name, this test only checks the
                # status code; add an assertion on what "cleanup" should remove.
                assert r.status_code == 202

    def test_rejects_non_video_extension(self):
        # Arrange
        import main
        from fastapi.testclient import TestClient

        client = TestClient(main.app)

        # Act — patch JWT_SECRET to "" so auth does not block the extension check
        with patch.object(main, "JWT_SECRET", ""):
            r = client.post(
                "/detect/video",
                content=b"data",
                headers={"X-Filename": "photo.jpg"},
            )

        # Assert
        assert r.status_code == 400

    def test_data_flows_through_streaming_buffer(self):
        # Arrange
        import main

        received_chunks = []
        drained = threading.Event()

        class _CaptureInf(_FakeInfStream):
            def run_detect_video_stream(
                self, readable, ai_cfg, media_name, on_annotation, status_callback=None
            ):
                while chunk := readable.read(4096):
                    received_chunks.append(chunk)
                # Signal EOF: the 202 response does not guarantee the engine has
                # finished reading, so the test waits on this event.
                drained.set()

        video_body = b"A" * 10000
        token = _access_jwt()
        with (
            tempfile.TemporaryDirectory() as vd,
            patch.dict(os.environ, {"VIDEOS_DIR": vd}),
        ):
            from fastapi.testclient import TestClient

            client = TestClient(main.app)
            with (
                patch.object(main, "JWT_SECRET", _TEST_JWT_SECRET),
                patch.object(main, "get_inference", return_value=_CaptureInf()),
            ):
                # Act
                r = client.post(
                    "/detect/video",
                    content=video_body,
                    headers={
                        "X-Filename": "v.mp4",
                        "X-Channel-Id": "test-channel",
                        "Authorization": f"Bearer {token}",
                    },
                )

                # Assert
                assert r.status_code == 202
                assert drained.wait(timeout=5), "fake engine never reached EOF"
                assert b"".join(received_chunks) == video_body
|