add ramdisk, load AI model to ramdisk and start recognition from it

rewrite zmq to DEALER and ROUTER
add GET_USER command to get CurrentUser from Python
all auth is on the python side
inference run and validate annotations on python
This commit is contained in:
Alex Bezdieniezhnykh
2025-01-29 17:45:26 +02:00
parent 82b3b526a7
commit 62623b7123
55 changed files with 945 additions and 895 deletions
+1 -1
View File
@@ -50,7 +50,7 @@ This is crucial for the build because build needs Python.h header and other file
pip install ultralytics pip install ultralytics
pip uninstall -y opencv-python pip uninstall -y opencv-python
pip install opencv-python cython msgpack cryptography rstream pika zmq pip install opencv-python cython msgpack cryptography rstream pika zmq pyjwt
``` ```
In case of fbgemm.dll error (Windows specific): In case of fbgemm.dll error (Windows specific):
+10
View File
@@ -0,0 +1,10 @@
cdef class AIRecognitionConfig:
cdef public double frame_recognition_seconds
cdef public double tracking_distance_confidence
cdef public double tracking_probability_increase
cdef public double tracking_intersection_threshold
cdef public int frame_period_recognition
cdef public bytes file_data
@staticmethod
cdef from_msgpack(bytes data)
+32
View File
@@ -0,0 +1,32 @@
from msgpack import unpackb
cdef class AIRecognitionConfig:
def __init__(self,
frame_recognition_seconds,
tracking_distance_confidence,
tracking_probability_increase,
tracking_intersection_threshold,
frame_period_recognition,
file_data
):
self.frame_recognition_seconds = frame_recognition_seconds
self.tracking_distance_confidence = tracking_distance_confidence
self.tracking_probability_increase = tracking_probability_increase
self.tracking_intersection_threshold = tracking_intersection_threshold
self.frame_period_recognition = frame_period_recognition
self.file_data = file_data
def __str__(self):
return (f'frame_seconds : {self.frame_recognition_seconds}, distance_confidence : {self.tracking_distance_confidence}, '
f'probability_increase : {self.tracking_probability_increase}, intersection_threshold : {self.tracking_intersection_threshold}, frame_period_recognition : {self.frame_period_recognition}')
@staticmethod
cdef from_msgpack(bytes data):
unpacked = unpackb(data, strict_map_key=False)
return AIRecognitionConfig(
unpacked.get("FrameRecognitionSeconds", 0.0),
unpacked.get("TrackingDistanceConfidence", 0.0),
unpacked.get("TrackingProbabilityIncrease", 0.0),
unpacked.get("TrackingIntersectionThreshold", 0.0),
unpacked.get("FramePeriodRecognition", 0),
unpacked.get("Data", b''))
+6 -4
View File
@@ -1,8 +1,10 @@
cdef class Detection: cdef class Detection:
cdef double x, y, w, h cdef public double x, y, w, h, confidence
cdef int cls cdef public int cls
cdef class Annotation: cdef class Annotation:
cdef bytes image cdef bytes image
cdef float time cdef long time
cdef list[Detection] detections cdef public list[Detection] detections
cdef bytes serialize(self)
+25 -3
View File
@@ -1,13 +1,35 @@
import msgpack
cdef class Detection: cdef class Detection:
def __init__(self, double x, double y, double w, double h, int cls): def __init__(self, double x, double y, double w, double h, int cls, double confidence):
self.x = x self.x = x
self.y = y self.y = y
self.w = w self.w = w
self.h = h self.h = h
self.cls = cls self.cls = cls
self.confidence = confidence
def __str__(self):
return f'{self.cls}: {self.x:.2f} {self.y:.2f} {self.w:.2f} {self.h:.2f}, prob: {(self.confidence*100):.1f}%'
cdef class Annotation: cdef class Annotation:
def __init__(self, bytes image_bytes, float time, list[Detection] detections): def __init__(self, bytes image_bytes, long time, list[Detection] detections):
self.image = image_bytes self.image = image_bytes
self.time = time self.time = time
self.detections = detections self.detections = detections if detections is not None else []
cdef bytes serialize(self):
return msgpack.packb({
"i": self.image, # "i" = image
"t": self.time, # "t" = time
"d": [ # "d" = detections
{
"x": det.x,
"y": det.y,
"w": det.w,
"h": det.h,
"c": det.cls,
"p": det.confidence
} for det in self.detections
]
})
+7 -1
View File
@@ -1,8 +1,14 @@
from user cimport User
cdef class ApiClient: cdef class ApiClient:
cdef str email, password, token, folder, token_file, api_url cdef str email, password, token, folder, token_file, api_url
cdef User user
cdef get_encryption_key(self, str hardware_hash) cdef get_encryption_key(self, str hardware_hash)
cdef login(self, str email, str password) cdef login(self)
cdef set_token(self, str token)
cdef get_user(self)
cdef load_bytes(self, str filename) cdef load_bytes(self, str filename)
cdef load_ai_model(self) cdef load_ai_model(self)
cdef load_queue_config(self) cdef load_queue_config(self)
+47 -11
View File
@@ -1,13 +1,14 @@
import io
import json import json
import os import os
from http import HTTPStatus from http import HTTPStatus
from uuid import UUID
import jwt
import requests import requests
cimport constants cimport constants
from hardware_service cimport HardwareService, HardwareInfo from hardware_service cimport HardwareService, HardwareInfo
from security cimport Security from security cimport Security
from io import BytesIO from io import BytesIO
from user cimport User, RoleEnum
cdef class ApiClient: cdef class ApiClient:
"""Handles API authentication and downloading of the AI model.""" """Handles API authentication and downloading of the AI model."""
@@ -15,9 +16,11 @@ cdef class ApiClient:
self.email = email self.email = email
self.password = password self.password = password
self.folder = folder self.folder = folder
self.user = None
if os.path.exists(<str>constants.TOKEN_FILE): if os.path.exists(<str>constants.TOKEN_FILE):
with open(<str>constants.TOKEN_FILE, "r") as file: with open(<str>constants.TOKEN_FILE, "r") as file:
self.token = file.read().strip() self.set_token(<str>file.read().strip())
else: else:
self.token = None self.token = None
@@ -25,21 +28,52 @@ cdef class ApiClient:
cdef str key = f'{self.email}-{self.password}-{hardware_hash}-#%@AzaionKey@%#---' cdef str key = f'{self.email}-{self.password}-{hardware_hash}-#%@AzaionKey@%#---'
return Security.calc_hash(key) return Security.calc_hash(key)
cdef login(self, str email, str password): cdef login(self):
response = requests.post(f"{constants.API_URL}/login", json={"email": email, "password": password}) response = requests.post(f"{constants.API_URL}/login",
json={"email": self.email, "password": self.password})
response.raise_for_status() response.raise_for_status()
self.token = response.json()["token"] token = response.json()["token"]
self.set_token(token)
with open(<str>constants.TOKEN_FILE, 'w') as file: with open(<str>constants.TOKEN_FILE, 'w') as file:
file.write(self.token) file.write(token)
cdef set_token(self, str token):
self.token = token
claims = jwt.decode(token, options={"verify_signature": False})
try:
id = str(UUID(claims.get("nameid", "")))
except ValueError:
raise ValueError("Invalid GUID format in claims")
email = claims.get("unique_name", "")
role_str = claims.get("role", "")
if role_str == "ApiAdmin":
role = RoleEnum.ApiAdmin
elif role_str == "Admin":
role = RoleEnum.Admin
elif role_str == "ResourceUploader":
role = RoleEnum.ResourceUploader
elif role_str == "Validator":
role = RoleEnum.Validator
elif role_str == "Operator":
role = RoleEnum.Operator
else:
role = RoleEnum.NONE
self.user = User(id, email, role)
cdef get_user(self):
if self.user is None:
self.login()
return self.user
cdef load_bytes(self, str filename): cdef load_bytes(self, str filename):
hardware_service = HardwareService() hardware_service = HardwareService()
cdef HardwareInfo hardware = hardware_service.get_hardware_info() cdef HardwareInfo hardware = hardware_service.get_hardware_info()
if self.token is None: if self.token is None:
self.login(self.email, self.password) self.login()
url = f"{constants.API_URL}/resources/get/{self.folder}" url = f"{constants.API_URL}/resources/get/{self.folder}"
headers = { headers = {
@@ -56,7 +90,7 @@ cdef class ApiClient:
response = requests.post(url, data=payload, headers=headers, stream=True) response = requests.post(url, data=payload, headers=headers, stream=True)
if response.status_code == HTTPStatus.UNAUTHORIZED or response.status_code == HTTPStatus.FORBIDDEN: if response.status_code == HTTPStatus.UNAUTHORIZED or response.status_code == HTTPStatus.FORBIDDEN:
self.login(self.email, self.password) self.login()
headers = { headers = {
"Authorization": f"Bearer {self.token}", "Authorization": f"Bearer {self.token}",
"Content-Type": "application/json" "Content-Type": "application/json"
@@ -69,7 +103,9 @@ cdef class ApiClient:
key = self.get_encryption_key(hardware.hash) key = self.get_encryption_key(hardware.hash)
stream = BytesIO(response.raw.read()) stream = BytesIO(response.raw.read())
return Security.decrypt_to(stream, key) data = Security.decrypt_to(stream, key)
print(f'loaded file: {filename}, {len(data)} bytes')
return data
cdef load_ai_model(self): cdef load_ai_model(self):
return self.load_bytes(constants.AI_MODEL_FILE) return self.load_bytes(constants.AI_MODEL_FILE)
+3 -3
View File
@@ -1,6 +1,4 @@
cdef str SOCKET_HOST # Host for the socket server cdef int ZMQ_PORT = 5127 # Port for the zmq
cdef int SOCKET_PORT # Port for the socket server
cdef int SOCKET_BUFFER_SIZE # Buffer size for socket communication
cdef int QUEUE_MAXSIZE # Maximum size of the command queue cdef int QUEUE_MAXSIZE # Maximum size of the command queue
cdef str COMMANDS_QUEUE # Name of the commands queue in rabbit cdef str COMMANDS_QUEUE # Name of the commands queue in rabbit
@@ -10,3 +8,5 @@ cdef str API_URL # Base URL for the external API
cdef str TOKEN_FILE # Name of the token file where temporary token would be stored cdef str TOKEN_FILE # Name of the token file where temporary token would be stored
cdef str QUEUE_CONFIG_FILENAME # queue config filename to load from api cdef str QUEUE_CONFIG_FILENAME # queue config filename to load from api
cdef str AI_MODEL_FILE # AI Model file cdef str AI_MODEL_FILE # AI Model file
cdef bytes DONE_SIGNAL
+3 -3
View File
@@ -1,6 +1,4 @@
cdef str SOCKET_HOST = "127.0.0.1" # Host for the socket server cdef int ZMQ_PORT = 5127 # Port for the zmq
cdef int SOCKET_PORT = 9127 # Port for the socket server
cdef int SOCKET_BUFFER_SIZE = 4096 # Buffer size for socket communication
cdef int QUEUE_MAXSIZE = 1000 # Maximum size of the command queue cdef int QUEUE_MAXSIZE = 1000 # Maximum size of the command queue
cdef str COMMANDS_QUEUE = "azaion-commands" cdef str COMMANDS_QUEUE = "azaion-commands"
@@ -10,3 +8,5 @@ cdef str API_URL = "https://api.azaion.com" # Base URL for the external API
cdef str TOKEN_FILE = "token" cdef str TOKEN_FILE = "token"
cdef str QUEUE_CONFIG_FILENAME = "secured-config.json" cdef str QUEUE_CONFIG_FILENAME = "secured-config.json"
cdef str AI_MODEL_FILE = "azaion.pt" cdef str AI_MODEL_FILE = "azaion.pt"
cdef bytes DONE_SIGNAL = b"DONE"
+3
View File
@@ -10,5 +10,8 @@ def start_server():
except Exception as e: except Exception as e:
processor.stop() processor.stop()
def on_annotation(self, cmd, annotation):
print('on_annotation hit!')
if __name__ == "__main__": if __name__ == "__main__":
start_server() start_server()
+17
View File
@@ -0,0 +1,17 @@
from remote_command cimport RemoteCommand
from annotation cimport Annotation
from ai_config cimport AIRecognitionConfig
cdef class Inference:
cdef object model
cdef object on_annotation
cdef Annotation _previous_annotation
cdef AIRecognitionConfig ai_config
cdef bint is_video(self, str filepath)
cdef run_inference(self, RemoteCommand cmd, int batch_size=?)
cdef _process_video(self, RemoteCommand cmd, int batch_size)
cdef _process_image(self, RemoteCommand cmd)
cdef frame_to_annotation(self, long time, frame, boxes: object)
cdef bint is_valid_annotation(self, Annotation annotation)
+76 -17
View File
@@ -1,30 +1,38 @@
import ai_config
import msgpack
from ultralytics import YOLO from ultralytics import YOLO
import mimetypes import mimetypes
import cv2 import cv2
from ultralytics.engine.results import Boxes from ultralytics.engine.results import Boxes
from remote_command cimport RemoteCommand from remote_command cimport RemoteCommand
from annotation cimport Detection, Annotation from annotation cimport Detection, Annotation
from secure_model cimport SecureModelLoader
from ai_config cimport AIRecognitionConfig
cdef class Inference: cdef class Inference:
def __init__(self, model_bytes, on_annotations): def __init__(self, model_bytes, on_annotation):
self.model = YOLO(model_bytes) loader = SecureModelLoader()
self.on_annotations = on_annotations model_path = loader.load_model(model_bytes)
self.model = YOLO(<str>model_path)
self.on_annotation = on_annotation
cdef bint is_video(self, str filepath): cdef bint is_video(self, str filepath):
mime_type, _ = mimetypes.guess_type(<str>filepath) mime_type, _ = mimetypes.guess_type(<str>filepath)
return mime_type and mime_type.startswith("video") return mime_type and mime_type.startswith("video")
cdef run_inference(self, RemoteCommand cmd, int batch_size=8, int frame_skip=4): cdef run_inference(self, RemoteCommand cmd, int batch_size=8):
print('run inference..')
if self.is_video(cmd.filename): if self.is_video(cmd.filename):
return self._process_video(cmd, batch_size, frame_skip) return self._process_video(cmd, batch_size)
else: else:
return self._process_image(cmd) return self._process_image(cmd)
cdef _process_video(self, RemoteCommand cmd, int batch_size, int frame_skip): cdef _process_video(self, RemoteCommand cmd, int batch_size):
frame_count = 0 frame_count = 0
batch_frame = [] batch_frame = []
annotations = []
v_input = cv2.VideoCapture(<str>cmd.filename) v_input = cv2.VideoCapture(<str>cmd.filename)
self.ai_config = AIRecognitionConfig.from_msgpack(cmd.data)
while v_input.isOpened(): while v_input.isOpened():
ret, frame = v_input.read() ret, frame = v_input.read()
@@ -33,7 +41,7 @@ cdef class Inference:
break break
frame_count += 1 frame_count += 1
if frame_count % frame_skip == 0: if frame_count % self.ai_config.frame_period_recognition == 0:
batch_frame.append((frame, ms)) batch_frame.append((frame, ms))
if len(batch_frame) == batch_size: if len(batch_frame) == batch_size:
@@ -41,10 +49,11 @@ cdef class Inference:
results = self.model.track(frames, persist=True) results = self.model.track(frames, persist=True)
for frame, res in zip(batch_frame, results): for frame, res in zip(batch_frame, results):
annotation = self.process_detections(int(frame[1]), frame[0], res.boxes) annotation = self.frame_to_annotation(int(frame[1]), frame[0], res.boxes)
if len(annotation.detections) > 0:
annotations.append(annotation) if self.is_valid_annotation(<Annotation>annotation):
self.on_annotations(cmd, annotations) self._previous_annotation = annotation
self.on_annotation(cmd, annotation)
batch_frame.clear() batch_frame.clear()
v_input.release() v_input.release()
@@ -52,15 +61,65 @@ cdef class Inference:
cdef _process_image(self, RemoteCommand cmd): cdef _process_image(self, RemoteCommand cmd):
frame = cv2.imread(<str>cmd.filename) frame = cv2.imread(<str>cmd.filename)
res = self.model.track(frame) res = self.model.track(frame)
annotation = self.process_detections(0, frame, res[0].boxes) annotation = self.frame_to_annotation(0, frame, res[0].boxes)
self.on_annotations(cmd, [annotation]) self.on_annotation(cmd, annotation)
cdef process_detections(self, float time, frame, boxes: Boxes): cdef frame_to_annotation(self, long time, frame, boxes: Boxes):
detections = [] detections = []
for box in boxes: for box in boxes:
b = box.xywhn[0].cpu().numpy() b = box.xywhn[0].cpu().numpy()
cls = int(box.cls[0].cpu().numpy().item()) cls = int(box.cls[0].cpu().numpy().item())
detections.append(Detection(<double>b[0], <double>b[1], <double>b[2], <double>b[3], cls)) confidence = box.conf[0].cpu().numpy().item()
_, encoded_image = cv2.imencode('.jpg', frame[0]) det = Detection(<double> b[0], <double> b[1], <double> b[2], <double> b[3], cls, confidence)
detections.append(det)
_, encoded_image = cv2.imencode('.jpg', frame)
image_bytes = encoded_image.tobytes() image_bytes = encoded_image.tobytes()
return Annotation(image_bytes, time, detections) return Annotation(image_bytes, time, detections)
cdef bint is_valid_annotation(self, Annotation annotation):
# No detections, invalid
if not annotation.detections:
return False
# First valid annotation, always accept
if self._previous_annotation is None:
return True
# Enough time has passed since last annotation
if annotation.time >= self._previous_annotation.time + <long>(self.ai_config.frame_recognition_seconds * 1000):
return True
# More objects detected than before
if len(annotation.detections) > len(self._previous_annotation.detections):
return True
cdef:
Detection current_det, prev_det
double dx, dy, distance_sq, min_distance_sq
Detection closest_det
# Check each detection against previous frame
for current_det in annotation.detections:
min_distance_sq = 1e18 # Initialize with large value
closest_det = None
# Find closest detection in previous frame
for prev_det in self._previous_annotation.detections:
dx = current_det.x - prev_det.x
dy = current_det.y - prev_det.y
distance_sq = dx * dx + dy * dy
if distance_sq < min_distance_sq:
min_distance_sq = distance_sq
closest_det = prev_det
# Check if beyond tracking distance
if min_distance_sq > self.ai_config.tracking_distance_confidence:
return True
# Check probability increase
if current_det.confidence >= closest_det.confidence + self.ai_config.tracking_probability_increase:
return True
# No validation criteria met
return False
+16 -13
View File
@@ -1,12 +1,13 @@
import traceback
from queue import Queue from queue import Queue
cimport constants cimport constants
import msgpack
from api_client cimport ApiClient from api_client cimport ApiClient
from annotation cimport Annotation from annotation cimport Annotation
from inference import Inference from inference cimport Inference
from remote_command cimport RemoteCommand, CommandType from remote_command cimport RemoteCommand, CommandType
from remote_command_handler cimport RemoteCommandHandler from remote_command_handler cimport RemoteCommandHandler
from user cimport User
import argparse import argparse
cdef class ParsedArguments: cdef class ParsedArguments:
@@ -36,11 +37,10 @@ cdef class CommandProcessor:
while self.running: while self.running:
try: try:
command = self.command_queue.get() command = self.command_queue.get()
print(f'command is : {command}')
model = self.api_client.load_ai_model() model = self.api_client.load_ai_model()
Inference(model, self.on_annotations).run_inference(command) Inference(model, self.on_annotation).run_inference(command)
except Exception as e: except Exception as e:
print(f"Error processing queue: {e}") traceback.print_exc()
cdef on_command(self, RemoteCommand command): cdef on_command(self, RemoteCommand command):
try: try:
@@ -48,17 +48,20 @@ cdef class CommandProcessor:
self.command_queue.put(command) self.command_queue.put(command)
elif command.command_type == CommandType.LOAD: elif command.command_type == CommandType.LOAD:
response = self.api_client.load_bytes(command.filename) response = self.api_client.load_bytes(command.filename)
print(f'loaded file: {command.filename}, {len(response)} bytes') self.remote_handler.send(command.client_id, response)
self.remote_handler.send(response) elif command.command_type == CommandType.GET_USER:
print(f'{len(response)} bytes was sent.') self.get_user(command, self.api_client.get_user())
else:
pass
except Exception as e: except Exception as e:
print(f"Error handling client: {e}") print(f"Error handling client: {e}")
cdef on_annotations(self, RemoteCommand cmd, annotations: [Annotation]): cdef get_user(self, RemoteCommand command, User user):
data = msgpack.packb(annotations) self.remote_handler.send(command.client_id, user.serialize())
self.remote_handler.send(data)
print(f'{len(data)} bytes was sent.') cdef on_annotation(self, RemoteCommand cmd, Annotation annotation):
data = annotation.serialize()
self.remote_handler.send(cmd.client_id, data)
def stop(self): def stop(self):
self.running = False self.running = False
+2
View File
@@ -1,8 +1,10 @@
cdef enum CommandType: cdef enum CommandType:
INFERENCE = 1 INFERENCE = 1
LOAD = 2 LOAD = 2
GET_USER = 3
cdef class RemoteCommand: cdef class RemoteCommand:
cdef public bytes client_id
cdef CommandType command_type cdef CommandType command_type
cdef str filename cdef str filename
cdef bytes data cdef bytes data
+3 -1
View File
@@ -10,8 +10,10 @@ cdef class RemoteCommand:
command_type_names = { command_type_names = {
1: "INFERENCE", 1: "INFERENCE",
2: "LOAD", 2: "LOAD",
3: "GET_USER"
} }
return f'{command_type_names[self.command_type]}: {self.filename}' data_str = f'. Data: {len(self.data)} bytes' if self.data else ''
return f'{command_type_names[self.command_type]}: {self.filename}{data_str}'
@staticmethod @staticmethod
cdef from_msgpack(bytes data): cdef from_msgpack(bytes data):
+9 -10
View File
@@ -1,16 +1,15 @@
cdef class RemoteCommandHandler: cdef class RemoteCommandHandler:
cdef object _on_command
cdef object _context cdef object _context
cdef object _socket cdef object _router
cdef object _dealer
cdef object _shutdown_event cdef object _shutdown_event
cdef object _pull_socket cdef object _on_command
cdef object _pull_thread
cdef object _push_socket cdef object _proxy_thread
cdef object _push_queue cdef object _workers
cdef object _push_thread
cdef start(self) cdef start(self)
cdef _pull_loop(self) cdef _proxy_loop(self)
cdef _push_loop(self) cdef _worker_loop(self)
cdef send(self, bytes message_bytes) cdef send(self, bytes client_id, bytes data)
cdef close(self) cdef close(self)
+44 -56
View File
@@ -1,9 +1,7 @@
from queue import Queue
import zmq import zmq
import json
from threading import Thread, Event from threading import Thread, Event
from remote_command cimport RemoteCommand from remote_command cimport RemoteCommand
cimport constants
cdef class RemoteCommandHandler: cdef class RemoteCommandHandler:
def __init__(self, object on_command): def __init__(self, object on_command):
@@ -11,68 +9,58 @@ cdef class RemoteCommandHandler:
self._context = zmq.Context.instance() self._context = zmq.Context.instance()
self._shutdown_event = Event() self._shutdown_event = Event()
self._pull_socket = self._context.socket(zmq.PULL) self._router = self._context.socket(zmq.ROUTER)
self._pull_socket.setsockopt(zmq.LINGER, 0) self._router.setsockopt(zmq.LINGER, 0)
self._pull_socket.bind("tcp://*:5127") self._router.bind(f'tcp://*:{constants.ZMQ_PORT}')
self._pull_thread = Thread(target=self._pull_loop, daemon=True)
self._push_queue = Queue() self._dealer = self._context.socket(zmq.DEALER)
self._dealer.setsockopt(zmq.LINGER, 0)
self._dealer.bind("inproc://backend")
self._push_socket = self._context.socket(zmq.PUSH) self._proxy_thread = Thread(target=self._proxy_loop, daemon=True)
self._push_socket.setsockopt(zmq.LINGER, 0)
self._push_socket.bind("tcp://*:5128") self._workers = []
self._push_thread = Thread(target=self._push_loop, daemon=True) for _ in range(4): # 4 worker threads
worker = Thread(target=self._worker_loop, daemon=True)
self._workers.append(worker)
cdef start(self): cdef start(self):
self._pull_thread.start() self._proxy_thread.start()
self._push_thread.start() for worker in self._workers:
worker.start()
cdef _pull_loop(self): cdef _proxy_loop(self):
while not self._shutdown_event.is_set(): zmq.proxy(self._router, self._dealer)
print('wait for the command...')
message = self._pull_socket.recv()
cmd = RemoteCommand.from_msgpack(<bytes>message)
print(f'received: {cmd}')
self._on_command(cmd)
cdef _push_loop(self): cdef _worker_loop(self):
worker_socket = self._context.socket(zmq.DEALER)
worker_socket.setsockopt(zmq.LINGER, 0)
worker_socket.connect("inproc://backend")
poller = zmq.Poller()
poller.register(worker_socket, zmq.POLLIN)
print('started receiver loop...')
while not self._shutdown_event.is_set(): while not self._shutdown_event.is_set():
try: try:
response = self._push_queue.get(timeout=1) # Timeout to check shutdown flag socks = dict(poller.poll(500))
self._push_socket.send(response) if worker_socket in socks:
except: client_id, message = worker_socket.recv_multipart()
continue cmd = RemoteCommand.from_msgpack(<bytes> message)
cmd.client_id = client_id
print(f'Received [{cmd}] from the client {client_id}')
self._on_command(cmd)
except Exception as e:
print(f"Worker error: {e}")
import traceback
traceback.print_exc()
cdef send(self, bytes message_bytes): cdef send(self, bytes client_id, bytes data):
print(f'about to send {len(message_bytes)}') with self._context.socket(zmq.DEALER) as socket:
try: socket.connect("inproc://backend")
self._push_queue.put(message_bytes) socket.send_multipart([client_id, data])
except Exception as e: print(f'{len(data)} bytes was sent to client {client_id}')
print(e)
cdef close(self): cdef close(self):
self._shutdown_event.set() self._shutdown_event.set()
self._pull_socket.close() self._router.close()
self._push_socket.close() self._dealer.close()
self._context.term() self._context.term()
cdef class QueueConfig:
cdef str host,
cdef int port, command_port
cdef str producer_user, producer_pw, consumer_user, consumer_pw
@staticmethod
cdef QueueConfig from_json(str json_string):
s = str(json_string).strip()
cdef dict config_dict = json.loads(s)["QueueConfig"]
cdef QueueConfig config = QueueConfig()
config.host = config_dict["Host"]
config.port = config_dict["Port"]
config.command_port = config_dict["CommandsPort"]
config.producer_user = config_dict["ProducerUsername"]
config.producer_pw = config_dict["ProducerPassword"]
config.consumer_user = config_dict["ConsumerUsername"]
config.consumer_pw = config_dict["ConsumerPassword"]
return config
+12
View File
@@ -0,0 +1,12 @@
cdef class SecureModelLoader:
cdef:
bytes _model_bytes
str _ramdisk_path
str _temp_file_path
int _disk_size_mb
cpdef str load_model(self, bytes model_bytes)
cdef str _get_ramdisk_path(self)
cdef void _create_ramdisk(self)
cdef void _store_model(self)
cdef void _cleanup(self)
+104
View File
@@ -0,0 +1,104 @@
import os
import platform
import tempfile
from pathlib import Path
from libc.stdio cimport FILE, fopen, fclose, remove
from libc.stdlib cimport free
from libc.string cimport strdup
cdef class SecureModelLoader:
def __cinit__(self, int disk_size_mb=512):
self._disk_size_mb = disk_size_mb
self._ramdisk_path = None
self._temp_file_path = None
cpdef str load_model(self, bytes model_bytes):
"""Public method to load YOLO model securely."""
self._model_bytes = model_bytes
self._create_ramdisk()
self._store_model()
return self._temp_file_path
cdef str _get_ramdisk_path(self):
"""Determine the RAM disk path based on the OS."""
if platform.system() == "Windows":
return "R:\\"
elif platform.system() == "Linux":
return "/mnt/ramdisk"
elif platform.system() == "Darwin":
return "/Volumes/RAMDisk"
else:
raise RuntimeError("Unsupported OS for RAM disk")
cdef void _create_ramdisk(self):
"""Create a RAM disk securely based on the OS."""
system = platform.system()
if system == "Windows":
# Create RAM disk via PowerShell
command = f'powershell -Command "subst R: {tempfile.gettempdir()}"'
if os.system(command) != 0:
raise RuntimeError("Failed to create RAM disk on Windows")
self._ramdisk_path = "R:\\"
elif system == "Linux":
# Use tmpfs for RAM disk
self._ramdisk_path = "/mnt/ramdisk"
if not Path(self._ramdisk_path).exists():
os.mkdir(self._ramdisk_path)
if os.system(f"mount -t tmpfs -o size={self._disk_size_mb}M tmpfs {self._ramdisk_path}") != 0:
raise RuntimeError("Failed to create RAM disk on Linux")
elif system == "Darwin":
# Use hdiutil for macOS RAM disk
block_size = 2048 # 512-byte blocks * 2048 = 1MB
num_blocks = self._disk_size_mb * block_size
result = os.popen(f"hdiutil attach -nomount ram://{num_blocks}").read().strip()
if result:
self._ramdisk_path = "/Volumes/RAMDisk"
os.system(f"diskutil eraseVolume HFS+ RAMDisk {result}")
else:
raise RuntimeError("Failed to create RAM disk on macOS")
cdef void _store_model(self):
"""Write model securely to the RAM disk."""
cdef char* temp_path
cdef FILE* cfile
with tempfile.NamedTemporaryFile(
dir=self._ramdisk_path, suffix='.pt', delete=False
) as tmp_file:
tmp_file.write(self._model_bytes)
self._temp_file_path = tmp_file.name
encoded_path = self._temp_file_path.encode('utf-8')
temp_path = strdup(encoded_path)
with nogil:
cfile = fopen(temp_path, "rb")
if cfile == NULL:
raise IOError(f"Could not open {self._temp_file_path}")
fclose(cfile)
cdef void _cleanup(self):
"""Remove the model file and unmount RAM disk securely."""
cdef char* c_path
if self._temp_file_path:
c_path = strdup(os.fsencode(self._temp_file_path))
with nogil:
remove(c_path)
free(c_path)
self._temp_file_path = None
# Unmount RAM disk based on OS
if self._ramdisk_path:
if platform.system() == "Windows":
os.system("subst R: /D")
elif platform.system() == "Linux":
os.system(f"umount {self._ramdisk_path}")
elif platform.system() == "Darwin":
os.system("hdiutil detach /Volumes/RAMDisk")
self._ramdisk_path = None
def __dealloc__(self):
"""Ensure cleanup when the object is deleted."""
self._cleanup()
+10 -1
View File
@@ -8,7 +8,10 @@ extensions = [
Extension('hardware_service', ['hardware_service.pyx'], extra_compile_args=["-g"], extra_link_args=["-g"]), Extension('hardware_service', ['hardware_service.pyx'], extra_compile_args=["-g"], extra_link_args=["-g"]),
Extension('remote_command', ['remote_command.pyx']), Extension('remote_command', ['remote_command.pyx']),
Extension('remote_command_handler', ['remote_command_handler.pyx']), Extension('remote_command_handler', ['remote_command_handler.pyx']),
Extension('user', ['user.pyx']),
Extension('api_client', ['api_client.pyx']), Extension('api_client', ['api_client.pyx']),
Extension('secure_model', ['secure_model.pyx']),
Extension('ai_config', ['ai_config.pyx']),
Extension('inference', ['inference.pyx']), Extension('inference', ['inference.pyx']),
Extension('main', ['main.pyx']), Extension('main', ['main.pyx']),
@@ -21,8 +24,14 @@ setup(
compiler_directives={ compiler_directives={
"language_level": 3, "language_level": 3,
"emit_code_comments" : False, "emit_code_comments" : False,
"binding": True "binding": True,
'boundscheck': False,
'wraparound': False
} }
), ),
install_requires=[
'ultralytics>=8.0.0',
'pywin32; platform_system=="Windows"'
],
zip_safe=False zip_safe=False
) )
+1
View File
@@ -0,0 +1 @@
eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1laWQiOiJkOTBhMzZjYS1lMjM3LTRmYmQtOWM3Yy0xMjcwNDBhYzg1NTYiLCJ1bmlxdWVfbmFtZSI6ImFkbWluQGF6YWlvbi5jb20iLCJyb2xlIjoiQXBpQWRtaW4iLCJuYmYiOjE3MzgxNjM2MzYsImV4cCI6MTczODE3ODAzNiwiaWF0IjoxNzM4MTYzNjM2LCJpc3MiOiJBemFpb25BcGkiLCJhdWQiOiJBbm5vdGF0b3JzL09yYW5nZVBpL0FkbWlucyJ9.7VVws5mwGqx--sGopOuZE9iu3dzt1UdVPXeje2KZTYk
+15
View File
@@ -0,0 +1,15 @@
cdef enum RoleEnum:
NONE = 0
Operator = 10
Validator = 20
CompanionPC = 30
Admin = 40
ResourceUploader = 50
ApiAdmin = 1000
cdef class User:
cdef public str id
cdef public str email
cdef public RoleEnum role
cdef bytes serialize(self)
+15
View File
@@ -0,0 +1,15 @@
import msgpack
cdef class User:
def __init__(self, str id, str email, RoleEnum role):
self.id = id
self.email = email
self.role = role
cdef bytes serialize(self):
return msgpack.packb({
"i": self.id,
"e": self.email,
"r": self.role
})
+1 -1
View File
@@ -480,7 +480,7 @@
Grid.Column="10" Grid.Column="10"
Padding="2" Width="25" Padding="2" Width="25"
Height="25" Height="25"
ToolTip="Розпізнати за допомогою AI. Клавіша: [A]" Background="Black" BorderBrush="Black" ToolTip="Розпізнати за допомогою AI. Клавіша: [R]" Background="Black" BorderBrush="Black"
Click="AutoDetect"> Click="AutoDetect">
<Path Stretch="Fill" Fill="LightGray" Data="M144.317 85.269h223.368c15.381 0 29.391 6.325 39.567 16.494l.025-.024c10.163 10.164 16.477 24.193 16.477 <Path Stretch="Fill" Fill="LightGray" Data="M144.317 85.269h223.368c15.381 0 29.391 6.325 39.567 16.494l.025-.024c10.163 10.164 16.477 24.193 16.477
39.599v189.728c0 15.401-6.326 29.425-16.485 39.584-10.159 10.159-24.183 16.484-39.584 16.484H144.317c-15.4 39.599v189.728c0 15.401-6.326 29.425-16.485 39.584-10.159 10.159-24.183 16.484-39.584 16.484H144.317c-15.4
+148 -154
View File
@@ -6,9 +6,7 @@ using System.Windows.Controls;
using System.Windows.Controls.Primitives; using System.Windows.Controls.Primitives;
using System.Windows.Input; using System.Windows.Input;
using System.Windows.Media; using System.Windows.Media;
using System.Windows.Media.Imaging;
using Azaion.Annotator.DTO; using Azaion.Annotator.DTO;
using Azaion.Annotator.Extensions;
using Azaion.Common.Database; using Azaion.Common.Database;
using Azaion.Common.DTO; using Azaion.Common.DTO;
using Azaion.Common.DTO.Config; using Azaion.Common.DTO.Config;
@@ -39,10 +37,9 @@ public partial class Annotator
private readonly IConfigUpdater _configUpdater; private readonly IConfigUpdater _configUpdater;
private readonly HelpWindow _helpWindow; private readonly HelpWindow _helpWindow;
private readonly ILogger<Annotator> _logger; private readonly ILogger<Annotator> _logger;
private readonly VLCFrameExtractor _vlcFrameExtractor;
private readonly IAIDetector _aiDetector;
private readonly AnnotationService _annotationService; private readonly AnnotationService _annotationService;
private readonly IDbFactory _dbFactory; private readonly IDbFactory _dbFactory;
private readonly IInferenceService _inferenceService;
private readonly CancellationTokenSource _ctSource = new(); private readonly CancellationTokenSource _ctSource = new();
private ObservableCollection<DetectionClass> AnnotationClasses { get; set; } = new(); private ObservableCollection<DetectionClass> AnnotationClasses { get; set; } = new();
@@ -67,10 +64,9 @@ public partial class Annotator
FormState formState, FormState formState,
HelpWindow helpWindow, HelpWindow helpWindow,
ILogger<Annotator> logger, ILogger<Annotator> logger,
VLCFrameExtractor vlcFrameExtractor,
IAIDetector aiDetector,
AnnotationService annotationService, AnnotationService annotationService,
IDbFactory dbFactory) IDbFactory dbFactory,
IInferenceService inferenceService)
{ {
InitializeComponent(); InitializeComponent();
_appConfig = appConfig.Value; _appConfig = appConfig.Value;
@@ -81,10 +77,9 @@ public partial class Annotator
_formState = formState; _formState = formState;
_helpWindow = helpWindow; _helpWindow = helpWindow;
_logger = logger; _logger = logger;
_vlcFrameExtractor = vlcFrameExtractor;
_aiDetector = aiDetector;
_annotationService = annotationService; _annotationService = annotationService;
_dbFactory = dbFactory; _dbFactory = dbFactory;
_inferenceService = inferenceService;
Loaded += OnLoaded; Loaded += OnLoaded;
Closed += OnFormClosed; Closed += OnFormClosed;
@@ -304,11 +299,16 @@ public partial class Annotator
var annotations = await _dbFactory.Run(async db => var annotations = await _dbFactory.Run(async db =>
await db.Annotations.LoadWith(x => x.Detections) await db.Annotations.LoadWith(x => x.Detections)
.Where(x => x.Name.Contains(_formState.VideoName)) .Where(x => x.OriginalMediaName == _formState.VideoName)
.ToListAsync(token: _ctSource.Token)); .ToListAsync(token: _ctSource.Token));
TimedAnnotations.Clear();
_formState.AnnotationResults.Clear();
foreach (var ann in annotations) foreach (var ann in annotations)
AddAnnotation(ann); {
TimedAnnotations.Add(ann.Time.Subtract(_thresholdBefore), ann.Time.Add(_thresholdAfter), ann);
_formState.AnnotationResults.Add(new AnnotationResult(_appConfig.AnnotationConfig.DetectionClassesDict, ann));
}
} }
//Add manually //Add manually
@@ -435,8 +435,6 @@ public partial class Annotator
_appConfig.DirectoriesConfig.VideosDirectory = dlg.FileName; _appConfig.DirectoriesConfig.VideosDirectory = dlg.FileName;
TbFolder.Text = dlg.FileName; TbFolder.Text = dlg.FileName;
await ReloadFiles();
await SaveUserSettings();
} }
private void TbFilter_OnTextChanged(object sender, TextChangedEventArgs e) private void TbFilter_OnTextChanged(object sender, TextChangedEventArgs e)
@@ -487,11 +485,8 @@ public partial class Annotator
if (LvFiles.SelectedIndex == -1) if (LvFiles.SelectedIndex == -1)
LvFiles.SelectedIndex = 0; LvFiles.SelectedIndex = 0;
await _mediator.Publish(new AnnotatorControlEvent(PlaybackControlEnum.Play)); var mct = new CancellationTokenSource();
_mediaPlayer.Stop(); var token = mct.Token;
var manualCancellationSource = new CancellationTokenSource();
var token = manualCancellationSource.Token;
_autoDetectDialog = new AutodetectDialog _autoDetectDialog = new AutodetectDialog
{ {
@@ -500,7 +495,7 @@ public partial class Annotator
}; };
_autoDetectDialog.Closing += (_, _) => _autoDetectDialog.Closing += (_, _) =>
{ {
manualCancellationSource.Cancel(); mct.Cancel();
_mediaPlayer.SeekTo(TimeSpan.Zero); _mediaPlayer.SeekTo(TimeSpan.Zero);
Editor.RemoveAllAnns(); Editor.RemoveAllAnns();
}; };
@@ -515,16 +510,17 @@ public partial class Annotator
var mediaInfo = Dispatcher.Invoke(() => (MediaFileInfo)LvFiles.SelectedItem); var mediaInfo = Dispatcher.Invoke(() => (MediaFileInfo)LvFiles.SelectedItem);
while (mediaInfo != null) while (mediaInfo != null)
{ {
_formState.CurrentMedia = mediaInfo; await Dispatcher.Invoke(async () =>
await Dispatcher.Invoke(async () => await ReloadAnnotations());
if (mediaInfo.MediaType == MediaTypes.Image)
{ {
await DetectImage(mediaInfo, manualCancellationSource, token); await _mediator.Publish(new AnnotatorControlEvent(PlaybackControlEnum.Play), token);
await Task.Delay(70, token); await ReloadAnnotations();
} });
else
await DetectVideo(mediaInfo, manualCancellationSource, token); await _inferenceService.RunInference(mediaInfo.Path, async (annotationImage, ct) =>
{
annotationImage.OriginalMediaName = mediaInfo.FName;
await ProcessDetection(annotationImage, ct);
}, token);
mediaInfo = Dispatcher.Invoke(() => mediaInfo = Dispatcher.Invoke(() =>
{ {
@@ -533,6 +529,7 @@ public partial class Annotator
LvFiles.SelectedIndex += 1; LvFiles.SelectedIndex += 1;
return (MediaFileInfo)LvFiles.SelectedItem; return (MediaFileInfo)LvFiles.SelectedItem;
}); });
LvFiles.Items.Refresh();
} }
Dispatcher.Invoke(() => Dispatcher.Invoke(() =>
{ {
@@ -546,143 +543,140 @@ public partial class Annotator
Dispatcher.Invoke(() => Editor.ResetBackground()); Dispatcher.Invoke(() => Editor.ResetBackground());
} }
private async Task DetectImage(MediaFileInfo mediaInfo, CancellationTokenSource manualCancellationSource, CancellationToken token) // private async Task DetectImage(MediaFileInfo mediaInfo, CancellationTokenSource manualCancellationSource, CancellationToken token)
// {
// try
// {
// var fName = Path.GetFileNameWithoutExtension(mediaInfo.Path);
// var stream = new FileStream(mediaInfo.Path, FileMode.Open);
// var detections = await _aiDetector.Detect(fName, stream, token);
// await ProcessDetection((TimeSpan.FromMilliseconds(0), stream), Path.GetExtension(mediaInfo.Path), detections, token);
// if (detections.Count != 0)
// mediaInfo.HasAnnotations = true;
// }
// catch (Exception e)
// {
// _logger.LogError(e, e.Message);
// await manualCancellationSource.CancelAsync();
// }
// }
// private async Task DetectVideo(MediaFileInfo mediaInfo, CancellationTokenSource manualCancellationSource, CancellationToken token)
// {
// var prevSeekTime = 0.0;
// await foreach (var timeframe in _vlcFrameExtractor.ExtractFrames(mediaInfo.Path, token))
// {
// Console.WriteLine($"Detect time: {timeframe.Time}");
// try
// {
// var fName = _formState.GetTimeName(timeframe.Time);
// var detections = await _aiDetector.Detect(fName, timeframe.Stream, token);
//
// var isValid = IsValidDetection(timeframe.Time, detections);
// Console.WriteLine($"Detection time: {timeframe.Time}");
//
// var log = string.Join(Environment.NewLine, detections.Select(det =>
// $"{_appConfig.AnnotationConfig.DetectionClassesDict[det.ClassNumber].Name}: " +
// $"xy=({det.CenterX:F2},{det.CenterY:F2}), " +
// $"size=({det.Width:F2}, {det.Height:F2}), " +
// $"prob: {det.Probability:F1}%"));
//
// log = $"Detection time: {timeframe.Time}, Valid: {isValid}. {Environment.NewLine} {log}";
// Dispatcher.Invoke(() => _autoDetectDialog.Log(log));
//
// if (timeframe.Time.TotalMilliseconds > prevSeekTime + 250)
// {
// Dispatcher.Invoke(() => SeekTo(timeframe.Time));
// prevSeekTime = timeframe.Time.TotalMilliseconds;
// if (!isValid) //Show frame anyway
// {
// Dispatcher.Invoke(() =>
// {
// Editor.RemoveAllAnns();
// Editor.Background = new ImageBrush
// {
// ImageSource = timeframe.Stream.OpenImage()
// };
// });
// }
// }
//
// if (!isValid)
// continue;
//
// mediaInfo.HasAnnotations = true;
// await ProcessDetection(timeframe, ".jpg", detections, token);
// await timeframe.Stream.DisposeAsync();
// }
// catch (Exception ex)
// {
// _logger.LogError(ex, ex.Message);
// await manualCancellationSource.CancelAsync();
// }
// }
// }
// private bool IsValidDetection(TimeSpan time, List<Detection> detections)
// {
// // No AI detection, forbid
// if (detections.Count == 0)
// return false;
//
// // Very first detection, allow
// if (!_previousDetection.HasValue)
// return true;
//
// var prev = _previousDetection.Value;
//
// // Time between detections is >= than Frame Recognition Seconds, allow
// if (time >= prev.Time.Add(TimeSpan.FromSeconds(_appConfig.AIRecognitionConfig.FrameRecognitionSeconds)))
// return true;
//
// // Detection is earlier than previous + FrameRecognitionSeconds.
// // Look to the detections more in detail
//
// // More detected objects, allow
// if (detections.Count > prev.Detections.Count)
// return true;
//
// foreach (var det in detections)
// {
// var point = new Point(det.CenterX, det.CenterY);
// var closestObject = prev.Detections
// .Select(p => new
// {
// Point = p,
// Distance = point.SqrDistance(new Point(p.CenterX, p.CenterY))
// })
// .OrderBy(x => x.Distance)
// .First();
//
// // Closest object is farther than Tracking distance confidence, hence it's a different object, allow
// if (closestObject.Distance > _appConfig.AIRecognitionConfig.TrackingDistanceConfidence)
// return true;
//
// // Since closest object within distance confidence, then it is tracking of the same object. Then if recognition probability for the object > increase from previous
// if (det.Probability >= closestObject.Point.Probability + _appConfig.AIRecognitionConfig.TrackingProbabilityIncrease)
// return true;
// }
//
// return false;
// }
private async Task ProcessDetection(AnnotationImage annotationImage, CancellationToken token = default)
{ {
try
{
var fName = Path.GetFileNameWithoutExtension(mediaInfo.Path);
var stream = new FileStream(mediaInfo.Path, FileMode.Open);
var detections = await _aiDetector.Detect(fName, stream, token);
await ProcessDetection((TimeSpan.FromMilliseconds(0), stream), Path.GetExtension(mediaInfo.Path), detections, token);
if (detections.Count != 0)
mediaInfo.HasAnnotations = true;
}
catch (Exception e)
{
_logger.LogError(e, e.Message);
await manualCancellationSource.CancelAsync();
}
}
private async Task DetectVideo(MediaFileInfo mediaInfo, CancellationTokenSource manualCancellationSource, CancellationToken token)
{
var prevSeekTime = 0.0;
await foreach (var timeframe in _vlcFrameExtractor.ExtractFrames(mediaInfo.Path, token))
{
Console.WriteLine($"Detect time: {timeframe.Time}");
try
{
var fName = _formState.GetTimeName(timeframe.Time);
var detections = await _aiDetector.Detect(fName, timeframe.Stream, token);
var isValid = IsValidDetection(timeframe.Time, detections);
Console.WriteLine($"Detection time: {timeframe.Time}");
var log = string.Join(Environment.NewLine, detections.Select(det =>
$"{_appConfig.AnnotationConfig.DetectionClassesDict[det.ClassNumber].Name}: " +
$"xy=({det.CenterX:F2},{det.CenterY:F2}), " +
$"size=({det.Width:F2}, {det.Height:F2}), " +
$"prob: {det.Probability:F1}%"));
log = $"Detection time: {timeframe.Time}, Valid: {isValid}. {Environment.NewLine} {log}";
Dispatcher.Invoke(() => _autoDetectDialog.Log(log));
if (timeframe.Time.TotalMilliseconds > prevSeekTime + 250)
{
Dispatcher.Invoke(() => SeekTo(timeframe.Time));
prevSeekTime = timeframe.Time.TotalMilliseconds;
if (!isValid) //Show frame anyway
{
Dispatcher.Invoke(() =>
{
Editor.RemoveAllAnns();
Editor.Background = new ImageBrush
{
ImageSource = timeframe.Stream.OpenImage()
};
});
}
}
if (!isValid)
continue;
mediaInfo.HasAnnotations = true;
await ProcessDetection(timeframe, ".jpg", detections, token);
await timeframe.Stream.DisposeAsync();
}
catch (Exception ex)
{
_logger.LogError(ex, ex.Message);
await manualCancellationSource.CancelAsync();
}
}
}
private bool IsValidDetection(TimeSpan time, List<Detection> detections)
{
// No AI detection, forbid
if (detections.Count == 0)
return false;
// Very first detection, allow
if (!_previousDetection.HasValue)
return true;
var prev = _previousDetection.Value;
// Time between detections is >= than Frame Recognition Seconds, allow
if (time >= prev.Time.Add(TimeSpan.FromSeconds(_appConfig.AIRecognitionConfig.FrameRecognitionSeconds)))
return true;
// Detection is earlier than previous + FrameRecognitionSeconds.
// Look to the detections more in detail
// More detected objects, allow
if (detections.Count > prev.Detections.Count)
return true;
foreach (var det in detections)
{
var point = new Point(det.CenterX, det.CenterY);
var closestObject = prev.Detections
.Select(p => new
{
Point = p,
Distance = point.SqrDistance(new Point(p.CenterX, p.CenterY))
})
.OrderBy(x => x.Distance)
.First();
// Closest object is farther than Tracking distance confidence, hence it's a different object, allow
if (closestObject.Distance > _appConfig.AIRecognitionConfig.TrackingDistanceConfidence)
return true;
// Since closest object within distance confidence, then it is tracking of the same object. Then if recognition probability for the object > increase from previous
if (det.Probability >= closestObject.Point.Probability + _appConfig.AIRecognitionConfig.TrackingProbabilityIncrease)
return true;
}
return false;
}
private async Task ProcessDetection((TimeSpan Time, Stream Stream) timeframe, string imageExtension, List<Detection> detections, CancellationToken token = default)
{
_previousDetection = (timeframe.Time, detections);
await Dispatcher.Invoke(async () => await Dispatcher.Invoke(async () =>
{ {
try try
{ {
var fName = _formState.GetTimeName(timeframe.Time); var annotation = await _annotationService.SaveAnnotation(annotationImage, token);
var annotation = await _annotationService.SaveAnnotation(fName, imageExtension, detections, SourceEnum.AI, timeframe.Stream, token);
Editor.Background = new ImageBrush { ImageSource = await annotation.ImagePath.OpenImage() }; Editor.Background = new ImageBrush { ImageSource = await annotation.ImagePath.OpenImage() };
Editor.RemoveAllAnns(); Editor.RemoveAllAnns();
ShowAnnotations(annotation, true); ShowAnnotations(annotation, true);
AddAnnotation(annotation); AddAnnotation(annotation);
var log = string.Join(Environment.NewLine, detections.Select(det => var log = string.Join(Environment.NewLine, annotation.Detections.Select(det =>
$"{_appConfig.AnnotationConfig.DetectionClassesDict[det.ClassNumber].Name}: " + $"{_appConfig.AnnotationConfig.DetectionClassesDict[det.ClassNumber].Name}: " +
$"xy=({det.CenterX:F2},{det.CenterY:F2}), " + $"xy=({det.CenterX:F2},{det.CenterY:F2}), " +
$"size=({det.Width:F2}, {det.Height:F2}), " + $"size=({det.Width:F2}, {det.Height:F2}), " +
+4 -4
View File
@@ -2,11 +2,11 @@
using System.Windows; using System.Windows;
using System.Windows.Input; using System.Windows.Input;
using Azaion.Annotator.DTO; using Azaion.Annotator.DTO;
using Azaion.Common;
using Azaion.Common.DTO; using Azaion.Common.DTO;
using Azaion.Common.DTO.Config; using Azaion.Common.DTO.Config;
using Azaion.Common.DTO.Queue; using Azaion.Common.DTO.Queue;
using Azaion.Common.Events; using Azaion.Common.Events;
using Azaion.Common.Extensions;
using Azaion.Common.Services; using Azaion.Common.Services;
using LibVLCSharp.Shared; using LibVLCSharp.Shared;
using MediatR; using MediatR;
@@ -79,7 +79,7 @@ public class AnnotatorEventHandler(
if (_keysControlEnumDict.TryGetValue(key, out var value)) if (_keysControlEnumDict.TryGetValue(key, out var value))
await ControlPlayback(value, cancellationToken); await ControlPlayback(value, cancellationToken);
if (key == Key.A) if (key == Key.R)
mainWindow.AutoDetect(null!, null!); mainWindow.AutoDetect(null!, null!);
#region Volume #region Volume
@@ -228,7 +228,7 @@ public class AnnotatorEventHandler(
return; return;
var time = formState.BackgroundTime ?? TimeSpan.FromMilliseconds(mediaPlayer.Time); var time = formState.BackgroundTime ?? TimeSpan.FromMilliseconds(mediaPlayer.Time);
var fName = formState.GetTimeName(time); var fName = formState.VideoName.ToTimeName(time);
var currentDetections = mainWindow.Editor.CurrentDetections var currentDetections = mainWindow.Editor.CurrentDetections
.Select(x => new Detection(fName, x.GetLabel(mainWindow.Editor.RenderSize, formState.BackgroundTime.HasValue ? mainWindow.Editor.RenderSize : formState.CurrentVideoSize))) .Select(x => new Detection(fName, x.GetLabel(mainWindow.Editor.RenderSize, formState.BackgroundTime.HasValue ? mainWindow.Editor.RenderSize : formState.CurrentVideoSize)))
@@ -267,7 +267,7 @@ public class AnnotatorEventHandler(
File.Copy(formState.CurrentMedia.Path, imgPath, overwrite: true); File.Copy(formState.CurrentMedia.Path, imgPath, overwrite: true);
NextMedia(); NextMedia();
} }
var annotation = await annotationService.SaveAnnotation(fName, imageExtension, currentDetections, SourceEnum.Manual, token: cancellationToken); var annotation = await annotationService.SaveAnnotation(formState.VideoName, time, imageExtension, currentDetections, SourceEnum.Manual, token: cancellationToken);
mainWindow.AddAnnotation(annotation); mainWindow.AddAnnotation(annotation);
} }
@@ -1,9 +0,0 @@
using System.Windows;
namespace Azaion.Annotator.Extensions;
public static class PointExtensions
{
public static double SqrDistance(this Point p1, Point p2) =>
(p2.X - p1.X) * (p2.X - p1.X) + (p2.Y - p1.Y) * (p2.Y - p1.Y);
}
@@ -1,130 +0,0 @@
using System.Collections.Concurrent;
using System.IO;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using Azaion.Common.DTO.Config;
using LibVLCSharp.Shared;
using Microsoft.Extensions.Options;
using SkiaSharp;
namespace Azaion.Annotator.Extensions;
public class VLCFrameExtractor(LibVLC libVLC, IOptions<AIRecognitionConfig> config)
{
private const uint RGBA_BYTES = 4;
private const int PLAYBACK_RATE = 4;
private uint _pitch; // Number of bytes per "line", aligned to x32.
private uint _lines; // Number of lines in the buffer, aligned to x32.
private uint _width; // Thumbnail width
private uint _height; // Thumbnail height
private MediaPlayer _mediaPlayer = null!;
private TimeSpan _lastFrameTimestamp;
private long _lastFrame;
private static uint Align32(uint size)
{
if (size % 32 == 0)
return size;
return (size / 32 + 1) * 32;// Align on the next multiple of 32
}
private static SKBitmap? _currentBitmap;
private static readonly ConcurrentQueue<FrameInfo> FramesQueue = new();
private static long _frameCounter;
public async IAsyncEnumerable<(TimeSpan Time, Stream Stream)> ExtractFrames(string mediaPath,
[EnumeratorCancellation] CancellationToken manualCancellationToken = default)
{
var videoFinishedCancellationSource = new CancellationTokenSource();
_mediaPlayer = new MediaPlayer(libVLC);
_mediaPlayer.Stopped += (s, e) => videoFinishedCancellationSource.CancelAfter(1);
using var media = new Media(libVLC, mediaPath);
await media.Parse(cancellationToken: videoFinishedCancellationSource.Token);
var videoTrack = media.Tracks.FirstOrDefault(x => x.Data.Video.Width != 0);
_width = videoTrack.Data.Video.Width;
_height = videoTrack.Data.Video.Height;
_pitch = Align32(_width * RGBA_BYTES);
_lines = Align32(_height);
_mediaPlayer.SetRate(PLAYBACK_RATE);
media.AddOption(":no-audio");
_mediaPlayer.SetVideoFormat("RV32", _width, _height, _pitch);
_mediaPlayer.SetVideoCallbacks(Lock, null, Display);
_mediaPlayer.Play(media);
_frameCounter = 0;
var surface = SKSurface.Create(new SKImageInfo((int) _width, (int) _height));
var videoFinishedCT = videoFinishedCancellationSource.Token;
while ( !(FramesQueue.IsEmpty && videoFinishedCT.IsCancellationRequested || manualCancellationToken.IsCancellationRequested))
{
if (FramesQueue.TryDequeue(out var frameInfo))
{
if (frameInfo.Bitmap == null)
continue;
surface.Canvas.DrawBitmap(frameInfo.Bitmap, 0, 0); // Effectively crops the original bitmap to get only the visible area
using var outputImage = surface.Snapshot();
using var data = outputImage.Encode(SKEncodedImageFormat.Jpeg, 85);
var ms = new MemoryStream();
data.SaveTo(ms);
yield return (frameInfo.Time, ms);
frameInfo.Bitmap?.Dispose();
}
else
{
await Task.Delay(TimeSpan.FromSeconds(1), videoFinishedCT);
}
}
FramesQueue.Clear(); //clear queue in case of manual stop
_mediaPlayer.Stop();
_mediaPlayer.Dispose();
}
private IntPtr Lock(IntPtr opaque, IntPtr planes)
{
_currentBitmap = new SKBitmap(new SKImageInfo((int)(_pitch / RGBA_BYTES), (int)_lines, SKColorType.Bgra8888));
Marshal.WriteIntPtr(planes, _currentBitmap.GetPixels());
return IntPtr.Zero;
}
private void Display(IntPtr opaque, IntPtr picture)
{
var playerTime = TimeSpan.FromMilliseconds(_mediaPlayer.Time);
if (_lastFrameTimestamp != playerTime)
{
_lastFrame = _frameCounter;
_lastFrameTimestamp = playerTime;
}
if (_frameCounter > 20 && _frameCounter % config.Value.FramePeriodRecognition == 0)
{
var msToAdd = (_frameCounter - _lastFrame) * (_lastFrame == 0 ? 0 : _lastFrameTimestamp.TotalMilliseconds / _lastFrame);
var time = _lastFrameTimestamp.Add(TimeSpan.FromMilliseconds(msToAdd));
FramesQueue.Enqueue(new FrameInfo(time, _currentBitmap));
}
else
{
_currentBitmap?.Dispose();
}
_currentBitmap = null;
_frameCounter++;
}
}
public class FrameInfo(TimeSpan time, SKBitmap? bitmap)
{
public TimeSpan Time { get; set; } = time;
public SKBitmap? Bitmap { get; set; } = bitmap;
}
-87
View File
@@ -1,87 +0,0 @@
using System.Diagnostics;
using System.IO;
using Azaion.Annotator.Extensions;
using Azaion.Common.DTO;
using Azaion.Common.DTO.Config;
using Azaion.CommonSecurity.Services;
using Compunet.YoloV8;
using Microsoft.Extensions.Options;
using SixLabors.ImageSharp;
using SixLabors.ImageSharp.PixelFormats;
using Detection = Azaion.Common.DTO.Detection;
namespace Azaion.Annotator;
public interface IAIDetector
{
Task<List<Detection>> Detect(string fName, Stream imageStream, CancellationToken cancellationToken = default);
}
public class YOLODetector(IOptions<AIRecognitionConfig> recognitionConfig, IResourceLoader resourceLoader) : IAIDetector, IDisposable
{
private readonly AIRecognitionConfig _recognitionConfig = recognitionConfig.Value;
private YoloPredictor? _predictor;
private const string YOLO_MODEL = "azaion.onnx";
public async Task<List<Detection>> Detect(string fName, Stream imageStream, CancellationToken cancellationToken)
{
if (_predictor == null)
{
await using var stream = await resourceLoader.Load(YOLO_MODEL, cancellationToken);
_predictor = new YoloPredictor(stream.ToArray());
}
imageStream.Seek(0, SeekOrigin.Begin);
using var image = Image.Load<Rgb24>(imageStream);
var result = await _predictor.DetectAsync(image);
var imageSize = new System.Windows.Size(image.Width, image.Height);
var detections = result.Select(d =>
{
var label = new YoloLabel(new CanvasLabel(d.Name.Id, d.Bounds.X, d.Bounds.Y, d.Bounds.Width, d.Bounds.Height), imageSize, imageSize);
return new Detection(fName, label, (double?)d.Confidence * 100);
}).ToList();
return FilterOverlapping(detections);
}
private List<Detection> FilterOverlapping(List<Detection> detections)
{
var k = _recognitionConfig.TrackingIntersectionThreshold;
var filteredDetections = new List<Detection>();
for (var i = 0; i < detections.Count; i++)
{
var detectionSelected = false;
for (var j = i + 1; j < detections.Count; j++)
{
var intersect = detections[i].ToRectangle();
intersect.Intersect(detections[j].ToRectangle());
var maxArea = Math.Max(detections[i].ToRectangle().Area(), detections[j].ToRectangle().Area());
if (!(intersect.Area() > k * maxArea))
continue;
if (detections[i].Probability > detections[j].Probability)
{
filteredDetections.Add(detections[i]);
detections.RemoveAt(j);
}
else
{
filteredDetections.Add(detections[j]);
detections.RemoveAt(i);
}
detectionSelected = true;
break;
}
if (!detectionSelected)
filteredDetections.Add(detections[i]);
}
return filteredDetections;
}
public void Dispose() => _predictor?.Dispose();
}
@@ -1,10 +1,15 @@
using MessagePack;
namespace Azaion.Common.DTO.Config; namespace Azaion.Common.DTO.Config;
[MessagePackObject]
public class AIRecognitionConfig public class AIRecognitionConfig
{ {
public double FrameRecognitionSeconds { get; set; } [Key("FrameRecognitionSeconds")] public double FrameRecognitionSeconds { get; set; }
public double TrackingDistanceConfidence { get; set; }
public double TrackingProbabilityIncrease { get; set; } [Key("TrackingDistanceConfidence")] public double TrackingDistanceConfidence { get; set; }
public double TrackingIntersectionThreshold { get; set; } [Key("TrackingProbabilityIncrease")] public double TrackingProbabilityIncrease { get; set; }
public int FramePeriodRecognition { get; set; } [Key("TrackingIntersectionThreshold")] public double TrackingIntersectionThreshold { get; set; }
[Key("FramePeriodRecognition")] public int FramePeriodRecognition { get; set; }
[Key("Data")] public byte[] Data { get; set; }
} }
-9
View File
@@ -8,8 +8,6 @@ namespace Azaion.Common.DTO.Config;
public class AppConfig public class AppConfig
{ {
public ApiConfig ApiConfig { get; set; } = null!;
public QueueConfig QueueConfig { get; set; } = null!; public QueueConfig QueueConfig { get; set; } = null!;
public DirectoriesConfig DirectoriesConfig { get; set; } = null!; public DirectoriesConfig DirectoriesConfig { get; set; } = null!;
@@ -39,13 +37,6 @@ public class ConfigUpdater : IConfigUpdater
var appConfig = new AppConfig var appConfig = new AppConfig
{ {
ApiConfig = new ApiConfig
{
Url = SecurityConstants.DEFAULT_API_URL,
RetryCount = SecurityConstants.DEFAULT_API_RETRY_COUNT,
TimeoutSeconds = SecurityConstants.DEFAULT_API_TIMEOUT_SECONDS
},
AnnotationConfig = new AnnotationConfig AnnotationConfig = new AnnotationConfig
{ {
AnnotationClasses = Constants.DefaultAnnotationClasses, AnnotationClasses = Constants.DefaultAnnotationClasses,
-2
View File
@@ -16,6 +16,4 @@ public class FormState
public int CurrentVolume { get; set; } = 100; public int CurrentVolume { get; set; } = 100;
public ObservableCollection<AnnotationResult> AnnotationResults { get; set; } = []; public ObservableCollection<AnnotationResult> AnnotationResults { get; set; } = [];
public WindowEnum ActiveWindow { get; set; } public WindowEnum ActiveWindow { get; set; }
public string GetTimeName(TimeSpan? ts) => $"{VideoName}_{ts:hmmssf}";
} }
+13 -11
View File
@@ -1,18 +1,18 @@
using System.Drawing; using System.Drawing;
using System.Globalization; using System.Globalization;
using System.IO; using System.IO;
using MessagePack;
using Newtonsoft.Json; using Newtonsoft.Json;
using Size = System.Windows.Size; using Size = System.Windows.Size;
namespace Azaion.Common.DTO; namespace Azaion.Common.DTO;
[MessagePackObject]
public abstract class Label public abstract class Label
{ {
[JsonProperty(PropertyName = "cl")] public int ClassNumber { get; set; } [JsonProperty(PropertyName = "cl")][Key("c")] public int ClassNumber { get; set; }
protected Label() protected Label() { }
{
}
protected Label(int classNumber) protected Label(int classNumber)
{ {
@@ -79,15 +79,16 @@ public class CanvasLabel : Label
} }
} }
[MessagePackObject]
public class YoloLabel : Label public class YoloLabel : Label
{ {
[JsonProperty(PropertyName = "x")] public double CenterX { get; set; } [JsonProperty(PropertyName = "x")][Key("x")] public double CenterX { get; set; }
[JsonProperty(PropertyName = "y")] public double CenterY { get; set; } [JsonProperty(PropertyName = "y")][Key("y")] public double CenterY { get; set; }
[JsonProperty(PropertyName = "w")] public double Width { get; set; } [JsonProperty(PropertyName = "w")][Key("w")] public double Width { get; set; }
[JsonProperty(PropertyName = "h")] public double Height { get; set; } [JsonProperty(PropertyName = "h")][Key("h")] public double Height { get; set; }
public YoloLabel() public YoloLabel()
{ {
@@ -184,12 +185,13 @@ public class YoloLabel : Label
public override string ToString() => $"{ClassNumber} {CenterX:F5} {CenterY:F5} {Width:F5} {Height:F5}".Replace(',', '.'); public override string ToString() => $"{ClassNumber} {CenterX:F5} {CenterY:F5} {Width:F5} {Height:F5}".Replace(',', '.');
} }
[MessagePackObject]
public class Detection : YoloLabel public class Detection : YoloLabel
{ {
public string AnnotationName { get; set; } = null!; [IgnoreMember]public string AnnotationName { get; set; } = null!;
public double? Probability { get; set; } [Key("p")] public double? Probability { get; set; }
//For db //For db & serialization
public Detection(){} public Detection(){}
public Detection(string annotationName, YoloLabel label, double? probability = null) public Detection(string annotationName, YoloLabel label, double? probability = null)
+4 -2
View File
@@ -1,4 +1,6 @@
namespace Azaion.Common.DTO; using Azaion.Common.Extensions;
namespace Azaion.Common.DTO;
public class MediaFileInfo public class MediaFileInfo
{ {
@@ -9,5 +11,5 @@ public class MediaFileInfo
public bool HasAnnotations { get; set; } public bool HasAnnotations { get; set; }
public MediaTypes MediaType { get; set; } public MediaTypes MediaType { get; set; }
public string FName => System.IO.Path.GetFileNameWithoutExtension(Name).Replace(" ", ""); public string FName => Name.ToFName();
} }
@@ -7,15 +7,17 @@ using MessagePack;
[MessagePackObject] [MessagePackObject]
public class AnnotationCreatedMessage public class AnnotationCreatedMessage
{ {
[Key(0)] public DateTime CreatedDate { get; set; } [Key(0)] public DateTime CreatedDate { get; set; }
[Key(1)] public string Name { get; set; } = null!; [Key(1)] public string Name { get; set; } = null!;
[Key(2)] public string ImageExtension { get; set; } = null!; [Key(2)] public string OriginalMediaName { get; set; } = null!;
[Key(3)] public string Detections { get; set; } = null!; [Key(3)] public TimeSpan Time { get; set; }
[Key(4)] public byte[] Image { get; set; } = null!; [Key(4)] public string ImageExtension { get; set; } = null!;
[Key(5)] public RoleEnum CreatedRole { get; set; } [Key(5)] public string Detections { get; set; } = null!;
[Key(6)] public string CreatedEmail { get; set; } = null!; [Key(6)] public byte[] Image { get; set; } = null!;
[Key(7)] public SourceEnum Source { get; set; } [Key(7)] public RoleEnum CreatedRole { get; set; }
[Key(8)] public AnnotationStatus Status { get; set; } [Key(8)] public string CreatedEmail { get; set; } = null!;
[Key(9)] public SourceEnum Source { get; set; }
[Key(10)] public AnnotationStatus Status { get; set; }
} }
[MessagePackObject] [MessagePackObject]
+24 -39
View File
@@ -3,9 +3,11 @@ using Azaion.Common.DTO;
using Azaion.Common.DTO.Config; using Azaion.Common.DTO.Config;
using Azaion.Common.DTO.Queue; using Azaion.Common.DTO.Queue;
using Azaion.CommonSecurity.DTO; using Azaion.CommonSecurity.DTO;
using MessagePack;
namespace Azaion.Common.Database; namespace Azaion.Common.Database;
[MessagePackObject]
public class Annotation public class Annotation
{ {
private static string _labelsDir = null!; private static string _labelsDir = null!;
@@ -19,53 +21,36 @@ public class Annotation
_thumbDir = config.ThumbnailsDirectory; _thumbDir = config.ThumbnailsDirectory;
} }
public string Name { get; set; } = null!; [IgnoreMember]public string Name { get; set; } = null!;
public string ImageExtension { get; set; } = null!; [IgnoreMember]public string OriginalMediaName { get; set; } = null!;
public DateTime CreatedDate { get; set; } [IgnoreMember]public TimeSpan Time { get; set; }
public string CreatedEmail { get; set; } = null!; [IgnoreMember]public string ImageExtension { get; set; } = null!;
public RoleEnum CreatedRole { get; set; } [IgnoreMember]public DateTime CreatedDate { get; set; }
public SourceEnum Source { get; set; } [IgnoreMember]public string CreatedEmail { get; set; } = null!;
public AnnotationStatus AnnotationStatus { get; set; } [IgnoreMember]public RoleEnum CreatedRole { get; set; }
[IgnoreMember]public SourceEnum Source { get; set; }
[IgnoreMember]public AnnotationStatus AnnotationStatus { get; set; }
public IEnumerable<Detection> Detections { get; set; } = null!; [Key("d")] public IEnumerable<Detection> Detections { get; set; } = null!;
[Key("t")] public long Milliseconds { get; set; }
public double Lat { get; set; } [Key("lat")]public double Lat { get; set; }
public double Lon { get; set; } [Key("lon")]public double Lon { get; set; }
#region Calculated #region Calculated
public List<int> Classes => Detections.Select(x => x.ClassNumber).ToList(); [IgnoreMember]public List<int> Classes => Detections.Select(x => x.ClassNumber).ToList();
public string ImagePath => Path.Combine(_imagesDir, $"{Name}{ImageExtension}"); [IgnoreMember]public string ImagePath => Path.Combine(_imagesDir, $"{Name}{ImageExtension}");
public string LabelPath => Path.Combine(_labelsDir, $"{Name}.txt"); [IgnoreMember]public string LabelPath => Path.Combine(_labelsDir, $"{Name}.txt");
public string ThumbPath => Path.Combine(_thumbDir, $"{Name}{Constants.THUMBNAIL_PREFIX}.jpg"); [IgnoreMember]public string ThumbPath => Path.Combine(_thumbDir, $"{Name}{Constants.THUMBNAIL_PREFIX}.jpg");
public string OriginalMediaName => $"{Name[..^7]}";
private TimeSpan? _time;
public TimeSpan Time
{
get
{
if (_time.HasValue)
return _time.Value;
var timeStr = Name.Split("_").LastOrDefault();
//For some reason, TimeSpan.ParseExact doesn't work on every platform.
if (!string.IsNullOrEmpty(timeStr) &&
timeStr.Length == 6 &&
int.TryParse(timeStr[..1], out var hours) &&
int.TryParse(timeStr[1..3], out var minutes) &&
int.TryParse(timeStr[3..5], out var seconds) &&
int.TryParse(timeStr[5..], out var milliseconds))
return new TimeSpan(0, hours, minutes, seconds, milliseconds * 100);
_time = TimeSpan.FromSeconds(0);
return _time.Value;
}
}
#endregion Calculated #endregion Calculated
} }
[MessagePackObject]
public class AnnotationImage : Annotation
{
[Key("i")] public byte[] Image { get; set; }
}
public enum AnnotationStatus public enum AnnotationStatus
{ {
+9 -6
View File
@@ -117,16 +117,19 @@ public static class AnnotationsDbSchemaHolder
MappingSchema = new MappingSchema(); MappingSchema = new MappingSchema();
var builder = new FluentMappingBuilder(MappingSchema); var builder = new FluentMappingBuilder(MappingSchema);
builder.Entity<Annotation>() var annotationBuilder = builder.Entity<Annotation>();
.HasTableName(Constants.ANNOTATIONS_TABLENAME) annotationBuilder.HasTableName(Constants.ANNOTATIONS_TABLENAME)
.HasPrimaryKey(x => x.Name) .HasPrimaryKey(x => x.Name)
.Ignore(x => x.Time) .Association(a => a.Detections, (a, d) => a.Name == d.AnnotationName)
.Property(x => x.Time).HasDataType(DataType.Int64).HasConversion(ts => ts.Ticks, t => new TimeSpan(t));
annotationBuilder
.Ignore(x => x.Milliseconds)
.Ignore(x => x.Classes)
.Ignore(x => x.Classes) .Ignore(x => x.Classes)
.Ignore(x => x.ImagePath) .Ignore(x => x.ImagePath)
.Ignore(x => x.LabelPath) .Ignore(x => x.LabelPath)
.Ignore(x => x.ThumbPath) .Ignore(x => x.ThumbPath);
.Ignore(x => x.OriginalMediaName)
.Association(a => a.Detections, (a, d) => a.Name == d.AnnotationName);
builder.Entity<Detection>() builder.Entity<Detection>()
.HasTableName(Constants.DETECTIONS_TABLENAME); .HasTableName(Constants.DETECTIONS_TABLENAME);
+1 -1
View File
@@ -24,7 +24,7 @@ public class ParallelExt
return Task.CompletedTask; return Task.CompletedTask;
} }
}; };
var threadsCount = (int)(Environment.ProcessorCount * parallelOptions.CpuUtilPercent / 100.0); var threadsCount = (int)Math.Round(Environment.ProcessorCount * parallelOptions.CpuUtilPercent / 100.0);
var processedCount = 0; var processedCount = 0;
var chunkSize = Math.Max(1, (int)(source.Count / (decimal)threadsCount)); var chunkSize = Math.Max(1, (int)(source.Count / (decimal)threadsCount));
@@ -0,0 +1,12 @@
using System.IO;
namespace Azaion.Common.Extensions;
public static class StringExtensions
{
public static string ToFName(this string path) =>
Path.GetFileNameWithoutExtension(path).Replace(" ", "");
public static string ToTimeName(this string fName, TimeSpan? ts) =>
$"{fName}_{ts:hmmssf}";
}
+29 -21
View File
@@ -22,29 +22,31 @@ namespace Azaion.Common.Services;
public class AnnotationService : INotificationHandler<AnnotationsDeletedEvent> public class AnnotationService : INotificationHandler<AnnotationsDeletedEvent>
{ {
private readonly AzaionApiClient _apiClient;
private readonly IDbFactory _dbFactory; private readonly IDbFactory _dbFactory;
private readonly FailsafeAnnotationsProducer _producer; private readonly FailsafeAnnotationsProducer _producer;
private readonly IGalleryService _galleryService; private readonly IGalleryService _galleryService;
private readonly IMediator _mediator; private readonly IMediator _mediator;
private readonly IHardwareService _hardwareService; private readonly IHardwareService _hardwareService;
private readonly IAuthProvider _authProvider;
private readonly QueueConfig _queueConfig; private readonly QueueConfig _queueConfig;
private Consumer _consumer = null!; private Consumer _consumer = null!;
public AnnotationService(AzaionApiClient apiClient, public AnnotationService(
IResourceLoader resourceLoader,
IDbFactory dbFactory, IDbFactory dbFactory,
FailsafeAnnotationsProducer producer, FailsafeAnnotationsProducer producer,
IOptions<QueueConfig> queueConfig, IOptions<QueueConfig> queueConfig,
IGalleryService galleryService, IGalleryService galleryService,
IMediator mediator, IMediator mediator,
IHardwareService hardwareService) IHardwareService hardwareService,
IAuthProvider authProvider)
{ {
_apiClient = apiClient;
_dbFactory = dbFactory; _dbFactory = dbFactory;
_producer = producer; _producer = producer;
_galleryService = galleryService; _galleryService = galleryService;
_mediator = mediator; _mediator = mediator;
_hardwareService = hardwareService; _hardwareService = hardwareService;
_authProvider = authProvider;
_queueConfig = queueConfig.Value; _queueConfig = queueConfig.Value;
Task.Run(async () => await Init()).Wait(); Task.Run(async () => await Init()).Wait();
@@ -73,7 +75,8 @@ public class AnnotationService : INotificationHandler<AnnotationsDeletedEvent>
await SaveAnnotationInner( await SaveAnnotationInner(
msg.CreatedDate, msg.CreatedDate,
msg.Name, msg.OriginalMediaName,
msg.Time,
msg.ImageExtension, msg.ImageExtension,
JsonConvert.DeserializeObject<List<Detection>>(msg.Detections) ?? [], JsonConvert.DeserializeObject<List<Detection>>(msg.Detections) ?? [],
msg.Source, msg.Source,
@@ -98,36 +101,39 @@ public class AnnotationService : INotificationHandler<AnnotationsDeletedEvent>
}); });
} }
//AI / Manual //AI
public async Task<Annotation> SaveAnnotation(string fName, string imageExtension, List<Detection> detections, SourceEnum source, Stream? stream = null, CancellationToken token = default) => public async Task<Annotation> SaveAnnotation(AnnotationImage a, CancellationToken cancellationToken = default)
await SaveAnnotationInner(DateTime.UtcNow, fName, imageExtension, detections, source, stream, _apiClient.User.Role, _apiClient.User.Email, generateThumbnail: true, token); {
a.Time = TimeSpan.FromMilliseconds(a.Milliseconds);
a.Name = a.OriginalMediaName.ToTimeName(a.Time);
return await SaveAnnotationInner(DateTime.Now, a.OriginalMediaName, a.Time, ".jpg", a.Detections.ToList(),
a.Source, new MemoryStream(a.Image), a.CreatedRole, a.CreatedEmail, generateThumbnail: true, cancellationToken);
}
//Manual //Manual
public async Task<Annotation> SaveAnnotation(string originalMediaName, TimeSpan time, string imageExtension, List<Detection> detections, SourceEnum source, Stream? stream = null, CancellationToken token = default) =>
await SaveAnnotationInner(DateTime.UtcNow, originalMediaName, time, imageExtension, detections, source, stream,
_authProvider.CurrentUser.Role, _authProvider.CurrentUser.Email, generateThumbnail: true, token);
//Manual Validate existing
public async Task ValidateAnnotation(Annotation annotation, CancellationToken token = default) => public async Task ValidateAnnotation(Annotation annotation, CancellationToken token = default) =>
await SaveAnnotationInner(DateTime.UtcNow, annotation.Name, annotation.ImageExtension, annotation.Detections.ToList(), SourceEnum.Manual, null, _apiClient.User.Role, _apiClient.User.Email, await SaveAnnotationInner(DateTime.UtcNow, annotation.OriginalMediaName, annotation.Time, annotation.ImageExtension, annotation.Detections.ToList(), SourceEnum.Manual, null,
generateThumbnail: false, token); _authProvider.CurrentUser.Role, _authProvider.CurrentUser.Email, generateThumbnail: false, token);
// //Queue (only from operators) private async Task<Annotation> SaveAnnotationInner(DateTime createdDate, string originalMediaName, TimeSpan time, string imageExtension, List<Detection> detections, SourceEnum source, Stream? stream,
// public async Task Consume(AnnotationCreatedMessage message, CancellationToken cancellationToken = default)
// {
//
// }
private async Task<Annotation> SaveAnnotationInner(DateTime createdDate, string fName, string imageExtension, List<Detection> detections, SourceEnum source, Stream? stream,
RoleEnum userRole, RoleEnum userRole,
string createdEmail, string createdEmail,
bool generateThumbnail = false, bool generateThumbnail = false,
CancellationToken token = default) CancellationToken token = default)
{ {
//Flow for roles:
// Operator or (AI from any role) -> Created
// Validator, Admin & Manual -> Validated
AnnotationStatus status; AnnotationStatus status;
var fName = originalMediaName.ToTimeName(time);
var annotation = await _dbFactory.Run(async db => var annotation = await _dbFactory.Run(async db =>
{ {
var ann = await db.Annotations.FirstOrDefaultAsync(x => x.Name == fName, token: token); var ann = await db.Annotations.FirstOrDefaultAsync(x => x.Name == fName, token: token);
// Manual Save from Validators -> Validated
// otherwise Created
status = userRole.IsValidator() && source == SourceEnum.Manual status = userRole.IsValidator() && source == SourceEnum.Manual
? AnnotationStatus.Validated ? AnnotationStatus.Validated
: AnnotationStatus.Created; : AnnotationStatus.Created;
@@ -149,6 +155,8 @@ public class AnnotationService : INotificationHandler<AnnotationsDeletedEvent>
{ {
CreatedDate = createdDate, CreatedDate = createdDate,
Name = fName, Name = fName,
OriginalMediaName = originalMediaName,
Time = time,
ImageExtension = imageExtension, ImageExtension = imageExtension,
CreatedEmail = createdEmail, CreatedEmail = createdEmail,
CreatedRole = userRole, CreatedRole = userRole,
+3 -2
View File
@@ -76,7 +76,7 @@ public class FailsafeAnnotationsProducer
await _annotationConfirmProducer.Send(validatedMessages, CompressionType.Gzip); await _annotationConfirmProducer.Send(validatedMessages, CompressionType.Gzip);
await _dbFactory.Run(async db => await _dbFactory.Run(async db =>
await db.AnnotationsQueue.DeleteAsync(aq => messagesChunk.Any(x => aq.Name == x.Name), token: cancellationToken)); await db.AnnotationsQueue.DeleteAsync(aq => messagesChunk.Any(x => aq.Name == x.OriginalMediaName), token: cancellationToken));
sent = true; sent = true;
} }
catch (Exception e) catch (Exception e)
@@ -106,7 +106,8 @@ public class FailsafeAnnotationsProducer
var annCreateMessage = new AnnotationCreatedMessage var annCreateMessage = new AnnotationCreatedMessage
{ {
Name = annotation.Name, Name = annotation.Name,
OriginalMediaName = annotation.OriginalMediaName,
Time = annotation.Time,
CreatedRole = annotation.CreatedRole, CreatedRole = annotation.CreatedRole,
CreatedEmail = annotation.CreatedEmail, CreatedEmail = annotation.CreatedEmail,
CreatedDate = annotation.CreatedDate, CreatedDate = annotation.CreatedDate,
+45 -13
View File
@@ -8,6 +8,7 @@ using Azaion.Common.Database;
using Azaion.Common.DTO; using Azaion.Common.DTO;
using Azaion.Common.DTO.Config; using Azaion.Common.DTO.Config;
using Azaion.Common.DTO.Queue; using Azaion.Common.DTO.Queue;
using Azaion.Common.Extensions;
using Azaion.CommonSecurity.DTO; using Azaion.CommonSecurity.DTO;
using LinqToDB; using LinqToDB;
using LinqToDB.Data; using LinqToDB.Data;
@@ -75,7 +76,6 @@ public class GalleryService(
var missedAnnotations = new ConcurrentBag<Annotation>(); var missedAnnotations = new ConcurrentBag<Annotation>();
try try
{ {
var prefixLen = Constants.THUMBNAIL_PREFIX.Length; var prefixLen = Constants.THUMBNAIL_PREFIX.Length;
var thumbnails = ThumbnailsDirectory.GetFiles() var thumbnails = ThumbnailsDirectory.GetFiles()
@@ -105,9 +105,37 @@ public class GalleryService(
return; return;
var detections = (await YoloLabel.ReadFromFile(labelName, cancellationToken)).Select(x => new Detection(fName, x)).ToList(); var detections = (await YoloLabel.ReadFromFile(labelName, cancellationToken)).Select(x => new Detection(fName, x)).ToList();
//get names and time
var fileName = Path.GetFileNameWithoutExtension(file.Name);
var strings = fileName.Split("_");
var timeStr = strings.LastOrDefault();
string originalMediaName;
TimeSpan time;
//For some reason, TimeSpan.ParseExact doesn't work on every platform.
if (!string.IsNullOrEmpty(timeStr) &&
timeStr.Length == 6 &&
int.TryParse(timeStr[..1], out var hours) &&
int.TryParse(timeStr[1..3], out var minutes) &&
int.TryParse(timeStr[3..5], out var seconds) &&
int.TryParse(timeStr[5..], out var milliseconds))
{
time = new TimeSpan(0, hours, minutes, seconds, milliseconds * 100);
originalMediaName = fileName[..^7];
}
else
{
originalMediaName = fileName;
time = TimeSpan.FromSeconds(0);
}
var annotation = new Annotation var annotation = new Annotation
{ {
Name = fName, Time = time,
OriginalMediaName = originalMediaName,
Name = file.Name.ToFName(),
ImageExtension = Path.GetExtension(file.Name), ImageExtension = Path.GetExtension(file.Name),
Detections = detections, Detections = detections,
CreatedDate = File.GetCreationTimeUtc(file.FullName), CreatedDate = File.GetCreationTimeUtc(file.FullName),
@@ -129,18 +157,22 @@ public class GalleryService(
logger.LogError(e, $"Failed to generate thumbnail for {file.Name}! Error: {e.Message}"); logger.LogError(e, $"Failed to generate thumbnail for {file.Name}! Error: {e.Message}");
} }
}, },
new ParallelOptions new ParallelOptions
{
ProgressFn = async num =>
{ {
ProgressFn = async num => Console.WriteLine($"Processed {num} item by Thread {Environment.CurrentManagedThreadId}");
{ ProcessedThumbnailsPercentage = imagesCount == 0 ? 0 : Math.Min(100, num * 100 / (double)imagesCount);
Console.WriteLine($"Processed {num} item by Thread {Environment.CurrentManagedThreadId}"); ThumbnailsUpdate?.Invoke(ProcessedThumbnailsPercentage);
ProcessedThumbnailsPercentage = imagesCount == 0 ? 0 : Math.Min(100, num * 100 / (double)imagesCount); await Task.CompletedTask;
ThumbnailsUpdate?.Invoke(ProcessedThumbnailsPercentage); },
await Task.CompletedTask; CpuUtilPercent = 100,
}, ProgressUpdateInterval = 200
CpuUtilPercent = 100, });
ProgressUpdateInterval = 200 }
}); catch (Exception e)
{
logger.LogError(e, $"Failed to refresh thumbnails! Error: {e.Message}");
} }
finally finally
{ {
@@ -0,0 +1,53 @@
using System.Text;
using Azaion.Common.Database;
using Azaion.Common.DTO.Config;
using Azaion.CommonSecurity;
using Azaion.CommonSecurity.DTO.Commands;
using Azaion.CommonSecurity.Services;
using MessagePack;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using NetMQ;
using NetMQ.Sockets;
namespace Azaion.Common.Services;
public interface IInferenceService
{
Task RunInference(string mediaPath, Func<AnnotationImage, CancellationToken, Task> processAnnotation, CancellationToken ct = default);
}
public class PythonInferenceService(ILogger<PythonInferenceService> logger, IOptions<AIRecognitionConfig> aiConfigOptions) : IInferenceService
{
public async Task RunInference(string mediaPath, Func<AnnotationImage, CancellationToken, Task> processAnnotation, CancellationToken ct = default)
{
using var dealer = new DealerSocket();
var clientId = Guid.NewGuid();
dealer.Options.Identity = Encoding.UTF8.GetBytes(clientId.ToString("N"));
dealer.Connect($"tcp://{SecurityConstants.ZMQ_HOST}:{SecurityConstants.ZMQ_PORT}");
var data = MessagePackSerializer.Serialize(aiConfigOptions.Value);
dealer.SendFrame(MessagePackSerializer.Serialize(new RemoteCommand(CommandType.Inference, mediaPath, data)));
while (true)
{
byte[] bytes = [];
try
{
var annotationStream = dealer.Get<AnnotationImage>(out bytes);
if (annotationStream == null)
{
if (bytes.Length == 4 && Encoding.UTF8.GetString(bytes) == "DONE")
break;
continue;
}
await processAnnotation(annotationStream, ct);
}
catch (Exception e)
{
logger.LogError(e, e.Message);
}
}
}
}
-9
View File
@@ -1,9 +0,0 @@
namespace Azaion.CommonSecurity.DTO;
public class ApiConfig
{
public string Url { get; set; } = null!;
public int RetryCount {get;set;}
public double TimeoutSeconds { get; set; }
public string ResourcesFolder { get; set; } = "";
}
@@ -1,24 +0,0 @@
using MessagePack;
namespace Azaion.CommonSecurity.DTO.Commands;
[MessagePackObject]
public class FileCommand
{
[Key("CommandType")]
public CommandType CommandType { get; set; }
[Key("Filename")]
public string Filename { get; set; }
[Key("Data")]
public byte[] Data { get; set; }
}
public enum CommandType
{
None = 0,
Inference = 1,
Load = 2
}
@@ -0,0 +1,24 @@
using MessagePack;
namespace Azaion.CommonSecurity.DTO.Commands;
[MessagePackObject]
public class RemoteCommand(CommandType commandType, string? filename = null, byte[]? data = null)
{
[Key("CommandType")]
public CommandType CommandType { get; set; } = commandType;
[Key("Filename")]
public string? Filename { get; set; } = filename;
[Key("Data")]
public byte[]? Data { get; set; } = data;
}
public enum CommandType
{
None = 0,
Inference = 1,
Load = 2,
GetUser = 3
}
@@ -1,6 +0,0 @@
namespace Azaion.CommonSecurity.DTO;
public class SecureAppConfig
{
public ApiConfig ApiConfig { get; set; } = null!;
}
+5 -15
View File
@@ -1,21 +1,11 @@
using System.Security.Claims; using MessagePack;
namespace Azaion.CommonSecurity.DTO; namespace Azaion.CommonSecurity.DTO;
[MessagePackObject]
public class User public class User
{ {
public Guid Id { get; set; } [Key("i")]public string Id { get; set; }
public string Email { get; set; } [Key("e")]public string Email { get; set; }
public RoleEnum Role { get; set; } [Key("r")]public RoleEnum Role { get; set; }
public User(IEnumerable<Claim> claims)
{
var claimDict = claims.ToDictionary(x => x.Type, x => x.Value);
Id = Guid.Parse(claimDict[SecurityConstants.CLAIM_NAME_ID]);
Email = claimDict[SecurityConstants.CLAIM_EMAIL];
if (!Enum.TryParse(claimDict[SecurityConstants.CLAIM_ROLE], out RoleEnum role))
role = RoleEnum.None;
Role = role;
}
} }
+2 -3
View File
@@ -19,9 +19,8 @@ public class SecurityConstants
#endregion ApiConfig #endregion ApiConfig
#region SocketClient #region SocketClient
public const string SOCKET_HOST = "127.0.0.1"; public const string ZMQ_HOST = "127.0.0.1";
public const int SOCKET_SEND_PORT = 5127; public const int ZMQ_PORT = 5127;
public const int SOCKET_RECEIVE_PORT = 5128;
#endregion SocketClient #endregion SocketClient
} }
@@ -1,127 +0,0 @@
using System.IdentityModel.Tokens.Jwt;
using System.Net;
using System.Net.Http.Headers;
using System.Security;
using System.Text;
using Azaion.CommonSecurity.DTO;
using Newtonsoft.Json;
namespace Azaion.CommonSecurity.Services;
public class AzaionApiClient(HttpClient httpClient) : IDisposable
{
const string JSON_MEDIA = "application/json";
private static ApiConfig _apiConfig = null!;
private string Email { get; set; } = null!;
private SecureString Password { get; set; } = new();
private string JwtToken { get; set; } = null!;
public User User { get; set; } = null!;
public static AzaionApiClient Create(ApiCredentials credentials)
{
try
{
if (!File.Exists(SecurityConstants.CONFIG_PATH))
throw new FileNotFoundException(SecurityConstants.CONFIG_PATH);
var configStr = File.ReadAllText(SecurityConstants.CONFIG_PATH);
_apiConfig = JsonConvert.DeserializeObject<SecureAppConfig>(configStr)!.ApiConfig;
}
catch (Exception e)
{
Console.WriteLine(e);
_apiConfig = new ApiConfig
{
Url = SecurityConstants.DEFAULT_API_URL,
RetryCount = SecurityConstants.DEFAULT_API_RETRY_COUNT ,
TimeoutSeconds = SecurityConstants.DEFAULT_API_TIMEOUT_SECONDS
};
}
var api = new AzaionApiClient(new HttpClient
{
BaseAddress = new Uri(_apiConfig.Url),
Timeout = TimeSpan.FromSeconds(_apiConfig.TimeoutSeconds)
});
api.EnterCredentials(credentials);
return api;
}
public void EnterCredentials(ApiCredentials credentials)
{
if (string.IsNullOrWhiteSpace(credentials.Email) || string.IsNullOrWhiteSpace(credentials.Password))
throw new Exception("Email or password is empty!");
Email = credentials.Email;
Password = credentials.Password.ToSecureString();
}
public async Task<Stream> GetResource(string fileName, string password, HardwareInfo hardware)
{
var response = await Send(httpClient, new HttpRequestMessage(HttpMethod.Post, $"/resources/get/{_apiConfig.ResourcesFolder}")
{
Content = new StringContent(JsonConvert.SerializeObject(new { fileName, password, hardware }), Encoding.UTF8, JSON_MEDIA)
});
return await response.Content.ReadAsStreamAsync();
}
private async Task Authorize()
{
if (string.IsNullOrEmpty(Email) || Password.Length == 0)
throw new Exception("Email or password is empty! Please do EnterCredentials first!");
var payload = new
{
email = Email,
password = Password.ToRealString()
};
var response = await httpClient.PostAsync(
"login",
new StringContent(JsonConvert.SerializeObject(payload), Encoding.UTF8, JSON_MEDIA));
if (!response.IsSuccessStatusCode)
throw new Exception($"EnterCredentials failed: {response.StatusCode}");
var responseData = await response.Content.ReadAsStringAsync();
var result = JsonConvert.DeserializeObject<LoginResponse>(responseData);
if (string.IsNullOrEmpty(result?.Token))
throw new Exception("JWT Token not found in response");
var handler = new JwtSecurityTokenHandler();
var token = handler.ReadJwtToken(result.Token);
User = new User(token.Claims);
JwtToken = result.Token;
}
private async Task<HttpResponseMessage> Send(HttpClient client, HttpRequestMessage request)
{
if (string.IsNullOrEmpty(JwtToken))
await Authorize();
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", JwtToken);
var response = await client.SendAsync(request);
if (response.StatusCode == HttpStatusCode.Unauthorized)
{
await Authorize();
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", JwtToken);
response = await client.SendAsync(request);
}
if (response.IsSuccessStatusCode)
return response;
var result = await response.Content.ReadAsStringAsync();
throw new Exception($"Failed: {response.StatusCode}! Result: {result}");
}
public void Dispose()
{
httpClient.Dispose();
Password.Dispose();
}
}
@@ -0,0 +1,60 @@
using System.Text;
using Azaion.CommonSecurity.DTO;
using Azaion.CommonSecurity.DTO.Commands;
using MessagePack;
using NetMQ;
using NetMQ.Sockets;
namespace Azaion.CommonSecurity.Services;
public interface IResourceLoader
{
Task<MemoryStream> LoadFile(string fileName, CancellationToken ct = default);
}
public interface IAuthProvider
{
User CurrentUser { get; }
}
public class PythonResourceLoader : IResourceLoader, IAuthProvider
{
private readonly DealerSocket _dealer = new();
private readonly Guid _clientId = Guid.NewGuid();
public User CurrentUser { get; }
public PythonResourceLoader(ApiCredentials credentials)
{
//Run python by credentials
_dealer.Options.Identity = Encoding.UTF8.GetBytes(_clientId.ToString("N"));
_dealer.Connect($"tcp://{SecurityConstants.ZMQ_HOST}:{SecurityConstants.ZMQ_PORT}");
_dealer.SendFrame(MessagePackSerializer.Serialize(new RemoteCommand(CommandType.GetUser)));
var user = _dealer.Get<User>(out _);
if (user == null)
throw new Exception("Can't get user from Auth provider");
CurrentUser = user;
}
public async Task<MemoryStream> LoadFile(string fileName, CancellationToken ct = default)
{
try
{
_dealer.SendFrame(MessagePackSerializer.Serialize(new RemoteCommand(CommandType.Load, fileName)));
if (!_dealer.TryReceiveFrameBytes(TimeSpan.FromMilliseconds(1000), out var bytes))
throw new Exception($"Unable to receive {fileName}");
return await Task.FromResult(new MemoryStream(bytes));
}
catch (Exception ex)
{
throw new Exception($"Failed to load fil0e '{fileName}': {ex.Message}", ex);
}
}
}
@@ -1,63 +0,0 @@
using Azaion.CommonSecurity.DTO;
using Azaion.CommonSecurity.DTO.Commands;
using MessagePack;
using NetMQ;
using NetMQ.Sockets;
namespace Azaion.CommonSecurity.Services;
public interface IResourceLoader
{
Task<MemoryStream> Load(string fileName, CancellationToken cancellationToken = default);
}
public class PythonResourceLoader : IResourceLoader
{
private readonly PushSocket _pushSocket = new();
private readonly PullSocket _pullSocket = new();
public PythonResourceLoader(ApiCredentials credentials)
{
//Run python by credentials
_pushSocket.Connect($"tcp://{SecurityConstants.SOCKET_HOST}:{SecurityConstants.SOCKET_SEND_PORT}");
_pullSocket.Connect($"tcp://{SecurityConstants.SOCKET_HOST}:{SecurityConstants.SOCKET_RECEIVE_PORT}");
}
public async Task<MemoryStream> Load(string fileName, CancellationToken cancellationToken = default)
{
try
{
var b = MessagePackSerializer.Serialize(new FileCommand
{
CommandType = CommandType.Load,
Filename = fileName
});
_pushSocket.SendFrame(b);
var bytes = _pullSocket.ReceiveFrameBytes(out bool more);
return new MemoryStream(bytes);
}
catch (Exception ex)
{
throw new Exception($"Failed to load fil0e '{fileName}': {ex.Message}", ex);
}
}
}
public class ResourceLoader(AzaionApiClient api, ApiCredentials credentials) : IResourceLoader
{
public async Task<MemoryStream> Load(string fileName, CancellationToken cancellationToken = default)
{
var hardwareService = new HardwareService();
var hardwareInfo = hardwareService.GetHardware();
var encryptedStream = Task.Run(() => api.GetResource(fileName, credentials.Password, hardwareInfo), cancellationToken).Result;
var key = Security.MakeEncryptionKey(credentials.Email, credentials.Password, hardwareInfo.Hash);
var stream = new MemoryStream();
await encryptedStream.DecryptTo(stream, key, cancellationToken);
stream.Seek(0, SeekOrigin.Begin);
return stream;
}
}
+16
View File
@@ -0,0 +1,16 @@
using MessagePack;
using NetMQ;
using NetMQ.Sockets;
namespace Azaion.CommonSecurity;
public static class ZeroMqExtensions
{
public static T? Get<T>(this DealerSocket dealer, out byte[] message)
{
if (!dealer.TryReceiveFrameBytes(TimeSpan.FromMinutes(2), out var bytes))
throw new Exception($"Unable to get {typeof(T).Name}");
message = bytes;
return MessagePackSerializer.Deserialize<T>(bytes);
}
}
@@ -58,13 +58,12 @@ public class DatasetExplorerEventHandler(
if (datasetExplorer.ThumbnailLoading) if (datasetExplorer.ThumbnailLoading)
return; return;
var fName = Path.GetFileNameWithoutExtension(datasetExplorer.CurrentAnnotation!.Annotation.ImagePath); var a = datasetExplorer.CurrentAnnotation!.Annotation;
var extension = Path.GetExtension(fName);
var detections = datasetExplorer.ExplorerEditor.CurrentDetections var detections = datasetExplorer.ExplorerEditor.CurrentDetections
.Select(x => new Detection(fName, x.GetLabel(datasetExplorer.ExplorerEditor.RenderSize))) .Select(x => new Detection(a.Name, x.GetLabel(datasetExplorer.ExplorerEditor.RenderSize)))
.ToList(); .ToList();
await annotationService.SaveAnnotation(fName, extension, detections, SourceEnum.Manual, token: cancellationToken); await annotationService.SaveAnnotation(a.OriginalMediaName, a.Time, a.ImageExtension, detections, SourceEnum.Manual, token: cancellationToken);
datasetExplorer.SwitchTab(toEditor: false); datasetExplorer.SwitchTab(toEditor: false);
break; break;
case PlaybackControlEnum.RemoveSelectedAnns: case PlaybackControlEnum.RemoveSelectedAnns:
+7 -15
View File
@@ -1,7 +1,5 @@
using System.IO; using System.IO;
using System.Reflection; using System.Reflection;
using System.Text;
using System.Text.Json;
using System.Windows; using System.Windows;
using System.Windows.Threading; using System.Windows.Threading;
using Azaion.Annotator; using Azaion.Annotator;
@@ -13,7 +11,6 @@ using Azaion.Common.Events;
using Azaion.Common.Extensions; using Azaion.Common.Extensions;
using Azaion.Common.Services; using Azaion.Common.Services;
using Azaion.CommonSecurity; using Azaion.CommonSecurity;
using Azaion.CommonSecurity.DTO;
using Azaion.CommonSecurity.Services; using Azaion.CommonSecurity.Services;
using Azaion.Dataset; using Azaion.Dataset;
using LibVLCSharp.Shared; using LibVLCSharp.Shared;
@@ -23,7 +20,6 @@ using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
using Newtonsoft.Json;
using Serilog; using Serilog;
using KeyEventArgs = System.Windows.Input.KeyEventArgs; using KeyEventArgs = System.Windows.Input.KeyEventArgs;
@@ -36,8 +32,7 @@ public partial class App
private IMediator _mediator = null!; private IMediator _mediator = null!;
private FormState _formState = null!; private FormState _formState = null!;
private AzaionApiClient _apiClient = null!; private PythonResourceLoader _resourceLoader = null!;
private IResourceLoader _resourceLoader = null!;
private Stream _securedConfig = null!; private Stream _securedConfig = null!;
private void OnDispatcherUnhandledException(object sender, DispatcherUnhandledExceptionEventArgs e) private void OnDispatcherUnhandledException(object sender, DispatcherUnhandledExceptionEventArgs e)
@@ -64,9 +59,8 @@ public partial class App
var login = new Login(); var login = new Login();
login.CredentialsEntered += async (s, args) => login.CredentialsEntered += async (s, args) =>
{ {
_apiClient = AzaionApiClient.Create(args);
_resourceLoader = new PythonResourceLoader(args); _resourceLoader = new PythonResourceLoader(args);
_securedConfig = await _resourceLoader.Load("secured-config.json"); _securedConfig = await _resourceLoader.LoadFile("secured-config.json");
AppDomain.CurrentDomain.AssemblyResolve += (_, a) => AppDomain.CurrentDomain.AssemblyResolve += (_, a) =>
{ {
@@ -75,7 +69,7 @@ public partial class App
{ {
try try
{ {
var stream = _resourceLoader.Load($"{assemblyName}.dll").GetAwaiter().GetResult(); var stream = _resourceLoader.LoadFile($"{assemblyName}.dll").GetAwaiter().GetResult();
return Assembly.Load(stream.ToArray()); return Assembly.Load(stream.ToArray());
} }
catch (Exception e) catch (Exception e)
@@ -124,11 +118,11 @@ public partial class App
services.AddSingleton<MainSuite>(); services.AddSingleton<MainSuite>();
services.AddSingleton<IHardwareService, HardwareService>(); services.AddSingleton<IHardwareService, HardwareService>();
services.AddSingleton(_apiClient); services.AddSingleton<IResourceLoader>(_resourceLoader);
services.AddSingleton(_resourceLoader); services.AddSingleton<IAuthProvider>(_resourceLoader);
services.AddSingleton<IInferenceService, PythonInferenceService>();
services.Configure<AppConfig>(context.Configuration); services.Configure<AppConfig>(context.Configuration);
services.ConfigureSection<ApiConfig>(context.Configuration);
services.ConfigureSection<QueueConfig>(context.Configuration); services.ConfigureSection<QueueConfig>(context.Configuration);
services.ConfigureSection<DirectoriesConfig>(context.Configuration); services.ConfigureSection<DirectoriesConfig>(context.Configuration);
services.ConfigureSection<AnnotationConfig>(context.Configuration); services.ConfigureSection<AnnotationConfig>(context.Configuration);
@@ -139,7 +133,6 @@ public partial class App
services.AddSingleton<Annotator.Annotator>(); services.AddSingleton<Annotator.Annotator>();
services.AddSingleton<DatasetExplorer>(); services.AddSingleton<DatasetExplorer>();
services.AddSingleton<HelpWindow>(); services.AddSingleton<HelpWindow>();
services.AddSingleton<IAIDetector, YOLODetector>();
services.AddMediatR(c => c.RegisterServicesFromAssemblies( services.AddMediatR(c => c.RegisterServicesFromAssemblies(
typeof(Annotator.Annotator).Assembly, typeof(Annotator.Annotator).Assembly,
typeof(DatasetExplorer).Assembly, typeof(DatasetExplorer).Assembly,
@@ -152,10 +145,9 @@ public partial class App
return new MediaPlayer(libVLC); return new MediaPlayer(libVLC);
}); });
services.AddSingleton<AnnotatorEventHandler>(); services.AddSingleton<AnnotatorEventHandler>();
services.AddSingleton<VLCFrameExtractor>();
services.AddSingleton<IDbFactory, DbFactory>(); services.AddSingleton<IDbFactory, DbFactory>();
services.AddSingleton<FailsafeAnnotationsProducer>(); services.AddSingleton<FailsafeAnnotationsProducer>();
services.AddSingleton<AnnotationService>(); services.AddSingleton<AnnotationService>();