mirror of
https://github.com/azaion/annotations.git
synced 2026-04-22 09:56:31 +00:00
add ramdisk, load AI model to ramdisk and start recognition from it
rewrite zmq to DEALER and ROUTER add GET_USER command to get CurrentUser from Python all auth is on the python side inference run and validate annotations on python
This commit is contained in:
+1
-1
@@ -50,7 +50,7 @@ This is crucial for the build because build needs Python.h header and other file
|
||||
pip install ultralytics
|
||||
|
||||
pip uninstall -y opencv-python
|
||||
pip install opencv-python cython msgpack cryptography rstream pika zmq
|
||||
pip install opencv-python cython msgpack cryptography rstream pika zmq pyjwt
|
||||
```
|
||||
In case of fbgemm.dll error (Windows specific):
|
||||
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
cdef class AIRecognitionConfig:
|
||||
cdef public double frame_recognition_seconds
|
||||
cdef public double tracking_distance_confidence
|
||||
cdef public double tracking_probability_increase
|
||||
cdef public double tracking_intersection_threshold
|
||||
cdef public int frame_period_recognition
|
||||
cdef public bytes file_data
|
||||
|
||||
@staticmethod
|
||||
cdef from_msgpack(bytes data)
|
||||
@@ -0,0 +1,32 @@
|
||||
from msgpack import unpackb
|
||||
|
||||
cdef class AIRecognitionConfig:
|
||||
def __init__(self,
|
||||
frame_recognition_seconds,
|
||||
tracking_distance_confidence,
|
||||
tracking_probability_increase,
|
||||
tracking_intersection_threshold,
|
||||
frame_period_recognition,
|
||||
file_data
|
||||
):
|
||||
self.frame_recognition_seconds = frame_recognition_seconds
|
||||
self.tracking_distance_confidence = tracking_distance_confidence
|
||||
self.tracking_probability_increase = tracking_probability_increase
|
||||
self.tracking_intersection_threshold = tracking_intersection_threshold
|
||||
self.frame_period_recognition = frame_period_recognition
|
||||
self.file_data = file_data
|
||||
|
||||
def __str__(self):
|
||||
return (f'frame_seconds : {self.frame_recognition_seconds}, distance_confidence : {self.tracking_distance_confidence}, '
|
||||
f'probability_increase : {self.tracking_probability_increase}, intersection_threshold : {self.tracking_intersection_threshold}, frame_period_recognition : {self.frame_period_recognition}')
|
||||
|
||||
@staticmethod
|
||||
cdef from_msgpack(bytes data):
|
||||
unpacked = unpackb(data, strict_map_key=False)
|
||||
return AIRecognitionConfig(
|
||||
unpacked.get("FrameRecognitionSeconds", 0.0),
|
||||
unpacked.get("TrackingDistanceConfidence", 0.0),
|
||||
unpacked.get("TrackingProbabilityIncrease", 0.0),
|
||||
unpacked.get("TrackingIntersectionThreshold", 0.0),
|
||||
unpacked.get("FramePeriodRecognition", 0),
|
||||
unpacked.get("Data", b''))
|
||||
@@ -1,8 +1,10 @@
|
||||
cdef class Detection:
|
||||
cdef double x, y, w, h
|
||||
cdef int cls
|
||||
cdef public double x, y, w, h, confidence
|
||||
cdef public int cls
|
||||
|
||||
cdef class Annotation:
|
||||
cdef bytes image
|
||||
cdef float time
|
||||
cdef list[Detection] detections
|
||||
cdef long time
|
||||
cdef public list[Detection] detections
|
||||
cdef bytes serialize(self)
|
||||
|
||||
|
||||
@@ -1,13 +1,35 @@
|
||||
import msgpack
|
||||
|
||||
cdef class Detection:
|
||||
def __init__(self, double x, double y, double w, double h, int cls):
|
||||
def __init__(self, double x, double y, double w, double h, int cls, double confidence):
|
||||
self.x = x
|
||||
self.y = y
|
||||
self.w = w
|
||||
self.h = h
|
||||
self.cls = cls
|
||||
self.confidence = confidence
|
||||
|
||||
def __str__(self):
|
||||
return f'{self.cls}: {self.x:.2f} {self.y:.2f} {self.w:.2f} {self.h:.2f}, prob: {(self.confidence*100):.1f}%'
|
||||
|
||||
cdef class Annotation:
|
||||
def __init__(self, bytes image_bytes, float time, list[Detection] detections):
|
||||
def __init__(self, bytes image_bytes, long time, list[Detection] detections):
|
||||
self.image = image_bytes
|
||||
self.time = time
|
||||
self.detections = detections
|
||||
self.detections = detections if detections is not None else []
|
||||
|
||||
cdef bytes serialize(self):
|
||||
return msgpack.packb({
|
||||
"i": self.image, # "i" = image
|
||||
"t": self.time, # "t" = time
|
||||
"d": [ # "d" = detections
|
||||
{
|
||||
"x": det.x,
|
||||
"y": det.y,
|
||||
"w": det.w,
|
||||
"h": det.h,
|
||||
"c": det.cls,
|
||||
"p": det.confidence
|
||||
} for det in self.detections
|
||||
]
|
||||
})
|
||||
|
||||
@@ -1,8 +1,14 @@
|
||||
from user cimport User
|
||||
|
||||
cdef class ApiClient:
|
||||
cdef str email, password, token, folder, token_file, api_url
|
||||
cdef User user
|
||||
|
||||
cdef get_encryption_key(self, str hardware_hash)
|
||||
cdef login(self, str email, str password)
|
||||
cdef login(self)
|
||||
cdef set_token(self, str token)
|
||||
cdef get_user(self)
|
||||
|
||||
cdef load_bytes(self, str filename)
|
||||
cdef load_ai_model(self)
|
||||
cdef load_queue_config(self)
|
||||
|
||||
+47
-11
@@ -1,13 +1,14 @@
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
from http import HTTPStatus
|
||||
|
||||
from uuid import UUID
|
||||
import jwt
|
||||
import requests
|
||||
cimport constants
|
||||
from hardware_service cimport HardwareService, HardwareInfo
|
||||
from security cimport Security
|
||||
from io import BytesIO
|
||||
from user cimport User, RoleEnum
|
||||
|
||||
cdef class ApiClient:
|
||||
"""Handles API authentication and downloading of the AI model."""
|
||||
@@ -15,9 +16,11 @@ cdef class ApiClient:
|
||||
self.email = email
|
||||
self.password = password
|
||||
self.folder = folder
|
||||
self.user = None
|
||||
|
||||
if os.path.exists(<str>constants.TOKEN_FILE):
|
||||
with open(<str>constants.TOKEN_FILE, "r") as file:
|
||||
self.token = file.read().strip()
|
||||
self.set_token(<str>file.read().strip())
|
||||
else:
|
||||
self.token = None
|
||||
|
||||
@@ -25,21 +28,52 @@ cdef class ApiClient:
|
||||
cdef str key = f'{self.email}-{self.password}-{hardware_hash}-#%@AzaionKey@%#---'
|
||||
return Security.calc_hash(key)
|
||||
|
||||
cdef login(self, str email, str password):
|
||||
response = requests.post(f"{constants.API_URL}/login", json={"email": email, "password": password})
|
||||
cdef login(self):
|
||||
response = requests.post(f"{constants.API_URL}/login",
|
||||
json={"email": self.email, "password": self.password})
|
||||
response.raise_for_status()
|
||||
self.token = response.json()["token"]
|
||||
|
||||
token = response.json()["token"]
|
||||
self.set_token(token)
|
||||
with open(<str>constants.TOKEN_FILE, 'w') as file:
|
||||
file.write(self.token)
|
||||
file.write(token)
|
||||
|
||||
cdef set_token(self, str token):
|
||||
self.token = token
|
||||
claims = jwt.decode(token, options={"verify_signature": False})
|
||||
|
||||
try:
|
||||
id = str(UUID(claims.get("nameid", "")))
|
||||
except ValueError:
|
||||
raise ValueError("Invalid GUID format in claims")
|
||||
|
||||
email = claims.get("unique_name", "")
|
||||
|
||||
role_str = claims.get("role", "")
|
||||
if role_str == "ApiAdmin":
|
||||
role = RoleEnum.ApiAdmin
|
||||
elif role_str == "Admin":
|
||||
role = RoleEnum.Admin
|
||||
elif role_str == "ResourceUploader":
|
||||
role = RoleEnum.ResourceUploader
|
||||
elif role_str == "Validator":
|
||||
role = RoleEnum.Validator
|
||||
elif role_str == "Operator":
|
||||
role = RoleEnum.Operator
|
||||
else:
|
||||
role = RoleEnum.NONE
|
||||
self.user = User(id, email, role)
|
||||
|
||||
cdef get_user(self):
|
||||
if self.user is None:
|
||||
self.login()
|
||||
return self.user
|
||||
|
||||
cdef load_bytes(self, str filename):
|
||||
hardware_service = HardwareService()
|
||||
cdef HardwareInfo hardware = hardware_service.get_hardware_info()
|
||||
|
||||
if self.token is None:
|
||||
self.login(self.email, self.password)
|
||||
self.login()
|
||||
|
||||
url = f"{constants.API_URL}/resources/get/{self.folder}"
|
||||
headers = {
|
||||
@@ -56,7 +90,7 @@ cdef class ApiClient:
|
||||
response = requests.post(url, data=payload, headers=headers, stream=True)
|
||||
|
||||
if response.status_code == HTTPStatus.UNAUTHORIZED or response.status_code == HTTPStatus.FORBIDDEN:
|
||||
self.login(self.email, self.password)
|
||||
self.login()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.token}",
|
||||
"Content-Type": "application/json"
|
||||
@@ -69,7 +103,9 @@ cdef class ApiClient:
|
||||
key = self.get_encryption_key(hardware.hash)
|
||||
|
||||
stream = BytesIO(response.raw.read())
|
||||
return Security.decrypt_to(stream, key)
|
||||
data = Security.decrypt_to(stream, key)
|
||||
print(f'loaded file: {filename}, {len(data)} bytes')
|
||||
return data
|
||||
|
||||
cdef load_ai_model(self):
|
||||
return self.load_bytes(constants.AI_MODEL_FILE)
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
cdef str SOCKET_HOST # Host for the socket server
|
||||
cdef int SOCKET_PORT # Port for the socket server
|
||||
cdef int SOCKET_BUFFER_SIZE # Buffer size for socket communication
|
||||
cdef int ZMQ_PORT = 5127 # Port for the zmq
|
||||
|
||||
cdef int QUEUE_MAXSIZE # Maximum size of the command queue
|
||||
cdef str COMMANDS_QUEUE # Name of the commands queue in rabbit
|
||||
@@ -10,3 +8,5 @@ cdef str API_URL # Base URL for the external API
|
||||
cdef str TOKEN_FILE # Name of the token file where temporary token would be stored
|
||||
cdef str QUEUE_CONFIG_FILENAME # queue config filename to load from api
|
||||
cdef str AI_MODEL_FILE # AI Model file
|
||||
|
||||
cdef bytes DONE_SIGNAL
|
||||
@@ -1,6 +1,4 @@
|
||||
cdef str SOCKET_HOST = "127.0.0.1" # Host for the socket server
|
||||
cdef int SOCKET_PORT = 9127 # Port for the socket server
|
||||
cdef int SOCKET_BUFFER_SIZE = 4096 # Buffer size for socket communication
|
||||
cdef int ZMQ_PORT = 5127 # Port for the zmq
|
||||
|
||||
cdef int QUEUE_MAXSIZE = 1000 # Maximum size of the command queue
|
||||
cdef str COMMANDS_QUEUE = "azaion-commands"
|
||||
@@ -10,3 +8,5 @@ cdef str API_URL = "https://api.azaion.com" # Base URL for the external API
|
||||
cdef str TOKEN_FILE = "token"
|
||||
cdef str QUEUE_CONFIG_FILENAME = "secured-config.json"
|
||||
cdef str AI_MODEL_FILE = "azaion.pt"
|
||||
|
||||
cdef bytes DONE_SIGNAL = b"DONE"
|
||||
@@ -10,5 +10,8 @@ def start_server():
|
||||
except Exception as e:
|
||||
processor.stop()
|
||||
|
||||
def on_annotation(self, cmd, annotation):
|
||||
print('on_annotation hit!')
|
||||
|
||||
if __name__ == "__main__":
|
||||
start_server()
|
||||
@@ -0,0 +1,17 @@
|
||||
from remote_command cimport RemoteCommand
|
||||
from annotation cimport Annotation
|
||||
from ai_config cimport AIRecognitionConfig
|
||||
|
||||
cdef class Inference:
|
||||
cdef object model
|
||||
cdef object on_annotation
|
||||
cdef Annotation _previous_annotation
|
||||
cdef AIRecognitionConfig ai_config
|
||||
|
||||
cdef bint is_video(self, str filepath)
|
||||
cdef run_inference(self, RemoteCommand cmd, int batch_size=?)
|
||||
cdef _process_video(self, RemoteCommand cmd, int batch_size)
|
||||
cdef _process_image(self, RemoteCommand cmd)
|
||||
|
||||
cdef frame_to_annotation(self, long time, frame, boxes: object)
|
||||
cdef bint is_valid_annotation(self, Annotation annotation)
|
||||
+76
-17
@@ -1,30 +1,38 @@
|
||||
import ai_config
|
||||
import msgpack
|
||||
from ultralytics import YOLO
|
||||
import mimetypes
|
||||
import cv2
|
||||
from ultralytics.engine.results import Boxes
|
||||
from remote_command cimport RemoteCommand
|
||||
from annotation cimport Detection, Annotation
|
||||
from secure_model cimport SecureModelLoader
|
||||
from ai_config cimport AIRecognitionConfig
|
||||
|
||||
cdef class Inference:
|
||||
def __init__(self, model_bytes, on_annotations):
|
||||
self.model = YOLO(model_bytes)
|
||||
self.on_annotations = on_annotations
|
||||
def __init__(self, model_bytes, on_annotation):
|
||||
loader = SecureModelLoader()
|
||||
model_path = loader.load_model(model_bytes)
|
||||
self.model = YOLO(<str>model_path)
|
||||
self.on_annotation = on_annotation
|
||||
|
||||
cdef bint is_video(self, str filepath):
|
||||
mime_type, _ = mimetypes.guess_type(<str>filepath)
|
||||
return mime_type and mime_type.startswith("video")
|
||||
|
||||
cdef run_inference(self, RemoteCommand cmd, int batch_size=8, int frame_skip=4):
|
||||
cdef run_inference(self, RemoteCommand cmd, int batch_size=8):
|
||||
print('run inference..')
|
||||
|
||||
if self.is_video(cmd.filename):
|
||||
return self._process_video(cmd, batch_size, frame_skip)
|
||||
return self._process_video(cmd, batch_size)
|
||||
else:
|
||||
return self._process_image(cmd)
|
||||
|
||||
cdef _process_video(self, RemoteCommand cmd, int batch_size, int frame_skip):
|
||||
cdef _process_video(self, RemoteCommand cmd, int batch_size):
|
||||
frame_count = 0
|
||||
batch_frame = []
|
||||
annotations = []
|
||||
v_input = cv2.VideoCapture(<str>cmd.filename)
|
||||
self.ai_config = AIRecognitionConfig.from_msgpack(cmd.data)
|
||||
|
||||
while v_input.isOpened():
|
||||
ret, frame = v_input.read()
|
||||
@@ -33,7 +41,7 @@ cdef class Inference:
|
||||
break
|
||||
|
||||
frame_count += 1
|
||||
if frame_count % frame_skip == 0:
|
||||
if frame_count % self.ai_config.frame_period_recognition == 0:
|
||||
batch_frame.append((frame, ms))
|
||||
|
||||
if len(batch_frame) == batch_size:
|
||||
@@ -41,10 +49,11 @@ cdef class Inference:
|
||||
results = self.model.track(frames, persist=True)
|
||||
|
||||
for frame, res in zip(batch_frame, results):
|
||||
annotation = self.process_detections(int(frame[1]), frame[0], res.boxes)
|
||||
if len(annotation.detections) > 0:
|
||||
annotations.append(annotation)
|
||||
self.on_annotations(cmd, annotations)
|
||||
annotation = self.frame_to_annotation(int(frame[1]), frame[0], res.boxes)
|
||||
|
||||
if self.is_valid_annotation(<Annotation>annotation):
|
||||
self._previous_annotation = annotation
|
||||
self.on_annotation(cmd, annotation)
|
||||
batch_frame.clear()
|
||||
|
||||
v_input.release()
|
||||
@@ -52,15 +61,65 @@ cdef class Inference:
|
||||
cdef _process_image(self, RemoteCommand cmd):
|
||||
frame = cv2.imread(<str>cmd.filename)
|
||||
res = self.model.track(frame)
|
||||
annotation = self.process_detections(0, frame, res[0].boxes)
|
||||
self.on_annotations(cmd, [annotation])
|
||||
annotation = self.frame_to_annotation(0, frame, res[0].boxes)
|
||||
self.on_annotation(cmd, annotation)
|
||||
|
||||
cdef process_detections(self, float time, frame, boxes: Boxes):
|
||||
cdef frame_to_annotation(self, long time, frame, boxes: Boxes):
|
||||
detections = []
|
||||
for box in boxes:
|
||||
b = box.xywhn[0].cpu().numpy()
|
||||
cls = int(box.cls[0].cpu().numpy().item())
|
||||
detections.append(Detection(<double>b[0], <double>b[1], <double>b[2], <double>b[3], cls))
|
||||
_, encoded_image = cv2.imencode('.jpg', frame[0])
|
||||
confidence = box.conf[0].cpu().numpy().item()
|
||||
det = Detection(<double> b[0], <double> b[1], <double> b[2], <double> b[3], cls, confidence)
|
||||
detections.append(det)
|
||||
_, encoded_image = cv2.imencode('.jpg', frame)
|
||||
image_bytes = encoded_image.tobytes()
|
||||
return Annotation(image_bytes, time, detections)
|
||||
|
||||
cdef bint is_valid_annotation(self, Annotation annotation):
|
||||
# No detections, invalid
|
||||
if not annotation.detections:
|
||||
return False
|
||||
|
||||
# First valid annotation, always accept
|
||||
if self._previous_annotation is None:
|
||||
return True
|
||||
|
||||
# Enough time has passed since last annotation
|
||||
if annotation.time >= self._previous_annotation.time + <long>(self.ai_config.frame_recognition_seconds * 1000):
|
||||
return True
|
||||
|
||||
# More objects detected than before
|
||||
if len(annotation.detections) > len(self._previous_annotation.detections):
|
||||
return True
|
||||
|
||||
cdef:
|
||||
Detection current_det, prev_det
|
||||
double dx, dy, distance_sq, min_distance_sq
|
||||
Detection closest_det
|
||||
|
||||
# Check each detection against previous frame
|
||||
for current_det in annotation.detections:
|
||||
min_distance_sq = 1e18 # Initialize with large value
|
||||
closest_det = None
|
||||
|
||||
# Find closest detection in previous frame
|
||||
for prev_det in self._previous_annotation.detections:
|
||||
dx = current_det.x - prev_det.x
|
||||
dy = current_det.y - prev_det.y
|
||||
distance_sq = dx * dx + dy * dy
|
||||
|
||||
if distance_sq < min_distance_sq:
|
||||
min_distance_sq = distance_sq
|
||||
closest_det = prev_det
|
||||
|
||||
# Check if beyond tracking distance
|
||||
if min_distance_sq > self.ai_config.tracking_distance_confidence:
|
||||
return True
|
||||
|
||||
# Check probability increase
|
||||
if current_det.confidence >= closest_det.confidence + self.ai_config.tracking_probability_increase:
|
||||
return True
|
||||
|
||||
# No validation criteria met
|
||||
return False
|
||||
|
||||
+16
-13
@@ -1,12 +1,13 @@
|
||||
import traceback
|
||||
from queue import Queue
|
||||
cimport constants
|
||||
import msgpack
|
||||
|
||||
from api_client cimport ApiClient
|
||||
from annotation cimport Annotation
|
||||
from inference import Inference
|
||||
from inference cimport Inference
|
||||
from remote_command cimport RemoteCommand, CommandType
|
||||
from remote_command_handler cimport RemoteCommandHandler
|
||||
from user cimport User
|
||||
import argparse
|
||||
|
||||
cdef class ParsedArguments:
|
||||
@@ -36,11 +37,10 @@ cdef class CommandProcessor:
|
||||
while self.running:
|
||||
try:
|
||||
command = self.command_queue.get()
|
||||
print(f'command is : {command}')
|
||||
model = self.api_client.load_ai_model()
|
||||
Inference(model, self.on_annotations).run_inference(command)
|
||||
Inference(model, self.on_annotation).run_inference(command)
|
||||
except Exception as e:
|
||||
print(f"Error processing queue: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
cdef on_command(self, RemoteCommand command):
|
||||
try:
|
||||
@@ -48,17 +48,20 @@ cdef class CommandProcessor:
|
||||
self.command_queue.put(command)
|
||||
elif command.command_type == CommandType.LOAD:
|
||||
response = self.api_client.load_bytes(command.filename)
|
||||
print(f'loaded file: {command.filename}, {len(response)} bytes')
|
||||
self.remote_handler.send(response)
|
||||
print(f'{len(response)} bytes was sent.')
|
||||
|
||||
self.remote_handler.send(command.client_id, response)
|
||||
elif command.command_type == CommandType.GET_USER:
|
||||
self.get_user(command, self.api_client.get_user())
|
||||
else:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"Error handling client: {e}")
|
||||
|
||||
cdef on_annotations(self, RemoteCommand cmd, annotations: [Annotation]):
|
||||
data = msgpack.packb(annotations)
|
||||
self.remote_handler.send(data)
|
||||
print(f'{len(data)} bytes was sent.')
|
||||
cdef get_user(self, RemoteCommand command, User user):
|
||||
self.remote_handler.send(command.client_id, user.serialize())
|
||||
|
||||
cdef on_annotation(self, RemoteCommand cmd, Annotation annotation):
|
||||
data = annotation.serialize()
|
||||
self.remote_handler.send(cmd.client_id, data)
|
||||
|
||||
def stop(self):
|
||||
self.running = False
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
cdef enum CommandType:
|
||||
INFERENCE = 1
|
||||
LOAD = 2
|
||||
GET_USER = 3
|
||||
|
||||
cdef class RemoteCommand:
|
||||
cdef public bytes client_id
|
||||
cdef CommandType command_type
|
||||
cdef str filename
|
||||
cdef bytes data
|
||||
|
||||
@@ -10,8 +10,10 @@ cdef class RemoteCommand:
|
||||
command_type_names = {
|
||||
1: "INFERENCE",
|
||||
2: "LOAD",
|
||||
3: "GET_USER"
|
||||
}
|
||||
return f'{command_type_names[self.command_type]}: {self.filename}'
|
||||
data_str = f'. Data: {len(self.data)} bytes' if self.data else ''
|
||||
return f'{command_type_names[self.command_type]}: {self.filename}{data_str}'
|
||||
|
||||
@staticmethod
|
||||
cdef from_msgpack(bytes data):
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
cdef class RemoteCommandHandler:
|
||||
cdef object _on_command
|
||||
cdef object _context
|
||||
cdef object _socket
|
||||
cdef object _router
|
||||
cdef object _dealer
|
||||
cdef object _shutdown_event
|
||||
cdef object _pull_socket
|
||||
cdef object _pull_thread
|
||||
cdef object _push_socket
|
||||
cdef object _push_queue
|
||||
cdef object _push_thread
|
||||
cdef object _on_command
|
||||
|
||||
cdef object _proxy_thread
|
||||
cdef object _workers
|
||||
|
||||
cdef start(self)
|
||||
cdef _pull_loop(self)
|
||||
cdef _push_loop(self)
|
||||
cdef send(self, bytes message_bytes)
|
||||
cdef _proxy_loop(self)
|
||||
cdef _worker_loop(self)
|
||||
cdef send(self, bytes client_id, bytes data)
|
||||
cdef close(self)
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from queue import Queue
|
||||
|
||||
import zmq
|
||||
import json
|
||||
from threading import Thread, Event
|
||||
from remote_command cimport RemoteCommand
|
||||
cimport constants
|
||||
|
||||
cdef class RemoteCommandHandler:
|
||||
def __init__(self, object on_command):
|
||||
@@ -11,68 +9,58 @@ cdef class RemoteCommandHandler:
|
||||
self._context = zmq.Context.instance()
|
||||
self._shutdown_event = Event()
|
||||
|
||||
self._pull_socket = self._context.socket(zmq.PULL)
|
||||
self._pull_socket.setsockopt(zmq.LINGER, 0)
|
||||
self._pull_socket.bind("tcp://*:5127")
|
||||
self._pull_thread = Thread(target=self._pull_loop, daemon=True)
|
||||
self._router = self._context.socket(zmq.ROUTER)
|
||||
self._router.setsockopt(zmq.LINGER, 0)
|
||||
self._router.bind(f'tcp://*:{constants.ZMQ_PORT}')
|
||||
|
||||
self._push_queue = Queue()
|
||||
self._dealer = self._context.socket(zmq.DEALER)
|
||||
self._dealer.setsockopt(zmq.LINGER, 0)
|
||||
self._dealer.bind("inproc://backend")
|
||||
|
||||
self._push_socket = self._context.socket(zmq.PUSH)
|
||||
self._push_socket.setsockopt(zmq.LINGER, 0)
|
||||
self._push_socket.bind("tcp://*:5128")
|
||||
self._push_thread = Thread(target=self._push_loop, daemon=True)
|
||||
self._proxy_thread = Thread(target=self._proxy_loop, daemon=True)
|
||||
|
||||
self._workers = []
|
||||
for _ in range(4): # 4 worker threads
|
||||
worker = Thread(target=self._worker_loop, daemon=True)
|
||||
self._workers.append(worker)
|
||||
|
||||
cdef start(self):
|
||||
self._pull_thread.start()
|
||||
self._push_thread.start()
|
||||
self._proxy_thread.start()
|
||||
for worker in self._workers:
|
||||
worker.start()
|
||||
|
||||
cdef _pull_loop(self):
|
||||
while not self._shutdown_event.is_set():
|
||||
print('wait for the command...')
|
||||
message = self._pull_socket.recv()
|
||||
cmd = RemoteCommand.from_msgpack(<bytes>message)
|
||||
print(f'received: {cmd}')
|
||||
self._on_command(cmd)
|
||||
cdef _proxy_loop(self):
|
||||
zmq.proxy(self._router, self._dealer)
|
||||
|
||||
cdef _push_loop(self):
|
||||
cdef _worker_loop(self):
|
||||
worker_socket = self._context.socket(zmq.DEALER)
|
||||
worker_socket.setsockopt(zmq.LINGER, 0)
|
||||
worker_socket.connect("inproc://backend")
|
||||
poller = zmq.Poller()
|
||||
poller.register(worker_socket, zmq.POLLIN)
|
||||
print('started receiver loop...')
|
||||
while not self._shutdown_event.is_set():
|
||||
try:
|
||||
response = self._push_queue.get(timeout=1) # Timeout to check shutdown flag
|
||||
self._push_socket.send(response)
|
||||
except:
|
||||
continue
|
||||
socks = dict(poller.poll(500))
|
||||
if worker_socket in socks:
|
||||
client_id, message = worker_socket.recv_multipart()
|
||||
cmd = RemoteCommand.from_msgpack(<bytes> message)
|
||||
cmd.client_id = client_id
|
||||
print(f'Received [{cmd}] from the client {client_id}')
|
||||
self._on_command(cmd)
|
||||
except Exception as e:
|
||||
print(f"Worker error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
cdef send(self, bytes message_bytes):
|
||||
print(f'about to send {len(message_bytes)}')
|
||||
try:
|
||||
self._push_queue.put(message_bytes)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
cdef send(self, bytes client_id, bytes data):
|
||||
with self._context.socket(zmq.DEALER) as socket:
|
||||
socket.connect("inproc://backend")
|
||||
socket.send_multipart([client_id, data])
|
||||
print(f'{len(data)} bytes was sent to client {client_id}')
|
||||
|
||||
cdef close(self):
|
||||
self._shutdown_event.set()
|
||||
self._pull_socket.close()
|
||||
self._push_socket.close()
|
||||
self._context.term()
|
||||
|
||||
|
||||
cdef class QueueConfig:
|
||||
cdef str host,
|
||||
cdef int port, command_port
|
||||
cdef str producer_user, producer_pw, consumer_user, consumer_pw
|
||||
|
||||
@staticmethod
|
||||
cdef QueueConfig from_json(str json_string):
|
||||
s = str(json_string).strip()
|
||||
cdef dict config_dict = json.loads(s)["QueueConfig"]
|
||||
cdef QueueConfig config = QueueConfig()
|
||||
|
||||
config.host = config_dict["Host"]
|
||||
config.port = config_dict["Port"]
|
||||
config.command_port = config_dict["CommandsPort"]
|
||||
config.producer_user = config_dict["ProducerUsername"]
|
||||
config.producer_pw = config_dict["ProducerPassword"]
|
||||
config.consumer_user = config_dict["ConsumerUsername"]
|
||||
config.consumer_pw = config_dict["ConsumerPassword"]
|
||||
return config
|
||||
self._router.close()
|
||||
self._dealer.close()
|
||||
self._context.term()
|
||||
@@ -0,0 +1,12 @@
|
||||
cdef class SecureModelLoader:
|
||||
cdef:
|
||||
bytes _model_bytes
|
||||
str _ramdisk_path
|
||||
str _temp_file_path
|
||||
int _disk_size_mb
|
||||
|
||||
cpdef str load_model(self, bytes model_bytes)
|
||||
cdef str _get_ramdisk_path(self)
|
||||
cdef void _create_ramdisk(self)
|
||||
cdef void _store_model(self)
|
||||
cdef void _cleanup(self)
|
||||
@@ -0,0 +1,104 @@
|
||||
import os
|
||||
import platform
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from libc.stdio cimport FILE, fopen, fclose, remove
|
||||
from libc.stdlib cimport free
|
||||
from libc.string cimport strdup
|
||||
|
||||
cdef class SecureModelLoader:
|
||||
def __cinit__(self, int disk_size_mb=512):
|
||||
self._disk_size_mb = disk_size_mb
|
||||
self._ramdisk_path = None
|
||||
self._temp_file_path = None
|
||||
|
||||
cpdef str load_model(self, bytes model_bytes):
|
||||
"""Public method to load YOLO model securely."""
|
||||
self._model_bytes = model_bytes
|
||||
self._create_ramdisk()
|
||||
self._store_model()
|
||||
return self._temp_file_path
|
||||
|
||||
cdef str _get_ramdisk_path(self):
|
||||
"""Determine the RAM disk path based on the OS."""
|
||||
if platform.system() == "Windows":
|
||||
return "R:\\"
|
||||
elif platform.system() == "Linux":
|
||||
return "/mnt/ramdisk"
|
||||
elif platform.system() == "Darwin":
|
||||
return "/Volumes/RAMDisk"
|
||||
else:
|
||||
raise RuntimeError("Unsupported OS for RAM disk")
|
||||
|
||||
cdef void _create_ramdisk(self):
|
||||
"""Create a RAM disk securely based on the OS."""
|
||||
system = platform.system()
|
||||
|
||||
if system == "Windows":
|
||||
# Create RAM disk via PowerShell
|
||||
command = f'powershell -Command "subst R: {tempfile.gettempdir()}"'
|
||||
if os.system(command) != 0:
|
||||
raise RuntimeError("Failed to create RAM disk on Windows")
|
||||
self._ramdisk_path = "R:\\"
|
||||
|
||||
elif system == "Linux":
|
||||
# Use tmpfs for RAM disk
|
||||
self._ramdisk_path = "/mnt/ramdisk"
|
||||
if not Path(self._ramdisk_path).exists():
|
||||
os.mkdir(self._ramdisk_path)
|
||||
if os.system(f"mount -t tmpfs -o size={self._disk_size_mb}M tmpfs {self._ramdisk_path}") != 0:
|
||||
raise RuntimeError("Failed to create RAM disk on Linux")
|
||||
|
||||
elif system == "Darwin":
|
||||
# Use hdiutil for macOS RAM disk
|
||||
block_size = 2048 # 512-byte blocks * 2048 = 1MB
|
||||
num_blocks = self._disk_size_mb * block_size
|
||||
result = os.popen(f"hdiutil attach -nomount ram://{num_blocks}").read().strip()
|
||||
if result:
|
||||
self._ramdisk_path = "/Volumes/RAMDisk"
|
||||
os.system(f"diskutil eraseVolume HFS+ RAMDisk {result}")
|
||||
else:
|
||||
raise RuntimeError("Failed to create RAM disk on macOS")
|
||||
|
||||
cdef void _store_model(self):
|
||||
"""Write model securely to the RAM disk."""
|
||||
cdef char* temp_path
|
||||
cdef FILE* cfile
|
||||
|
||||
with tempfile.NamedTemporaryFile(
|
||||
dir=self._ramdisk_path, suffix='.pt', delete=False
|
||||
) as tmp_file:
|
||||
tmp_file.write(self._model_bytes)
|
||||
self._temp_file_path = tmp_file.name
|
||||
|
||||
encoded_path = self._temp_file_path.encode('utf-8')
|
||||
temp_path = strdup(encoded_path)
|
||||
with nogil:
|
||||
cfile = fopen(temp_path, "rb")
|
||||
if cfile == NULL:
|
||||
raise IOError(f"Could not open {self._temp_file_path}")
|
||||
fclose(cfile)
|
||||
|
||||
cdef void _cleanup(self):
|
||||
"""Remove the model file and unmount RAM disk securely."""
|
||||
cdef char* c_path
|
||||
if self._temp_file_path:
|
||||
c_path = strdup(os.fsencode(self._temp_file_path))
|
||||
with nogil:
|
||||
remove(c_path)
|
||||
free(c_path)
|
||||
self._temp_file_path = None
|
||||
|
||||
# Unmount RAM disk based on OS
|
||||
if self._ramdisk_path:
|
||||
if platform.system() == "Windows":
|
||||
os.system("subst R: /D")
|
||||
elif platform.system() == "Linux":
|
||||
os.system(f"umount {self._ramdisk_path}")
|
||||
elif platform.system() == "Darwin":
|
||||
os.system("hdiutil detach /Volumes/RAMDisk")
|
||||
self._ramdisk_path = None
|
||||
|
||||
def __dealloc__(self):
|
||||
"""Ensure cleanup when the object is deleted."""
|
||||
self._cleanup()
|
||||
+10
-1
@@ -8,7 +8,10 @@ extensions = [
|
||||
Extension('hardware_service', ['hardware_service.pyx'], extra_compile_args=["-g"], extra_link_args=["-g"]),
|
||||
Extension('remote_command', ['remote_command.pyx']),
|
||||
Extension('remote_command_handler', ['remote_command_handler.pyx']),
|
||||
Extension('user', ['user.pyx']),
|
||||
Extension('api_client', ['api_client.pyx']),
|
||||
Extension('secure_model', ['secure_model.pyx']),
|
||||
Extension('ai_config', ['ai_config.pyx']),
|
||||
Extension('inference', ['inference.pyx']),
|
||||
|
||||
Extension('main', ['main.pyx']),
|
||||
@@ -21,8 +24,14 @@ setup(
|
||||
compiler_directives={
|
||||
"language_level": 3,
|
||||
"emit_code_comments" : False,
|
||||
"binding": True
|
||||
"binding": True,
|
||||
'boundscheck': False,
|
||||
'wraparound': False
|
||||
}
|
||||
),
|
||||
install_requires=[
|
||||
'ultralytics>=8.0.0',
|
||||
'pywin32; platform_system=="Windows"'
|
||||
],
|
||||
zip_safe=False
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1laWQiOiJkOTBhMzZjYS1lMjM3LTRmYmQtOWM3Yy0xMjcwNDBhYzg1NTYiLCJ1bmlxdWVfbmFtZSI6ImFkbWluQGF6YWlvbi5jb20iLCJyb2xlIjoiQXBpQWRtaW4iLCJuYmYiOjE3MzgxNjM2MzYsImV4cCI6MTczODE3ODAzNiwiaWF0IjoxNzM4MTYzNjM2LCJpc3MiOiJBemFpb25BcGkiLCJhdWQiOiJBbm5vdGF0b3JzL09yYW5nZVBpL0FkbWlucyJ9.7VVws5mwGqx--sGopOuZE9iu3dzt1UdVPXeje2KZTYk
|
||||
@@ -0,0 +1,15 @@
|
||||
cdef enum RoleEnum:
|
||||
NONE = 0
|
||||
Operator = 10
|
||||
Validator = 20
|
||||
CompanionPC = 30
|
||||
Admin = 40
|
||||
ResourceUploader = 50
|
||||
ApiAdmin = 1000
|
||||
|
||||
cdef class User:
|
||||
cdef public str id
|
||||
cdef public str email
|
||||
cdef public RoleEnum role
|
||||
|
||||
cdef bytes serialize(self)
|
||||
@@ -0,0 +1,15 @@
|
||||
import msgpack
|
||||
|
||||
cdef class User:
|
||||
|
||||
def __init__(self, str id, str email, RoleEnum role):
|
||||
self.id = id
|
||||
self.email = email
|
||||
self.role = role
|
||||
|
||||
cdef bytes serialize(self):
|
||||
return msgpack.packb({
|
||||
"i": self.id,
|
||||
"e": self.email,
|
||||
"r": self.role
|
||||
})
|
||||
@@ -480,7 +480,7 @@
|
||||
Grid.Column="10"
|
||||
Padding="2" Width="25"
|
||||
Height="25"
|
||||
ToolTip="Розпізнати за допомогою AI. Клавіша: [A]" Background="Black" BorderBrush="Black"
|
||||
ToolTip="Розпізнати за допомогою AI. Клавіша: [R]" Background="Black" BorderBrush="Black"
|
||||
Click="AutoDetect">
|
||||
<Path Stretch="Fill" Fill="LightGray" Data="M144.317 85.269h223.368c15.381 0 29.391 6.325 39.567 16.494l.025-.024c10.163 10.164 16.477 24.193 16.477
|
||||
39.599v189.728c0 15.401-6.326 29.425-16.485 39.584-10.159 10.159-24.183 16.484-39.584 16.484H144.317c-15.4
|
||||
|
||||
+148
-154
@@ -6,9 +6,7 @@ using System.Windows.Controls;
|
||||
using System.Windows.Controls.Primitives;
|
||||
using System.Windows.Input;
|
||||
using System.Windows.Media;
|
||||
using System.Windows.Media.Imaging;
|
||||
using Azaion.Annotator.DTO;
|
||||
using Azaion.Annotator.Extensions;
|
||||
using Azaion.Common.Database;
|
||||
using Azaion.Common.DTO;
|
||||
using Azaion.Common.DTO.Config;
|
||||
@@ -39,10 +37,9 @@ public partial class Annotator
|
||||
private readonly IConfigUpdater _configUpdater;
|
||||
private readonly HelpWindow _helpWindow;
|
||||
private readonly ILogger<Annotator> _logger;
|
||||
private readonly VLCFrameExtractor _vlcFrameExtractor;
|
||||
private readonly IAIDetector _aiDetector;
|
||||
private readonly AnnotationService _annotationService;
|
||||
private readonly IDbFactory _dbFactory;
|
||||
private readonly IInferenceService _inferenceService;
|
||||
private readonly CancellationTokenSource _ctSource = new();
|
||||
|
||||
private ObservableCollection<DetectionClass> AnnotationClasses { get; set; } = new();
|
||||
@@ -67,10 +64,9 @@ public partial class Annotator
|
||||
FormState formState,
|
||||
HelpWindow helpWindow,
|
||||
ILogger<Annotator> logger,
|
||||
VLCFrameExtractor vlcFrameExtractor,
|
||||
IAIDetector aiDetector,
|
||||
AnnotationService annotationService,
|
||||
IDbFactory dbFactory)
|
||||
IDbFactory dbFactory,
|
||||
IInferenceService inferenceService)
|
||||
{
|
||||
InitializeComponent();
|
||||
_appConfig = appConfig.Value;
|
||||
@@ -81,10 +77,9 @@ public partial class Annotator
|
||||
_formState = formState;
|
||||
_helpWindow = helpWindow;
|
||||
_logger = logger;
|
||||
_vlcFrameExtractor = vlcFrameExtractor;
|
||||
_aiDetector = aiDetector;
|
||||
_annotationService = annotationService;
|
||||
_dbFactory = dbFactory;
|
||||
_inferenceService = inferenceService;
|
||||
|
||||
Loaded += OnLoaded;
|
||||
Closed += OnFormClosed;
|
||||
@@ -304,11 +299,16 @@ public partial class Annotator
|
||||
|
||||
var annotations = await _dbFactory.Run(async db =>
|
||||
await db.Annotations.LoadWith(x => x.Detections)
|
||||
.Where(x => x.Name.Contains(_formState.VideoName))
|
||||
.Where(x => x.OriginalMediaName == _formState.VideoName)
|
||||
.ToListAsync(token: _ctSource.Token));
|
||||
|
||||
TimedAnnotations.Clear();
|
||||
_formState.AnnotationResults.Clear();
|
||||
foreach (var ann in annotations)
|
||||
AddAnnotation(ann);
|
||||
{
|
||||
TimedAnnotations.Add(ann.Time.Subtract(_thresholdBefore), ann.Time.Add(_thresholdAfter), ann);
|
||||
_formState.AnnotationResults.Add(new AnnotationResult(_appConfig.AnnotationConfig.DetectionClassesDict, ann));
|
||||
}
|
||||
}
|
||||
|
||||
//Add manually
|
||||
@@ -435,8 +435,6 @@ public partial class Annotator
|
||||
|
||||
_appConfig.DirectoriesConfig.VideosDirectory = dlg.FileName;
|
||||
TbFolder.Text = dlg.FileName;
|
||||
await ReloadFiles();
|
||||
await SaveUserSettings();
|
||||
}
|
||||
|
||||
private void TbFilter_OnTextChanged(object sender, TextChangedEventArgs e)
|
||||
@@ -487,11 +485,8 @@ public partial class Annotator
|
||||
if (LvFiles.SelectedIndex == -1)
|
||||
LvFiles.SelectedIndex = 0;
|
||||
|
||||
await _mediator.Publish(new AnnotatorControlEvent(PlaybackControlEnum.Play));
|
||||
_mediaPlayer.Stop();
|
||||
|
||||
var manualCancellationSource = new CancellationTokenSource();
|
||||
var token = manualCancellationSource.Token;
|
||||
var mct = new CancellationTokenSource();
|
||||
var token = mct.Token;
|
||||
|
||||
_autoDetectDialog = new AutodetectDialog
|
||||
{
|
||||
@@ -500,7 +495,7 @@ public partial class Annotator
|
||||
};
|
||||
_autoDetectDialog.Closing += (_, _) =>
|
||||
{
|
||||
manualCancellationSource.Cancel();
|
||||
mct.Cancel();
|
||||
_mediaPlayer.SeekTo(TimeSpan.Zero);
|
||||
Editor.RemoveAllAnns();
|
||||
};
|
||||
@@ -515,16 +510,17 @@ public partial class Annotator
|
||||
var mediaInfo = Dispatcher.Invoke(() => (MediaFileInfo)LvFiles.SelectedItem);
|
||||
while (mediaInfo != null)
|
||||
{
|
||||
_formState.CurrentMedia = mediaInfo;
|
||||
await Dispatcher.Invoke(async () => await ReloadAnnotations());
|
||||
|
||||
if (mediaInfo.MediaType == MediaTypes.Image)
|
||||
await Dispatcher.Invoke(async () =>
|
||||
{
|
||||
await DetectImage(mediaInfo, manualCancellationSource, token);
|
||||
await Task.Delay(70, token);
|
||||
}
|
||||
else
|
||||
await DetectVideo(mediaInfo, manualCancellationSource, token);
|
||||
await _mediator.Publish(new AnnotatorControlEvent(PlaybackControlEnum.Play), token);
|
||||
await ReloadAnnotations();
|
||||
});
|
||||
|
||||
await _inferenceService.RunInference(mediaInfo.Path, async (annotationImage, ct) =>
|
||||
{
|
||||
annotationImage.OriginalMediaName = mediaInfo.FName;
|
||||
await ProcessDetection(annotationImage, ct);
|
||||
}, token);
|
||||
|
||||
mediaInfo = Dispatcher.Invoke(() =>
|
||||
{
|
||||
@@ -533,6 +529,7 @@ public partial class Annotator
|
||||
LvFiles.SelectedIndex += 1;
|
||||
return (MediaFileInfo)LvFiles.SelectedItem;
|
||||
});
|
||||
LvFiles.Items.Refresh();
|
||||
}
|
||||
Dispatcher.Invoke(() =>
|
||||
{
|
||||
@@ -546,143 +543,140 @@ public partial class Annotator
|
||||
Dispatcher.Invoke(() => Editor.ResetBackground());
|
||||
}
|
||||
|
||||
private async Task DetectImage(MediaFileInfo mediaInfo, CancellationTokenSource manualCancellationSource, CancellationToken token)
|
||||
// private async Task DetectImage(MediaFileInfo mediaInfo, CancellationTokenSource manualCancellationSource, CancellationToken token)
|
||||
// {
|
||||
// try
|
||||
// {
|
||||
// var fName = Path.GetFileNameWithoutExtension(mediaInfo.Path);
|
||||
// var stream = new FileStream(mediaInfo.Path, FileMode.Open);
|
||||
// var detections = await _aiDetector.Detect(fName, stream, token);
|
||||
// await ProcessDetection((TimeSpan.FromMilliseconds(0), stream), Path.GetExtension(mediaInfo.Path), detections, token);
|
||||
// if (detections.Count != 0)
|
||||
// mediaInfo.HasAnnotations = true;
|
||||
// }
|
||||
// catch (Exception e)
|
||||
// {
|
||||
// _logger.LogError(e, e.Message);
|
||||
// await manualCancellationSource.CancelAsync();
|
||||
// }
|
||||
// }
|
||||
|
||||
// private async Task DetectVideo(MediaFileInfo mediaInfo, CancellationTokenSource manualCancellationSource, CancellationToken token)
|
||||
// {
|
||||
// var prevSeekTime = 0.0;
|
||||
// await foreach (var timeframe in _vlcFrameExtractor.ExtractFrames(mediaInfo.Path, token))
|
||||
// {
|
||||
// Console.WriteLine($"Detect time: {timeframe.Time}");
|
||||
// try
|
||||
// {
|
||||
// var fName = _formState.GetTimeName(timeframe.Time);
|
||||
// var detections = await _aiDetector.Detect(fName, timeframe.Stream, token);
|
||||
//
|
||||
// var isValid = IsValidDetection(timeframe.Time, detections);
|
||||
// Console.WriteLine($"Detection time: {timeframe.Time}");
|
||||
//
|
||||
// var log = string.Join(Environment.NewLine, detections.Select(det =>
|
||||
// $"{_appConfig.AnnotationConfig.DetectionClassesDict[det.ClassNumber].Name}: " +
|
||||
// $"xy=({det.CenterX:F2},{det.CenterY:F2}), " +
|
||||
// $"size=({det.Width:F2}, {det.Height:F2}), " +
|
||||
// $"prob: {det.Probability:F1}%"));
|
||||
//
|
||||
// log = $"Detection time: {timeframe.Time}, Valid: {isValid}. {Environment.NewLine} {log}";
|
||||
// Dispatcher.Invoke(() => _autoDetectDialog.Log(log));
|
||||
//
|
||||
// if (timeframe.Time.TotalMilliseconds > prevSeekTime + 250)
|
||||
// {
|
||||
// Dispatcher.Invoke(() => SeekTo(timeframe.Time));
|
||||
// prevSeekTime = timeframe.Time.TotalMilliseconds;
|
||||
// if (!isValid) //Show frame anyway
|
||||
// {
|
||||
// Dispatcher.Invoke(() =>
|
||||
// {
|
||||
// Editor.RemoveAllAnns();
|
||||
// Editor.Background = new ImageBrush
|
||||
// {
|
||||
// ImageSource = timeframe.Stream.OpenImage()
|
||||
// };
|
||||
// });
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// if (!isValid)
|
||||
// continue;
|
||||
//
|
||||
// mediaInfo.HasAnnotations = true;
|
||||
// await ProcessDetection(timeframe, ".jpg", detections, token);
|
||||
// await timeframe.Stream.DisposeAsync();
|
||||
// }
|
||||
// catch (Exception ex)
|
||||
// {
|
||||
// _logger.LogError(ex, ex.Message);
|
||||
// await manualCancellationSource.CancelAsync();
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// private bool IsValidDetection(TimeSpan time, List<Detection> detections)
|
||||
// {
|
||||
// // No AI detection, forbid
|
||||
// if (detections.Count == 0)
|
||||
// return false;
|
||||
//
|
||||
// // Very first detection, allow
|
||||
// if (!_previousDetection.HasValue)
|
||||
// return true;
|
||||
//
|
||||
// var prev = _previousDetection.Value;
|
||||
//
|
||||
// // Time between detections is >= than Frame Recognition Seconds, allow
|
||||
// if (time >= prev.Time.Add(TimeSpan.FromSeconds(_appConfig.AIRecognitionConfig.FrameRecognitionSeconds)))
|
||||
// return true;
|
||||
//
|
||||
// // Detection is earlier than previous + FrameRecognitionSeconds.
|
||||
// // Look to the detections more in detail
|
||||
//
|
||||
// // More detected objects, allow
|
||||
// if (detections.Count > prev.Detections.Count)
|
||||
// return true;
|
||||
//
|
||||
// foreach (var det in detections)
|
||||
// {
|
||||
// var point = new Point(det.CenterX, det.CenterY);
|
||||
// var closestObject = prev.Detections
|
||||
// .Select(p => new
|
||||
// {
|
||||
// Point = p,
|
||||
// Distance = point.SqrDistance(new Point(p.CenterX, p.CenterY))
|
||||
// })
|
||||
// .OrderBy(x => x.Distance)
|
||||
// .First();
|
||||
//
|
||||
// // Closest object is farther than Tracking distance confidence, hence it's a different object, allow
|
||||
// if (closestObject.Distance > _appConfig.AIRecognitionConfig.TrackingDistanceConfidence)
|
||||
// return true;
|
||||
//
|
||||
// // Since closest object within distance confidence, then it is tracking of the same object. Then if recognition probability for the object > increase from previous
|
||||
// if (det.Probability >= closestObject.Point.Probability + _appConfig.AIRecognitionConfig.TrackingProbabilityIncrease)
|
||||
// return true;
|
||||
// }
|
||||
//
|
||||
// return false;
|
||||
// }
|
||||
|
||||
private async Task ProcessDetection(AnnotationImage annotationImage, CancellationToken token = default)
|
||||
{
|
||||
try
|
||||
{
|
||||
var fName = Path.GetFileNameWithoutExtension(mediaInfo.Path);
|
||||
var stream = new FileStream(mediaInfo.Path, FileMode.Open);
|
||||
var detections = await _aiDetector.Detect(fName, stream, token);
|
||||
await ProcessDetection((TimeSpan.FromMilliseconds(0), stream), Path.GetExtension(mediaInfo.Path), detections, token);
|
||||
if (detections.Count != 0)
|
||||
mediaInfo.HasAnnotations = true;
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
_logger.LogError(e, e.Message);
|
||||
await manualCancellationSource.CancelAsync();
|
||||
}
|
||||
}
|
||||
|
||||
private async Task DetectVideo(MediaFileInfo mediaInfo, CancellationTokenSource manualCancellationSource, CancellationToken token)
|
||||
{
|
||||
var prevSeekTime = 0.0;
|
||||
await foreach (var timeframe in _vlcFrameExtractor.ExtractFrames(mediaInfo.Path, token))
|
||||
{
|
||||
Console.WriteLine($"Detect time: {timeframe.Time}");
|
||||
try
|
||||
{
|
||||
var fName = _formState.GetTimeName(timeframe.Time);
|
||||
var detections = await _aiDetector.Detect(fName, timeframe.Stream, token);
|
||||
|
||||
var isValid = IsValidDetection(timeframe.Time, detections);
|
||||
Console.WriteLine($"Detection time: {timeframe.Time}");
|
||||
|
||||
var log = string.Join(Environment.NewLine, detections.Select(det =>
|
||||
$"{_appConfig.AnnotationConfig.DetectionClassesDict[det.ClassNumber].Name}: " +
|
||||
$"xy=({det.CenterX:F2},{det.CenterY:F2}), " +
|
||||
$"size=({det.Width:F2}, {det.Height:F2}), " +
|
||||
$"prob: {det.Probability:F1}%"));
|
||||
|
||||
log = $"Detection time: {timeframe.Time}, Valid: {isValid}. {Environment.NewLine} {log}";
|
||||
Dispatcher.Invoke(() => _autoDetectDialog.Log(log));
|
||||
|
||||
if (timeframe.Time.TotalMilliseconds > prevSeekTime + 250)
|
||||
{
|
||||
Dispatcher.Invoke(() => SeekTo(timeframe.Time));
|
||||
prevSeekTime = timeframe.Time.TotalMilliseconds;
|
||||
if (!isValid) //Show frame anyway
|
||||
{
|
||||
Dispatcher.Invoke(() =>
|
||||
{
|
||||
Editor.RemoveAllAnns();
|
||||
Editor.Background = new ImageBrush
|
||||
{
|
||||
ImageSource = timeframe.Stream.OpenImage()
|
||||
};
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (!isValid)
|
||||
continue;
|
||||
|
||||
mediaInfo.HasAnnotations = true;
|
||||
await ProcessDetection(timeframe, ".jpg", detections, token);
|
||||
await timeframe.Stream.DisposeAsync();
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, ex.Message);
|
||||
await manualCancellationSource.CancelAsync();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private bool IsValidDetection(TimeSpan time, List<Detection> detections)
|
||||
{
|
||||
// No AI detection, forbid
|
||||
if (detections.Count == 0)
|
||||
return false;
|
||||
|
||||
// Very first detection, allow
|
||||
if (!_previousDetection.HasValue)
|
||||
return true;
|
||||
|
||||
var prev = _previousDetection.Value;
|
||||
|
||||
// Time between detections is >= than Frame Recognition Seconds, allow
|
||||
if (time >= prev.Time.Add(TimeSpan.FromSeconds(_appConfig.AIRecognitionConfig.FrameRecognitionSeconds)))
|
||||
return true;
|
||||
|
||||
// Detection is earlier than previous + FrameRecognitionSeconds.
|
||||
// Look to the detections more in detail
|
||||
|
||||
// More detected objects, allow
|
||||
if (detections.Count > prev.Detections.Count)
|
||||
return true;
|
||||
|
||||
foreach (var det in detections)
|
||||
{
|
||||
var point = new Point(det.CenterX, det.CenterY);
|
||||
var closestObject = prev.Detections
|
||||
.Select(p => new
|
||||
{
|
||||
Point = p,
|
||||
Distance = point.SqrDistance(new Point(p.CenterX, p.CenterY))
|
||||
})
|
||||
.OrderBy(x => x.Distance)
|
||||
.First();
|
||||
|
||||
// Closest object is farther than Tracking distance confidence, hence it's a different object, allow
|
||||
if (closestObject.Distance > _appConfig.AIRecognitionConfig.TrackingDistanceConfidence)
|
||||
return true;
|
||||
|
||||
// Since closest object within distance confidence, then it is tracking of the same object. Then if recognition probability for the object > increase from previous
|
||||
if (det.Probability >= closestObject.Point.Probability + _appConfig.AIRecognitionConfig.TrackingProbabilityIncrease)
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
private async Task ProcessDetection((TimeSpan Time, Stream Stream) timeframe, string imageExtension, List<Detection> detections, CancellationToken token = default)
|
||||
{
|
||||
_previousDetection = (timeframe.Time, detections);
|
||||
await Dispatcher.Invoke(async () =>
|
||||
{
|
||||
try
|
||||
{
|
||||
var fName = _formState.GetTimeName(timeframe.Time);
|
||||
|
||||
var annotation = await _annotationService.SaveAnnotation(fName, imageExtension, detections, SourceEnum.AI, timeframe.Stream, token);
|
||||
var annotation = await _annotationService.SaveAnnotation(annotationImage, token);
|
||||
|
||||
Editor.Background = new ImageBrush { ImageSource = await annotation.ImagePath.OpenImage() };
|
||||
Editor.RemoveAllAnns();
|
||||
ShowAnnotations(annotation, true);
|
||||
AddAnnotation(annotation);
|
||||
|
||||
var log = string.Join(Environment.NewLine, detections.Select(det =>
|
||||
var log = string.Join(Environment.NewLine, annotation.Detections.Select(det =>
|
||||
$"{_appConfig.AnnotationConfig.DetectionClassesDict[det.ClassNumber].Name}: " +
|
||||
$"xy=({det.CenterX:F2},{det.CenterY:F2}), " +
|
||||
$"size=({det.Width:F2}, {det.Height:F2}), " +
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
using System.Windows;
|
||||
using System.Windows.Input;
|
||||
using Azaion.Annotator.DTO;
|
||||
using Azaion.Common;
|
||||
using Azaion.Common.DTO;
|
||||
using Azaion.Common.DTO.Config;
|
||||
using Azaion.Common.DTO.Queue;
|
||||
using Azaion.Common.Events;
|
||||
using Azaion.Common.Extensions;
|
||||
using Azaion.Common.Services;
|
||||
using LibVLCSharp.Shared;
|
||||
using MediatR;
|
||||
@@ -79,7 +79,7 @@ public class AnnotatorEventHandler(
|
||||
if (_keysControlEnumDict.TryGetValue(key, out var value))
|
||||
await ControlPlayback(value, cancellationToken);
|
||||
|
||||
if (key == Key.A)
|
||||
if (key == Key.R)
|
||||
mainWindow.AutoDetect(null!, null!);
|
||||
|
||||
#region Volume
|
||||
@@ -228,7 +228,7 @@ public class AnnotatorEventHandler(
|
||||
return;
|
||||
|
||||
var time = formState.BackgroundTime ?? TimeSpan.FromMilliseconds(mediaPlayer.Time);
|
||||
var fName = formState.GetTimeName(time);
|
||||
var fName = formState.VideoName.ToTimeName(time);
|
||||
|
||||
var currentDetections = mainWindow.Editor.CurrentDetections
|
||||
.Select(x => new Detection(fName, x.GetLabel(mainWindow.Editor.RenderSize, formState.BackgroundTime.HasValue ? mainWindow.Editor.RenderSize : formState.CurrentVideoSize)))
|
||||
@@ -267,7 +267,7 @@ public class AnnotatorEventHandler(
|
||||
File.Copy(formState.CurrentMedia.Path, imgPath, overwrite: true);
|
||||
NextMedia();
|
||||
}
|
||||
var annotation = await annotationService.SaveAnnotation(fName, imageExtension, currentDetections, SourceEnum.Manual, token: cancellationToken);
|
||||
var annotation = await annotationService.SaveAnnotation(formState.VideoName, time, imageExtension, currentDetections, SourceEnum.Manual, token: cancellationToken);
|
||||
mainWindow.AddAnnotation(annotation);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
using System.Windows;
|
||||
|
||||
namespace Azaion.Annotator.Extensions;
|
||||
|
||||
public static class PointExtensions
|
||||
{
|
||||
public static double SqrDistance(this Point p1, Point p2) =>
|
||||
(p2.X - p1.X) * (p2.X - p1.X) + (p2.Y - p1.Y) * (p2.Y - p1.Y);
|
||||
}
|
||||
@@ -1,130 +0,0 @@
|
||||
using System.Collections.Concurrent;
|
||||
using System.IO;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Runtime.InteropServices;
|
||||
using Azaion.Common.DTO.Config;
|
||||
using LibVLCSharp.Shared;
|
||||
using Microsoft.Extensions.Options;
|
||||
using SkiaSharp;
|
||||
|
||||
namespace Azaion.Annotator.Extensions;
|
||||
|
||||
public class VLCFrameExtractor(LibVLC libVLC, IOptions<AIRecognitionConfig> config)
|
||||
{
|
||||
private const uint RGBA_BYTES = 4;
|
||||
private const int PLAYBACK_RATE = 4;
|
||||
|
||||
private uint _pitch; // Number of bytes per "line", aligned to x32.
|
||||
private uint _lines; // Number of lines in the buffer, aligned to x32.
|
||||
private uint _width; // Thumbnail width
|
||||
private uint _height; // Thumbnail height
|
||||
|
||||
private MediaPlayer _mediaPlayer = null!;
|
||||
|
||||
private TimeSpan _lastFrameTimestamp;
|
||||
private long _lastFrame;
|
||||
|
||||
private static uint Align32(uint size)
|
||||
{
|
||||
if (size % 32 == 0)
|
||||
return size;
|
||||
return (size / 32 + 1) * 32;// Align on the next multiple of 32
|
||||
}
|
||||
|
||||
private static SKBitmap? _currentBitmap;
|
||||
private static readonly ConcurrentQueue<FrameInfo> FramesQueue = new();
|
||||
private static long _frameCounter;
|
||||
|
||||
public async IAsyncEnumerable<(TimeSpan Time, Stream Stream)> ExtractFrames(string mediaPath,
|
||||
[EnumeratorCancellation] CancellationToken manualCancellationToken = default)
|
||||
{
|
||||
var videoFinishedCancellationSource = new CancellationTokenSource();
|
||||
|
||||
_mediaPlayer = new MediaPlayer(libVLC);
|
||||
_mediaPlayer.Stopped += (s, e) => videoFinishedCancellationSource.CancelAfter(1);
|
||||
|
||||
using var media = new Media(libVLC, mediaPath);
|
||||
await media.Parse(cancellationToken: videoFinishedCancellationSource.Token);
|
||||
var videoTrack = media.Tracks.FirstOrDefault(x => x.Data.Video.Width != 0);
|
||||
_width = videoTrack.Data.Video.Width;
|
||||
_height = videoTrack.Data.Video.Height;
|
||||
|
||||
_pitch = Align32(_width * RGBA_BYTES);
|
||||
_lines = Align32(_height);
|
||||
_mediaPlayer.SetRate(PLAYBACK_RATE);
|
||||
|
||||
media.AddOption(":no-audio");
|
||||
_mediaPlayer.SetVideoFormat("RV32", _width, _height, _pitch);
|
||||
_mediaPlayer.SetVideoCallbacks(Lock, null, Display);
|
||||
|
||||
_mediaPlayer.Play(media);
|
||||
_frameCounter = 0;
|
||||
var surface = SKSurface.Create(new SKImageInfo((int) _width, (int) _height));
|
||||
var videoFinishedCT = videoFinishedCancellationSource.Token;
|
||||
|
||||
while ( !(FramesQueue.IsEmpty && videoFinishedCT.IsCancellationRequested || manualCancellationToken.IsCancellationRequested))
|
||||
{
|
||||
if (FramesQueue.TryDequeue(out var frameInfo))
|
||||
{
|
||||
if (frameInfo.Bitmap == null)
|
||||
continue;
|
||||
|
||||
surface.Canvas.DrawBitmap(frameInfo.Bitmap, 0, 0); // Effectively crops the original bitmap to get only the visible area
|
||||
|
||||
using var outputImage = surface.Snapshot();
|
||||
using var data = outputImage.Encode(SKEncodedImageFormat.Jpeg, 85);
|
||||
var ms = new MemoryStream();
|
||||
data.SaveTo(ms);
|
||||
|
||||
yield return (frameInfo.Time, ms);
|
||||
|
||||
frameInfo.Bitmap?.Dispose();
|
||||
}
|
||||
else
|
||||
{
|
||||
await Task.Delay(TimeSpan.FromSeconds(1), videoFinishedCT);
|
||||
}
|
||||
}
|
||||
FramesQueue.Clear(); //clear queue in case of manual stop
|
||||
_mediaPlayer.Stop();
|
||||
_mediaPlayer.Dispose();
|
||||
}
|
||||
|
||||
private IntPtr Lock(IntPtr opaque, IntPtr planes)
|
||||
{
|
||||
_currentBitmap = new SKBitmap(new SKImageInfo((int)(_pitch / RGBA_BYTES), (int)_lines, SKColorType.Bgra8888));
|
||||
Marshal.WriteIntPtr(planes, _currentBitmap.GetPixels());
|
||||
return IntPtr.Zero;
|
||||
}
|
||||
|
||||
private void Display(IntPtr opaque, IntPtr picture)
|
||||
{
|
||||
var playerTime = TimeSpan.FromMilliseconds(_mediaPlayer.Time);
|
||||
if (_lastFrameTimestamp != playerTime)
|
||||
{
|
||||
_lastFrame = _frameCounter;
|
||||
_lastFrameTimestamp = playerTime;
|
||||
}
|
||||
|
||||
if (_frameCounter > 20 && _frameCounter % config.Value.FramePeriodRecognition == 0)
|
||||
{
|
||||
var msToAdd = (_frameCounter - _lastFrame) * (_lastFrame == 0 ? 0 : _lastFrameTimestamp.TotalMilliseconds / _lastFrame);
|
||||
var time = _lastFrameTimestamp.Add(TimeSpan.FromMilliseconds(msToAdd));
|
||||
|
||||
FramesQueue.Enqueue(new FrameInfo(time, _currentBitmap));
|
||||
}
|
||||
else
|
||||
{
|
||||
_currentBitmap?.Dispose();
|
||||
}
|
||||
|
||||
_currentBitmap = null;
|
||||
_frameCounter++;
|
||||
}
|
||||
}
|
||||
|
||||
public class FrameInfo(TimeSpan time, SKBitmap? bitmap)
|
||||
{
|
||||
public TimeSpan Time { get; set; } = time;
|
||||
public SKBitmap? Bitmap { get; set; } = bitmap;
|
||||
}
|
||||
@@ -1,87 +0,0 @@
|
||||
using System.Diagnostics;
|
||||
using System.IO;
|
||||
using Azaion.Annotator.Extensions;
|
||||
using Azaion.Common.DTO;
|
||||
using Azaion.Common.DTO.Config;
|
||||
using Azaion.CommonSecurity.Services;
|
||||
using Compunet.YoloV8;
|
||||
using Microsoft.Extensions.Options;
|
||||
using SixLabors.ImageSharp;
|
||||
using SixLabors.ImageSharp.PixelFormats;
|
||||
using Detection = Azaion.Common.DTO.Detection;
|
||||
|
||||
namespace Azaion.Annotator;
|
||||
|
||||
public interface IAIDetector
|
||||
{
|
||||
Task<List<Detection>> Detect(string fName, Stream imageStream, CancellationToken cancellationToken = default);
|
||||
}
|
||||
|
||||
public class YOLODetector(IOptions<AIRecognitionConfig> recognitionConfig, IResourceLoader resourceLoader) : IAIDetector, IDisposable
|
||||
{
|
||||
private readonly AIRecognitionConfig _recognitionConfig = recognitionConfig.Value;
|
||||
private YoloPredictor? _predictor;
|
||||
private const string YOLO_MODEL = "azaion.onnx";
|
||||
|
||||
|
||||
public async Task<List<Detection>> Detect(string fName, Stream imageStream, CancellationToken cancellationToken)
|
||||
{
|
||||
if (_predictor == null)
|
||||
{
|
||||
await using var stream = await resourceLoader.Load(YOLO_MODEL, cancellationToken);
|
||||
_predictor = new YoloPredictor(stream.ToArray());
|
||||
}
|
||||
|
||||
imageStream.Seek(0, SeekOrigin.Begin);
|
||||
|
||||
using var image = Image.Load<Rgb24>(imageStream);
|
||||
var result = await _predictor.DetectAsync(image);
|
||||
var imageSize = new System.Windows.Size(image.Width, image.Height);
|
||||
var detections = result.Select(d =>
|
||||
{
|
||||
var label = new YoloLabel(new CanvasLabel(d.Name.Id, d.Bounds.X, d.Bounds.Y, d.Bounds.Width, d.Bounds.Height), imageSize, imageSize);
|
||||
return new Detection(fName, label, (double?)d.Confidence * 100);
|
||||
}).ToList();
|
||||
|
||||
return FilterOverlapping(detections);
|
||||
}
|
||||
|
||||
private List<Detection> FilterOverlapping(List<Detection> detections)
|
||||
{
|
||||
var k = _recognitionConfig.TrackingIntersectionThreshold;
|
||||
var filteredDetections = new List<Detection>();
|
||||
for (var i = 0; i < detections.Count; i++)
|
||||
{
|
||||
var detectionSelected = false;
|
||||
for (var j = i + 1; j < detections.Count; j++)
|
||||
{
|
||||
var intersect = detections[i].ToRectangle();
|
||||
intersect.Intersect(detections[j].ToRectangle());
|
||||
|
||||
var maxArea = Math.Max(detections[i].ToRectangle().Area(), detections[j].ToRectangle().Area());
|
||||
if (!(intersect.Area() > k * maxArea))
|
||||
continue;
|
||||
|
||||
if (detections[i].Probability > detections[j].Probability)
|
||||
{
|
||||
filteredDetections.Add(detections[i]);
|
||||
detections.RemoveAt(j);
|
||||
}
|
||||
else
|
||||
{
|
||||
filteredDetections.Add(detections[j]);
|
||||
detections.RemoveAt(i);
|
||||
}
|
||||
detectionSelected = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if (!detectionSelected)
|
||||
filteredDetections.Add(detections[i]);
|
||||
}
|
||||
|
||||
return filteredDetections;
|
||||
}
|
||||
|
||||
public void Dispose() => _predictor?.Dispose();
|
||||
}
|
||||
@@ -1,10 +1,15 @@
|
||||
using MessagePack;
|
||||
|
||||
namespace Azaion.Common.DTO.Config;
|
||||
|
||||
[MessagePackObject]
|
||||
public class AIRecognitionConfig
|
||||
{
|
||||
public double FrameRecognitionSeconds { get; set; }
|
||||
public double TrackingDistanceConfidence { get; set; }
|
||||
public double TrackingProbabilityIncrease { get; set; }
|
||||
public double TrackingIntersectionThreshold { get; set; }
|
||||
public int FramePeriodRecognition { get; set; }
|
||||
[Key("FrameRecognitionSeconds")] public double FrameRecognitionSeconds { get; set; }
|
||||
|
||||
[Key("TrackingDistanceConfidence")] public double TrackingDistanceConfidence { get; set; }
|
||||
[Key("TrackingProbabilityIncrease")] public double TrackingProbabilityIncrease { get; set; }
|
||||
[Key("TrackingIntersectionThreshold")] public double TrackingIntersectionThreshold { get; set; }
|
||||
[Key("FramePeriodRecognition")] public int FramePeriodRecognition { get; set; }
|
||||
[Key("Data")] public byte[] Data { get; set; }
|
||||
}
|
||||
@@ -8,8 +8,6 @@ namespace Azaion.Common.DTO.Config;
|
||||
|
||||
public class AppConfig
|
||||
{
|
||||
public ApiConfig ApiConfig { get; set; } = null!;
|
||||
|
||||
public QueueConfig QueueConfig { get; set; } = null!;
|
||||
|
||||
public DirectoriesConfig DirectoriesConfig { get; set; } = null!;
|
||||
@@ -39,13 +37,6 @@ public class ConfigUpdater : IConfigUpdater
|
||||
|
||||
var appConfig = new AppConfig
|
||||
{
|
||||
ApiConfig = new ApiConfig
|
||||
{
|
||||
Url = SecurityConstants.DEFAULT_API_URL,
|
||||
RetryCount = SecurityConstants.DEFAULT_API_RETRY_COUNT,
|
||||
TimeoutSeconds = SecurityConstants.DEFAULT_API_TIMEOUT_SECONDS
|
||||
},
|
||||
|
||||
AnnotationConfig = new AnnotationConfig
|
||||
{
|
||||
AnnotationClasses = Constants.DefaultAnnotationClasses,
|
||||
|
||||
@@ -16,6 +16,4 @@ public class FormState
|
||||
public int CurrentVolume { get; set; } = 100;
|
||||
public ObservableCollection<AnnotationResult> AnnotationResults { get; set; } = [];
|
||||
public WindowEnum ActiveWindow { get; set; }
|
||||
|
||||
public string GetTimeName(TimeSpan? ts) => $"{VideoName}_{ts:hmmssf}";
|
||||
}
|
||||
+13
-11
@@ -1,18 +1,18 @@
|
||||
using System.Drawing;
|
||||
using System.Globalization;
|
||||
using System.IO;
|
||||
using MessagePack;
|
||||
using Newtonsoft.Json;
|
||||
using Size = System.Windows.Size;
|
||||
|
||||
namespace Azaion.Common.DTO;
|
||||
|
||||
[MessagePackObject]
|
||||
public abstract class Label
|
||||
{
|
||||
[JsonProperty(PropertyName = "cl")] public int ClassNumber { get; set; }
|
||||
[JsonProperty(PropertyName = "cl")][Key("c")] public int ClassNumber { get; set; }
|
||||
|
||||
protected Label()
|
||||
{
|
||||
}
|
||||
protected Label() { }
|
||||
|
||||
protected Label(int classNumber)
|
||||
{
|
||||
@@ -79,15 +79,16 @@ public class CanvasLabel : Label
|
||||
}
|
||||
}
|
||||
|
||||
[MessagePackObject]
|
||||
public class YoloLabel : Label
|
||||
{
|
||||
[JsonProperty(PropertyName = "x")] public double CenterX { get; set; }
|
||||
[JsonProperty(PropertyName = "x")][Key("x")] public double CenterX { get; set; }
|
||||
|
||||
[JsonProperty(PropertyName = "y")] public double CenterY { get; set; }
|
||||
[JsonProperty(PropertyName = "y")][Key("y")] public double CenterY { get; set; }
|
||||
|
||||
[JsonProperty(PropertyName = "w")] public double Width { get; set; }
|
||||
[JsonProperty(PropertyName = "w")][Key("w")] public double Width { get; set; }
|
||||
|
||||
[JsonProperty(PropertyName = "h")] public double Height { get; set; }
|
||||
[JsonProperty(PropertyName = "h")][Key("h")] public double Height { get; set; }
|
||||
|
||||
public YoloLabel()
|
||||
{
|
||||
@@ -184,12 +185,13 @@ public class YoloLabel : Label
|
||||
public override string ToString() => $"{ClassNumber} {CenterX:F5} {CenterY:F5} {Width:F5} {Height:F5}".Replace(',', '.');
|
||||
}
|
||||
|
||||
[MessagePackObject]
|
||||
public class Detection : YoloLabel
|
||||
{
|
||||
public string AnnotationName { get; set; } = null!;
|
||||
public double? Probability { get; set; }
|
||||
[IgnoreMember]public string AnnotationName { get; set; } = null!;
|
||||
[Key("p")] public double? Probability { get; set; }
|
||||
|
||||
//For db
|
||||
//For db & serialization
|
||||
public Detection(){}
|
||||
|
||||
public Detection(string annotationName, YoloLabel label, double? probability = null)
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
namespace Azaion.Common.DTO;
|
||||
using Azaion.Common.Extensions;
|
||||
|
||||
namespace Azaion.Common.DTO;
|
||||
|
||||
public class MediaFileInfo
|
||||
{
|
||||
@@ -9,5 +11,5 @@ public class MediaFileInfo
|
||||
public bool HasAnnotations { get; set; }
|
||||
public MediaTypes MediaType { get; set; }
|
||||
|
||||
public string FName => System.IO.Path.GetFileNameWithoutExtension(Name).Replace(" ", "");
|
||||
public string FName => Name.ToFName();
|
||||
}
|
||||
@@ -7,15 +7,17 @@ using MessagePack;
|
||||
[MessagePackObject]
|
||||
public class AnnotationCreatedMessage
|
||||
{
|
||||
[Key(0)] public DateTime CreatedDate { get; set; }
|
||||
[Key(1)] public string Name { get; set; } = null!;
|
||||
[Key(2)] public string ImageExtension { get; set; } = null!;
|
||||
[Key(3)] public string Detections { get; set; } = null!;
|
||||
[Key(4)] public byte[] Image { get; set; } = null!;
|
||||
[Key(5)] public RoleEnum CreatedRole { get; set; }
|
||||
[Key(6)] public string CreatedEmail { get; set; } = null!;
|
||||
[Key(7)] public SourceEnum Source { get; set; }
|
||||
[Key(8)] public AnnotationStatus Status { get; set; }
|
||||
[Key(0)] public DateTime CreatedDate { get; set; }
|
||||
[Key(1)] public string Name { get; set; } = null!;
|
||||
[Key(2)] public string OriginalMediaName { get; set; } = null!;
|
||||
[Key(3)] public TimeSpan Time { get; set; }
|
||||
[Key(4)] public string ImageExtension { get; set; } = null!;
|
||||
[Key(5)] public string Detections { get; set; } = null!;
|
||||
[Key(6)] public byte[] Image { get; set; } = null!;
|
||||
[Key(7)] public RoleEnum CreatedRole { get; set; }
|
||||
[Key(8)] public string CreatedEmail { get; set; } = null!;
|
||||
[Key(9)] public SourceEnum Source { get; set; }
|
||||
[Key(10)] public AnnotationStatus Status { get; set; }
|
||||
}
|
||||
|
||||
[MessagePackObject]
|
||||
|
||||
@@ -3,9 +3,11 @@ using Azaion.Common.DTO;
|
||||
using Azaion.Common.DTO.Config;
|
||||
using Azaion.Common.DTO.Queue;
|
||||
using Azaion.CommonSecurity.DTO;
|
||||
using MessagePack;
|
||||
|
||||
namespace Azaion.Common.Database;
|
||||
|
||||
[MessagePackObject]
|
||||
public class Annotation
|
||||
{
|
||||
private static string _labelsDir = null!;
|
||||
@@ -19,53 +21,36 @@ public class Annotation
|
||||
_thumbDir = config.ThumbnailsDirectory;
|
||||
}
|
||||
|
||||
public string Name { get; set; } = null!;
|
||||
public string ImageExtension { get; set; } = null!;
|
||||
public DateTime CreatedDate { get; set; }
|
||||
public string CreatedEmail { get; set; } = null!;
|
||||
public RoleEnum CreatedRole { get; set; }
|
||||
public SourceEnum Source { get; set; }
|
||||
public AnnotationStatus AnnotationStatus { get; set; }
|
||||
[IgnoreMember]public string Name { get; set; } = null!;
|
||||
[IgnoreMember]public string OriginalMediaName { get; set; } = null!;
|
||||
[IgnoreMember]public TimeSpan Time { get; set; }
|
||||
[IgnoreMember]public string ImageExtension { get; set; } = null!;
|
||||
[IgnoreMember]public DateTime CreatedDate { get; set; }
|
||||
[IgnoreMember]public string CreatedEmail { get; set; } = null!;
|
||||
[IgnoreMember]public RoleEnum CreatedRole { get; set; }
|
||||
[IgnoreMember]public SourceEnum Source { get; set; }
|
||||
[IgnoreMember]public AnnotationStatus AnnotationStatus { get; set; }
|
||||
|
||||
public IEnumerable<Detection> Detections { get; set; } = null!;
|
||||
[Key("d")] public IEnumerable<Detection> Detections { get; set; } = null!;
|
||||
[Key("t")] public long Milliseconds { get; set; }
|
||||
|
||||
public double Lat { get; set; }
|
||||
public double Lon { get; set; }
|
||||
[Key("lat")]public double Lat { get; set; }
|
||||
[Key("lon")]public double Lon { get; set; }
|
||||
|
||||
#region Calculated
|
||||
public List<int> Classes => Detections.Select(x => x.ClassNumber).ToList();
|
||||
public string ImagePath => Path.Combine(_imagesDir, $"{Name}{ImageExtension}");
|
||||
public string LabelPath => Path.Combine(_labelsDir, $"{Name}.txt");
|
||||
public string ThumbPath => Path.Combine(_thumbDir, $"{Name}{Constants.THUMBNAIL_PREFIX}.jpg");
|
||||
public string OriginalMediaName => $"{Name[..^7]}";
|
||||
[IgnoreMember]public List<int> Classes => Detections.Select(x => x.ClassNumber).ToList();
|
||||
[IgnoreMember]public string ImagePath => Path.Combine(_imagesDir, $"{Name}{ImageExtension}");
|
||||
[IgnoreMember]public string LabelPath => Path.Combine(_labelsDir, $"{Name}.txt");
|
||||
[IgnoreMember]public string ThumbPath => Path.Combine(_thumbDir, $"{Name}{Constants.THUMBNAIL_PREFIX}.jpg");
|
||||
|
||||
private TimeSpan? _time;
|
||||
public TimeSpan Time
|
||||
{
|
||||
get
|
||||
{
|
||||
if (_time.HasValue)
|
||||
return _time.Value;
|
||||
|
||||
var timeStr = Name.Split("_").LastOrDefault();
|
||||
|
||||
//For some reason, TimeSpan.ParseExact doesn't work on every platform.
|
||||
if (!string.IsNullOrEmpty(timeStr) &&
|
||||
timeStr.Length == 6 &&
|
||||
int.TryParse(timeStr[..1], out var hours) &&
|
||||
int.TryParse(timeStr[1..3], out var minutes) &&
|
||||
int.TryParse(timeStr[3..5], out var seconds) &&
|
||||
int.TryParse(timeStr[5..], out var milliseconds))
|
||||
return new TimeSpan(0, hours, minutes, seconds, milliseconds * 100);
|
||||
|
||||
_time = TimeSpan.FromSeconds(0);
|
||||
return _time.Value;
|
||||
}
|
||||
}
|
||||
#endregion Calculated
|
||||
}
|
||||
|
||||
|
||||
[MessagePackObject]
|
||||
public class AnnotationImage : Annotation
|
||||
{
|
||||
[Key("i")] public byte[] Image { get; set; }
|
||||
}
|
||||
|
||||
public enum AnnotationStatus
|
||||
{
|
||||
|
||||
@@ -117,16 +117,19 @@ public static class AnnotationsDbSchemaHolder
|
||||
MappingSchema = new MappingSchema();
|
||||
var builder = new FluentMappingBuilder(MappingSchema);
|
||||
|
||||
builder.Entity<Annotation>()
|
||||
.HasTableName(Constants.ANNOTATIONS_TABLENAME)
|
||||
var annotationBuilder = builder.Entity<Annotation>();
|
||||
annotationBuilder.HasTableName(Constants.ANNOTATIONS_TABLENAME)
|
||||
.HasPrimaryKey(x => x.Name)
|
||||
.Ignore(x => x.Time)
|
||||
.Association(a => a.Detections, (a, d) => a.Name == d.AnnotationName)
|
||||
.Property(x => x.Time).HasDataType(DataType.Int64).HasConversion(ts => ts.Ticks, t => new TimeSpan(t));
|
||||
|
||||
annotationBuilder
|
||||
.Ignore(x => x.Milliseconds)
|
||||
.Ignore(x => x.Classes)
|
||||
.Ignore(x => x.Classes)
|
||||
.Ignore(x => x.ImagePath)
|
||||
.Ignore(x => x.LabelPath)
|
||||
.Ignore(x => x.ThumbPath)
|
||||
.Ignore(x => x.OriginalMediaName)
|
||||
.Association(a => a.Detections, (a, d) => a.Name == d.AnnotationName);
|
||||
.Ignore(x => x.ThumbPath);
|
||||
|
||||
builder.Entity<Detection>()
|
||||
.HasTableName(Constants.DETECTIONS_TABLENAME);
|
||||
|
||||
@@ -24,7 +24,7 @@ public class ParallelExt
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
};
|
||||
var threadsCount = (int)(Environment.ProcessorCount * parallelOptions.CpuUtilPercent / 100.0);
|
||||
var threadsCount = (int)Math.Round(Environment.ProcessorCount * parallelOptions.CpuUtilPercent / 100.0);
|
||||
|
||||
var processedCount = 0;
|
||||
var chunkSize = Math.Max(1, (int)(source.Count / (decimal)threadsCount));
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
using System.IO;
|
||||
|
||||
namespace Azaion.Common.Extensions;
|
||||
|
||||
public static class StringExtensions
|
||||
{
|
||||
public static string ToFName(this string path) =>
|
||||
Path.GetFileNameWithoutExtension(path).Replace(" ", "");
|
||||
|
||||
public static string ToTimeName(this string fName, TimeSpan? ts) =>
|
||||
$"{fName}_{ts:hmmssf}";
|
||||
}
|
||||
@@ -22,29 +22,31 @@ namespace Azaion.Common.Services;
|
||||
|
||||
public class AnnotationService : INotificationHandler<AnnotationsDeletedEvent>
|
||||
{
|
||||
private readonly AzaionApiClient _apiClient;
|
||||
private readonly IDbFactory _dbFactory;
|
||||
private readonly FailsafeAnnotationsProducer _producer;
|
||||
private readonly IGalleryService _galleryService;
|
||||
private readonly IMediator _mediator;
|
||||
private readonly IHardwareService _hardwareService;
|
||||
private readonly IAuthProvider _authProvider;
|
||||
private readonly QueueConfig _queueConfig;
|
||||
private Consumer _consumer = null!;
|
||||
|
||||
public AnnotationService(AzaionApiClient apiClient,
|
||||
public AnnotationService(
|
||||
IResourceLoader resourceLoader,
|
||||
IDbFactory dbFactory,
|
||||
FailsafeAnnotationsProducer producer,
|
||||
IOptions<QueueConfig> queueConfig,
|
||||
IGalleryService galleryService,
|
||||
IMediator mediator,
|
||||
IHardwareService hardwareService)
|
||||
IHardwareService hardwareService,
|
||||
IAuthProvider authProvider)
|
||||
{
|
||||
_apiClient = apiClient;
|
||||
_dbFactory = dbFactory;
|
||||
_producer = producer;
|
||||
_galleryService = galleryService;
|
||||
_mediator = mediator;
|
||||
_hardwareService = hardwareService;
|
||||
_authProvider = authProvider;
|
||||
_queueConfig = queueConfig.Value;
|
||||
|
||||
Task.Run(async () => await Init()).Wait();
|
||||
@@ -73,7 +75,8 @@ public class AnnotationService : INotificationHandler<AnnotationsDeletedEvent>
|
||||
|
||||
await SaveAnnotationInner(
|
||||
msg.CreatedDate,
|
||||
msg.Name,
|
||||
msg.OriginalMediaName,
|
||||
msg.Time,
|
||||
msg.ImageExtension,
|
||||
JsonConvert.DeserializeObject<List<Detection>>(msg.Detections) ?? [],
|
||||
msg.Source,
|
||||
@@ -98,36 +101,39 @@ public class AnnotationService : INotificationHandler<AnnotationsDeletedEvent>
|
||||
});
|
||||
}
|
||||
|
||||
//AI / Manual
|
||||
public async Task<Annotation> SaveAnnotation(string fName, string imageExtension, List<Detection> detections, SourceEnum source, Stream? stream = null, CancellationToken token = default) =>
|
||||
await SaveAnnotationInner(DateTime.UtcNow, fName, imageExtension, detections, source, stream, _apiClient.User.Role, _apiClient.User.Email, generateThumbnail: true, token);
|
||||
//AI
|
||||
public async Task<Annotation> SaveAnnotation(AnnotationImage a, CancellationToken cancellationToken = default)
|
||||
{
|
||||
a.Time = TimeSpan.FromMilliseconds(a.Milliseconds);
|
||||
a.Name = a.OriginalMediaName.ToTimeName(a.Time);
|
||||
return await SaveAnnotationInner(DateTime.Now, a.OriginalMediaName, a.Time, ".jpg", a.Detections.ToList(),
|
||||
a.Source, new MemoryStream(a.Image), a.CreatedRole, a.CreatedEmail, generateThumbnail: true, cancellationToken);
|
||||
}
|
||||
|
||||
//Manual
|
||||
public async Task<Annotation> SaveAnnotation(string originalMediaName, TimeSpan time, string imageExtension, List<Detection> detections, SourceEnum source, Stream? stream = null, CancellationToken token = default) =>
|
||||
await SaveAnnotationInner(DateTime.UtcNow, originalMediaName, time, imageExtension, detections, source, stream,
|
||||
_authProvider.CurrentUser.Role, _authProvider.CurrentUser.Email, generateThumbnail: true, token);
|
||||
|
||||
//Manual Validate existing
|
||||
public async Task ValidateAnnotation(Annotation annotation, CancellationToken token = default) =>
|
||||
await SaveAnnotationInner(DateTime.UtcNow, annotation.Name, annotation.ImageExtension, annotation.Detections.ToList(), SourceEnum.Manual, null, _apiClient.User.Role, _apiClient.User.Email,
|
||||
generateThumbnail: false, token);
|
||||
await SaveAnnotationInner(DateTime.UtcNow, annotation.OriginalMediaName, annotation.Time, annotation.ImageExtension, annotation.Detections.ToList(), SourceEnum.Manual, null,
|
||||
_authProvider.CurrentUser.Role, _authProvider.CurrentUser.Email, generateThumbnail: false, token);
|
||||
|
||||
// //Queue (only from operators)
|
||||
// public async Task Consume(AnnotationCreatedMessage message, CancellationToken cancellationToken = default)
|
||||
// {
|
||||
//
|
||||
// }
|
||||
|
||||
private async Task<Annotation> SaveAnnotationInner(DateTime createdDate, string fName, string imageExtension, List<Detection> detections, SourceEnum source, Stream? stream,
|
||||
private async Task<Annotation> SaveAnnotationInner(DateTime createdDate, string originalMediaName, TimeSpan time, string imageExtension, List<Detection> detections, SourceEnum source, Stream? stream,
|
||||
RoleEnum userRole,
|
||||
string createdEmail,
|
||||
bool generateThumbnail = false,
|
||||
CancellationToken token = default)
|
||||
{
|
||||
//Flow for roles:
|
||||
// Operator or (AI from any role) -> Created
|
||||
// Validator, Admin & Manual -> Validated
|
||||
|
||||
AnnotationStatus status;
|
||||
|
||||
var fName = originalMediaName.ToTimeName(time);
|
||||
var annotation = await _dbFactory.Run(async db =>
|
||||
{
|
||||
var ann = await db.Annotations.FirstOrDefaultAsync(x => x.Name == fName, token: token);
|
||||
// Manual Save from Validators -> Validated
|
||||
// otherwise Created
|
||||
status = userRole.IsValidator() && source == SourceEnum.Manual
|
||||
? AnnotationStatus.Validated
|
||||
: AnnotationStatus.Created;
|
||||
@@ -149,6 +155,8 @@ public class AnnotationService : INotificationHandler<AnnotationsDeletedEvent>
|
||||
{
|
||||
CreatedDate = createdDate,
|
||||
Name = fName,
|
||||
OriginalMediaName = originalMediaName,
|
||||
Time = time,
|
||||
ImageExtension = imageExtension,
|
||||
CreatedEmail = createdEmail,
|
||||
CreatedRole = userRole,
|
||||
|
||||
@@ -76,7 +76,7 @@ public class FailsafeAnnotationsProducer
|
||||
await _annotationConfirmProducer.Send(validatedMessages, CompressionType.Gzip);
|
||||
|
||||
await _dbFactory.Run(async db =>
|
||||
await db.AnnotationsQueue.DeleteAsync(aq => messagesChunk.Any(x => aq.Name == x.Name), token: cancellationToken));
|
||||
await db.AnnotationsQueue.DeleteAsync(aq => messagesChunk.Any(x => aq.Name == x.OriginalMediaName), token: cancellationToken));
|
||||
sent = true;
|
||||
}
|
||||
catch (Exception e)
|
||||
@@ -106,7 +106,8 @@ public class FailsafeAnnotationsProducer
|
||||
var annCreateMessage = new AnnotationCreatedMessage
|
||||
{
|
||||
Name = annotation.Name,
|
||||
|
||||
OriginalMediaName = annotation.OriginalMediaName,
|
||||
Time = annotation.Time,
|
||||
CreatedRole = annotation.CreatedRole,
|
||||
CreatedEmail = annotation.CreatedEmail,
|
||||
CreatedDate = annotation.CreatedDate,
|
||||
|
||||
@@ -8,6 +8,7 @@ using Azaion.Common.Database;
|
||||
using Azaion.Common.DTO;
|
||||
using Azaion.Common.DTO.Config;
|
||||
using Azaion.Common.DTO.Queue;
|
||||
using Azaion.Common.Extensions;
|
||||
using Azaion.CommonSecurity.DTO;
|
||||
using LinqToDB;
|
||||
using LinqToDB.Data;
|
||||
@@ -75,7 +76,6 @@ public class GalleryService(
|
||||
var missedAnnotations = new ConcurrentBag<Annotation>();
|
||||
try
|
||||
{
|
||||
|
||||
var prefixLen = Constants.THUMBNAIL_PREFIX.Length;
|
||||
|
||||
var thumbnails = ThumbnailsDirectory.GetFiles()
|
||||
@@ -105,9 +105,37 @@ public class GalleryService(
|
||||
return;
|
||||
|
||||
var detections = (await YoloLabel.ReadFromFile(labelName, cancellationToken)).Select(x => new Detection(fName, x)).ToList();
|
||||
|
||||
//get names and time
|
||||
var fileName = Path.GetFileNameWithoutExtension(file.Name);
|
||||
var strings = fileName.Split("_");
|
||||
var timeStr = strings.LastOrDefault();
|
||||
|
||||
string originalMediaName;
|
||||
TimeSpan time;
|
||||
|
||||
//For some reason, TimeSpan.ParseExact doesn't work on every platform.
|
||||
if (!string.IsNullOrEmpty(timeStr) &&
|
||||
timeStr.Length == 6 &&
|
||||
int.TryParse(timeStr[..1], out var hours) &&
|
||||
int.TryParse(timeStr[1..3], out var minutes) &&
|
||||
int.TryParse(timeStr[3..5], out var seconds) &&
|
||||
int.TryParse(timeStr[5..], out var milliseconds))
|
||||
{
|
||||
time = new TimeSpan(0, hours, minutes, seconds, milliseconds * 100);
|
||||
originalMediaName = fileName[..^7];
|
||||
}
|
||||
else
|
||||
{
|
||||
originalMediaName = fileName;
|
||||
time = TimeSpan.FromSeconds(0);
|
||||
}
|
||||
|
||||
var annotation = new Annotation
|
||||
{
|
||||
Name = fName,
|
||||
Time = time,
|
||||
OriginalMediaName = originalMediaName,
|
||||
Name = file.Name.ToFName(),
|
||||
ImageExtension = Path.GetExtension(file.Name),
|
||||
Detections = detections,
|
||||
CreatedDate = File.GetCreationTimeUtc(file.FullName),
|
||||
@@ -129,18 +157,22 @@ public class GalleryService(
|
||||
logger.LogError(e, $"Failed to generate thumbnail for {file.Name}! Error: {e.Message}");
|
||||
}
|
||||
},
|
||||
new ParallelOptions
|
||||
new ParallelOptions
|
||||
{
|
||||
ProgressFn = async num =>
|
||||
{
|
||||
ProgressFn = async num =>
|
||||
{
|
||||
Console.WriteLine($"Processed {num} item by Thread {Environment.CurrentManagedThreadId}");
|
||||
ProcessedThumbnailsPercentage = imagesCount == 0 ? 0 : Math.Min(100, num * 100 / (double)imagesCount);
|
||||
ThumbnailsUpdate?.Invoke(ProcessedThumbnailsPercentage);
|
||||
await Task.CompletedTask;
|
||||
},
|
||||
CpuUtilPercent = 100,
|
||||
ProgressUpdateInterval = 200
|
||||
});
|
||||
Console.WriteLine($"Processed {num} item by Thread {Environment.CurrentManagedThreadId}");
|
||||
ProcessedThumbnailsPercentage = imagesCount == 0 ? 0 : Math.Min(100, num * 100 / (double)imagesCount);
|
||||
ThumbnailsUpdate?.Invoke(ProcessedThumbnailsPercentage);
|
||||
await Task.CompletedTask;
|
||||
},
|
||||
CpuUtilPercent = 100,
|
||||
ProgressUpdateInterval = 200
|
||||
});
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
logger.LogError(e, $"Failed to refresh thumbnails! Error: {e.Message}");
|
||||
}
|
||||
finally
|
||||
{
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
using System.Text;
|
||||
using Azaion.Common.Database;
|
||||
using Azaion.Common.DTO.Config;
|
||||
using Azaion.CommonSecurity;
|
||||
using Azaion.CommonSecurity.DTO.Commands;
|
||||
using Azaion.CommonSecurity.Services;
|
||||
using MessagePack;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Options;
|
||||
using NetMQ;
|
||||
using NetMQ.Sockets;
|
||||
|
||||
namespace Azaion.Common.Services;
|
||||
|
||||
public interface IInferenceService
|
||||
{
|
||||
Task RunInference(string mediaPath, Func<AnnotationImage, CancellationToken, Task> processAnnotation, CancellationToken ct = default);
|
||||
}
|
||||
|
||||
public class PythonInferenceService(ILogger<PythonInferenceService> logger, IOptions<AIRecognitionConfig> aiConfigOptions) : IInferenceService
|
||||
{
|
||||
public async Task RunInference(string mediaPath, Func<AnnotationImage, CancellationToken, Task> processAnnotation, CancellationToken ct = default)
|
||||
{
|
||||
using var dealer = new DealerSocket();
|
||||
var clientId = Guid.NewGuid();
|
||||
dealer.Options.Identity = Encoding.UTF8.GetBytes(clientId.ToString("N"));
|
||||
dealer.Connect($"tcp://{SecurityConstants.ZMQ_HOST}:{SecurityConstants.ZMQ_PORT}");
|
||||
|
||||
var data = MessagePackSerializer.Serialize(aiConfigOptions.Value);
|
||||
dealer.SendFrame(MessagePackSerializer.Serialize(new RemoteCommand(CommandType.Inference, mediaPath, data)));
|
||||
|
||||
while (true)
|
||||
{
|
||||
byte[] bytes = [];
|
||||
try
|
||||
{
|
||||
var annotationStream = dealer.Get<AnnotationImage>(out bytes);
|
||||
if (annotationStream == null)
|
||||
{
|
||||
if (bytes.Length == 4 && Encoding.UTF8.GetString(bytes) == "DONE")
|
||||
break;
|
||||
continue;
|
||||
}
|
||||
|
||||
await processAnnotation(annotationStream, ct);
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
logger.LogError(e, e.Message);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
namespace Azaion.CommonSecurity.DTO;
|
||||
|
||||
public class ApiConfig
|
||||
{
|
||||
public string Url { get; set; } = null!;
|
||||
public int RetryCount {get;set;}
|
||||
public double TimeoutSeconds { get; set; }
|
||||
public string ResourcesFolder { get; set; } = "";
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
using MessagePack;
|
||||
|
||||
namespace Azaion.CommonSecurity.DTO.Commands;
|
||||
|
||||
[MessagePackObject]
|
||||
public class FileCommand
|
||||
{
|
||||
[Key("CommandType")]
|
||||
public CommandType CommandType { get; set; }
|
||||
|
||||
[Key("Filename")]
|
||||
public string Filename { get; set; }
|
||||
|
||||
[Key("Data")]
|
||||
public byte[] Data { get; set; }
|
||||
|
||||
}
|
||||
|
||||
public enum CommandType
|
||||
{
|
||||
None = 0,
|
||||
Inference = 1,
|
||||
Load = 2
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
using MessagePack;
|
||||
|
||||
namespace Azaion.CommonSecurity.DTO.Commands;
|
||||
|
||||
[MessagePackObject]
|
||||
public class RemoteCommand(CommandType commandType, string? filename = null, byte[]? data = null)
|
||||
{
|
||||
[Key("CommandType")]
|
||||
public CommandType CommandType { get; set; } = commandType;
|
||||
|
||||
[Key("Filename")]
|
||||
public string? Filename { get; set; } = filename;
|
||||
|
||||
[Key("Data")]
|
||||
public byte[]? Data { get; set; } = data;
|
||||
}
|
||||
|
||||
public enum CommandType
|
||||
{
|
||||
None = 0,
|
||||
Inference = 1,
|
||||
Load = 2,
|
||||
GetUser = 3
|
||||
}
|
||||
@@ -1,6 +0,0 @@
|
||||
namespace Azaion.CommonSecurity.DTO;
|
||||
|
||||
public class SecureAppConfig
|
||||
{
|
||||
public ApiConfig ApiConfig { get; set; } = null!;
|
||||
}
|
||||
@@ -1,21 +1,11 @@
|
||||
using System.Security.Claims;
|
||||
using MessagePack;
|
||||
|
||||
namespace Azaion.CommonSecurity.DTO;
|
||||
|
||||
[MessagePackObject]
|
||||
public class User
|
||||
{
|
||||
public Guid Id { get; set; }
|
||||
public string Email { get; set; }
|
||||
public RoleEnum Role { get; set; }
|
||||
|
||||
public User(IEnumerable<Claim> claims)
|
||||
{
|
||||
var claimDict = claims.ToDictionary(x => x.Type, x => x.Value);
|
||||
|
||||
Id = Guid.Parse(claimDict[SecurityConstants.CLAIM_NAME_ID]);
|
||||
Email = claimDict[SecurityConstants.CLAIM_EMAIL];
|
||||
if (!Enum.TryParse(claimDict[SecurityConstants.CLAIM_ROLE], out RoleEnum role))
|
||||
role = RoleEnum.None;
|
||||
Role = role;
|
||||
}
|
||||
[Key("i")]public string Id { get; set; }
|
||||
[Key("e")]public string Email { get; set; }
|
||||
[Key("r")]public RoleEnum Role { get; set; }
|
||||
}
|
||||
@@ -19,9 +19,8 @@ public class SecurityConstants
|
||||
#endregion ApiConfig
|
||||
|
||||
#region SocketClient
|
||||
public const string SOCKET_HOST = "127.0.0.1";
|
||||
public const int SOCKET_SEND_PORT = 5127;
|
||||
public const int SOCKET_RECEIVE_PORT = 5128;
|
||||
public const string ZMQ_HOST = "127.0.0.1";
|
||||
public const int ZMQ_PORT = 5127;
|
||||
|
||||
#endregion SocketClient
|
||||
}
|
||||
@@ -1,127 +0,0 @@
|
||||
using System.IdentityModel.Tokens.Jwt;
|
||||
using System.Net;
|
||||
using System.Net.Http.Headers;
|
||||
using System.Security;
|
||||
using System.Text;
|
||||
using Azaion.CommonSecurity.DTO;
|
||||
using Newtonsoft.Json;
|
||||
|
||||
namespace Azaion.CommonSecurity.Services;
|
||||
|
||||
public class AzaionApiClient(HttpClient httpClient) : IDisposable
|
||||
{
|
||||
const string JSON_MEDIA = "application/json";
|
||||
|
||||
private static ApiConfig _apiConfig = null!;
|
||||
|
||||
private string Email { get; set; } = null!;
|
||||
private SecureString Password { get; set; } = new();
|
||||
private string JwtToken { get; set; } = null!;
|
||||
public User User { get; set; } = null!;
|
||||
|
||||
public static AzaionApiClient Create(ApiCredentials credentials)
|
||||
{
|
||||
try
|
||||
{
|
||||
if (!File.Exists(SecurityConstants.CONFIG_PATH))
|
||||
throw new FileNotFoundException(SecurityConstants.CONFIG_PATH);
|
||||
var configStr = File.ReadAllText(SecurityConstants.CONFIG_PATH);
|
||||
_apiConfig = JsonConvert.DeserializeObject<SecureAppConfig>(configStr)!.ApiConfig;
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
Console.WriteLine(e);
|
||||
_apiConfig = new ApiConfig
|
||||
{
|
||||
Url = SecurityConstants.DEFAULT_API_URL,
|
||||
RetryCount = SecurityConstants.DEFAULT_API_RETRY_COUNT ,
|
||||
TimeoutSeconds = SecurityConstants.DEFAULT_API_TIMEOUT_SECONDS
|
||||
};
|
||||
}
|
||||
|
||||
var api = new AzaionApiClient(new HttpClient
|
||||
{
|
||||
BaseAddress = new Uri(_apiConfig.Url),
|
||||
Timeout = TimeSpan.FromSeconds(_apiConfig.TimeoutSeconds)
|
||||
});
|
||||
|
||||
api.EnterCredentials(credentials);
|
||||
return api;
|
||||
}
|
||||
|
||||
public void EnterCredentials(ApiCredentials credentials)
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(credentials.Email) || string.IsNullOrWhiteSpace(credentials.Password))
|
||||
throw new Exception("Email or password is empty!");
|
||||
|
||||
Email = credentials.Email;
|
||||
Password = credentials.Password.ToSecureString();
|
||||
}
|
||||
|
||||
public async Task<Stream> GetResource(string fileName, string password, HardwareInfo hardware)
|
||||
{
|
||||
var response = await Send(httpClient, new HttpRequestMessage(HttpMethod.Post, $"/resources/get/{_apiConfig.ResourcesFolder}")
|
||||
{
|
||||
Content = new StringContent(JsonConvert.SerializeObject(new { fileName, password, hardware }), Encoding.UTF8, JSON_MEDIA)
|
||||
});
|
||||
return await response.Content.ReadAsStreamAsync();
|
||||
}
|
||||
|
||||
private async Task Authorize()
|
||||
{
|
||||
if (string.IsNullOrEmpty(Email) || Password.Length == 0)
|
||||
throw new Exception("Email or password is empty! Please do EnterCredentials first!");
|
||||
|
||||
var payload = new
|
||||
{
|
||||
email = Email,
|
||||
password = Password.ToRealString()
|
||||
};
|
||||
var response = await httpClient.PostAsync(
|
||||
"login",
|
||||
new StringContent(JsonConvert.SerializeObject(payload), Encoding.UTF8, JSON_MEDIA));
|
||||
|
||||
if (!response.IsSuccessStatusCode)
|
||||
throw new Exception($"EnterCredentials failed: {response.StatusCode}");
|
||||
|
||||
var responseData = await response.Content.ReadAsStringAsync();
|
||||
|
||||
var result = JsonConvert.DeserializeObject<LoginResponse>(responseData);
|
||||
|
||||
if (string.IsNullOrEmpty(result?.Token))
|
||||
throw new Exception("JWT Token not found in response");
|
||||
|
||||
var handler = new JwtSecurityTokenHandler();
|
||||
var token = handler.ReadJwtToken(result.Token);
|
||||
|
||||
User = new User(token.Claims);
|
||||
JwtToken = result.Token;
|
||||
}
|
||||
|
||||
private async Task<HttpResponseMessage> Send(HttpClient client, HttpRequestMessage request)
|
||||
{
|
||||
if (string.IsNullOrEmpty(JwtToken))
|
||||
await Authorize();
|
||||
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", JwtToken);
|
||||
var response = await client.SendAsync(request);
|
||||
|
||||
if (response.StatusCode == HttpStatusCode.Unauthorized)
|
||||
{
|
||||
await Authorize();
|
||||
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", JwtToken);
|
||||
response = await client.SendAsync(request);
|
||||
}
|
||||
|
||||
if (response.IsSuccessStatusCode)
|
||||
return response;
|
||||
|
||||
var result = await response.Content.ReadAsStringAsync();
|
||||
throw new Exception($"Failed: {response.StatusCode}! Result: {result}");
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
httpClient.Dispose();
|
||||
Password.Dispose();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
using System.Text;
|
||||
using Azaion.CommonSecurity.DTO;
|
||||
using Azaion.CommonSecurity.DTO.Commands;
|
||||
using MessagePack;
|
||||
using NetMQ;
|
||||
using NetMQ.Sockets;
|
||||
|
||||
namespace Azaion.CommonSecurity.Services;
|
||||
|
||||
public interface IResourceLoader
|
||||
{
|
||||
Task<MemoryStream> LoadFile(string fileName, CancellationToken ct = default);
|
||||
}
|
||||
|
||||
public interface IAuthProvider
|
||||
{
|
||||
User CurrentUser { get; }
|
||||
}
|
||||
|
||||
|
||||
public class PythonResourceLoader : IResourceLoader, IAuthProvider
|
||||
{
|
||||
private readonly DealerSocket _dealer = new();
|
||||
private readonly Guid _clientId = Guid.NewGuid();
|
||||
|
||||
public User CurrentUser { get; }
|
||||
|
||||
public PythonResourceLoader(ApiCredentials credentials)
|
||||
{
|
||||
//Run python by credentials
|
||||
_dealer.Options.Identity = Encoding.UTF8.GetBytes(_clientId.ToString("N"));
|
||||
_dealer.Connect($"tcp://{SecurityConstants.ZMQ_HOST}:{SecurityConstants.ZMQ_PORT}");
|
||||
|
||||
_dealer.SendFrame(MessagePackSerializer.Serialize(new RemoteCommand(CommandType.GetUser)));
|
||||
var user = _dealer.Get<User>(out _);
|
||||
if (user == null)
|
||||
throw new Exception("Can't get user from Auth provider");
|
||||
|
||||
CurrentUser = user;
|
||||
}
|
||||
|
||||
|
||||
public async Task<MemoryStream> LoadFile(string fileName, CancellationToken ct = default)
|
||||
{
|
||||
try
|
||||
{
|
||||
_dealer.SendFrame(MessagePackSerializer.Serialize(new RemoteCommand(CommandType.Load, fileName)));
|
||||
|
||||
if (!_dealer.TryReceiveFrameBytes(TimeSpan.FromMilliseconds(1000), out var bytes))
|
||||
throw new Exception($"Unable to receive {fileName}");
|
||||
|
||||
return await Task.FromResult(new MemoryStream(bytes));
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
throw new Exception($"Failed to load fil0e '{fileName}': {ex.Message}", ex);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,63 +0,0 @@
|
||||
using Azaion.CommonSecurity.DTO;
|
||||
using Azaion.CommonSecurity.DTO.Commands;
|
||||
using MessagePack;
|
||||
using NetMQ;
|
||||
using NetMQ.Sockets;
|
||||
|
||||
namespace Azaion.CommonSecurity.Services;
|
||||
|
||||
public interface IResourceLoader
|
||||
{
|
||||
Task<MemoryStream> Load(string fileName, CancellationToken cancellationToken = default);
|
||||
}
|
||||
|
||||
public class PythonResourceLoader : IResourceLoader
|
||||
{
|
||||
private readonly PushSocket _pushSocket = new();
|
||||
private readonly PullSocket _pullSocket = new();
|
||||
|
||||
public PythonResourceLoader(ApiCredentials credentials)
|
||||
{
|
||||
//Run python by credentials
|
||||
_pushSocket.Connect($"tcp://{SecurityConstants.SOCKET_HOST}:{SecurityConstants.SOCKET_SEND_PORT}");
|
||||
_pullSocket.Connect($"tcp://{SecurityConstants.SOCKET_HOST}:{SecurityConstants.SOCKET_RECEIVE_PORT}");
|
||||
}
|
||||
|
||||
public async Task<MemoryStream> Load(string fileName, CancellationToken cancellationToken = default)
|
||||
{
|
||||
try
|
||||
{
|
||||
var b = MessagePackSerializer.Serialize(new FileCommand
|
||||
{
|
||||
CommandType = CommandType.Load,
|
||||
Filename = fileName
|
||||
});
|
||||
_pushSocket.SendFrame(b);
|
||||
|
||||
var bytes = _pullSocket.ReceiveFrameBytes(out bool more);
|
||||
return new MemoryStream(bytes);
|
||||
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
throw new Exception($"Failed to load fil0e '{fileName}': {ex.Message}", ex);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public class ResourceLoader(AzaionApiClient api, ApiCredentials credentials) : IResourceLoader
|
||||
{
|
||||
public async Task<MemoryStream> Load(string fileName, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var hardwareService = new HardwareService();
|
||||
var hardwareInfo = hardwareService.GetHardware();
|
||||
|
||||
var encryptedStream = Task.Run(() => api.GetResource(fileName, credentials.Password, hardwareInfo), cancellationToken).Result;
|
||||
|
||||
var key = Security.MakeEncryptionKey(credentials.Email, credentials.Password, hardwareInfo.Hash);
|
||||
var stream = new MemoryStream();
|
||||
await encryptedStream.DecryptTo(stream, key, cancellationToken);
|
||||
stream.Seek(0, SeekOrigin.Begin);
|
||||
return stream;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
using MessagePack;
|
||||
using NetMQ;
|
||||
using NetMQ.Sockets;
|
||||
|
||||
namespace Azaion.CommonSecurity;
|
||||
|
||||
public static class ZeroMqExtensions
|
||||
{
|
||||
public static T? Get<T>(this DealerSocket dealer, out byte[] message)
|
||||
{
|
||||
if (!dealer.TryReceiveFrameBytes(TimeSpan.FromMinutes(2), out var bytes))
|
||||
throw new Exception($"Unable to get {typeof(T).Name}");
|
||||
message = bytes;
|
||||
return MessagePackSerializer.Deserialize<T>(bytes);
|
||||
}
|
||||
}
|
||||
@@ -58,13 +58,12 @@ public class DatasetExplorerEventHandler(
|
||||
if (datasetExplorer.ThumbnailLoading)
|
||||
return;
|
||||
|
||||
var fName = Path.GetFileNameWithoutExtension(datasetExplorer.CurrentAnnotation!.Annotation.ImagePath);
|
||||
var extension = Path.GetExtension(fName);
|
||||
var a = datasetExplorer.CurrentAnnotation!.Annotation;
|
||||
|
||||
var detections = datasetExplorer.ExplorerEditor.CurrentDetections
|
||||
.Select(x => new Detection(fName, x.GetLabel(datasetExplorer.ExplorerEditor.RenderSize)))
|
||||
.Select(x => new Detection(a.Name, x.GetLabel(datasetExplorer.ExplorerEditor.RenderSize)))
|
||||
.ToList();
|
||||
await annotationService.SaveAnnotation(fName, extension, detections, SourceEnum.Manual, token: cancellationToken);
|
||||
await annotationService.SaveAnnotation(a.OriginalMediaName, a.Time, a.ImageExtension, detections, SourceEnum.Manual, token: cancellationToken);
|
||||
datasetExplorer.SwitchTab(toEditor: false);
|
||||
break;
|
||||
case PlaybackControlEnum.RemoveSelectedAnns:
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
using System.IO;
|
||||
using System.Reflection;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using System.Windows;
|
||||
using System.Windows.Threading;
|
||||
using Azaion.Annotator;
|
||||
@@ -13,7 +11,6 @@ using Azaion.Common.Events;
|
||||
using Azaion.Common.Extensions;
|
||||
using Azaion.Common.Services;
|
||||
using Azaion.CommonSecurity;
|
||||
using Azaion.CommonSecurity.DTO;
|
||||
using Azaion.CommonSecurity.Services;
|
||||
using Azaion.Dataset;
|
||||
using LibVLCSharp.Shared;
|
||||
@@ -23,7 +20,6 @@ using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Hosting;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Options;
|
||||
using Newtonsoft.Json;
|
||||
using Serilog;
|
||||
using KeyEventArgs = System.Windows.Input.KeyEventArgs;
|
||||
|
||||
@@ -36,8 +32,7 @@ public partial class App
|
||||
private IMediator _mediator = null!;
|
||||
private FormState _formState = null!;
|
||||
|
||||
private AzaionApiClient _apiClient = null!;
|
||||
private IResourceLoader _resourceLoader = null!;
|
||||
private PythonResourceLoader _resourceLoader = null!;
|
||||
private Stream _securedConfig = null!;
|
||||
|
||||
private void OnDispatcherUnhandledException(object sender, DispatcherUnhandledExceptionEventArgs e)
|
||||
@@ -64,9 +59,8 @@ public partial class App
|
||||
var login = new Login();
|
||||
login.CredentialsEntered += async (s, args) =>
|
||||
{
|
||||
_apiClient = AzaionApiClient.Create(args);
|
||||
_resourceLoader = new PythonResourceLoader(args);
|
||||
_securedConfig = await _resourceLoader.Load("secured-config.json");
|
||||
_securedConfig = await _resourceLoader.LoadFile("secured-config.json");
|
||||
|
||||
AppDomain.CurrentDomain.AssemblyResolve += (_, a) =>
|
||||
{
|
||||
@@ -75,7 +69,7 @@ public partial class App
|
||||
{
|
||||
try
|
||||
{
|
||||
var stream = _resourceLoader.Load($"{assemblyName}.dll").GetAwaiter().GetResult();
|
||||
var stream = _resourceLoader.LoadFile($"{assemblyName}.dll").GetAwaiter().GetResult();
|
||||
return Assembly.Load(stream.ToArray());
|
||||
}
|
||||
catch (Exception e)
|
||||
@@ -124,11 +118,11 @@ public partial class App
|
||||
services.AddSingleton<MainSuite>();
|
||||
services.AddSingleton<IHardwareService, HardwareService>();
|
||||
|
||||
services.AddSingleton(_apiClient);
|
||||
services.AddSingleton(_resourceLoader);
|
||||
services.AddSingleton<IResourceLoader>(_resourceLoader);
|
||||
services.AddSingleton<IAuthProvider>(_resourceLoader);
|
||||
services.AddSingleton<IInferenceService, PythonInferenceService>();
|
||||
|
||||
services.Configure<AppConfig>(context.Configuration);
|
||||
services.ConfigureSection<ApiConfig>(context.Configuration);
|
||||
services.ConfigureSection<QueueConfig>(context.Configuration);
|
||||
services.ConfigureSection<DirectoriesConfig>(context.Configuration);
|
||||
services.ConfigureSection<AnnotationConfig>(context.Configuration);
|
||||
@@ -139,7 +133,6 @@ public partial class App
|
||||
services.AddSingleton<Annotator.Annotator>();
|
||||
services.AddSingleton<DatasetExplorer>();
|
||||
services.AddSingleton<HelpWindow>();
|
||||
services.AddSingleton<IAIDetector, YOLODetector>();
|
||||
services.AddMediatR(c => c.RegisterServicesFromAssemblies(
|
||||
typeof(Annotator.Annotator).Assembly,
|
||||
typeof(DatasetExplorer).Assembly,
|
||||
@@ -152,10 +145,9 @@ public partial class App
|
||||
return new MediaPlayer(libVLC);
|
||||
});
|
||||
services.AddSingleton<AnnotatorEventHandler>();
|
||||
services.AddSingleton<VLCFrameExtractor>();
|
||||
services.AddSingleton<IDbFactory, DbFactory>();
|
||||
|
||||
services.AddSingleton<FailsafeAnnotationsProducer>();
|
||||
services.AddSingleton<FailsafeAnnotationsProducer>();
|
||||
|
||||
services.AddSingleton<AnnotationService>();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user