mirror of
https://github.com/azaion/annotations.git
synced 2026-04-22 22:26:31 +00:00
fixed inference bugs
add DONE during inference, correct handling on C# side
This commit is contained in:
+1
-1
@@ -50,7 +50,7 @@ This is crucial for the build because build needs Python.h header and other file
|
|||||||
pip install ultralytics
|
pip install ultralytics
|
||||||
|
|
||||||
pip uninstall -y opencv-python
|
pip uninstall -y opencv-python
|
||||||
pip install opencv-python cython msgpack cryptography rstream pika zmq pyjwt
|
pip install opencv-python cython msgpack cryptography rstream pika zmq pyjwt pyinstaller tensorboard
|
||||||
```
|
```
|
||||||
In case of fbgemm.dll error (Windows specific):
|
In case of fbgemm.dll error (Windows specific):
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,22 @@
|
|||||||
|
pyinstaller --onefile ^
|
||||||
|
--collect-all jwt ^
|
||||||
|
--collect-all requests ^
|
||||||
|
--collect-all psutil ^
|
||||||
|
--collect-all cryptography ^
|
||||||
|
--collect-all msgpack ^
|
||||||
|
--collect-all expecttest ^
|
||||||
|
--collect-all torch ^
|
||||||
|
--collect-all ultralytics ^
|
||||||
|
--collect-all zmq ^
|
||||||
|
--hidden-import user ^
|
||||||
|
--hidden-import security ^
|
||||||
|
--hidden-import secure_model ^
|
||||||
|
--hidden-import api_client ^
|
||||||
|
--hidden-import hardware_service ^
|
||||||
|
--hidden-import constants ^
|
||||||
|
--hidden-import annotation ^
|
||||||
|
--hidden-import remote_command ^
|
||||||
|
--hidden-import ai_config ^
|
||||||
|
--hidden-import inference ^
|
||||||
|
--hidden-import remote_command_handler ^
|
||||||
|
start.py
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
import main
|
|
||||||
from main import ParsedArguments
|
|
||||||
|
|
||||||
|
|
||||||
def start_server():
|
|
||||||
args = ParsedArguments('admin@azaion.com', 'Az@1on1000Odm$n', 'stage', True)
|
|
||||||
processor = main.CommandProcessor(args)
|
|
||||||
try:
|
|
||||||
processor.start()
|
|
||||||
except Exception as e:
|
|
||||||
processor.stop()
|
|
||||||
|
|
||||||
def on_annotation(self, cmd, annotation):
|
|
||||||
print('on_annotation hit!')
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
start_server()
|
|
||||||
@@ -7,11 +7,13 @@ cdef class Inference:
|
|||||||
cdef object on_annotation
|
cdef object on_annotation
|
||||||
cdef Annotation _previous_annotation
|
cdef Annotation _previous_annotation
|
||||||
cdef AIRecognitionConfig ai_config
|
cdef AIRecognitionConfig ai_config
|
||||||
|
cdef bint stop_signal
|
||||||
|
|
||||||
cdef bint is_video(self, str filepath)
|
cdef bint is_video(self, str filepath)
|
||||||
cdef run_inference(self, RemoteCommand cmd, int batch_size=?)
|
cdef run_inference(self, RemoteCommand cmd, int batch_size=?)
|
||||||
cdef _process_video(self, RemoteCommand cmd, int batch_size)
|
cdef _process_video(self, RemoteCommand cmd, int batch_size)
|
||||||
cdef _process_image(self, RemoteCommand cmd)
|
cdef _process_image(self, RemoteCommand cmd)
|
||||||
|
cdef stop(self)
|
||||||
|
|
||||||
cdef frame_to_annotation(self, long time, frame, boxes: object)
|
cdef frame_to_annotation(self, long time, frame, boxes: object)
|
||||||
cdef bint is_valid_annotation(self, Annotation annotation)
|
cdef bint is_valid_annotation(self, Annotation annotation)
|
||||||
|
|||||||
+12
-7
@@ -1,5 +1,3 @@
|
|||||||
import ai_config
|
|
||||||
import msgpack
|
|
||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import cv2
|
import cv2
|
||||||
@@ -13,6 +11,7 @@ cdef class Inference:
|
|||||||
def __init__(self, model_bytes, on_annotation):
|
def __init__(self, model_bytes, on_annotation):
|
||||||
loader = SecureModelLoader()
|
loader = SecureModelLoader()
|
||||||
model_path = loader.load_model(model_bytes)
|
model_path = loader.load_model(model_bytes)
|
||||||
|
self.stop_signal = False
|
||||||
self.model = YOLO(<str>model_path)
|
self.model = YOLO(<str>model_path)
|
||||||
self.on_annotation = on_annotation
|
self.on_annotation = on_annotation
|
||||||
|
|
||||||
@@ -22,19 +21,20 @@ cdef class Inference:
|
|||||||
|
|
||||||
cdef run_inference(self, RemoteCommand cmd, int batch_size=8):
|
cdef run_inference(self, RemoteCommand cmd, int batch_size=8):
|
||||||
print('run inference..')
|
print('run inference..')
|
||||||
|
self.stop_signal = False
|
||||||
if self.is_video(cmd.filename):
|
if self.is_video(cmd.filename):
|
||||||
return self._process_video(cmd, batch_size)
|
self._process_video(cmd, batch_size)
|
||||||
else:
|
else:
|
||||||
return self._process_image(cmd)
|
self._process_image(cmd)
|
||||||
|
|
||||||
cdef _process_video(self, RemoteCommand cmd, int batch_size):
|
cdef _process_video(self, RemoteCommand cmd, int batch_size):
|
||||||
frame_count = 0
|
frame_count = 0
|
||||||
batch_frame = []
|
batch_frame = []
|
||||||
|
self._previous_annotation = None
|
||||||
v_input = cv2.VideoCapture(<str>cmd.filename)
|
v_input = cv2.VideoCapture(<str>cmd.filename)
|
||||||
self.ai_config = AIRecognitionConfig.from_msgpack(cmd.data)
|
self.ai_config = AIRecognitionConfig.from_msgpack(cmd.data)
|
||||||
|
|
||||||
while v_input.isOpened():
|
while v_input.isOpened() and not self.stop_signal:
|
||||||
ret, frame = v_input.read()
|
ret, frame = v_input.read()
|
||||||
ms = v_input.get(cv2.CAP_PROP_POS_MSEC)
|
ms = v_input.get(cv2.CAP_PROP_POS_MSEC)
|
||||||
if not ret or frame is None:
|
if not ret or frame is None:
|
||||||
@@ -51,7 +51,9 @@ cdef class Inference:
|
|||||||
for frame, res in zip(batch_frame, results):
|
for frame, res in zip(batch_frame, results):
|
||||||
annotation = self.frame_to_annotation(int(frame[1]), frame[0], res.boxes)
|
annotation = self.frame_to_annotation(int(frame[1]), frame[0], res.boxes)
|
||||||
|
|
||||||
if self.is_valid_annotation(<Annotation>annotation):
|
is_valid = self.is_valid_annotation(<Annotation>annotation)
|
||||||
|
print(f'Is valid annotation: {is_valid}')
|
||||||
|
if is_valid:
|
||||||
self._previous_annotation = annotation
|
self._previous_annotation = annotation
|
||||||
self.on_annotation(cmd, annotation)
|
self.on_annotation(cmd, annotation)
|
||||||
batch_frame.clear()
|
batch_frame.clear()
|
||||||
@@ -64,6 +66,9 @@ cdef class Inference:
|
|||||||
annotation = self.frame_to_annotation(0, frame, res[0].boxes)
|
annotation = self.frame_to_annotation(0, frame, res[0].boxes)
|
||||||
self.on_annotation(cmd, annotation)
|
self.on_annotation(cmd, annotation)
|
||||||
|
|
||||||
|
cdef stop(self):
|
||||||
|
self.stop_signal = True
|
||||||
|
|
||||||
cdef frame_to_annotation(self, long time, frame, boxes: Boxes):
|
cdef frame_to_annotation(self, long time, frame, boxes: Boxes):
|
||||||
detections = []
|
detections = []
|
||||||
for box in boxes:
|
for box in boxes:
|
||||||
|
|||||||
+16
-34
@@ -1,6 +1,7 @@
|
|||||||
import traceback
|
import traceback
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
cimport constants
|
cimport constants
|
||||||
|
import msgpack
|
||||||
|
|
||||||
from api_client cimport ApiClient
|
from api_client cimport ApiClient
|
||||||
from annotation cimport Annotation
|
from annotation cimport Annotation
|
||||||
@@ -8,23 +9,21 @@ from inference cimport Inference
|
|||||||
from remote_command cimport RemoteCommand, CommandType
|
from remote_command cimport RemoteCommand, CommandType
|
||||||
from remote_command_handler cimport RemoteCommandHandler
|
from remote_command_handler cimport RemoteCommandHandler
|
||||||
from user cimport User
|
from user cimport User
|
||||||
import argparse
|
|
||||||
|
|
||||||
cdef class ParsedArguments:
|
cdef class ParsedArguments:
|
||||||
cdef str email, password, folder;
|
cdef str email, password, folder;
|
||||||
cdef bint persist_token
|
|
||||||
|
|
||||||
def __init__(self, str email, str password, str folder, bint persist_token):
|
def __init__(self, str email, str password, str folder):
|
||||||
self.email = email
|
self.email = email
|
||||||
self.password = password
|
self.password = password
|
||||||
self.folder = folder
|
self.folder = folder
|
||||||
self.persist_token = persist_token
|
|
||||||
|
|
||||||
cdef class CommandProcessor:
|
cdef class CommandProcessor:
|
||||||
cdef ApiClient api_client
|
cdef ApiClient api_client
|
||||||
cdef RemoteCommandHandler remote_handler
|
cdef RemoteCommandHandler remote_handler
|
||||||
cdef object command_queue
|
cdef object command_queue
|
||||||
cdef bint running
|
cdef bint running
|
||||||
|
cdef Inference inference
|
||||||
|
|
||||||
def __init__(self, args: ParsedArguments):
|
def __init__(self, args: ParsedArguments):
|
||||||
self.api_client = ApiClient(args.email, args.password, args.folder)
|
self.api_client = ApiClient(args.email, args.password, args.folder)
|
||||||
@@ -32,25 +31,31 @@ cdef class CommandProcessor:
|
|||||||
self.command_queue = Queue(maxsize=constants.QUEUE_MAXSIZE)
|
self.command_queue = Queue(maxsize=constants.QUEUE_MAXSIZE)
|
||||||
self.remote_handler.start()
|
self.remote_handler.start()
|
||||||
self.running = True
|
self.running = True
|
||||||
|
model = self.api_client.load_ai_model()
|
||||||
|
self.inference = Inference(model, self.on_annotation)
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
while self.running:
|
while self.running:
|
||||||
try:
|
try:
|
||||||
command = self.command_queue.get()
|
command = self.command_queue.get()
|
||||||
model = self.api_client.load_ai_model()
|
self.inference.run_inference(command)
|
||||||
Inference(model, self.on_annotation).run_inference(command)
|
self.remote_handler.send(command.client_id, <bytes>'DONE'.encode('utf-8'))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
cdef on_command(self, RemoteCommand command):
|
cdef on_command(self, RemoteCommand command):
|
||||||
try:
|
try:
|
||||||
if command.command_type == CommandType.INFERENCE:
|
if command.command_type == CommandType.GET_USER:
|
||||||
self.command_queue.put(command)
|
self.get_user(command, self.api_client.get_user())
|
||||||
elif command.command_type == CommandType.LOAD:
|
elif command.command_type == CommandType.LOAD:
|
||||||
response = self.api_client.load_bytes(command.filename)
|
response = self.api_client.load_bytes(command.filename)
|
||||||
self.remote_handler.send(command.client_id, response)
|
self.remote_handler.send(command.client_id, response)
|
||||||
elif command.command_type == CommandType.GET_USER:
|
elif command.command_type == CommandType.INFERENCE:
|
||||||
self.get_user(command, self.api_client.get_user())
|
self.command_queue.put(command)
|
||||||
|
elif command.command_type == CommandType.STOP_INFERENCE:
|
||||||
|
self.inference.stop()
|
||||||
|
elif command.command_type == CommandType.EXIT:
|
||||||
|
self.stop()
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -64,28 +69,5 @@ cdef class CommandProcessor:
|
|||||||
self.remote_handler.send(cmd.client_id, data)
|
self.remote_handler.send(cmd.client_id, data)
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
|
self.remote_handler.stop()
|
||||||
self.running = False
|
self.running = False
|
||||||
|
|
||||||
def parse_arguments():
|
|
||||||
parser = argparse.ArgumentParser(description="Command Processor")
|
|
||||||
parser.add_argument("-e", "--email", type=str, default="", help="Email")
|
|
||||||
parser.add_argument("-p", "--pw", type=str, default="", help="Password")
|
|
||||||
parser.add_argument("-f", "--folder", type=str, default="", help="Folder to API inner folder to download file from")
|
|
||||||
parser.add_argument("-t", "--persist_token", type=bool, default=True, help="True for persisting token from API")
|
|
||||||
cdef args = parser.parse_args()
|
|
||||||
|
|
||||||
cdef str email = args.email
|
|
||||||
cdef str password = args.pw
|
|
||||||
cdef str folder = args.folder
|
|
||||||
cdef bint persist_token = args.persist_token
|
|
||||||
|
|
||||||
return ParsedArguments(email, password, folder, persist_token)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
args = parse_arguments()
|
|
||||||
processor = CommandProcessor(args)
|
|
||||||
try:
|
|
||||||
processor.start()
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
processor.stop()
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
cdef enum CommandType:
|
cdef enum CommandType:
|
||||||
INFERENCE = 1
|
GET_USER = 10
|
||||||
LOAD = 2
|
LOAD = 20
|
||||||
GET_USER = 3
|
INFERENCE = 30
|
||||||
|
STOP_INFERENCE = 40
|
||||||
|
EXIT = 100
|
||||||
|
|
||||||
cdef class RemoteCommand:
|
cdef class RemoteCommand:
|
||||||
cdef public bytes client_id
|
cdef public bytes client_id
|
||||||
|
|||||||
@@ -8,9 +8,11 @@ cdef class RemoteCommand:
|
|||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
command_type_names = {
|
command_type_names = {
|
||||||
1: "INFERENCE",
|
10: "GET_USER",
|
||||||
2: "LOAD",
|
20: "LOAD",
|
||||||
3: "GET_USER"
|
30: "INFERENCE",
|
||||||
|
40: "STOP INFERENCE",
|
||||||
|
100: "EXIT"
|
||||||
}
|
}
|
||||||
data_str = f'. Data: {len(self.data)} bytes' if self.data else ''
|
data_str = f'. Data: {len(self.data)} bytes' if self.data else ''
|
||||||
return f'{command_type_names[self.command_type]}: {self.filename}{data_str}'
|
return f'{command_type_names[self.command_type]}: {self.filename}{data_str}'
|
||||||
|
|||||||
@@ -12,4 +12,4 @@ cdef class RemoteCommandHandler:
|
|||||||
cdef _proxy_loop(self)
|
cdef _proxy_loop(self)
|
||||||
cdef _worker_loop(self)
|
cdef _worker_loop(self)
|
||||||
cdef send(self, bytes client_id, bytes data)
|
cdef send(self, bytes client_id, bytes data)
|
||||||
cdef close(self)
|
cdef stop(self)
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
|
import time
|
||||||
import zmq
|
import zmq
|
||||||
from threading import Thread, Event
|
from threading import Thread, Event
|
||||||
from remote_command cimport RemoteCommand
|
from remote_command cimport RemoteCommand
|
||||||
cimport constants
|
cimport constants
|
||||||
|
|
||||||
|
|
||||||
cdef class RemoteCommandHandler:
|
cdef class RemoteCommandHandler:
|
||||||
def __init__(self, object on_command):
|
def __init__(self, object on_command):
|
||||||
self._on_command = on_command
|
self._on_command = on_command
|
||||||
@@ -59,8 +61,9 @@ cdef class RemoteCommandHandler:
|
|||||||
socket.send_multipart([client_id, data])
|
socket.send_multipart([client_id, data])
|
||||||
print(f'{len(data)} bytes was sent to client {client_id}')
|
print(f'{len(data)} bytes was sent to client {client_id}')
|
||||||
|
|
||||||
cdef close(self):
|
cdef stop(self):
|
||||||
self._shutdown_event.set()
|
self._shutdown_event.set()
|
||||||
|
time.sleep(0.5)
|
||||||
self._router.close()
|
self._router.close()
|
||||||
self._dealer.close()
|
self._dealer.close()
|
||||||
self._context.term()
|
self._context.term()
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
import argparse
|
||||||
|
from main import ParsedArguments, CommandProcessor
|
||||||
|
|
||||||
|
|
||||||
|
def parse_arguments():
|
||||||
|
parser = argparse.ArgumentParser(description="Command Processor")
|
||||||
|
parser.add_argument("-e", "--email", type=str, default="", help="Email")
|
||||||
|
parser.add_argument("-p", "--pw", type=str, default="", help="Password")
|
||||||
|
parser.add_argument("-f", "--folder", type=str, default="", help="Folder to API inner folder to download file from")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
return ParsedArguments(args.email, args.pw, args.folder)
|
||||||
|
|
||||||
|
def start(args: ParsedArguments):
|
||||||
|
processor = CommandProcessor(args)
|
||||||
|
try:
|
||||||
|
processor.start()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
processor.stop()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
start(parse_arguments())
|
||||||
+1
-1
@@ -1 +1 @@
|
|||||||
eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1laWQiOiJkOTBhMzZjYS1lMjM3LTRmYmQtOWM3Yy0xMjcwNDBhYzg1NTYiLCJ1bmlxdWVfbmFtZSI6ImFkbWluQGF6YWlvbi5jb20iLCJyb2xlIjoiQXBpQWRtaW4iLCJuYmYiOjE3MzgxNjM2MzYsImV4cCI6MTczODE3ODAzNiwiaWF0IjoxNzM4MTYzNjM2LCJpc3MiOiJBemFpb25BcGkiLCJhdWQiOiJBbm5vdGF0b3JzL09yYW5nZVBpL0FkbWlucyJ9.7VVws5mwGqx--sGopOuZE9iu3dzt1UdVPXeje2KZTYk
|
eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1laWQiOiJkOTBhMzZjYS1lMjM3LTRmYmQtOWM3Yy0xMjcwNDBhYzg1NTYiLCJ1bmlxdWVfbmFtZSI6ImFkbWluQGF6YWlvbi5jb20iLCJyb2xlIjoiQXBpQWRtaW4iLCJuYmYiOjE3MzgzNjUwMjksImV4cCI6MTczODM3OTQyOSwiaWF0IjoxNzM4MzY1MDI5LCJpc3MiOiJBemFpb25BcGkiLCJhdWQiOiJBbm5vdGF0b3JzL09yYW5nZVBpL0FkbWlucyJ9.5teWb-gnhRngV337u_0OyUQ-o2-plN7shrvvKUsckPw
|
||||||
@@ -200,7 +200,9 @@ public partial class Annotator
|
|||||||
DgAnnotations.MouseDoubleClick += (sender, args) =>
|
DgAnnotations.MouseDoubleClick += (sender, args) =>
|
||||||
{
|
{
|
||||||
var dgRow = ItemsControl.ContainerFromElement((DataGrid)sender, (args.OriginalSource as DependencyObject)!) as DataGridRow;
|
var dgRow = ItemsControl.ContainerFromElement((DataGrid)sender, (args.OriginalSource as DependencyObject)!) as DataGridRow;
|
||||||
|
if (dgRow != null)
|
||||||
OpenAnnotationResult((AnnotationResult)dgRow!.Item);
|
OpenAnnotationResult((AnnotationResult)dgRow!.Item);
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
DgAnnotations.KeyUp += async (sender, args) =>
|
DgAnnotations.KeyUp += async (sender, args) =>
|
||||||
@@ -531,13 +533,13 @@ public partial class Annotator
|
|||||||
LvFiles.SelectedIndex += 1;
|
LvFiles.SelectedIndex += 1;
|
||||||
return (MediaFileInfo)LvFiles.SelectedItem;
|
return (MediaFileInfo)LvFiles.SelectedItem;
|
||||||
});
|
});
|
||||||
LvFiles.Items.Refresh();
|
Dispatcher.Invoke(() => LvFiles.Items.Refresh());
|
||||||
}
|
}
|
||||||
Dispatcher.Invoke(() =>
|
Dispatcher.Invoke(() =>
|
||||||
{
|
{
|
||||||
_mediaPlayer.Stop();
|
|
||||||
LvFiles.Items.Refresh();
|
LvFiles.Items.Refresh();
|
||||||
IsInferenceNow = false;
|
IsInferenceNow = false;
|
||||||
|
FollowAI = false;
|
||||||
});
|
});
|
||||||
}, token);
|
}, token);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ namespace Azaion.Common.DTO.Config;
|
|||||||
|
|
||||||
public class AppConfig
|
public class AppConfig
|
||||||
{
|
{
|
||||||
|
public ApiConfig ApiConfig { get; set; } = null!;
|
||||||
|
|
||||||
public QueueConfig QueueConfig { get; set; } = null!;
|
public QueueConfig QueueConfig { get; set; } = null!;
|
||||||
|
|
||||||
public DirectoriesConfig DirectoriesConfig { get; set; } = null!;
|
public DirectoriesConfig DirectoriesConfig { get; set; } = null!;
|
||||||
|
|||||||
@@ -31,16 +31,11 @@ public class PythonInferenceService(ILogger<PythonInferenceService> logger, IOpt
|
|||||||
|
|
||||||
while (true)
|
while (true)
|
||||||
{
|
{
|
||||||
byte[] bytes = [];
|
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
var annotationStream = dealer.Get<AnnotationImage>(out bytes);
|
var annotationStream = dealer.Get<AnnotationImage>(bytes => bytes.Length == 4 && Encoding.UTF8.GetString(bytes) == "DONE");
|
||||||
if (annotationStream == null)
|
if (annotationStream == null)
|
||||||
{
|
|
||||||
if (bytes.Length == 4 && Encoding.UTF8.GetString(bytes) == "DONE")
|
|
||||||
break;
|
break;
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
await processAnnotation(annotationStream);
|
await processAnnotation(annotationStream);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,9 @@
|
|||||||
|
namespace Azaion.CommonSecurity.DTO;
|
||||||
|
|
||||||
|
public class ApiConfig
|
||||||
|
{
|
||||||
|
public string Url { get; set; } = null!;
|
||||||
|
public int RetryCount {get;set;}
|
||||||
|
public double TimeoutSeconds { get; set; }
|
||||||
|
public string ResourcesFolder { get; set; } = "";
|
||||||
|
}
|
||||||
@@ -18,7 +18,9 @@ public class RemoteCommand(CommandType commandType, string? filename = null, byt
|
|||||||
public enum CommandType
|
public enum CommandType
|
||||||
{
|
{
|
||||||
None = 0,
|
None = 0,
|
||||||
Inference = 1,
|
GetUser = 10,
|
||||||
Load = 2,
|
Load = 20,
|
||||||
GetUser = 3
|
Inference = 30,
|
||||||
|
StopInference = 40,
|
||||||
|
Exit = 100
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
namespace Azaion.CommonSecurity.DTO;
|
||||||
|
|
||||||
|
public class SecureAppConfig
|
||||||
|
{
|
||||||
|
public ApiConfig ApiConfig { get; set; } = null!;
|
||||||
|
}
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
using System.Security.Claims;
|
||||||
using MessagePack;
|
using MessagePack;
|
||||||
|
|
||||||
namespace Azaion.CommonSecurity.DTO;
|
namespace Azaion.CommonSecurity.DTO;
|
||||||
@@ -8,4 +9,18 @@ public class User
|
|||||||
[Key("i")]public string Id { get; set; }
|
[Key("i")]public string Id { get; set; }
|
||||||
[Key("e")]public string Email { get; set; }
|
[Key("e")]public string Email { get; set; }
|
||||||
[Key("r")]public RoleEnum Role { get; set; }
|
[Key("r")]public RoleEnum Role { get; set; }
|
||||||
|
|
||||||
|
//For deserializing
|
||||||
|
public User(){}
|
||||||
|
|
||||||
|
public User(IEnumerable<Claim> claims)
|
||||||
|
{
|
||||||
|
var claimDict = claims.ToDictionary(x => x.Type, x => x.Value);
|
||||||
|
|
||||||
|
Id = claimDict[SecurityConstants.CLAIM_NAME_ID];
|
||||||
|
Email = claimDict[SecurityConstants.CLAIM_EMAIL];
|
||||||
|
if (!Enum.TryParse(claimDict[SecurityConstants.CLAIM_ROLE], out RoleEnum role))
|
||||||
|
role = RoleEnum.None;
|
||||||
|
Role = role;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,110 @@
|
|||||||
|
using System.IdentityModel.Tokens.Jwt;
|
||||||
|
using System.Net;
|
||||||
|
using System.Net.Http.Headers;
|
||||||
|
using System.Security;
|
||||||
|
using System.Text;
|
||||||
|
using Azaion.CommonSecurity.DTO;
|
||||||
|
using Newtonsoft.Json;
|
||||||
|
|
||||||
|
namespace Azaion.CommonSecurity.Services;
|
||||||
|
|
||||||
|
public class AzaionApiClient(HttpClient httpClient) : IDisposable
|
||||||
|
{
|
||||||
|
const string JSON_MEDIA = "application/json";
|
||||||
|
|
||||||
|
private static ApiConfig _apiConfig = null!;
|
||||||
|
|
||||||
|
private string Email { get; set; } = null!;
|
||||||
|
private SecureString Password { get; set; } = new();
|
||||||
|
private string JwtToken { get; set; } = null!;
|
||||||
|
public User User { get; set; } = null!;
|
||||||
|
|
||||||
|
public static AzaionApiClient Create(ApiCredentials credentials, ApiConfig apiConfig)
|
||||||
|
{
|
||||||
|
_apiConfig = apiConfig;
|
||||||
|
var api = new AzaionApiClient(new HttpClient
|
||||||
|
{
|
||||||
|
BaseAddress = new Uri(_apiConfig.Url),
|
||||||
|
Timeout = TimeSpan.FromSeconds(_apiConfig.TimeoutSeconds)
|
||||||
|
});
|
||||||
|
|
||||||
|
api.EnterCredentials(credentials);
|
||||||
|
return api;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void EnterCredentials(ApiCredentials credentials)
|
||||||
|
{
|
||||||
|
if (string.IsNullOrWhiteSpace(credentials.Email) || string.IsNullOrWhiteSpace(credentials.Password))
|
||||||
|
throw new Exception("Email or password is empty!");
|
||||||
|
|
||||||
|
Email = credentials.Email;
|
||||||
|
Password = credentials.Password.ToSecureString();
|
||||||
|
}
|
||||||
|
|
||||||
|
public async Task<Stream> GetResource(string fileName, string password, HardwareInfo hardware, CancellationToken cancellationToken = default)
|
||||||
|
{
|
||||||
|
var response = await Send(httpClient, new HttpRequestMessage(HttpMethod.Post, $"/resources/get/{_apiConfig.ResourcesFolder}")
|
||||||
|
{
|
||||||
|
Content = new StringContent(JsonConvert.SerializeObject(new { fileName, password, hardware }), Encoding.UTF8, JSON_MEDIA)
|
||||||
|
}, cancellationToken);
|
||||||
|
return await response.Content.ReadAsStreamAsync(cancellationToken);
|
||||||
|
}
|
||||||
|
|
||||||
|
private async Task Authorize()
|
||||||
|
{
|
||||||
|
if (string.IsNullOrEmpty(Email) || Password.Length == 0)
|
||||||
|
throw new Exception("Email or password is empty! Please do EnterCredentials first!");
|
||||||
|
|
||||||
|
var payload = new
|
||||||
|
{
|
||||||
|
email = Email,
|
||||||
|
password = Password.ToRealString()
|
||||||
|
};
|
||||||
|
var response = await httpClient.PostAsync(
|
||||||
|
"login",
|
||||||
|
new StringContent(JsonConvert.SerializeObject(payload), Encoding.UTF8, JSON_MEDIA));
|
||||||
|
|
||||||
|
if (!response.IsSuccessStatusCode)
|
||||||
|
throw new Exception($"EnterCredentials failed: {response.StatusCode}");
|
||||||
|
|
||||||
|
var responseData = await response.Content.ReadAsStringAsync();
|
||||||
|
|
||||||
|
var result = JsonConvert.DeserializeObject<LoginResponse>(responseData);
|
||||||
|
|
||||||
|
if (string.IsNullOrEmpty(result?.Token))
|
||||||
|
throw new Exception("JWT Token not found in response");
|
||||||
|
|
||||||
|
var handler = new JwtSecurityTokenHandler();
|
||||||
|
var token = handler.ReadJwtToken(result.Token);
|
||||||
|
|
||||||
|
User = new User(token.Claims);
|
||||||
|
JwtToken = result.Token;
|
||||||
|
}
|
||||||
|
|
||||||
|
private async Task<HttpResponseMessage> Send(HttpClient client, HttpRequestMessage request, CancellationToken cancellationToken = default)
|
||||||
|
{
|
||||||
|
if (string.IsNullOrEmpty(JwtToken))
|
||||||
|
await Authorize();
|
||||||
|
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", JwtToken);
|
||||||
|
var response = await client.SendAsync(request, cancellationToken);
|
||||||
|
|
||||||
|
if (response.StatusCode == HttpStatusCode.Unauthorized)
|
||||||
|
{
|
||||||
|
await Authorize();
|
||||||
|
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", JwtToken);
|
||||||
|
response = await client.SendAsync(request, cancellationToken);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (response.IsSuccessStatusCode)
|
||||||
|
return response;
|
||||||
|
|
||||||
|
var result = await response.Content.ReadAsStringAsync(cancellationToken);
|
||||||
|
throw new Exception($"Failed: {response.StatusCode}! Result: {result}");
|
||||||
|
}
|
||||||
|
|
||||||
|
public void Dispose()
|
||||||
|
{
|
||||||
|
httpClient.Dispose();
|
||||||
|
Password.Dispose();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
using System.Text;
|
using System.Diagnostics;
|
||||||
|
using System.Text;
|
||||||
using Azaion.CommonSecurity.DTO;
|
using Azaion.CommonSecurity.DTO;
|
||||||
using Azaion.CommonSecurity.DTO.Commands;
|
using Azaion.CommonSecurity.DTO.Commands;
|
||||||
using MessagePack;
|
using MessagePack;
|
||||||
@@ -9,7 +10,7 @@ namespace Azaion.CommonSecurity.Services;
|
|||||||
|
|
||||||
public interface IResourceLoader
|
public interface IResourceLoader
|
||||||
{
|
{
|
||||||
Task<MemoryStream> LoadFile(string fileName, CancellationToken ct = default);
|
MemoryStream LoadFileFromPython(string fileName);
|
||||||
}
|
}
|
||||||
|
|
||||||
public interface IAuthProvider
|
public interface IAuthProvider
|
||||||
@@ -20,27 +21,64 @@ public interface IAuthProvider
|
|||||||
|
|
||||||
public class PythonResourceLoader : IResourceLoader, IAuthProvider
|
public class PythonResourceLoader : IResourceLoader, IAuthProvider
|
||||||
{
|
{
|
||||||
|
private readonly ApiCredentials _credentials;
|
||||||
|
private readonly AzaionApiClient _api;
|
||||||
private readonly DealerSocket _dealer = new();
|
private readonly DealerSocket _dealer = new();
|
||||||
private readonly Guid _clientId = Guid.NewGuid();
|
private readonly Guid _clientId = Guid.NewGuid();
|
||||||
|
|
||||||
public User CurrentUser { get; }
|
public User CurrentUser { get; }
|
||||||
|
|
||||||
public PythonResourceLoader(ApiCredentials credentials)
|
public PythonResourceLoader(ApiConfig apiConfig, ApiCredentials credentials, AzaionApiClient api)
|
||||||
{
|
{
|
||||||
//Run python by credentials
|
_credentials = credentials;
|
||||||
|
_api = api;
|
||||||
|
//StartPython(apiConfig, credentials);
|
||||||
_dealer.Options.Identity = Encoding.UTF8.GetBytes(_clientId.ToString("N"));
|
_dealer.Options.Identity = Encoding.UTF8.GetBytes(_clientId.ToString("N"));
|
||||||
_dealer.Connect($"tcp://{SecurityConstants.ZMQ_HOST}:{SecurityConstants.ZMQ_PORT}");
|
_dealer.Connect($"tcp://{SecurityConstants.ZMQ_HOST}:{SecurityConstants.ZMQ_PORT}");
|
||||||
|
|
||||||
_dealer.SendFrame(MessagePackSerializer.Serialize(new RemoteCommand(CommandType.GetUser)));
|
_dealer.SendFrame(MessagePackSerializer.Serialize(new RemoteCommand(CommandType.GetUser)));
|
||||||
var user = _dealer.Get<User>(out _);
|
var user = _dealer.Get<User>();
|
||||||
if (user == null)
|
if (user == null)
|
||||||
throw new Exception("Can't get user from Auth provider");
|
throw new Exception("Can't get user from Auth provider");
|
||||||
|
|
||||||
CurrentUser = user;
|
CurrentUser = user;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void StartPython( ApiConfig apiConfig, ApiCredentials credentials)
|
||||||
|
{
|
||||||
|
//var inferenceExe = LoadPythonFile().GetAwaiter().GetResult();
|
||||||
|
string outputProcess = "";
|
||||||
|
string errorProcess = "";
|
||||||
|
|
||||||
public async Task<MemoryStream> LoadFile(string fileName, CancellationToken ct = default)
|
var path = "azaion-inference.exe";
|
||||||
|
var arguments = $"-e {credentials.Email} -p {credentials.Password} -f {apiConfig.ResourcesFolder}";
|
||||||
|
using var process = new Process();
|
||||||
|
process.StartInfo.FileName = path;
|
||||||
|
process.StartInfo.Arguments = arguments;
|
||||||
|
process.StartInfo.UseShellExecute = false;
|
||||||
|
process.StartInfo.RedirectStandardOutput = true;
|
||||||
|
process.StartInfo.RedirectStandardError = true;
|
||||||
|
//process.StartInfo.CreateNoWindow = true;
|
||||||
|
process.OutputDataReceived += (sender, e) => { if (e.Data != null) Console.WriteLine(e.Data); };
|
||||||
|
process.ErrorDataReceived += (sender, e) => { if (e.Data != null) Console.WriteLine(e.Data); };
|
||||||
|
process.Start();
|
||||||
|
}
|
||||||
|
|
||||||
|
public async Task LoadFileFromApi(CancellationToken cancellationToken = default)
|
||||||
|
{
|
||||||
|
var hardwareService = new HardwareService();
|
||||||
|
var hardwareInfo = hardwareService.GetHardware();
|
||||||
|
|
||||||
|
var encryptedStream = await _api.GetResource("azaion_inference.exe", _credentials.Password, hardwareInfo, cancellationToken);
|
||||||
|
|
||||||
|
var key = Security.MakeEncryptionKey(_credentials.Email, _credentials.Password, hardwareInfo.Hash);
|
||||||
|
var stream = new MemoryStream();
|
||||||
|
await encryptedStream.DecryptTo(stream, key, cancellationToken);
|
||||||
|
stream.Seek(0, SeekOrigin.Begin);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public MemoryStream LoadFileFromPython(string fileName)
|
||||||
{
|
{
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
@@ -49,7 +87,7 @@ public class PythonResourceLoader : IResourceLoader, IAuthProvider
|
|||||||
if (!_dealer.TryReceiveFrameBytes(TimeSpan.FromMilliseconds(1000), out var bytes))
|
if (!_dealer.TryReceiveFrameBytes(TimeSpan.FromMilliseconds(1000), out var bytes))
|
||||||
throw new Exception($"Unable to receive {fileName}");
|
throw new Exception($"Unable to receive {fileName}");
|
||||||
|
|
||||||
return await Task.FromResult(new MemoryStream(bytes));
|
return new MemoryStream(bytes);
|
||||||
}
|
}
|
||||||
catch (Exception ex)
|
catch (Exception ex)
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -6,11 +6,12 @@ namespace Azaion.CommonSecurity;
|
|||||||
|
|
||||||
public static class ZeroMqExtensions
|
public static class ZeroMqExtensions
|
||||||
{
|
{
|
||||||
public static T? Get<T>(this DealerSocket dealer, out byte[] message)
|
public static T? Get<T>(this DealerSocket dealer, Func<byte[], bool>? shouldInterceptFn = null) where T : class
|
||||||
{
|
{
|
||||||
if (!dealer.TryReceiveFrameBytes(TimeSpan.FromMinutes(2), out var bytes))
|
if (!dealer.TryReceiveFrameBytes(TimeSpan.FromMinutes(2), out var bytes))
|
||||||
throw new Exception($"Unable to get {typeof(T).Name}");
|
throw new Exception($"Unable to get {typeof(T).Name}");
|
||||||
message = bytes;
|
if (shouldInterceptFn != null && shouldInterceptFn(bytes))
|
||||||
|
return null;
|
||||||
return MessagePackSerializer.Deserialize<T>(bytes);
|
return MessagePackSerializer.Deserialize<T>(bytes);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -11,6 +11,7 @@ using Azaion.Common.Events;
|
|||||||
using Azaion.Common.Extensions;
|
using Azaion.Common.Extensions;
|
||||||
using Azaion.Common.Services;
|
using Azaion.Common.Services;
|
||||||
using Azaion.CommonSecurity;
|
using Azaion.CommonSecurity;
|
||||||
|
using Azaion.CommonSecurity.DTO;
|
||||||
using Azaion.CommonSecurity.Services;
|
using Azaion.CommonSecurity.Services;
|
||||||
using Azaion.Dataset;
|
using Azaion.Dataset;
|
||||||
using LibVLCSharp.Shared;
|
using LibVLCSharp.Shared;
|
||||||
@@ -20,6 +21,7 @@ using Microsoft.Extensions.DependencyInjection;
|
|||||||
using Microsoft.Extensions.Hosting;
|
using Microsoft.Extensions.Hosting;
|
||||||
using Microsoft.Extensions.Logging;
|
using Microsoft.Extensions.Logging;
|
||||||
using Microsoft.Extensions.Options;
|
using Microsoft.Extensions.Options;
|
||||||
|
using Newtonsoft.Json;
|
||||||
using Serilog;
|
using Serilog;
|
||||||
using KeyEventArgs = System.Windows.Input.KeyEventArgs;
|
using KeyEventArgs = System.Windows.Input.KeyEventArgs;
|
||||||
|
|
||||||
@@ -53,14 +55,38 @@ public partial class App
|
|||||||
"Azaion.Dataset"
|
"Azaion.Dataset"
|
||||||
];
|
];
|
||||||
|
|
||||||
|
private ApiConfig ReadConfig()
|
||||||
|
{
|
||||||
|
try
|
||||||
|
{
|
||||||
|
if (!File.Exists(SecurityConstants.CONFIG_PATH))
|
||||||
|
throw new FileNotFoundException(SecurityConstants.CONFIG_PATH);
|
||||||
|
var configStr = File.ReadAllText(SecurityConstants.CONFIG_PATH);
|
||||||
|
return JsonConvert.DeserializeObject<SecureAppConfig>(configStr)!.ApiConfig;
|
||||||
|
}
|
||||||
|
catch (Exception e)
|
||||||
|
{
|
||||||
|
Console.WriteLine(e);
|
||||||
|
return new ApiConfig
|
||||||
|
{
|
||||||
|
Url = SecurityConstants.DEFAULT_API_URL,
|
||||||
|
RetryCount = SecurityConstants.DEFAULT_API_RETRY_COUNT ,
|
||||||
|
TimeoutSeconds = SecurityConstants.DEFAULT_API_TIMEOUT_SECONDS
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private void StartLogin()
|
private void StartLogin()
|
||||||
{
|
{
|
||||||
new ConfigUpdater().CheckConfig();
|
new ConfigUpdater().CheckConfig();
|
||||||
var login = new Login();
|
var login = new Login();
|
||||||
login.CredentialsEntered += async (s, args) =>
|
login.CredentialsEntered += (_, credentials) =>
|
||||||
{
|
{
|
||||||
_resourceLoader = new PythonResourceLoader(args);
|
var apiConfig = ReadConfig();
|
||||||
_securedConfig = await _resourceLoader.LoadFile("secured-config.json");
|
var api = AzaionApiClient.Create(credentials, apiConfig);
|
||||||
|
|
||||||
|
_resourceLoader = new PythonResourceLoader(apiConfig, credentials, api);
|
||||||
|
_securedConfig = _resourceLoader.LoadFileFromPython("secured-config.json");
|
||||||
|
|
||||||
AppDomain.CurrentDomain.AssemblyResolve += (_, a) =>
|
AppDomain.CurrentDomain.AssemblyResolve += (_, a) =>
|
||||||
{
|
{
|
||||||
@@ -69,7 +95,7 @@ public partial class App
|
|||||||
{
|
{
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
var stream = _resourceLoader.LoadFile($"{assemblyName}.dll").GetAwaiter().GetResult();
|
var stream = _resourceLoader.LoadFileFromPython($"{assemblyName}.dll");
|
||||||
return Assembly.Load(stream.ToArray());
|
return Assembly.Load(stream.ToArray());
|
||||||
}
|
}
|
||||||
catch (Exception e)
|
catch (Exception e)
|
||||||
@@ -88,7 +114,7 @@ public partial class App
|
|||||||
};
|
};
|
||||||
|
|
||||||
StartMain();
|
StartMain();
|
||||||
await _host.StartAsync();
|
_host.Start();
|
||||||
EventManager.RegisterClassHandler(typeof(UIElement), UIElement.KeyDownEvent, new RoutedEventHandler(GlobalClick));
|
EventManager.RegisterClassHandler(typeof(UIElement), UIElement.KeyDownEvent, new RoutedEventHandler(GlobalClick));
|
||||||
_host.Services.GetRequiredService<MainSuite>().Show();
|
_host.Services.GetRequiredService<MainSuite>().Show();
|
||||||
};
|
};
|
||||||
|
|||||||
Reference in New Issue
Block a user