add ramdisk, load AI model to ramdisk and start recognition from it

rewrite zmq to DEALER and ROUTER
add GET_USER command to get CurrentUser from Python
all auth is on the python side
run inference and validate annotations on the Python side
This commit is contained in:
Alex Bezdieniezhnykh
2025-01-29 17:45:26 +02:00
parent 82b3b526a7
commit 62623b7123
55 changed files with 945 additions and 895 deletions
+1 -1
View File
@@ -50,7 +50,7 @@ This is crucial for the build because build needs Python.h header and other file
pip install ultralytics
pip uninstall -y opencv-python
pip install opencv-python cython msgpack cryptography rstream pika zmq
pip install opencv-python cython msgpack cryptography rstream pika zmq pyjwt
```
In case of fbgemm.dll error (Windows specific):
+10
View File
@@ -0,0 +1,10 @@
# Declaration of the recognition settings object. Every attribute is `public`,
# so it can be read and assigned from Python code as well as from Cython.
cdef class AIRecognitionConfig:
    # Floating-point thresholds, stored as C doubles.
    cdef public double frame_recognition_seconds, tracking_distance_confidence
    cdef public double tracking_probability_increase, tracking_intersection_threshold
    # Integer frame period.
    cdef public int frame_period_recognition
    # Raw bytes payload carried alongside the settings.
    cdef public bytes file_data

    # Alternate constructor: builds an instance from a msgpack-encoded buffer.
    @staticmethod
    cdef from_msgpack(bytes data)
+32
View File
@@ -0,0 +1,32 @@
from msgpack import unpackb
cdef class AIRecognitionConfig:
    """Recognition settings, usually decoded from a msgpack payload."""

    def __init__(self,
                 frame_recognition_seconds,
                 tracking_distance_confidence,
                 tracking_probability_increase,
                 tracking_intersection_threshold,
                 frame_period_recognition,
                 file_data
                 ):
        """Store the given values on the typed attributes, one to one."""
        self.frame_recognition_seconds = frame_recognition_seconds
        self.tracking_distance_confidence = tracking_distance_confidence
        self.tracking_probability_increase = tracking_probability_increase
        self.tracking_intersection_threshold = tracking_intersection_threshold
        self.frame_period_recognition = frame_period_recognition
        self.file_data = file_data

    def __str__(self):
        """Return a one-line summary of the numeric settings (file_data is omitted)."""
        parts = (
            f'frame_seconds : {self.frame_recognition_seconds}',
            f'distance_confidence : {self.tracking_distance_confidence}',
            f'probability_increase : {self.tracking_probability_increase}',
            f'intersection_threshold : {self.tracking_intersection_threshold}',
            f'frame_period_recognition : {self.frame_period_recognition}',
        )
        return ', '.join(parts)

    @staticmethod
    cdef from_msgpack(bytes data):
        """Decode ``data`` with msgpack and build a config from its map.

        Missing keys fall back to 0.0 / 0 / b''.
        """
        fields = unpackb(data, strict_map_key=False)
        # NOTE(review): the fallback for "FramePeriodRecognition" is 0; any
        # caller that uses it as a divisor/modulus would fail on it - confirm
        # the sender always provides a positive value.
        return AIRecognitionConfig(
            frame_recognition_seconds=fields.get("FrameRecognitionSeconds", 0.0),
            tracking_distance_confidence=fields.get("TrackingDistanceConfidence", 0.0),
            tracking_probability_increase=fields.get("TrackingProbabilityIncrease", 0.0),
            tracking_intersection_threshold=fields.get("TrackingIntersectionThreshold", 0.0),
            frame_period_recognition=fields.get("FramePeriodRecognition", 0),
            file_data=fields.get("Data", b''),
        )
+6 -4
View File
@@ -1,8 +1,10 @@
cdef class Detection:
cdef double x, y, w, h
cdef int cls
cdef public double x, y, w, h, confidence
cdef public int cls
cdef class Annotation:
cdef bytes image
cdef float time
cdef list[Detection] detections
cdef long time
cdef public list[Detection] detections
cdef bytes serialize(self)
+25 -3
View File
@@ -1,13 +1,35 @@
import msgpack
cdef class Detection:
def __init__(self, double x, double y, double w, double h, int cls):
def __init__(self, double x, double y, double w, double h, int cls, double confidence):
self.x = x
self.y = y
self.w = w
self.h = h
self.cls = cls
self.confidence = confidence
def __str__(self):
return f'{self.cls}: {self.x:.2f} {self.y:.2f} {self.w:.2f} {self.h:.2f}, prob: {(self.confidence*100):.1f}%'
cdef class Annotation:
def __init__(self, bytes image_bytes, float time, list[Detection] detections):
def __init__(self, bytes image_bytes, long time, list[Detection] detections):
self.image = image_bytes
self.time = time
self.detections = detections
self.detections = detections if detections is not None else []
cdef bytes serialize(self):
return msgpack.packb({
"i": self.image, # "i" = image
"t": self.time, # "t" = time
"d": [ # "d" = detections
{
"x": det.x,
"y": det.y,
"w": det.w,
"h": det.h,
"c": det.cls,
"p": det.confidence
} for det in self.detections
]
})
+7 -1
View File
@@ -1,8 +1,14 @@
from user cimport User
cdef class ApiClient:
cdef str email, password, token, folder, token_file, api_url
cdef User user
cdef get_encryption_key(self, str hardware_hash)
cdef login(self, str email, str password)
cdef login(self)
cdef set_token(self, str token)
cdef get_user(self)
cdef load_bytes(self, str filename)
cdef load_ai_model(self)
cdef load_queue_config(self)
+47 -11
View File
@@ -1,13 +1,14 @@
import io
import json
import os
from http import HTTPStatus
from uuid import UUID
import jwt
import requests
cimport constants
from hardware_service cimport HardwareService, HardwareInfo
from security cimport Security
from io import BytesIO
from user cimport User, RoleEnum
cdef class ApiClient:
"""Handles API authentication and downloading of the AI model."""
@@ -15,9 +16,11 @@ cdef class ApiClient:
self.email = email
self.password = password
self.folder = folder
self.user = None
if os.path.exists(<str>constants.TOKEN_FILE):
with open(<str>constants.TOKEN_FILE, "r") as file:
self.token = file.read().strip()
self.set_token(<str>file.read().strip())
else:
self.token = None
@@ -25,21 +28,52 @@ cdef class ApiClient:
cdef str key = f'{self.email}-{self.password}-{hardware_hash}-#%@AzaionKey@%#---'
return Security.calc_hash(key)
cdef login(self, str email, str password):
response = requests.post(f"{constants.API_URL}/login", json={"email": email, "password": password})
cdef login(self):
response = requests.post(f"{constants.API_URL}/login",
json={"email": self.email, "password": self.password})
response.raise_for_status()
self.token = response.json()["token"]
token = response.json()["token"]
self.set_token(token)
with open(<str>constants.TOKEN_FILE, 'w') as file:
file.write(self.token)
file.write(token)
cdef set_token(self, str token):
    """Remember ``token`` and rebuild ``self.user`` from its JWT claims.

    Reads the "nameid", "unique_name" and "role" claims. A role string that
    matches no known role maps to ``RoleEnum.NONE``.

    Raises:
        ValueError: when the "nameid" claim is not a valid GUID.
    """
    self.token = token
    # NOTE(review): the signature is deliberately not verified here, so the
    # claims are only as trustworthy as the place the token came from.
    claims = jwt.decode(token, options={"verify_signature": False})

    try:
        user_id = str(UUID(claims.get("nameid", "")))
    except ValueError as err:
        # Keep the original parsing error attached as the cause.
        raise ValueError("Invalid GUID format in claims") from err

    email = claims.get("unique_name", "")

    role_str = claims.get("role", "")
    if role_str == "ApiAdmin":
        role = RoleEnum.ApiAdmin
    elif role_str == "Admin":
        role = RoleEnum.Admin
    elif role_str == "ResourceUploader":
        role = RoleEnum.ResourceUploader
    elif role_str == "CompanionPC":
        # RoleEnum declares CompanionPC; without this branch such users
        # were silently downgraded to NONE.
        role = RoleEnum.CompanionPC
    elif role_str == "Validator":
        role = RoleEnum.Validator
    elif role_str == "Operator":
        role = RoleEnum.Operator
    else:
        role = RoleEnum.NONE

    self.user = User(user_id, email, role)
cdef get_user(self):
    """Return the current user, logging in first when none is known yet."""
    if self.user is not None:
        return self.user
    # login() stores the fresh token via set_token(), which fills self.user.
    self.login()
    return self.user
cdef load_bytes(self, str filename):
hardware_service = HardwareService()
cdef HardwareInfo hardware = hardware_service.get_hardware_info()
if self.token is None:
self.login(self.email, self.password)
self.login()
url = f"{constants.API_URL}/resources/get/{self.folder}"
headers = {
@@ -56,7 +90,7 @@ cdef class ApiClient:
response = requests.post(url, data=payload, headers=headers, stream=True)
if response.status_code == HTTPStatus.UNAUTHORIZED or response.status_code == HTTPStatus.FORBIDDEN:
self.login(self.email, self.password)
self.login()
headers = {
"Authorization": f"Bearer {self.token}",
"Content-Type": "application/json"
@@ -69,7 +103,9 @@ cdef class ApiClient:
key = self.get_encryption_key(hardware.hash)
stream = BytesIO(response.raw.read())
return Security.decrypt_to(stream, key)
data = Security.decrypt_to(stream, key)
print(f'loaded file: {filename}, {len(data)} bytes')
return data
cdef load_ai_model(self):
return self.load_bytes(constants.AI_MODEL_FILE)
+3 -3
View File
@@ -1,6 +1,4 @@
cdef str SOCKET_HOST # Host for the socket server
cdef int SOCKET_PORT # Port for the socket server
cdef int SOCKET_BUFFER_SIZE # Buffer size for socket communication
cdef int ZMQ_PORT = 5127 # Port for the zmq
cdef int QUEUE_MAXSIZE # Maximum size of the command queue
cdef str COMMANDS_QUEUE # Name of the commands queue in rabbit
@@ -10,3 +8,5 @@ cdef str API_URL # Base URL for the external API
cdef str TOKEN_FILE # Name of the token file where temporary token would be stored
cdef str QUEUE_CONFIG_FILENAME # queue config filename to load from api
cdef str AI_MODEL_FILE # AI Model file
cdef bytes DONE_SIGNAL
+3 -3
View File
@@ -1,6 +1,4 @@
cdef str SOCKET_HOST = "127.0.0.1" # Host for the socket server
cdef int SOCKET_PORT = 9127 # Port for the socket server
cdef int SOCKET_BUFFER_SIZE = 4096 # Buffer size for socket communication
cdef int ZMQ_PORT = 5127 # Port for the zmq
cdef int QUEUE_MAXSIZE = 1000 # Maximum size of the command queue
cdef str COMMANDS_QUEUE = "azaion-commands"
@@ -10,3 +8,5 @@ cdef str API_URL = "https://api.azaion.com" # Base URL for the external API
cdef str TOKEN_FILE = "token"
cdef str QUEUE_CONFIG_FILENAME = "secured-config.json"
cdef str AI_MODEL_FILE = "azaion.pt"
cdef bytes DONE_SIGNAL = b"DONE"
+3
View File
@@ -10,5 +10,8 @@ def start_server():
except Exception as e:
processor.stop()
def on_annotation(self, cmd, annotation):
print('on_annotation hit!')
if __name__ == "__main__":
start_server()
+17
View File
@@ -0,0 +1,17 @@
from remote_command cimport RemoteCommand
from annotation cimport Annotation
from ai_config cimport AIRecognitionConfig
# Declaration of the inference runner: it holds the loaded model, the callback
# that receives accepted annotations, and the state used to filter annotations.
cdef class Inference:
    cdef:
        object model                        # model object created in __init__
        object on_annotation                # callable invoked as on_annotation(cmd, annotation)
        Annotation _previous_annotation     # last annotation that passed validation
        AIRecognitionConfig ai_config       # settings decoded from the command payload

    cdef bint is_video(self, str filepath)
    cdef run_inference(self, RemoteCommand cmd, int batch_size=?)
    cdef _process_video(self, RemoteCommand cmd, int batch_size)
    cdef _process_image(self, RemoteCommand cmd)
    cdef frame_to_annotation(self, long time, frame, boxes: object)
    cdef bint is_valid_annotation(self, Annotation annotation)
+76 -17
View File
@@ -1,30 +1,38 @@
import ai_config
import msgpack
from ultralytics import YOLO
import mimetypes
import cv2
from ultralytics.engine.results import Boxes
from remote_command cimport RemoteCommand
from annotation cimport Detection, Annotation
from secure_model cimport SecureModelLoader
from ai_config cimport AIRecognitionConfig
cdef class Inference:
def __init__(self, model_bytes, on_annotations):
self.model = YOLO(model_bytes)
self.on_annotations = on_annotations
def __init__(self, model_bytes, on_annotation):
loader = SecureModelLoader()
model_path = loader.load_model(model_bytes)
self.model = YOLO(<str>model_path)
self.on_annotation = on_annotation
cdef bint is_video(self, str filepath):
mime_type, _ = mimetypes.guess_type(<str>filepath)
return mime_type and mime_type.startswith("video")
cdef run_inference(self, RemoteCommand cmd, int batch_size=8, int frame_skip=4):
cdef run_inference(self, RemoteCommand cmd, int batch_size=8):
print('run inference..')
if self.is_video(cmd.filename):
return self._process_video(cmd, batch_size, frame_skip)
return self._process_video(cmd, batch_size)
else:
return self._process_image(cmd)
cdef _process_video(self, RemoteCommand cmd, int batch_size, int frame_skip):
cdef _process_video(self, RemoteCommand cmd, int batch_size):
frame_count = 0
batch_frame = []
annotations = []
v_input = cv2.VideoCapture(<str>cmd.filename)
self.ai_config = AIRecognitionConfig.from_msgpack(cmd.data)
while v_input.isOpened():
ret, frame = v_input.read()
@@ -33,7 +41,7 @@ cdef class Inference:
break
frame_count += 1
if frame_count % frame_skip == 0:
if frame_count % self.ai_config.frame_period_recognition == 0:
batch_frame.append((frame, ms))
if len(batch_frame) == batch_size:
@@ -41,10 +49,11 @@ cdef class Inference:
results = self.model.track(frames, persist=True)
for frame, res in zip(batch_frame, results):
annotation = self.process_detections(int(frame[1]), frame[0], res.boxes)
if len(annotation.detections) > 0:
annotations.append(annotation)
self.on_annotations(cmd, annotations)
annotation = self.frame_to_annotation(int(frame[1]), frame[0], res.boxes)
if self.is_valid_annotation(<Annotation>annotation):
self._previous_annotation = annotation
self.on_annotation(cmd, annotation)
batch_frame.clear()
v_input.release()
@@ -52,15 +61,65 @@ cdef class Inference:
cdef _process_image(self, RemoteCommand cmd):
frame = cv2.imread(<str>cmd.filename)
res = self.model.track(frame)
annotation = self.process_detections(0, frame, res[0].boxes)
self.on_annotations(cmd, [annotation])
annotation = self.frame_to_annotation(0, frame, res[0].boxes)
self.on_annotation(cmd, annotation)
cdef process_detections(self, float time, frame, boxes: Boxes):
cdef frame_to_annotation(self, long time, frame, boxes: Boxes):
detections = []
for box in boxes:
b = box.xywhn[0].cpu().numpy()
cls = int(box.cls[0].cpu().numpy().item())
detections.append(Detection(<double>b[0], <double>b[1], <double>b[2], <double>b[3], cls))
_, encoded_image = cv2.imencode('.jpg', frame[0])
confidence = box.conf[0].cpu().numpy().item()
det = Detection(<double> b[0], <double> b[1], <double> b[2], <double> b[3], cls, confidence)
detections.append(det)
_, encoded_image = cv2.imencode('.jpg', frame)
image_bytes = encoded_image.tobytes()
return Annotation(image_bytes, time, detections)
cdef bint is_valid_annotation(self, Annotation annotation):
    """Decide whether ``annotation`` differs enough from the last accepted one.

    Returns False when there are no detections or when no criterion below
    holds. Returns True when any of these holds, checked in this order:
    there is no previous annotation; at least ``frame_recognition_seconds``
    have passed since it; more objects are detected than before; some
    detection lies beyond the tracking distance from every previous
    detection; or some detection is more confident than its nearest previous
    detection by at least ``tracking_probability_increase``.
    """
    cdef:
        Annotation last
        AIRecognitionConfig cfg
        Detection cur, old, nearest
        double dx, dy, dist_sq, best_sq

    if not annotation.detections:
        return False

    last = self._previous_annotation
    if last is None:
        # Nothing to compare against yet.
        return True

    cfg = self.ai_config
    # Times are compared in milliseconds; the config value is in seconds.
    if (annotation.time >= last.time + <long>(cfg.frame_recognition_seconds * 1000)
            or len(annotation.detections) > len(last.detections)):
        return True

    for cur in annotation.detections:
        # Nearest previous detection by squared distance between centres.
        best_sq = 1e18
        nearest = None
        for old in last.detections:
            dx = cur.x - old.x
            dy = cur.y - old.y
            dist_sq = dx * dx + dy * dy
            if dist_sq < best_sq:
                best_sq = dist_sq
                nearest = old

        # NOTE(review): a *squared* distance is compared with
        # tracking_distance_confidence as-is - confirm the threshold is
        # meant to be a squared value.
        if (best_sq > cfg.tracking_distance_confidence
                or cur.confidence >= nearest.confidence + cfg.tracking_probability_increase):
            return True

    return False
+16 -13
View File
@@ -1,12 +1,13 @@
import traceback
from queue import Queue
cimport constants
import msgpack
from api_client cimport ApiClient
from annotation cimport Annotation
from inference import Inference
from inference cimport Inference
from remote_command cimport RemoteCommand, CommandType
from remote_command_handler cimport RemoteCommandHandler
from user cimport User
import argparse
cdef class ParsedArguments:
@@ -36,11 +37,10 @@ cdef class CommandProcessor:
while self.running:
try:
command = self.command_queue.get()
print(f'command is : {command}')
model = self.api_client.load_ai_model()
Inference(model, self.on_annotations).run_inference(command)
Inference(model, self.on_annotation).run_inference(command)
except Exception as e:
print(f"Error processing queue: {e}")
traceback.print_exc()
cdef on_command(self, RemoteCommand command):
try:
@@ -48,17 +48,20 @@ cdef class CommandProcessor:
self.command_queue.put(command)
elif command.command_type == CommandType.LOAD:
response = self.api_client.load_bytes(command.filename)
print(f'loaded file: {command.filename}, {len(response)} bytes')
self.remote_handler.send(response)
print(f'{len(response)} bytes was sent.')
self.remote_handler.send(command.client_id, response)
elif command.command_type == CommandType.GET_USER:
self.get_user(command, self.api_client.get_user())
else:
pass
except Exception as e:
print(f"Error handling client: {e}")
cdef on_annotations(self, RemoteCommand cmd, annotations: [Annotation]):
data = msgpack.packb(annotations)
self.remote_handler.send(data)
print(f'{len(data)} bytes was sent.')
cdef get_user(self, RemoteCommand command, User user):
self.remote_handler.send(command.client_id, user.serialize())
cdef on_annotation(self, RemoteCommand cmd, Annotation annotation):
data = annotation.serialize()
self.remote_handler.send(cmd.client_id, data)
def stop(self):
self.running = False
+2
View File
@@ -1,8 +1,10 @@
cdef enum CommandType:
INFERENCE = 1
LOAD = 2
GET_USER = 3
cdef class RemoteCommand:
cdef public bytes client_id
cdef CommandType command_type
cdef str filename
cdef bytes data
+3 -1
View File
@@ -10,8 +10,10 @@ cdef class RemoteCommand:
command_type_names = {
1: "INFERENCE",
2: "LOAD",
3: "GET_USER"
}
return f'{command_type_names[self.command_type]}: {self.filename}'
data_str = f'. Data: {len(self.data)} bytes' if self.data else ''
return f'{command_type_names[self.command_type]}: {self.filename}{data_str}'
@staticmethod
cdef from_msgpack(bytes data):
+9 -10
View File
@@ -1,16 +1,15 @@
cdef class RemoteCommandHandler:
cdef object _on_command
cdef object _context
cdef object _socket
cdef object _router
cdef object _dealer
cdef object _shutdown_event
cdef object _pull_socket
cdef object _pull_thread
cdef object _push_socket
cdef object _push_queue
cdef object _push_thread
cdef object _on_command
cdef object _proxy_thread
cdef object _workers
cdef start(self)
cdef _pull_loop(self)
cdef _push_loop(self)
cdef send(self, bytes message_bytes)
cdef _proxy_loop(self)
cdef _worker_loop(self)
cdef send(self, bytes client_id, bytes data)
cdef close(self)
+44 -56
View File
@@ -1,9 +1,7 @@
from queue import Queue
import zmq
import json
from threading import Thread, Event
from remote_command cimport RemoteCommand
cimport constants
cdef class RemoteCommandHandler:
def __init__(self, object on_command):
@@ -11,68 +9,58 @@ cdef class RemoteCommandHandler:
self._context = zmq.Context.instance()
self._shutdown_event = Event()
self._pull_socket = self._context.socket(zmq.PULL)
self._pull_socket.setsockopt(zmq.LINGER, 0)
self._pull_socket.bind("tcp://*:5127")
self._pull_thread = Thread(target=self._pull_loop, daemon=True)
self._router = self._context.socket(zmq.ROUTER)
self._router.setsockopt(zmq.LINGER, 0)
self._router.bind(f'tcp://*:{constants.ZMQ_PORT}')
self._push_queue = Queue()
self._dealer = self._context.socket(zmq.DEALER)
self._dealer.setsockopt(zmq.LINGER, 0)
self._dealer.bind("inproc://backend")
self._push_socket = self._context.socket(zmq.PUSH)
self._push_socket.setsockopt(zmq.LINGER, 0)
self._push_socket.bind("tcp://*:5128")
self._push_thread = Thread(target=self._push_loop, daemon=True)
self._proxy_thread = Thread(target=self._proxy_loop, daemon=True)
self._workers = []
for _ in range(4): # 4 worker threads
worker = Thread(target=self._worker_loop, daemon=True)
self._workers.append(worker)
cdef start(self):
self._pull_thread.start()
self._push_thread.start()
self._proxy_thread.start()
for worker in self._workers:
worker.start()
cdef _pull_loop(self):
cdef _proxy_loop(self):
zmq.proxy(self._router, self._dealer)
cdef _worker_loop(self):
worker_socket = self._context.socket(zmq.DEALER)
worker_socket.setsockopt(zmq.LINGER, 0)
worker_socket.connect("inproc://backend")
poller = zmq.Poller()
poller.register(worker_socket, zmq.POLLIN)
print('started receiver loop...')
while not self._shutdown_event.is_set():
print('wait for the command...')
message = self._pull_socket.recv()
cmd = RemoteCommand.from_msgpack(<bytes>message)
print(f'received: {cmd}')
try:
socks = dict(poller.poll(500))
if worker_socket in socks:
client_id, message = worker_socket.recv_multipart()
cmd = RemoteCommand.from_msgpack(<bytes> message)
cmd.client_id = client_id
print(f'Received [{cmd}] from the client {client_id}')
self._on_command(cmd)
cdef _push_loop(self):
while not self._shutdown_event.is_set():
try:
response = self._push_queue.get(timeout=1) # Timeout to check shutdown flag
self._push_socket.send(response)
except:
continue
cdef send(self, bytes message_bytes):
print(f'about to send {len(message_bytes)}')
try:
self._push_queue.put(message_bytes)
except Exception as e:
print(e)
print(f"Worker error: {e}")
import traceback
traceback.print_exc()
cdef send(self, bytes client_id, bytes data):
with self._context.socket(zmq.DEALER) as socket:
socket.connect("inproc://backend")
socket.send_multipart([client_id, data])
print(f'{len(data)} bytes was sent to client {client_id}')
cdef close(self):
self._shutdown_event.set()
self._pull_socket.close()
self._push_socket.close()
self._router.close()
self._dealer.close()
self._context.term()
cdef class QueueConfig:
cdef str host,
cdef int port, command_port
cdef str producer_user, producer_pw, consumer_user, consumer_pw
@staticmethod
cdef QueueConfig from_json(str json_string):
s = str(json_string).strip()
cdef dict config_dict = json.loads(s)["QueueConfig"]
cdef QueueConfig config = QueueConfig()
config.host = config_dict["Host"]
config.port = config_dict["Port"]
config.command_port = config_dict["CommandsPort"]
config.producer_user = config_dict["ProducerUsername"]
config.producer_pw = config_dict["ProducerPassword"]
config.consumer_user = config_dict["ConsumerUsername"]
config.consumer_pw = config_dict["ConsumerPassword"]
return config
+12
View File
@@ -0,0 +1,12 @@
# Declaration of the loader that stages model bytes as a file and returns its path.
cdef class SecureModelLoader:
    cdef bytes _model_bytes      # bytes handed to load_model()
    cdef str _ramdisk_path       # directory the model file is written into
    cdef str _temp_file_path     # path of the written model file
    cdef int _disk_size_mb       # requested volume size in megabytes

    # Only load_model is callable from Python (cpdef); the rest are internal steps.
    cpdef str load_model(self, bytes model_bytes)
    cdef str _get_ramdisk_path(self)
    cdef void _create_ramdisk(self)
    cdef void _store_model(self)
    cdef void _cleanup(self)
+104
View File
@@ -0,0 +1,104 @@
import os
import platform
import tempfile
from pathlib import Path
from libc.stdio cimport FILE, fopen, fclose, remove
from libc.stdlib cimport free
from libc.string cimport strdup
cdef class SecureModelLoader:
    """Write model bytes to a ``.pt`` file on an OS-specific scratch volume.

    ``load_model`` prepares the volume, writes the bytes into a temporary
    file there and returns that file's path. ``__dealloc__`` removes the file
    and releases the volume again, so the returned path is only usable while
    this object is alive.
    """

    def __cinit__(self, int disk_size_mb=512):
        # disk_size_mb sizes the tmpfs (Linux) / ram:// device (macOS); the
        # Windows branch does not use it.
        self._disk_size_mb = disk_size_mb
        self._ramdisk_path = None
        self._temp_file_path = None

    cpdef str load_model(self, bytes model_bytes):
        """Stage ``model_bytes`` as a file and return the file's path."""
        self._model_bytes = model_bytes
        self._create_ramdisk()
        self._store_model()
        return self._temp_file_path

    cdef str _get_ramdisk_path(self):
        """Return the fixed scratch-volume path for the current OS.

        Raises:
            RuntimeError: on any OS other than Windows, Linux or macOS.
        """
        system = platform.system()
        if system == "Windows":
            return "R:\\"
        if system == "Linux":
            return "/mnt/ramdisk"
        if system == "Darwin":
            return "/Volumes/RAMDisk"
        raise RuntimeError("Unsupported OS for RAM disk")

    cdef void _create_ramdisk(self):
        """Prepare the scratch volume and record it in ``_ramdisk_path``.

        On an unrecognised OS nothing is done and ``_ramdisk_path`` stays
        None, so the model file later lands in the default temp directory.

        NOTE(review): whether the RuntimeError raised here reaches the
        caller of a ``cdef void`` method depends on the Cython version and
        the exception spec in the .pxd - confirm.
        """
        system = platform.system()
        if system == "Windows":
            # NOTE(review): `subst` only maps drive R: onto the regular temp
            # directory, which is disk-backed - this is not a RAM disk. The
            # temp path is also interpolated unquoted, so a path containing
            # spaces would break the command.
            command = f'powershell -Command "subst R: {tempfile.gettempdir()}"'
            if os.system(command) != 0:
                raise RuntimeError("Failed to create RAM disk on Windows")
            self._ramdisk_path = "R:\\"
        elif system == "Linux":
            # tmpfs keeps the file contents in memory.
            self._ramdisk_path = "/mnt/ramdisk"
            if not Path(self._ramdisk_path).exists():
                os.mkdir(self._ramdisk_path)
            if os.system(f"mount -t tmpfs -o size={self._disk_size_mb}M tmpfs {self._ramdisk_path}") != 0:
                raise RuntimeError("Failed to create RAM disk on Linux")
        elif system == "Darwin":
            # hdiutil counts in 512-byte blocks: 2048 blocks = 1 MB.
            block_size = 2048
            num_blocks = self._disk_size_mb * block_size
            result = os.popen(f"hdiutil attach -nomount ram://{num_blocks}").read().strip()
            if result:
                self._ramdisk_path = "/Volumes/RAMDisk"
                # NOTE(review): the exit status of diskutil is not checked.
                os.system(f"diskutil eraseVolume HFS+ RAMDisk {result}")
            else:
                raise RuntimeError("Failed to create RAM disk on macOS")

    cdef void _store_model(self):
        """Write the model bytes to a ``.pt`` temp file and check it can be opened.

        Raises:
            MemoryError: when the C copy of the path cannot be allocated.
            IOError: when the written file cannot be opened for reading.
        """
        cdef char* temp_path
        cdef FILE* cfile = NULL

        # delete=False: the file must outlive this `with`; _cleanup removes it.
        with tempfile.NamedTemporaryFile(
            dir=self._ramdisk_path, suffix='.pt', delete=False
        ) as tmp_file:
            tmp_file.write(self._model_bytes)
            self._temp_file_path = tmp_file.name

        # Same path encoding as _cleanup uses.
        encoded_path = os.fsencode(self._temp_file_path)
        temp_path = strdup(encoded_path)
        if temp_path == NULL:
            raise MemoryError("Could not allocate a copy of the model path")
        try:
            with nogil:
                cfile = fopen(temp_path, "rb")
        finally:
            # The C copy of the path was previously never released.
            free(temp_path)

        if cfile == NULL:
            raise IOError(f"Could not open {self._temp_file_path}")
        fclose(cfile)

    cdef void _cleanup(self):
        """Delete the model file (if any) and release the scratch volume (if any)."""
        cdef char* c_path

        if self._temp_file_path:
            c_path = strdup(os.fsencode(self._temp_file_path))
            # strdup can fail; never hand NULL to remove().
            if c_path != NULL:
                with nogil:
                    remove(c_path)
                    free(c_path)
            self._temp_file_path = None

        if self._ramdisk_path:
            system = platform.system()
            if system == "Windows":
                os.system("subst R: /D")
            elif system == "Linux":
                os.system(f"umount {self._ramdisk_path}")
            elif system == "Darwin":
                os.system("hdiutil detach /Volumes/RAMDisk")
            self._ramdisk_path = None

    def __dealloc__(self):
        """Run the cleanup when the object is destroyed.

        NOTE(review): Cython's documentation warns that Python attributes may
        already be torn down inside __dealloc__; _cleanup reads ``str``
        attributes and calls into ``os`` - confirm this is safe here.
        """
        self._cleanup()
+10 -1
View File
@@ -8,7 +8,10 @@ extensions = [
Extension('hardware_service', ['hardware_service.pyx'], extra_compile_args=["-g"], extra_link_args=["-g"]),
Extension('remote_command', ['remote_command.pyx']),
Extension('remote_command_handler', ['remote_command_handler.pyx']),
Extension('user', ['user.pyx']),
Extension('api_client', ['api_client.pyx']),
Extension('secure_model', ['secure_model.pyx']),
Extension('ai_config', ['ai_config.pyx']),
Extension('inference', ['inference.pyx']),
Extension('main', ['main.pyx']),
@@ -21,8 +24,14 @@ setup(
compiler_directives={
"language_level": 3,
"emit_code_comments" : False,
"binding": True
"binding": True,
'boundscheck': False,
'wraparound': False
}
),
install_requires=[
'ultralytics>=8.0.0',
'pywin32; platform_system=="Windows"'
],
zip_safe=False
)
+1
View File
@@ -0,0 +1 @@
eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1laWQiOiJkOTBhMzZjYS1lMjM3LTRmYmQtOWM3Yy0xMjcwNDBhYzg1NTYiLCJ1bmlxdWVfbmFtZSI6ImFkbWluQGF6YWlvbi5jb20iLCJyb2xlIjoiQXBpQWRtaW4iLCJuYmYiOjE3MzgxNjM2MzYsImV4cCI6MTczODE3ODAzNiwiaWF0IjoxNzM4MTYzNjM2LCJpc3MiOiJBemFpb25BcGkiLCJhdWQiOiJBbm5vdGF0b3JzL09yYW5nZVBpL0FkbWlucyJ9.7VVws5mwGqx--sGopOuZE9iu3dzt1UdVPXeje2KZTYk
+15
View File
@@ -0,0 +1,15 @@
# User roles as a C enum with explicit numeric values.
# NOTE(review): presumably these numbers mirror the role codes used by the
# backend API - confirm before changing any of them.
cdef enum RoleEnum:
# Fallback for a missing or unrecognised role.
NONE = 0
Operator = 10
Validator = 20
CompanionPC = 30
Admin = 40
ResourceUploader = 50
ApiAdmin = 1000
# Declaration of the user record; all three attributes are readable and
# assignable from Python.
cdef class User:
    cdef public str id, email
    cdef public RoleEnum role

    # Encodes the record as bytes.
    cdef bytes serialize(self)
+15
View File
@@ -0,0 +1,15 @@
import msgpack
cdef class User:
    """Identity of a user: id, email and role."""

    def __init__(self, str id, str email, RoleEnum role):
        self.id, self.email, self.role = id, email, role

    cdef bytes serialize(self):
        """Pack the user as a msgpack map with the short keys "i", "e" and "r"."""
        payload = {
            "i": self.id,
            "e": self.email,
            "r": self.role,
        }
        return msgpack.packb(payload)
+1 -1
View File
@@ -480,7 +480,7 @@
Grid.Column="10"
Padding="2" Width="25"
Height="25"
ToolTip="Розпізнати за допомогою AI. Клавіша: [A]" Background="Black" BorderBrush="Black"
ToolTip="Розпізнати за допомогою AI. Клавіша: [R]" Background="Black" BorderBrush="Black"
Click="AutoDetect">
<Path Stretch="Fill" Fill="LightGray" Data="M144.317 85.269h223.368c15.381 0 29.391 6.325 39.567 16.494l.025-.024c10.163 10.164 16.477 24.193 16.477
39.599v189.728c0 15.401-6.326 29.425-16.485 39.584-10.159 10.159-24.183 16.484-39.584 16.484H144.317c-15.4
+148 -154
View File
@@ -6,9 +6,7 @@ using System.Windows.Controls;
using System.Windows.Controls.Primitives;
using System.Windows.Input;
using System.Windows.Media;
using System.Windows.Media.Imaging;
using Azaion.Annotator.DTO;
using Azaion.Annotator.Extensions;
using Azaion.Common.Database;
using Azaion.Common.DTO;
using Azaion.Common.DTO.Config;
@@ -39,10 +37,9 @@ public partial class Annotator
private readonly IConfigUpdater _configUpdater;
private readonly HelpWindow _helpWindow;
private readonly ILogger<Annotator> _logger;
private readonly VLCFrameExtractor _vlcFrameExtractor;
private readonly IAIDetector _aiDetector;
private readonly AnnotationService _annotationService;
private readonly IDbFactory _dbFactory;
private readonly IInferenceService _inferenceService;
private readonly CancellationTokenSource _ctSource = new();
private ObservableCollection<DetectionClass> AnnotationClasses { get; set; } = new();
@@ -67,10 +64,9 @@ public partial class Annotator
FormState formState,
HelpWindow helpWindow,
ILogger<Annotator> logger,
VLCFrameExtractor vlcFrameExtractor,
IAIDetector aiDetector,
AnnotationService annotationService,
IDbFactory dbFactory)
IDbFactory dbFactory,
IInferenceService inferenceService)
{
InitializeComponent();
_appConfig = appConfig.Value;
@@ -81,10 +77,9 @@ public partial class Annotator
_formState = formState;
_helpWindow = helpWindow;
_logger = logger;
_vlcFrameExtractor = vlcFrameExtractor;
_aiDetector = aiDetector;
_annotationService = annotationService;
_dbFactory = dbFactory;
_inferenceService = inferenceService;
Loaded += OnLoaded;
Closed += OnFormClosed;
@@ -304,11 +299,16 @@ public partial class Annotator
var annotations = await _dbFactory.Run(async db =>
await db.Annotations.LoadWith(x => x.Detections)
.Where(x => x.Name.Contains(_formState.VideoName))
.Where(x => x.OriginalMediaName == _formState.VideoName)
.ToListAsync(token: _ctSource.Token));
TimedAnnotations.Clear();
_formState.AnnotationResults.Clear();
foreach (var ann in annotations)
AddAnnotation(ann);
{
TimedAnnotations.Add(ann.Time.Subtract(_thresholdBefore), ann.Time.Add(_thresholdAfter), ann);
_formState.AnnotationResults.Add(new AnnotationResult(_appConfig.AnnotationConfig.DetectionClassesDict, ann));
}
}
//Add manually
@@ -435,8 +435,6 @@ public partial class Annotator
_appConfig.DirectoriesConfig.VideosDirectory = dlg.FileName;
TbFolder.Text = dlg.FileName;
await ReloadFiles();
await SaveUserSettings();
}
private void TbFilter_OnTextChanged(object sender, TextChangedEventArgs e)
@@ -487,11 +485,8 @@ public partial class Annotator
if (LvFiles.SelectedIndex == -1)
LvFiles.SelectedIndex = 0;
await _mediator.Publish(new AnnotatorControlEvent(PlaybackControlEnum.Play));
_mediaPlayer.Stop();
var manualCancellationSource = new CancellationTokenSource();
var token = manualCancellationSource.Token;
var mct = new CancellationTokenSource();
var token = mct.Token;
_autoDetectDialog = new AutodetectDialog
{
@@ -500,7 +495,7 @@ public partial class Annotator
};
_autoDetectDialog.Closing += (_, _) =>
{
manualCancellationSource.Cancel();
mct.Cancel();
_mediaPlayer.SeekTo(TimeSpan.Zero);
Editor.RemoveAllAnns();
};
@@ -515,16 +510,17 @@ public partial class Annotator
var mediaInfo = Dispatcher.Invoke(() => (MediaFileInfo)LvFiles.SelectedItem);
while (mediaInfo != null)
{
_formState.CurrentMedia = mediaInfo;
await Dispatcher.Invoke(async () => await ReloadAnnotations());
if (mediaInfo.MediaType == MediaTypes.Image)
await Dispatcher.Invoke(async () =>
{
await DetectImage(mediaInfo, manualCancellationSource, token);
await Task.Delay(70, token);
}
else
await DetectVideo(mediaInfo, manualCancellationSource, token);
await _mediator.Publish(new AnnotatorControlEvent(PlaybackControlEnum.Play), token);
await ReloadAnnotations();
});
await _inferenceService.RunInference(mediaInfo.Path, async (annotationImage, ct) =>
{
annotationImage.OriginalMediaName = mediaInfo.FName;
await ProcessDetection(annotationImage, ct);
}, token);
mediaInfo = Dispatcher.Invoke(() =>
{
@@ -533,6 +529,7 @@ public partial class Annotator
LvFiles.SelectedIndex += 1;
return (MediaFileInfo)LvFiles.SelectedItem;
});
LvFiles.Items.Refresh();
}
Dispatcher.Invoke(() =>
{
@@ -546,143 +543,140 @@ public partial class Annotator
Dispatcher.Invoke(() => Editor.ResetBackground());
}
private async Task DetectImage(MediaFileInfo mediaInfo, CancellationTokenSource manualCancellationSource, CancellationToken token)
// private async Task DetectImage(MediaFileInfo mediaInfo, CancellationTokenSource manualCancellationSource, CancellationToken token)
// {
// try
// {
// var fName = Path.GetFileNameWithoutExtension(mediaInfo.Path);
// var stream = new FileStream(mediaInfo.Path, FileMode.Open);
// var detections = await _aiDetector.Detect(fName, stream, token);
// await ProcessDetection((TimeSpan.FromMilliseconds(0), stream), Path.GetExtension(mediaInfo.Path), detections, token);
// if (detections.Count != 0)
// mediaInfo.HasAnnotations = true;
// }
// catch (Exception e)
// {
// _logger.LogError(e, e.Message);
// await manualCancellationSource.CancelAsync();
// }
// }
// private async Task DetectVideo(MediaFileInfo mediaInfo, CancellationTokenSource manualCancellationSource, CancellationToken token)
// {
// var prevSeekTime = 0.0;
// await foreach (var timeframe in _vlcFrameExtractor.ExtractFrames(mediaInfo.Path, token))
// {
// Console.WriteLine($"Detect time: {timeframe.Time}");
// try
// {
// var fName = _formState.GetTimeName(timeframe.Time);
// var detections = await _aiDetector.Detect(fName, timeframe.Stream, token);
//
// var isValid = IsValidDetection(timeframe.Time, detections);
// Console.WriteLine($"Detection time: {timeframe.Time}");
//
// var log = string.Join(Environment.NewLine, detections.Select(det =>
// $"{_appConfig.AnnotationConfig.DetectionClassesDict[det.ClassNumber].Name}: " +
// $"xy=({det.CenterX:F2},{det.CenterY:F2}), " +
// $"size=({det.Width:F2}, {det.Height:F2}), " +
// $"prob: {det.Probability:F1}%"));
//
// log = $"Detection time: {timeframe.Time}, Valid: {isValid}. {Environment.NewLine} {log}";
// Dispatcher.Invoke(() => _autoDetectDialog.Log(log));
//
// if (timeframe.Time.TotalMilliseconds > prevSeekTime + 250)
// {
// Dispatcher.Invoke(() => SeekTo(timeframe.Time));
// prevSeekTime = timeframe.Time.TotalMilliseconds;
// if (!isValid) //Show frame anyway
// {
// Dispatcher.Invoke(() =>
// {
// Editor.RemoveAllAnns();
// Editor.Background = new ImageBrush
// {
// ImageSource = timeframe.Stream.OpenImage()
// };
// });
// }
// }
//
// if (!isValid)
// continue;
//
// mediaInfo.HasAnnotations = true;
// await ProcessDetection(timeframe, ".jpg", detections, token);
// await timeframe.Stream.DisposeAsync();
// }
// catch (Exception ex)
// {
// _logger.LogError(ex, ex.Message);
// await manualCancellationSource.CancelAsync();
// }
// }
// }
// private bool IsValidDetection(TimeSpan time, List<Detection> detections)
// {
// // No AI detection, forbid
// if (detections.Count == 0)
// return false;
//
// // Very first detection, allow
// if (!_previousDetection.HasValue)
// return true;
//
// var prev = _previousDetection.Value;
//
// // Time between detections is >= than Frame Recognition Seconds, allow
// if (time >= prev.Time.Add(TimeSpan.FromSeconds(_appConfig.AIRecognitionConfig.FrameRecognitionSeconds)))
// return true;
//
// // Detection is earlier than previous + FrameRecognitionSeconds.
// // Look to the detections more in detail
//
// // More detected objects, allow
// if (detections.Count > prev.Detections.Count)
// return true;
//
// foreach (var det in detections)
// {
// var point = new Point(det.CenterX, det.CenterY);
// var closestObject = prev.Detections
// .Select(p => new
// {
// Point = p,
// Distance = point.SqrDistance(new Point(p.CenterX, p.CenterY))
// })
// .OrderBy(x => x.Distance)
// .First();
//
// // Closest object is farther than Tracking distance confidence, hence it's a different object, allow
// if (closestObject.Distance > _appConfig.AIRecognitionConfig.TrackingDistanceConfidence)
// return true;
//
// // Since closest object within distance confidence, then it is tracking of the same object. Then if recognition probability for the object > increase from previous
// if (det.Probability >= closestObject.Point.Probability + _appConfig.AIRecognitionConfig.TrackingProbabilityIncrease)
// return true;
// }
//
// return false;
// }
private async Task ProcessDetection(AnnotationImage annotationImage, CancellationToken token = default)
{
try
{
var fName = Path.GetFileNameWithoutExtension(mediaInfo.Path);
var stream = new FileStream(mediaInfo.Path, FileMode.Open);
var detections = await _aiDetector.Detect(fName, stream, token);
await ProcessDetection((TimeSpan.FromMilliseconds(0), stream), Path.GetExtension(mediaInfo.Path), detections, token);
if (detections.Count != 0)
mediaInfo.HasAnnotations = true;
}
catch (Exception e)
{
_logger.LogError(e, e.Message);
await manualCancellationSource.CancelAsync();
}
}
/// <summary>
/// Runs the AI detector over the frames extracted from a video, logs every result to the
/// auto-detect dialog, keeps the player position roughly in sync, and saves the frames
/// whose detections pass <c>IsValidDetection</c>.
/// </summary>
/// <param name="mediaInfo">Video to process; <c>HasAnnotations</c> is set once a frame is accepted.</param>
/// <param name="manualCancellationSource">Cancelled when processing a frame throws.</param>
/// <param name="token">Token passed to frame extraction, detection and saving.</param>
private async Task DetectVideo(MediaFileInfo mediaInfo, CancellationTokenSource manualCancellationSource, CancellationToken token)
{
    // Playback position (ms) of the last frame the UI was moved to.
    var lastSeekMs = 0.0;

    await foreach (var frame in _vlcFrameExtractor.ExtractFrames(mediaInfo.Path, token))
    {
        Console.WriteLine($"Detect time: {frame.Time}");
        try
        {
            var frameName = _formState.GetTimeName(frame.Time);
            var found = await _aiDetector.Detect(frameName, frame.Stream, token);
            var accepted = IsValidDetection(frame.Time, found);
            Console.WriteLine($"Detection time: {frame.Time}");

            var perDetection = found.Select(det =>
                $"{_appConfig.AnnotationConfig.DetectionClassesDict[det.ClassNumber].Name}: " +
                $"xy=({det.CenterX:F2},{det.CenterY:F2}), " +
                $"size=({det.Width:F2}, {det.Height:F2}), " +
                $"prob: {det.Probability:F1}%");
            var details = string.Join(Environment.NewLine, perDetection);
            var message = $"Detection time: {frame.Time}, Valid: {accepted}. {Environment.NewLine} {details}";
            Dispatcher.Invoke(() => _autoDetectDialog.Log(message));

            // Move the UI only when the frame is more than 250 ms past the last seek.
            var frameMs = frame.Time.TotalMilliseconds;
            if (frameMs > lastSeekMs + 250)
            {
                Dispatcher.Invoke(() => SeekTo(frame.Time));
                lastSeekMs = frameMs;

                // A rejected frame is still shown, just without annotations.
                if (!accepted)
                {
                    Dispatcher.Invoke(() =>
                    {
                        Editor.RemoveAllAnns();
                        Editor.Background = new ImageBrush
                        {
                            ImageSource = frame.Stream.OpenImage()
                        };
                    });
                }
            }

            if (accepted)
            {
                mediaInfo.HasAnnotations = true;
                await ProcessDetection(frame, ".jpg", found, token);
                // Only accepted frames have their stream disposed here.
                await frame.Stream.DisposeAsync();
            }
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, ex.Message);
            await manualCancellationSource.CancelAsync();
        }
    }
}
/// <summary>
/// Decides whether the detections of a frame should be kept, by comparing them with the
/// previously stored detection (<c>_previousDetection</c>).
/// </summary>
/// <param name="time">Playback time of the frame.</param>
/// <param name="detections">Detections found in the frame.</param>
/// <returns><c>true</c> when the frame should be kept; otherwise <c>false</c>.</returns>
private bool IsValidDetection(TimeSpan time, List<Detection> detections)
{
    // Nothing detected: reject.
    if (detections.Count == 0)
        return false;

    // Nothing to compare against yet: accept.
    if (!_previousDetection.HasValue)
        return true;

    var last = _previousDetection.Value;

    // Enough time has passed since the previous detection: accept.
    var recognitionWindow = TimeSpan.FromSeconds(_appConfig.AIRecognitionConfig.FrameRecognitionSeconds);
    if (time >= last.Time.Add(recognitionWindow))
        return true;

    // Still inside the recognition window, so compare in detail.
    // More objects than before: accept.
    if (detections.Count > last.Detections.Count)
        return true;

    foreach (var current in detections)
    {
        var center = new Point(current.CenterX, current.CenterY);
        var nearest = last.Detections
            .Select(p => (Previous: p, Distance: center.SqrDistance(new Point(p.CenterX, p.CenterY))))
            .MinBy(pair => pair.Distance);

        // Nearest previous object is too far away, so this is a different object: accept.
        // NOTE(review): Distance is a squared distance - confirm TrackingDistanceConfidence is squared too.
        if (nearest.Distance > _appConfig.AIRecognitionConfig.TrackingDistanceConfidence)
            return true;

        // Same object as before: accept only when its probability rose by at least the configured increase.
        if (current.Probability >= nearest.Previous.Probability + _appConfig.AIRecognitionConfig.TrackingProbabilityIncrease)
            return true;
    }

    return false;
}
private async Task ProcessDetection((TimeSpan Time, Stream Stream) timeframe, string imageExtension, List<Detection> detections, CancellationToken token = default)
{
_previousDetection = (timeframe.Time, detections);
await Dispatcher.Invoke(async () =>
{
try
{
var fName = _formState.GetTimeName(timeframe.Time);
var annotation = await _annotationService.SaveAnnotation(fName, imageExtension, detections, SourceEnum.AI, timeframe.Stream, token);
var annotation = await _annotationService.SaveAnnotation(annotationImage, token);
Editor.Background = new ImageBrush { ImageSource = await annotation.ImagePath.OpenImage() };
Editor.RemoveAllAnns();
ShowAnnotations(annotation, true);
AddAnnotation(annotation);
var log = string.Join(Environment.NewLine, detections.Select(det =>
var log = string.Join(Environment.NewLine, annotation.Detections.Select(det =>
$"{_appConfig.AnnotationConfig.DetectionClassesDict[det.ClassNumber].Name}: " +
$"xy=({det.CenterX:F2},{det.CenterY:F2}), " +
$"size=({det.Width:F2}, {det.Height:F2}), " +
+4 -4
View File
@@ -2,11 +2,11 @@
using System.Windows;
using System.Windows.Input;
using Azaion.Annotator.DTO;
using Azaion.Common;
using Azaion.Common.DTO;
using Azaion.Common.DTO.Config;
using Azaion.Common.DTO.Queue;
using Azaion.Common.Events;
using Azaion.Common.Extensions;
using Azaion.Common.Services;
using LibVLCSharp.Shared;
using MediatR;
@@ -79,7 +79,7 @@ public class AnnotatorEventHandler(
if (_keysControlEnumDict.TryGetValue(key, out var value))
await ControlPlayback(value, cancellationToken);
if (key == Key.A)
if (key == Key.R)
mainWindow.AutoDetect(null!, null!);
#region Volume
@@ -228,7 +228,7 @@ public class AnnotatorEventHandler(
return;
var time = formState.BackgroundTime ?? TimeSpan.FromMilliseconds(mediaPlayer.Time);
var fName = formState.GetTimeName(time);
var fName = formState.VideoName.ToTimeName(time);
var currentDetections = mainWindow.Editor.CurrentDetections
.Select(x => new Detection(fName, x.GetLabel(mainWindow.Editor.RenderSize, formState.BackgroundTime.HasValue ? mainWindow.Editor.RenderSize : formState.CurrentVideoSize)))
@@ -267,7 +267,7 @@ public class AnnotatorEventHandler(
File.Copy(formState.CurrentMedia.Path, imgPath, overwrite: true);
NextMedia();
}
var annotation = await annotationService.SaveAnnotation(fName, imageExtension, currentDetections, SourceEnum.Manual, token: cancellationToken);
var annotation = await annotationService.SaveAnnotation(formState.VideoName, time, imageExtension, currentDetections, SourceEnum.Manual, token: cancellationToken);
mainWindow.AddAnnotation(annotation);
}
@@ -1,9 +0,0 @@
using System.Windows;
namespace Azaion.Annotator.Extensions;
public static class PointExtensions
{
    /// <summary>
    /// Returns the squared Euclidean distance between two points (no square root is taken).
    /// </summary>
    public static double SqrDistance(this Point p1, Point p2)
    {
        var dx = p2.X - p1.X;
        var dy = p2.Y - p1.Y;
        return dx * dx + dy * dy;
    }
}
@@ -1,130 +0,0 @@
using System.Collections.Concurrent;
using System.IO;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using Azaion.Common.DTO.Config;
using LibVLCSharp.Shared;
using Microsoft.Extensions.Options;
using SkiaSharp;
namespace Azaion.Annotator.Extensions;
/// <summary>
/// Plays a media file through LibVLC with custom video callbacks at an accelerated rate and
/// yields selected decoded frames as JPEG-encoded streams together with their estimated time.
/// </summary>
/// <remarks>
/// NOTE(review): the frame queue, the frame counter and the in-flight bitmap are static, so
/// two extractions running at the same time would share them - confirm only one runs at once.
/// </remarks>
public class VLCFrameExtractor(LibVLC libVLC, IOptions<AIRecognitionConfig> config)
{
    private const uint RGBA_BYTES = 4;
    private const int PLAYBACK_RATE = 4;

    private uint _pitch; // Number of bytes per "line", aligned to x32.
    private uint _lines; // Number of lines in the buffer, aligned to x32.
    private uint _width; // Thumbnail width
    private uint _height; // Thumbnail height

    private MediaPlayer _mediaPlayer = null!;

    private TimeSpan _lastFrameTimestamp;
    private long _lastFrame;

    private static uint Align32(uint size)
    {
        if (size % 32 == 0)
            return size;
        return (size / 32 + 1) * 32; // Align on the next multiple of 32
    }

    private static SKBitmap? _currentBitmap;
    private static readonly ConcurrentQueue<FrameInfo> FramesQueue = new();
    private static long _frameCounter;

    /// <summary>
    /// Streams frames of <paramref name="mediaPath"/> as (time, JPEG stream) pairs until the
    /// player raises <c>Stopped</c> and the queue is drained, or until
    /// <paramref name="manualCancellationToken"/> is cancelled.
    /// </summary>
    /// <param name="mediaPath">Path of the media file to play.</param>
    /// <param name="manualCancellationToken">Stops the enumeration when cancelled.</param>
    /// <returns>
    /// One item per queued frame. The stream's position is left at the end of the written
    /// JPEG data and the consumer owns the stream.
    /// </returns>
    /// <exception cref="InvalidOperationException">The media has no video track with a non-zero size.</exception>
    public async IAsyncEnumerable<(TimeSpan Time, Stream Stream)> ExtractFrames(string mediaPath,
        [EnumeratorCancellation] CancellationToken manualCancellationToken = default)
    {
        var videoFinishedCancellationSource = new CancellationTokenSource();

        // Keep a local reference so the finally block always releases the player created here.
        var player = new MediaPlayer(libVLC);
        _mediaPlayer = player;
        player.Stopped += (s, e) => videoFinishedCancellationSource.CancelAfter(1);

        // try/finally (no catch, so 'yield return' stays legal) guarantees the player is stopped
        // and disposed on normal completion, on exceptions and when the consumer stops early.
        try
        {
            using var media = new Media(libVLC, mediaPath);
            await media.Parse(cancellationToken: videoFinishedCancellationSource.Token);
            var videoTrack = media.Tracks.FirstOrDefault(x => x.Data.Video.Width != 0);
            _width = videoTrack.Data.Video.Width;
            _height = videoTrack.Data.Video.Height;
            if (_width == 0 || _height == 0)
                throw new InvalidOperationException($"No video track found in '{mediaPath}'");

            _pitch = Align32(_width * RGBA_BYTES);
            _lines = Align32(_height);
            player.SetRate(PLAYBACK_RATE);

            media.AddOption(":no-audio");
            player.SetVideoFormat("RV32", _width, _height, _pitch);
            player.SetVideoCallbacks(Lock, null, Display);

            // Reset the counters before playback starts so Display never sees stale values.
            _frameCounter = 0;
            _lastFrame = 0;
            _lastFrameTimestamp = TimeSpan.Zero;
            player.Play(media);

            using var surface = SKSurface.Create(new SKImageInfo((int) _width, (int) _height));
            var videoFinishedCT = videoFinishedCancellationSource.Token;

            while (!(FramesQueue.IsEmpty && videoFinishedCT.IsCancellationRequested || manualCancellationToken.IsCancellationRequested))
            {
                if (!FramesQueue.TryDequeue(out var frameInfo))
                {
                    try
                    {
                        await Task.Delay(TimeSpan.FromSeconds(1), videoFinishedCT);
                    }
                    catch (OperationCanceledException)
                    {
                        // The "video finished" token fired while waiting. That is the normal end
                        // of playback, so let the loop condition decide instead of throwing.
                    }
                    continue;
                }

                if (frameInfo.Bitmap == null)
                    continue;

                MemoryStream jpeg;
                // The raw bitmap is no longer needed once it has been encoded.
                using (var bitmap = frameInfo.Bitmap)
                    jpeg = EncodeJpeg(surface, bitmap);

                yield return (frameInfo.Time, jpeg);
            }
        }
        finally
        {
            // Stop first so the Display callback should no longer enqueue frames, then release
            // whatever is still queued (e.g. after a manual stop).
            player.Stop();
            while (FramesQueue.TryDequeue(out var leftover))
                leftover.Bitmap?.Dispose();
            player.Dispose();
        }
    }

    // Draws the bitmap onto the surface (cropping it to the visible area) and encodes the result as JPEG.
    private static MemoryStream EncodeJpeg(SKSurface surface, SKBitmap bitmap)
    {
        surface.Canvas.DrawBitmap(bitmap, 0, 0); // Effectively crops the original bitmap to get only the visible area
        using var outputImage = surface.Snapshot();
        using var data = outputImage.Encode(SKEncodedImageFormat.Jpeg, 85);
        var ms = new MemoryStream();
        data.SaveTo(ms);
        return ms;
    }

    // Video "lock" callback: allocates the bitmap for the next frame and hands its pixel buffer to VLC.
    private IntPtr Lock(IntPtr opaque, IntPtr planes)
    {
        _currentBitmap = new SKBitmap(new SKImageInfo((int)(_pitch / RGBA_BYTES), (int)_lines, SKColorType.Bgra8888));
        Marshal.WriteIntPtr(planes, _currentBitmap.GetPixels());
        return IntPtr.Zero;
    }

    // Video "display" callback: queues the current frame on selected frame numbers, otherwise drops it.
    private void Display(IntPtr opaque, IntPtr picture)
    {
        var playerTime = TimeSpan.FromMilliseconds(_mediaPlayer.Time);
        if (_lastFrameTimestamp != playerTime)
        {
            _lastFrame = _frameCounter;
            _lastFrameTimestamp = playerTime;
        }

        // A configured period of 0 would throw DivideByZeroException inside this native
        // callback; treat it as "every frame".
        var period = config.Value.FramePeriodRecognition;
        if (period == 0)
            period = 1;

        if (_frameCounter > 20 && _frameCounter % period == 0)
        {
            // The player time does not change on every frame, so extrapolate from the average
            // frame duration observed up to the last time change.
            var msToAdd = (_frameCounter - _lastFrame) * (_lastFrame == 0 ? 0 : _lastFrameTimestamp.TotalMilliseconds / _lastFrame);
            var time = _lastFrameTimestamp.Add(TimeSpan.FromMilliseconds(msToAdd));
            FramesQueue.Enqueue(new FrameInfo(time, _currentBitmap));
        }
        else
        {
            _currentBitmap?.Dispose();
        }
        _currentBitmap = null;
        _frameCounter++;
    }
}
/// <summary>A decoded video frame together with the playback time assigned to it.</summary>
public class FrameInfo
{
    public FrameInfo(TimeSpan time, SKBitmap? bitmap)
    {
        Time = time;
        Bitmap = bitmap;
    }

    // Playback time assigned to the frame.
    public TimeSpan Time { get; set; }

    // Pixel data of the frame; may be null.
    public SKBitmap? Bitmap { get; set; }
}
-87
View File
@@ -1,87 +0,0 @@
using System.Diagnostics;
using System.IO;
using Azaion.Annotator.Extensions;
using Azaion.Common.DTO;
using Azaion.Common.DTO.Config;
using Azaion.CommonSecurity.Services;
using Compunet.YoloV8;
using Microsoft.Extensions.Options;
using SixLabors.ImageSharp;
using SixLabors.ImageSharp.PixelFormats;
using Detection = Azaion.Common.DTO.Detection;
namespace Azaion.Annotator;
public interface IAIDetector
{
Task<List<Detection>> Detect(string fName, Stream imageStream, CancellationToken cancellationToken = default);
}
public class YOLODetector(IOptions<AIRecognitionConfig> recognitionConfig, IResourceLoader resourceLoader) : IAIDetector, IDisposable
{
private readonly AIRecognitionConfig _recognitionConfig = recognitionConfig.Value;
private YoloPredictor? _predictor;
private const string YOLO_MODEL = "azaion.onnx";
public async Task<List<Detection>> Detect(string fName, Stream imageStream, CancellationToken cancellationToken)
{
if (_predictor == null)
{
await using var stream = await resourceLoader.Load(YOLO_MODEL, cancellationToken);
_predictor = new YoloPredictor(stream.ToArray());
}
imageStream.Seek(0, SeekOrigin.Begin);
using var image = Image.Load<Rgb24>(imageStream);
var result = await _predictor.DetectAsync(image);
var imageSize = new System.Windows.Size(image.Width, image.Height);
var detections = result.Select(d =>
{
var label = new YoloLabel(new CanvasLabel(d.Name.Id, d.Bounds.X, d.Bounds.Y, d.Bounds.Width, d.Bounds.Height), imageSize, imageSize);
return new Detection(fName, label, (double?)d.Confidence * 100);
}).ToList();
return FilterOverlapping(detections);
}
private List<Detection> FilterOverlapping(List<Detection> detections)
{
var k = _recognitionConfig.TrackingIntersectionThreshold;
var filteredDetections = new List<Detection>();
for (var i = 0; i < detections.Count; i++)
{
var detectionSelected = false;
for (var j = i + 1; j < detections.Count; j++)
{
var intersect = detections[i].ToRectangle();
intersect.Intersect(detections[j].ToRectangle());
var maxArea = Math.Max(detections[i].ToRectangle().Area(), detections[j].ToRectangle().Area());
if (!(intersect.Area() > k * maxArea))
continue;
if (detections[i].Probability > detections[j].Probability)
{
filteredDetections.Add(detections[i]);
detections.RemoveAt(j);
}
else
{
filteredDetections.Add(detections[j]);
detections.RemoveAt(i);
}
detectionSelected = true;
break;
}
if (!detectionSelected)
filteredDetections.Add(detections[i]);
}
return filteredDetections;
}
public void Dispose() => _predictor?.Dispose();
}
@@ -1,10 +1,15 @@
using MessagePack;
namespace Azaion.Common.DTO.Config;
[MessagePackObject]
public class AIRecognitionConfig
{
public double FrameRecognitionSeconds { get; set; }
public double TrackingDistanceConfidence { get; set; }
public double TrackingProbabilityIncrease { get; set; }
public double TrackingIntersectionThreshold { get; set; }
public int FramePeriodRecognition { get; set; }
[Key("FrameRecognitionSeconds")] public double FrameRecognitionSeconds { get; set; }
[Key("TrackingDistanceConfidence")] public double TrackingDistanceConfidence { get; set; }
[Key("TrackingProbabilityIncrease")] public double TrackingProbabilityIncrease { get; set; }
[Key("TrackingIntersectionThreshold")] public double TrackingIntersectionThreshold { get; set; }
[Key("FramePeriodRecognition")] public int FramePeriodRecognition { get; set; }
[Key("Data")] public byte[] Data { get; set; }
}
-9
View File
@@ -8,8 +8,6 @@ namespace Azaion.Common.DTO.Config;
public class AppConfig
{
public ApiConfig ApiConfig { get; set; } = null!;
public QueueConfig QueueConfig { get; set; } = null!;
public DirectoriesConfig DirectoriesConfig { get; set; } = null!;
@@ -39,13 +37,6 @@ public class ConfigUpdater : IConfigUpdater
var appConfig = new AppConfig
{
ApiConfig = new ApiConfig
{
Url = SecurityConstants.DEFAULT_API_URL,
RetryCount = SecurityConstants.DEFAULT_API_RETRY_COUNT,
TimeoutSeconds = SecurityConstants.DEFAULT_API_TIMEOUT_SECONDS
},
AnnotationConfig = new AnnotationConfig
{
AnnotationClasses = Constants.DefaultAnnotationClasses,
-2
View File
@@ -16,6 +16,4 @@ public class FormState
public int CurrentVolume { get; set; } = 100;
public ObservableCollection<AnnotationResult> AnnotationResults { get; set; } = [];
public WindowEnum ActiveWindow { get; set; }
public string GetTimeName(TimeSpan? ts) => $"{VideoName}_{ts:hmmssf}";
}
+13 -11
View File
@@ -1,18 +1,18 @@
using System.Drawing;
using System.Globalization;
using System.IO;
using MessagePack;
using Newtonsoft.Json;
using Size = System.Windows.Size;
namespace Azaion.Common.DTO;
[MessagePackObject]
public abstract class Label
{
[JsonProperty(PropertyName = "cl")] public int ClassNumber { get; set; }
[JsonProperty(PropertyName = "cl")][Key("c")] public int ClassNumber { get; set; }
protected Label()
{
}
protected Label() { }
protected Label(int classNumber)
{
@@ -79,15 +79,16 @@ public class CanvasLabel : Label
}
}
[MessagePackObject]
public class YoloLabel : Label
{
[JsonProperty(PropertyName = "x")] public double CenterX { get; set; }
[JsonProperty(PropertyName = "x")][Key("x")] public double CenterX { get; set; }
[JsonProperty(PropertyName = "y")] public double CenterY { get; set; }
[JsonProperty(PropertyName = "y")][Key("y")] public double CenterY { get; set; }
[JsonProperty(PropertyName = "w")] public double Width { get; set; }
[JsonProperty(PropertyName = "w")][Key("w")] public double Width { get; set; }
[JsonProperty(PropertyName = "h")] public double Height { get; set; }
[JsonProperty(PropertyName = "h")][Key("h")] public double Height { get; set; }
public YoloLabel()
{
@@ -184,12 +185,13 @@ public class YoloLabel : Label
public override string ToString() => $"{ClassNumber} {CenterX:F5} {CenterY:F5} {Width:F5} {Height:F5}".Replace(',', '.');
}
[MessagePackObject]
public class Detection : YoloLabel
{
public string AnnotationName { get; set; } = null!;
public double? Probability { get; set; }
[IgnoreMember]public string AnnotationName { get; set; } = null!;
[Key("p")] public double? Probability { get; set; }
//For db
//For db & serialization
public Detection(){}
public Detection(string annotationName, YoloLabel label, double? probability = null)
+4 -2
View File
@@ -1,4 +1,6 @@
namespace Azaion.Common.DTO;
using Azaion.Common.Extensions;
namespace Azaion.Common.DTO;
public class MediaFileInfo
{
@@ -9,5 +11,5 @@ public class MediaFileInfo
public bool HasAnnotations { get; set; }
public MediaTypes MediaType { get; set; }
public string FName => System.IO.Path.GetFileNameWithoutExtension(Name).Replace(" ", "");
public string FName => Name.ToFName();
}
@@ -9,13 +9,15 @@ public class AnnotationCreatedMessage
{
[Key(0)] public DateTime CreatedDate { get; set; }
[Key(1)] public string Name { get; set; } = null!;
[Key(2)] public string ImageExtension { get; set; } = null!;
[Key(3)] public string Detections { get; set; } = null!;
[Key(4)] public byte[] Image { get; set; } = null!;
[Key(5)] public RoleEnum CreatedRole { get; set; }
[Key(6)] public string CreatedEmail { get; set; } = null!;
[Key(7)] public SourceEnum Source { get; set; }
[Key(8)] public AnnotationStatus Status { get; set; }
[Key(2)] public string OriginalMediaName { get; set; } = null!;
[Key(3)] public TimeSpan Time { get; set; }
[Key(4)] public string ImageExtension { get; set; } = null!;
[Key(5)] public string Detections { get; set; } = null!;
[Key(6)] public byte[] Image { get; set; } = null!;
[Key(7)] public RoleEnum CreatedRole { get; set; }
[Key(8)] public string CreatedEmail { get; set; } = null!;
[Key(9)] public SourceEnum Source { get; set; }
[Key(10)] public AnnotationStatus Status { get; set; }
}
[MessagePackObject]
+24 -39
View File
@@ -3,9 +3,11 @@ using Azaion.Common.DTO;
using Azaion.Common.DTO.Config;
using Azaion.Common.DTO.Queue;
using Azaion.CommonSecurity.DTO;
using MessagePack;
namespace Azaion.Common.Database;
[MessagePackObject]
public class Annotation
{
private static string _labelsDir = null!;
@@ -19,53 +21,36 @@ public class Annotation
_thumbDir = config.ThumbnailsDirectory;
}
public string Name { get; set; } = null!;
public string ImageExtension { get; set; } = null!;
public DateTime CreatedDate { get; set; }
public string CreatedEmail { get; set; } = null!;
public RoleEnum CreatedRole { get; set; }
public SourceEnum Source { get; set; }
public AnnotationStatus AnnotationStatus { get; set; }
[IgnoreMember]public string Name { get; set; } = null!;
[IgnoreMember]public string OriginalMediaName { get; set; } = null!;
[IgnoreMember]public TimeSpan Time { get; set; }
[IgnoreMember]public string ImageExtension { get; set; } = null!;
[IgnoreMember]public DateTime CreatedDate { get; set; }
[IgnoreMember]public string CreatedEmail { get; set; } = null!;
[IgnoreMember]public RoleEnum CreatedRole { get; set; }
[IgnoreMember]public SourceEnum Source { get; set; }
[IgnoreMember]public AnnotationStatus AnnotationStatus { get; set; }
public IEnumerable<Detection> Detections { get; set; } = null!;
[Key("d")] public IEnumerable<Detection> Detections { get; set; } = null!;
[Key("t")] public long Milliseconds { get; set; }
public double Lat { get; set; }
public double Lon { get; set; }
[Key("lat")]public double Lat { get; set; }
[Key("lon")]public double Lon { get; set; }
#region Calculated
public List<int> Classes => Detections.Select(x => x.ClassNumber).ToList();
public string ImagePath => Path.Combine(_imagesDir, $"{Name}{ImageExtension}");
public string LabelPath => Path.Combine(_labelsDir, $"{Name}.txt");
public string ThumbPath => Path.Combine(_thumbDir, $"{Name}{Constants.THUMBNAIL_PREFIX}.jpg");
public string OriginalMediaName => $"{Name[..^7]}";
[IgnoreMember]public List<int> Classes => Detections.Select(x => x.ClassNumber).ToList();
[IgnoreMember]public string ImagePath => Path.Combine(_imagesDir, $"{Name}{ImageExtension}");
[IgnoreMember]public string LabelPath => Path.Combine(_labelsDir, $"{Name}.txt");
[IgnoreMember]public string ThumbPath => Path.Combine(_thumbDir, $"{Name}{Constants.THUMBNAIL_PREFIX}.jpg");
private TimeSpan? _time;
public TimeSpan Time
{
get
{
if (_time.HasValue)
return _time.Value;
var timeStr = Name.Split("_").LastOrDefault();
//For some reason, TimeSpan.ParseExact doesn't work on every platform.
if (!string.IsNullOrEmpty(timeStr) &&
timeStr.Length == 6 &&
int.TryParse(timeStr[..1], out var hours) &&
int.TryParse(timeStr[1..3], out var minutes) &&
int.TryParse(timeStr[3..5], out var seconds) &&
int.TryParse(timeStr[5..], out var milliseconds))
return new TimeSpan(0, hours, minutes, seconds, milliseconds * 100);
_time = TimeSpan.FromSeconds(0);
return _time.Value;
}
}
#endregion Calculated
}
/// <summary>An <see cref="Annotation"/> that additionally carries the image bytes (MessagePack key "i").</summary>
[MessagePackObject]
public class AnnotationImage : Annotation
{
    // Filled in by deserialization; "= null!" follows the convention of the other DTOs in
    // this code base and avoids an uninitialized non-nullable property.
    [Key("i")] public byte[] Image { get; set; } = null!;
}
public enum AnnotationStatus
{
+9 -6
View File
@@ -117,16 +117,19 @@ public static class AnnotationsDbSchemaHolder
MappingSchema = new MappingSchema();
var builder = new FluentMappingBuilder(MappingSchema);
builder.Entity<Annotation>()
.HasTableName(Constants.ANNOTATIONS_TABLENAME)
var annotationBuilder = builder.Entity<Annotation>();
annotationBuilder.HasTableName(Constants.ANNOTATIONS_TABLENAME)
.HasPrimaryKey(x => x.Name)
.Ignore(x => x.Time)
.Association(a => a.Detections, (a, d) => a.Name == d.AnnotationName)
.Property(x => x.Time).HasDataType(DataType.Int64).HasConversion(ts => ts.Ticks, t => new TimeSpan(t));
annotationBuilder
.Ignore(x => x.Milliseconds)
.Ignore(x => x.Classes)
.Ignore(x => x.Classes)
.Ignore(x => x.ImagePath)
.Ignore(x => x.LabelPath)
.Ignore(x => x.ThumbPath)
.Ignore(x => x.OriginalMediaName)
.Association(a => a.Detections, (a, d) => a.Name == d.AnnotationName);
.Ignore(x => x.ThumbPath);
builder.Entity<Detection>()
.HasTableName(Constants.DETECTIONS_TABLENAME);
+1 -1
View File
@@ -24,7 +24,7 @@ public class ParallelExt
return Task.CompletedTask;
}
};
var threadsCount = (int)(Environment.ProcessorCount * parallelOptions.CpuUtilPercent / 100.0);
var threadsCount = (int)Math.Round(Environment.ProcessorCount * parallelOptions.CpuUtilPercent / 100.0);
var processedCount = 0;
var chunkSize = Math.Max(1, (int)(source.Count / (decimal)threadsCount));
@@ -0,0 +1,12 @@
using System.IO;
namespace Azaion.Common.Extensions;
public static class StringExtensions
{
    /// <summary>
    /// Returns the file name of <paramref name="path"/> without directory and extension,
    /// with all spaces removed.
    /// </summary>
    public static string ToFName(this string path)
    {
        var stem = Path.GetFileNameWithoutExtension(path);
        return stem.Replace(" ", "");
    }

    /// <summary>
    /// Appends "_" and the time formatted as <c>hmmssf</c> to <paramref name="fName"/>;
    /// a null time contributes an empty string.
    /// </summary>
    public static string ToTimeName(this string fName, TimeSpan? ts)
    {
        var stamp = $"{ts:hmmssf}";
        return $"{fName}_{stamp}";
    }
}
+29 -21
View File
@@ -22,29 +22,31 @@ namespace Azaion.Common.Services;
public class AnnotationService : INotificationHandler<AnnotationsDeletedEvent>
{
private readonly AzaionApiClient _apiClient;
private readonly IDbFactory _dbFactory;
private readonly FailsafeAnnotationsProducer _producer;
private readonly IGalleryService _galleryService;
private readonly IMediator _mediator;
private readonly IHardwareService _hardwareService;
private readonly IAuthProvider _authProvider;
private readonly QueueConfig _queueConfig;
private Consumer _consumer = null!;
public AnnotationService(AzaionApiClient apiClient,
public AnnotationService(
IResourceLoader resourceLoader,
IDbFactory dbFactory,
FailsafeAnnotationsProducer producer,
IOptions<QueueConfig> queueConfig,
IGalleryService galleryService,
IMediator mediator,
IHardwareService hardwareService)
IHardwareService hardwareService,
IAuthProvider authProvider)
{
_apiClient = apiClient;
_dbFactory = dbFactory;
_producer = producer;
_galleryService = galleryService;
_mediator = mediator;
_hardwareService = hardwareService;
_authProvider = authProvider;
_queueConfig = queueConfig.Value;
Task.Run(async () => await Init()).Wait();
@@ -73,7 +75,8 @@ public class AnnotationService : INotificationHandler<AnnotationsDeletedEvent>
await SaveAnnotationInner(
msg.CreatedDate,
msg.Name,
msg.OriginalMediaName,
msg.Time,
msg.ImageExtension,
JsonConvert.DeserializeObject<List<Detection>>(msg.Detections) ?? [],
msg.Source,
@@ -98,36 +101,39 @@ public class AnnotationService : INotificationHandler<AnnotationsDeletedEvent>
});
}
//AI / Manual
public async Task<Annotation> SaveAnnotation(string fName, string imageExtension, List<Detection> detections, SourceEnum source, Stream? stream = null, CancellationToken token = default) =>
await SaveAnnotationInner(DateTime.UtcNow, fName, imageExtension, detections, source, stream, _apiClient.User.Role, _apiClient.User.Email, generateThumbnail: true, token);
//AI
/// <summary>
/// Saves an annotation produced by AI inference. The time and name of
/// <paramref name="a"/> are derived from its <c>Milliseconds</c> and <c>OriginalMediaName</c>
/// before saving, and the image bytes are saved as a ".jpg" image.
/// </summary>
/// <param name="a">Annotation with image bytes; its <c>Time</c> and <c>Name</c> are overwritten.</param>
/// <param name="cancellationToken">Token passed to the save operation.</param>
/// <returns>The saved annotation.</returns>
public async Task<Annotation> SaveAnnotation(AnnotationImage a, CancellationToken cancellationToken = default)
{
    a.Time = TimeSpan.FromMilliseconds(a.Milliseconds);
    a.Name = a.OriginalMediaName.ToTimeName(a.Time);

    // UtcNow (not Now) so the created date is in UTC like in the other save paths.
    return await SaveAnnotationInner(DateTime.UtcNow, a.OriginalMediaName, a.Time, ".jpg", a.Detections.ToList(),
        a.Source, new MemoryStream(a.Image), a.CreatedRole, a.CreatedEmail, generateThumbnail: true, cancellationToken);
}
//Manual
/// <summary>
/// Saves an annotation on behalf of the current user, stamped with the current UTC time,
/// and generates a thumbnail.
/// </summary>
/// <param name="originalMediaName">Name of the media the annotation belongs to.</param>
/// <param name="time">Position of the annotated frame within the media.</param>
/// <param name="imageExtension">Extension of the annotation image, including the dot.</param>
/// <param name="detections">Detections to store.</param>
/// <param name="source">Origin of the annotation.</param>
/// <param name="stream">Optional image data.</param>
/// <param name="token">Token passed to the save operation.</param>
/// <returns>The saved annotation.</returns>
public async Task<Annotation> SaveAnnotation(string originalMediaName, TimeSpan time, string imageExtension, List<Detection> detections, SourceEnum source, Stream? stream = null, CancellationToken token = default)
{
    return await SaveAnnotationInner(
        DateTime.UtcNow,
        originalMediaName,
        time,
        imageExtension,
        detections,
        source,
        stream,
        _authProvider.CurrentUser.Role,
        _authProvider.CurrentUser.Email,
        generateThumbnail: true,
        token);
}
//Manual Validate existing
public async Task ValidateAnnotation(Annotation annotation, CancellationToken token = default) =>
await SaveAnnotationInner(DateTime.UtcNow, annotation.Name, annotation.ImageExtension, annotation.Detections.ToList(), SourceEnum.Manual, null, _apiClient.User.Role, _apiClient.User.Email,
generateThumbnail: false, token);
await SaveAnnotationInner(DateTime.UtcNow, annotation.OriginalMediaName, annotation.Time, annotation.ImageExtension, annotation.Detections.ToList(), SourceEnum.Manual, null,
_authProvider.CurrentUser.Role, _authProvider.CurrentUser.Email, generateThumbnail: false, token);
// //Queue (only from operators)
// public async Task Consume(AnnotationCreatedMessage message, CancellationToken cancellationToken = default)
// {
//
// }
private async Task<Annotation> SaveAnnotationInner(DateTime createdDate, string fName, string imageExtension, List<Detection> detections, SourceEnum source, Stream? stream,
private async Task<Annotation> SaveAnnotationInner(DateTime createdDate, string originalMediaName, TimeSpan time, string imageExtension, List<Detection> detections, SourceEnum source, Stream? stream,
RoleEnum userRole,
string createdEmail,
bool generateThumbnail = false,
CancellationToken token = default)
{
//Flow for roles:
// Operator or (AI from any role) -> Created
// Validator, Admin & Manual -> Validated
AnnotationStatus status;
var fName = originalMediaName.ToTimeName(time);
var annotation = await _dbFactory.Run(async db =>
{
var ann = await db.Annotations.FirstOrDefaultAsync(x => x.Name == fName, token: token);
// Manual Save from Validators -> Validated
// otherwise Created
status = userRole.IsValidator() && source == SourceEnum.Manual
? AnnotationStatus.Validated
: AnnotationStatus.Created;
@@ -149,6 +155,8 @@ public class AnnotationService : INotificationHandler<AnnotationsDeletedEvent>
{
CreatedDate = createdDate,
Name = fName,
OriginalMediaName = originalMediaName,
Time = time,
ImageExtension = imageExtension,
CreatedEmail = createdEmail,
CreatedRole = userRole,
+3 -2
View File
@@ -76,7 +76,7 @@ public class FailsafeAnnotationsProducer
await _annotationConfirmProducer.Send(validatedMessages, CompressionType.Gzip);
await _dbFactory.Run(async db =>
await db.AnnotationsQueue.DeleteAsync(aq => messagesChunk.Any(x => aq.Name == x.Name), token: cancellationToken));
await db.AnnotationsQueue.DeleteAsync(aq => messagesChunk.Any(x => aq.Name == x.OriginalMediaName), token: cancellationToken));
sent = true;
}
catch (Exception e)
@@ -106,7 +106,8 @@ public class FailsafeAnnotationsProducer
var annCreateMessage = new AnnotationCreatedMessage
{
Name = annotation.Name,
OriginalMediaName = annotation.OriginalMediaName,
Time = annotation.Time,
CreatedRole = annotation.CreatedRole,
CreatedEmail = annotation.CreatedEmail,
CreatedDate = annotation.CreatedDate,
+34 -2
View File
@@ -8,6 +8,7 @@ using Azaion.Common.Database;
using Azaion.Common.DTO;
using Azaion.Common.DTO.Config;
using Azaion.Common.DTO.Queue;
using Azaion.Common.Extensions;
using Azaion.CommonSecurity.DTO;
using LinqToDB;
using LinqToDB.Data;
@@ -75,7 +76,6 @@ public class GalleryService(
var missedAnnotations = new ConcurrentBag<Annotation>();
try
{
var prefixLen = Constants.THUMBNAIL_PREFIX.Length;
var thumbnails = ThumbnailsDirectory.GetFiles()
@@ -105,9 +105,37 @@ public class GalleryService(
return;
var detections = (await YoloLabel.ReadFromFile(labelName, cancellationToken)).Select(x => new Detection(fName, x)).ToList();
//get names and time
var fileName = Path.GetFileNameWithoutExtension(file.Name);
var strings = fileName.Split("_");
var timeStr = strings.LastOrDefault();
string originalMediaName;
TimeSpan time;
//For some reason, TimeSpan.ParseExact doesn't work on every platform.
if (!string.IsNullOrEmpty(timeStr) &&
timeStr.Length == 6 &&
int.TryParse(timeStr[..1], out var hours) &&
int.TryParse(timeStr[1..3], out var minutes) &&
int.TryParse(timeStr[3..5], out var seconds) &&
int.TryParse(timeStr[5..], out var milliseconds))
{
time = new TimeSpan(0, hours, minutes, seconds, milliseconds * 100);
originalMediaName = fileName[..^7];
}
else
{
originalMediaName = fileName;
time = TimeSpan.FromSeconds(0);
}
var annotation = new Annotation
{
Name = fName,
Time = time,
OriginalMediaName = originalMediaName,
Name = file.Name.ToFName(),
ImageExtension = Path.GetExtension(file.Name),
Detections = detections,
CreatedDate = File.GetCreationTimeUtc(file.FullName),
@@ -142,6 +170,10 @@ public class GalleryService(
ProgressUpdateInterval = 200
});
}
catch (Exception e)
{
logger.LogError(e, $"Failed to refresh thumbnails! Error: {e.Message}");
}
finally
{
var copyOptions = new BulkCopyOptions
@@ -0,0 +1,53 @@
using System.Text;
using Azaion.Common.Database;
using Azaion.Common.DTO.Config;
using Azaion.CommonSecurity;
using Azaion.CommonSecurity.DTO.Commands;
using Azaion.CommonSecurity.Services;
using MessagePack;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using NetMQ;
using NetMQ.Sockets;
namespace Azaion.Common.Services;
/// <summary>
/// Contract for running inference over a media file and handing each resulting
/// <see cref="AnnotationImage"/> to a caller-supplied callback.
/// </summary>
public interface IInferenceService
{
    /// <summary>Runs inference for the given media.</summary>
    /// <param name="mediaPath">Path of the media to process.</param>
    /// <param name="processAnnotation">Asynchronous callback that receives an annotation and the cancellation token.</param>
    /// <param name="ct">Cancellation token; defaults to <see cref="CancellationToken.None"/>.</param>
    Task RunInference(
        string mediaPath,
        Func<AnnotationImage, CancellationToken, Task> processAnnotation,
        CancellationToken ct = default);
}
/// <summary>
/// <see cref="IInferenceService"/> implementation that sends an <c>Inference</c> command over a
/// ZeroMQ DEALER socket and forwards every <see cref="AnnotationImage"/> reply to a callback
/// until a "DONE" frame arrives or cancellation is requested.
/// </summary>
public class PythonInferenceService(ILogger<PythonInferenceService> logger, IOptions<AIRecognitionConfig> aiConfigOptions) : IInferenceService
{
    /// <summary>
    /// Sends the inference request for <paramref name="mediaPath"/> (with the serialized AI config as payload)
    /// and processes the replies one frame at a time.
    /// </summary>
    /// <param name="mediaPath">Path of the media, passed as the command's filename.</param>
    /// <param name="processAnnotation">Callback invoked for every deserialized annotation.</param>
    /// <param name="ct">
    /// Cancellation token. It is observed between frames; a receive that is already waiting
    /// still runs until a frame arrives or the receive times out.
    /// </param>
    public async Task RunInference(string mediaPath, Func<AnnotationImage, CancellationToken, Task> processAnnotation, CancellationToken ct = default)
    {
        using var dealer = new DealerSocket();
        var clientId = Guid.NewGuid();
        dealer.Options.Identity = Encoding.UTF8.GetBytes(clientId.ToString("N"));
        dealer.Connect($"tcp://{SecurityConstants.ZMQ_HOST}:{SecurityConstants.ZMQ_PORT}");

        var data = MessagePackSerializer.Serialize(aiConfigOptions.Value);
        dealer.SendFrame(MessagePackSerializer.Serialize(new RemoteCommand(CommandType.Inference, mediaPath, data)));

        // Stop as soon as the caller cancels instead of looping until "DONE".
        while (!ct.IsCancellationRequested)
        {
            byte[] bytes = [];
            try
            {
                var annotation = dealer.Get<AnnotationImage>(out bytes);
                if (annotation != null)
                {
                    await processAnnotation(annotation, ct);
                    continue;
                }

                if (IsDoneSignal(bytes))
                    break;
            }
            catch (OperationCanceledException) when (ct.IsCancellationRequested)
            {
                // The callback honoured the token; this is a normal stop, not an error.
                break;
            }
            catch (Exception e)
            {
                // Get() stores the raw frame in 'bytes' before deserializing, so the end-of-stream
                // sentinel is still visible here when deserialization of that frame threw.
                if (IsDoneSignal(bytes))
                    break;

                logger.LogError(e, e.Message);
            }
        }
    }

    // True when the frame is exactly the 4-byte UTF-8 text "DONE".
    private static bool IsDoneSignal(byte[] frame) =>
        frame.Length == 4 && Encoding.UTF8.GetString(frame) == "DONE";
}
-9
View File
@@ -1,9 +0,0 @@
namespace Azaion.CommonSecurity.DTO;
/// <summary>Settings describing how to reach the API.</summary>
public class ApiConfig
{
    /// <summary>API URL. Has no usable default (initialized with <c>null!</c>).</summary>
    public string Url { get; set; } = null!;

    /// <summary>Retry count setting.</summary>
    public int RetryCount { get; set; }

    /// <summary>Timeout, expressed in seconds.</summary>
    public double TimeoutSeconds { get; set; }

    /// <summary>Resources folder name; empty by default.</summary>
    public string ResourcesFolder { get; set; } = "";
}
@@ -1,24 +0,0 @@
using MessagePack;
namespace Azaion.CommonSecurity.DTO.Commands;
/// <summary>
/// MessagePack-serializable command carrying a command type, a file name and a binary payload.
/// Members are serialized under the string keys given in their <c>Key</c> attributes.
/// </summary>
[MessagePackObject]
public class FileCommand
{
    /// <summary>Kind of command.</summary>
    [Key("CommandType")] public CommandType CommandType { get; set; }

    /// <summary>File name the command refers to.</summary>
    [Key("Filename")] public string Filename { get; set; }

    /// <summary>Binary payload of the command.</summary>
    [Key("Data")] public byte[] Data { get; set; }
}
/// <summary>Command kinds; the numeric values are explicit.</summary>
public enum CommandType
{
    /// <summary>No command.</summary>
    None = 0,

    /// <summary>Inference command.</summary>
    Inference = 1,

    /// <summary>Load command.</summary>
    Load = 2,
}
@@ -0,0 +1,24 @@
using MessagePack;
namespace Azaion.CommonSecurity.DTO.Commands;
/// <summary>
/// MessagePack-serializable command built from a command type plus an optional file name and
/// optional binary payload. Members are serialized under the string keys in their <c>Key</c> attributes.
/// </summary>
[MessagePackObject]
public class RemoteCommand(CommandType commandType, string? filename = null, byte[]? data = null)
{
    /// <summary>Kind of command; initialized from the constructor argument.</summary>
    [Key("CommandType")] public CommandType CommandType { get; set; } = commandType;

    /// <summary>Optional file name; <c>null</c> when not supplied.</summary>
    [Key("Filename")] public string? Filename { get; set; } = filename;

    /// <summary>Optional binary payload; <c>null</c> when not supplied.</summary>
    [Key("Data")] public byte[]? Data { get; set; } = data;
}
/// <summary>Command kinds; the numeric values are explicit.</summary>
public enum CommandType
{
    /// <summary>No command.</summary>
    None = 0,

    /// <summary>Inference command.</summary>
    Inference = 1,

    /// <summary>Load command.</summary>
    Load = 2,

    /// <summary>Get-user command.</summary>
    GetUser = 3,
}
@@ -1,6 +0,0 @@
namespace Azaion.CommonSecurity.DTO;
/// <summary>Configuration wrapper that holds the <see cref="ApiConfig"/> section.</summary>
public class SecureAppConfig
{
    /// <summary>API settings. Has no usable default (initialized with <c>null!</c>).</summary>
    public ApiConfig ApiConfig { get; set; } = null!;
}
+5 -15
View File
@@ -1,21 +1,11 @@
using System.Security.Claims;
using MessagePack;
namespace Azaion.CommonSecurity.DTO;
[MessagePackObject]
public class User
{
public Guid Id { get; set; }
public string Email { get; set; }
public RoleEnum Role { get; set; }
public User(IEnumerable<Claim> claims)
{
var claimDict = claims.ToDictionary(x => x.Type, x => x.Value);
Id = Guid.Parse(claimDict[SecurityConstants.CLAIM_NAME_ID]);
Email = claimDict[SecurityConstants.CLAIM_EMAIL];
if (!Enum.TryParse(claimDict[SecurityConstants.CLAIM_ROLE], out RoleEnum role))
role = RoleEnum.None;
Role = role;
}
[Key("i")]public string Id { get; set; }
[Key("e")]public string Email { get; set; }
[Key("r")]public RoleEnum Role { get; set; }
}
+2 -3
View File
@@ -19,9 +19,8 @@ public class SecurityConstants
#endregion ApiConfig
#region SocketClient
public const string SOCKET_HOST = "127.0.0.1";
public const int SOCKET_SEND_PORT = 5127;
public const int SOCKET_RECEIVE_PORT = 5128;
public const string ZMQ_HOST = "127.0.0.1";
public const int ZMQ_PORT = 5127;
#endregion SocketClient
}
@@ -1,127 +0,0 @@
using System.IdentityModel.Tokens.Jwt;
using System.Net;
using System.Net.Http.Headers;
using System.Security;
using System.Text;
using Azaion.CommonSecurity.DTO;
using Newtonsoft.Json;
namespace Azaion.CommonSecurity.Services;
/// <summary>
/// HTTP API client: stores credentials, logs in to obtain a JWT, and attaches that token as a
/// Bearer header to requests, re-authorizing once when the server answers 401.
/// </summary>
public class AzaionApiClient(HttpClient httpClient) : IDisposable
{
    const string JSON_MEDIA = "application/json";

    // Shared by all instances; assigned in Create().
    private static ApiConfig _apiConfig = null!;

    private string Email { get; set; } = null!;
    private SecureString Password { get; set; } = new();
    private string JwtToken { get; set; } = null!;

    /// <summary>User built from the claims of the most recently obtained JWT.</summary>
    public User User { get; set; } = null!;

    /// <summary>
    /// Creates a client configured from the file at <c>SecurityConstants.CONFIG_PATH</c>. When that
    /// file is missing or cannot be read/parsed, the error is written to the console and the
    /// default URL, retry count and timeout constants are used instead.
    /// </summary>
    /// <param name="credentials">Credentials stored on the new client.</param>
    /// <exception cref="Exception">Email or password is empty.</exception>
    public static AzaionApiClient Create(ApiCredentials credentials)
    {
        try
        {
            if (!File.Exists(SecurityConstants.CONFIG_PATH))
                throw new FileNotFoundException(SecurityConstants.CONFIG_PATH);
            var configStr = File.ReadAllText(SecurityConstants.CONFIG_PATH);
            _apiConfig = JsonConvert.DeserializeObject<SecureAppConfig>(configStr)!.ApiConfig;
        }
        catch (Exception e)
        {
            Console.WriteLine(e);
            _apiConfig = new ApiConfig
            {
                Url = SecurityConstants.DEFAULT_API_URL,
                RetryCount = SecurityConstants.DEFAULT_API_RETRY_COUNT,
                TimeoutSeconds = SecurityConstants.DEFAULT_API_TIMEOUT_SECONDS
            };
        }

        var api = new AzaionApiClient(new HttpClient
        {
            BaseAddress = new Uri(_apiConfig.Url),
            Timeout = TimeSpan.FromSeconds(_apiConfig.TimeoutSeconds)
        });
        api.EnterCredentials(credentials);
        return api;
    }

    /// <summary>Stores the e-mail and password used for later authorization.</summary>
    /// <exception cref="Exception">Email or password is null, empty or whitespace.</exception>
    public void EnterCredentials(ApiCredentials credentials)
    {
        if (string.IsNullOrWhiteSpace(credentials.Email) || string.IsNullOrWhiteSpace(credentials.Password))
            throw new Exception("Email or password is empty!");

        Email = credentials.Email;
        Password = credentials.Password.ToSecureString();
    }

    /// <summary>
    /// POSTs <paramref name="fileName"/>, <paramref name="password"/> and <paramref name="hardware"/> as JSON
    /// to the resources endpoint and returns the response body as a stream.
    /// </summary>
    /// <exception cref="Exception">Authorization failed or the server returned a non-success status.</exception>
    public async Task<Stream> GetResource(string fileName, string password, HardwareInfo hardware)
    {
        var json = JsonConvert.SerializeObject(new { fileName, password, hardware });

        // A factory is passed (not a request instance) so a retry can send a brand-new message.
        var response = await Send(httpClient, () => new HttpRequestMessage(HttpMethod.Post, $"/resources/get/{_apiConfig.ResourcesFolder}")
        {
            Content = new StringContent(json, Encoding.UTF8, JSON_MEDIA)
        });
        return await response.Content.ReadAsStreamAsync();
    }

    // Logs in with the stored credentials and updates JwtToken and User from the returned token.
    private async Task Authorize()
    {
        if (string.IsNullOrEmpty(Email) || Password.Length == 0)
            throw new Exception("Email or password is empty! Please do EnterCredentials first!");

        var payload = new
        {
            email = Email,
            password = Password.ToRealString()
        };
        var response = await httpClient.PostAsync(
            "login",
            new StringContent(JsonConvert.SerializeObject(payload), Encoding.UTF8, JSON_MEDIA));

        if (!response.IsSuccessStatusCode)
            throw new Exception($"EnterCredentials failed: {response.StatusCode}");

        var responseData = await response.Content.ReadAsStringAsync();
        var result = JsonConvert.DeserializeObject<LoginResponse>(responseData);
        if (string.IsNullOrEmpty(result?.Token))
            throw new Exception("JWT Token not found in response");

        var handler = new JwtSecurityTokenHandler();
        var token = handler.ReadJwtToken(result.Token);
        User = new User(token.Claims);
        JwtToken = result.Token;
    }

    // Sends an authorized request; on 401 re-authorizes and retries exactly once.
    // HttpClient refuses to send the same HttpRequestMessage instance twice, so every attempt
    // builds its own message through createRequest.
    private async Task<HttpResponseMessage> Send(HttpClient client, Func<HttpRequestMessage> createRequest)
    {
        if (string.IsNullOrEmpty(JwtToken))
            await Authorize();

        var response = await SendWithToken(client, createRequest);
        if (response.StatusCode == HttpStatusCode.Unauthorized)
        {
            // The first response is discarded; release it before retrying.
            response.Dispose();
            await Authorize();
            response = await SendWithToken(client, createRequest);
        }

        if (response.IsSuccessStatusCode)
            return response;

        var result = await response.Content.ReadAsStringAsync();
        throw new Exception($"Failed: {response.StatusCode}! Result: {result}");
    }

    // Builds a fresh request, attaches the current JWT as a Bearer token and sends it.
    private async Task<HttpResponseMessage> SendWithToken(HttpClient client, Func<HttpRequestMessage> createRequest)
    {
        using var request = createRequest();
        request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", JwtToken);
        return await client.SendAsync(request);
    }

    /// <summary>Disposes the underlying <see cref="HttpClient"/> and the stored password.</summary>
    public void Dispose()
    {
        httpClient.Dispose();
        Password.Dispose();
    }
}
@@ -0,0 +1,60 @@
using System.Text;
using Azaion.CommonSecurity.DTO;
using Azaion.CommonSecurity.DTO.Commands;
using MessagePack;
using NetMQ;
using NetMQ.Sockets;
namespace Azaion.CommonSecurity.Services;
/// <summary>Contract for loading a named file into memory.</summary>
public interface IResourceLoader
{
    /// <summary>Loads the file identified by <paramref name="fileName"/>.</summary>
    /// <param name="fileName">Name of the file to load.</param>
    /// <param name="ct">Cancellation token; defaults to <see cref="CancellationToken.None"/>.</param>
    /// <returns>A task producing a <see cref="MemoryStream"/> with the file content.</returns>
    Task<MemoryStream> LoadFile(
        string fileName,
        CancellationToken ct = default);
}
/// <summary>Exposes the current user.</summary>
public interface IAuthProvider
{
    /// <summary>The current <see cref="User"/>; read-only.</summary>
    User CurrentUser { get; }
}
/// <summary>
/// Loads files and the current user over a ZeroMQ DEALER socket: commands are sent as
/// MessagePack-serialized <see cref="RemoteCommand"/> frames and replies are read from the same socket.
/// </summary>
public class PythonResourceLoader : IResourceLoader, IAuthProvider, IDisposable
{
    private readonly DealerSocket _dealer = new();
    private readonly Guid _clientId = Guid.NewGuid();

    /// <summary>User returned in reply to the <c>GetUser</c> command sent by the constructor.</summary>
    public User CurrentUser { get; }

    /// <summary>
    /// Connects the socket, sends a <c>GetUser</c> command and stores the reply as <see cref="CurrentUser"/>.
    /// </summary>
    /// <param name="credentials">Credentials; not read by this constructor.</param>
    /// <exception cref="Exception">No user could be obtained.</exception>
    public PythonResourceLoader(ApiCredentials credentials)
    {
        try
        {
            //Run python by credentials
            _dealer.Options.Identity = Encoding.UTF8.GetBytes(_clientId.ToString("N"));
            _dealer.Connect($"tcp://{SecurityConstants.ZMQ_HOST}:{SecurityConstants.ZMQ_PORT}");

            _dealer.SendFrame(MessagePackSerializer.Serialize(new RemoteCommand(CommandType.GetUser)));
            var user = _dealer.Get<User>(out _);
            if (user == null)
                throw new Exception("Can't get user from Auth provider");

            CurrentUser = user;
        }
        catch
        {
            // Nobody can call Dispose() on a half-constructed instance, so release the socket here.
            _dealer.Dispose();
            throw;
        }
    }

    /// <summary>
    /// Sends a <c>Load</c> command for <paramref name="fileName"/> and wraps the reply frame in a
    /// <see cref="MemoryStream"/>. Waits at most 1000 ms for the reply.
    /// </summary>
    /// <param name="fileName">Name of the file to request.</param>
    /// <param name="ct">Cancellation token; not observed by this implementation.</param>
    /// <exception cref="Exception">Sending failed or no reply arrived in time; the original error is the inner exception.</exception>
    public async Task<MemoryStream> LoadFile(string fileName, CancellationToken ct = default)
    {
        try
        {
            _dealer.SendFrame(MessagePackSerializer.Serialize(new RemoteCommand(CommandType.Load, fileName)));

            if (!_dealer.TryReceiveFrameBytes(TimeSpan.FromMilliseconds(1000), out var bytes))
                throw new Exception($"Unable to receive {fileName}");

            return await Task.FromResult(new MemoryStream(bytes));
        }
        catch (Exception ex)
        {
            throw new Exception($"Failed to load file '{fileName}': {ex.Message}", ex);
        }
    }

    /// <summary>Releases the underlying DEALER socket.</summary>
    public void Dispose()
    {
        _dealer.Dispose();
    }
}
@@ -1,63 +0,0 @@
using Azaion.CommonSecurity.DTO;
using Azaion.CommonSecurity.DTO.Commands;
using MessagePack;
using NetMQ;
using NetMQ.Sockets;
namespace Azaion.CommonSecurity.Services;
/// <summary>Contract for loading a named resource into memory.</summary>
public interface IResourceLoader
{
    /// <summary>Loads the resource identified by <paramref name="fileName"/>.</summary>
    /// <param name="fileName">Name of the resource to load.</param>
    /// <param name="cancellationToken">Cancellation token; defaults to <see cref="CancellationToken.None"/>.</param>
    /// <returns>A task producing a <see cref="MemoryStream"/> with the resource content.</returns>
    Task<MemoryStream> Load(
        string fileName,
        CancellationToken cancellationToken = default);
}
/// <summary>
/// Loads resources through a pair of ZeroMQ sockets: a <c>Load</c> command is pushed on one
/// socket and the reply frame is pulled from the other.
/// </summary>
public class PythonResourceLoader : IResourceLoader
{
    private readonly PushSocket _pushSocket = new();
    private readonly PullSocket _pullSocket = new();

    /// <summary>Connects the push and pull sockets to their respective ports.</summary>
    /// <param name="credentials">Credentials; not read by this constructor.</param>
    public PythonResourceLoader(ApiCredentials credentials)
    {
        //Run python by credentials
        _pushSocket.Connect($"tcp://{SecurityConstants.SOCKET_HOST}:{SecurityConstants.SOCKET_SEND_PORT}");
        _pullSocket.Connect($"tcp://{SecurityConstants.SOCKET_HOST}:{SecurityConstants.SOCKET_RECEIVE_PORT}");
    }

    /// <summary>
    /// Sends a <c>Load</c> command for <paramref name="fileName"/> and wraps the first reply frame
    /// in a <see cref="MemoryStream"/>. The receive call has no timeout.
    /// </summary>
    /// <param name="fileName">Name of the file to request.</param>
    /// <param name="cancellationToken">Cancellation token; not observed by this implementation.</param>
    /// <exception cref="Exception">Sending or receiving failed; the original error is the inner exception.</exception>
    public async Task<MemoryStream> Load(string fileName, CancellationToken cancellationToken = default)
    {
        try
        {
            var b = MessagePackSerializer.Serialize(new FileCommand
            {
                CommandType = CommandType.Load,
                Filename = fileName
            });
            _pushSocket.SendFrame(b);

            // Only the first frame is used; the "more frames" flag is deliberately discarded.
            var bytes = _pullSocket.ReceiveFrameBytes(out _);
            return new MemoryStream(bytes);
        }
        catch (Exception ex)
        {
            throw new Exception($"Failed to load file '{fileName}': {ex.Message}", ex);
        }
    }
}
/// <summary>
/// <see cref="IResourceLoader"/> that downloads an encrypted resource through
/// <see cref="AzaionApiClient"/> and decrypts it into memory.
/// </summary>
public class ResourceLoader(AzaionApiClient api, ApiCredentials credentials) : IResourceLoader
{
    /// <summary>
    /// Downloads <paramref name="fileName"/>, decrypts it with a key made from the credentials'
    /// e-mail and password plus the hardware hash, and returns the result rewound to the start.
    /// </summary>
    /// <param name="fileName">Name of the resource to download.</param>
    /// <param name="cancellationToken">Token passed to the download task and to decryption.</param>
    /// <returns>A <see cref="MemoryStream"/> positioned at offset 0.</returns>
    public async Task<MemoryStream> Load(string fileName, CancellationToken cancellationToken = default)
    {
        var hardwareService = new HardwareService();
        var hardwareInfo = hardwareService.GetHardware();

        // Await the download instead of blocking on .Result: no thread is held while waiting, and
        // failures surface as the original exception rather than an AggregateException.
        // The encrypted stream is only needed until decryption finishes, so it is disposed here.
        await using var encryptedStream = await Task.Run(
            () => api.GetResource(fileName, credentials.Password, hardwareInfo),
            cancellationToken);

        var key = Security.MakeEncryptionKey(credentials.Email, credentials.Password, hardwareInfo.Hash);
        var stream = new MemoryStream();
        await encryptedStream.DecryptTo(stream, key, cancellationToken);
        stream.Seek(0, SeekOrigin.Begin);
        return stream;
    }
}
+16
View File
@@ -0,0 +1,16 @@
using MessagePack;
using NetMQ;
using NetMQ.Sockets;
namespace Azaion.CommonSecurity;
public static class ZeroMqExtensions
{
public static T? Get<T>(this DealerSocket dealer, out byte[] message)
{
if (!dealer.TryReceiveFrameBytes(TimeSpan.FromMinutes(2), out var bytes))
throw new Exception($"Unable to get {typeof(T).Name}");
message = bytes;
return MessagePackSerializer.Deserialize<T>(bytes);
}
}
@@ -58,13 +58,12 @@ public class DatasetExplorerEventHandler(
if (datasetExplorer.ThumbnailLoading)
return;
var fName = Path.GetFileNameWithoutExtension(datasetExplorer.CurrentAnnotation!.Annotation.ImagePath);
var extension = Path.GetExtension(fName);
var a = datasetExplorer.CurrentAnnotation!.Annotation;
var detections = datasetExplorer.ExplorerEditor.CurrentDetections
.Select(x => new Detection(fName, x.GetLabel(datasetExplorer.ExplorerEditor.RenderSize)))
.Select(x => new Detection(a.Name, x.GetLabel(datasetExplorer.ExplorerEditor.RenderSize)))
.ToList();
await annotationService.SaveAnnotation(fName, extension, detections, SourceEnum.Manual, token: cancellationToken);
await annotationService.SaveAnnotation(a.OriginalMediaName, a.Time, a.ImageExtension, detections, SourceEnum.Manual, token: cancellationToken);
datasetExplorer.SwitchTab(toEditor: false);
break;
case PlaybackControlEnum.RemoveSelectedAnns:
+6 -14
View File
@@ -1,7 +1,5 @@
using System.IO;
using System.Reflection;
using System.Text;
using System.Text.Json;
using System.Windows;
using System.Windows.Threading;
using Azaion.Annotator;
@@ -13,7 +11,6 @@ using Azaion.Common.Events;
using Azaion.Common.Extensions;
using Azaion.Common.Services;
using Azaion.CommonSecurity;
using Azaion.CommonSecurity.DTO;
using Azaion.CommonSecurity.Services;
using Azaion.Dataset;
using LibVLCSharp.Shared;
@@ -23,7 +20,6 @@ using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Newtonsoft.Json;
using Serilog;
using KeyEventArgs = System.Windows.Input.KeyEventArgs;
@@ -36,8 +32,7 @@ public partial class App
private IMediator _mediator = null!;
private FormState _formState = null!;
private AzaionApiClient _apiClient = null!;
private IResourceLoader _resourceLoader = null!;
private PythonResourceLoader _resourceLoader = null!;
private Stream _securedConfig = null!;
private void OnDispatcherUnhandledException(object sender, DispatcherUnhandledExceptionEventArgs e)
@@ -64,9 +59,8 @@ public partial class App
var login = new Login();
login.CredentialsEntered += async (s, args) =>
{
_apiClient = AzaionApiClient.Create(args);
_resourceLoader = new PythonResourceLoader(args);
_securedConfig = await _resourceLoader.Load("secured-config.json");
_securedConfig = await _resourceLoader.LoadFile("secured-config.json");
AppDomain.CurrentDomain.AssemblyResolve += (_, a) =>
{
@@ -75,7 +69,7 @@ public partial class App
{
try
{
var stream = _resourceLoader.Load($"{assemblyName}.dll").GetAwaiter().GetResult();
var stream = _resourceLoader.LoadFile($"{assemblyName}.dll").GetAwaiter().GetResult();
return Assembly.Load(stream.ToArray());
}
catch (Exception e)
@@ -124,11 +118,11 @@ public partial class App
services.AddSingleton<MainSuite>();
services.AddSingleton<IHardwareService, HardwareService>();
services.AddSingleton(_apiClient);
services.AddSingleton(_resourceLoader);
services.AddSingleton<IResourceLoader>(_resourceLoader);
services.AddSingleton<IAuthProvider>(_resourceLoader);
services.AddSingleton<IInferenceService, PythonInferenceService>();
services.Configure<AppConfig>(context.Configuration);
services.ConfigureSection<ApiConfig>(context.Configuration);
services.ConfigureSection<QueueConfig>(context.Configuration);
services.ConfigureSection<DirectoriesConfig>(context.Configuration);
services.ConfigureSection<AnnotationConfig>(context.Configuration);
@@ -139,7 +133,6 @@ public partial class App
services.AddSingleton<Annotator.Annotator>();
services.AddSingleton<DatasetExplorer>();
services.AddSingleton<HelpWindow>();
services.AddSingleton<IAIDetector, YOLODetector>();
services.AddMediatR(c => c.RegisterServicesFromAssemblies(
typeof(Annotator.Annotator).Assembly,
typeof(DatasetExplorer).Assembly,
@@ -152,7 +145,6 @@ public partial class App
return new MediaPlayer(libVLC);
});
services.AddSingleton<AnnotatorEventHandler>();
services.AddSingleton<VLCFrameExtractor>();
services.AddSingleton<IDbFactory, DbFactory>();
services.AddSingleton<FailsafeAnnotationsProducer>();