fixed inference bugs

add DONE during inference, correct handling on C# side
This commit is contained in:
Alex Bezdieniezhnykh
2025-02-01 02:09:11 +02:00
parent e7afa96a0b
commit 739759628a
23 changed files with 324 additions and 95 deletions
+1 -1
View File
@@ -50,7 +50,7 @@ This is crucial for the build because build needs Python.h header and other file
pip install ultralytics pip install ultralytics
pip uninstall -y opencv-python pip uninstall -y opencv-python
pip install opencv-python cython msgpack cryptography rstream pika zmq pyjwt pip install opencv-python cython msgpack cryptography rstream pika zmq pyjwt pyinstaller tensorboard
``` ```
In case of fbgemm.dll error (Windows specific): In case of fbgemm.dll error (Windows specific):
+22
View File
@@ -0,0 +1,22 @@
pyinstaller --onefile ^
--collect-all jwt ^
--collect-all requests ^
--collect-all psutil ^
--collect-all cryptography ^
--collect-all msgpack ^
--collect-all expecttest ^
--collect-all torch ^
--collect-all ultralytics ^
--collect-all zmq ^
--hidden-import user ^
--hidden-import security ^
--hidden-import secure_model ^
--hidden-import api_client ^
--hidden-import hardware_service ^
--hidden-import constants ^
--hidden-import annotation ^
--hidden-import remote_command ^
--hidden-import ai_config ^
--hidden-import inference ^
--hidden-import remote_command_handler ^
start.py
-17
View File
@@ -1,17 +0,0 @@
import main
from main import ParsedArguments
def start_server():
args = ParsedArguments('admin@azaion.com', 'Az@1on1000Odm$n', 'stage', True)
processor = main.CommandProcessor(args)
try:
processor.start()
except Exception as e:
processor.stop()
def on_annotation(self, cmd, annotation):
print('on_annotation hit!')
if __name__ == "__main__":
start_server()
+2
View File
@@ -7,11 +7,13 @@ cdef class Inference:
cdef object on_annotation cdef object on_annotation
cdef Annotation _previous_annotation cdef Annotation _previous_annotation
cdef AIRecognitionConfig ai_config cdef AIRecognitionConfig ai_config
cdef bint stop_signal
cdef bint is_video(self, str filepath) cdef bint is_video(self, str filepath)
cdef run_inference(self, RemoteCommand cmd, int batch_size=?) cdef run_inference(self, RemoteCommand cmd, int batch_size=?)
cdef _process_video(self, RemoteCommand cmd, int batch_size) cdef _process_video(self, RemoteCommand cmd, int batch_size)
cdef _process_image(self, RemoteCommand cmd) cdef _process_image(self, RemoteCommand cmd)
cdef stop(self)
cdef frame_to_annotation(self, long time, frame, boxes: object) cdef frame_to_annotation(self, long time, frame, boxes: object)
cdef bint is_valid_annotation(self, Annotation annotation) cdef bint is_valid_annotation(self, Annotation annotation)
+12 -7
View File
@@ -1,5 +1,3 @@
import ai_config
import msgpack
from ultralytics import YOLO from ultralytics import YOLO
import mimetypes import mimetypes
import cv2 import cv2
@@ -13,6 +11,7 @@ cdef class Inference:
def __init__(self, model_bytes, on_annotation): def __init__(self, model_bytes, on_annotation):
loader = SecureModelLoader() loader = SecureModelLoader()
model_path = loader.load_model(model_bytes) model_path = loader.load_model(model_bytes)
self.stop_signal = False
self.model = YOLO(<str>model_path) self.model = YOLO(<str>model_path)
self.on_annotation = on_annotation self.on_annotation = on_annotation
@@ -22,19 +21,20 @@ cdef class Inference:
cdef run_inference(self, RemoteCommand cmd, int batch_size=8): cdef run_inference(self, RemoteCommand cmd, int batch_size=8):
print('run inference..') print('run inference..')
self.stop_signal = False
if self.is_video(cmd.filename): if self.is_video(cmd.filename):
return self._process_video(cmd, batch_size) self._process_video(cmd, batch_size)
else: else:
return self._process_image(cmd) self._process_image(cmd)
cdef _process_video(self, RemoteCommand cmd, int batch_size): cdef _process_video(self, RemoteCommand cmd, int batch_size):
frame_count = 0 frame_count = 0
batch_frame = [] batch_frame = []
self._previous_annotation = None
v_input = cv2.VideoCapture(<str>cmd.filename) v_input = cv2.VideoCapture(<str>cmd.filename)
self.ai_config = AIRecognitionConfig.from_msgpack(cmd.data) self.ai_config = AIRecognitionConfig.from_msgpack(cmd.data)
while v_input.isOpened(): while v_input.isOpened() and not self.stop_signal:
ret, frame = v_input.read() ret, frame = v_input.read()
ms = v_input.get(cv2.CAP_PROP_POS_MSEC) ms = v_input.get(cv2.CAP_PROP_POS_MSEC)
if not ret or frame is None: if not ret or frame is None:
@@ -51,7 +51,9 @@ cdef class Inference:
for frame, res in zip(batch_frame, results): for frame, res in zip(batch_frame, results):
annotation = self.frame_to_annotation(int(frame[1]), frame[0], res.boxes) annotation = self.frame_to_annotation(int(frame[1]), frame[0], res.boxes)
if self.is_valid_annotation(<Annotation>annotation): is_valid = self.is_valid_annotation(<Annotation>annotation)
print(f'Is valid annotation: {is_valid}')
if is_valid:
self._previous_annotation = annotation self._previous_annotation = annotation
self.on_annotation(cmd, annotation) self.on_annotation(cmd, annotation)
batch_frame.clear() batch_frame.clear()
@@ -64,6 +66,9 @@ cdef class Inference:
annotation = self.frame_to_annotation(0, frame, res[0].boxes) annotation = self.frame_to_annotation(0, frame, res[0].boxes)
self.on_annotation(cmd, annotation) self.on_annotation(cmd, annotation)
cdef stop(self):
self.stop_signal = True
cdef frame_to_annotation(self, long time, frame, boxes: Boxes): cdef frame_to_annotation(self, long time, frame, boxes: Boxes):
detections = [] detections = []
for box in boxes: for box in boxes:
+16 -34
View File
@@ -1,6 +1,7 @@
import traceback import traceback
from queue import Queue from queue import Queue
cimport constants cimport constants
import msgpack
from api_client cimport ApiClient from api_client cimport ApiClient
from annotation cimport Annotation from annotation cimport Annotation
@@ -8,23 +9,21 @@ from inference cimport Inference
from remote_command cimport RemoteCommand, CommandType from remote_command cimport RemoteCommand, CommandType
from remote_command_handler cimport RemoteCommandHandler from remote_command_handler cimport RemoteCommandHandler
from user cimport User from user cimport User
import argparse
cdef class ParsedArguments: cdef class ParsedArguments:
cdef str email, password, folder; cdef str email, password, folder;
cdef bint persist_token
def __init__(self, str email, str password, str folder, bint persist_token): def __init__(self, str email, str password, str folder):
self.email = email self.email = email
self.password = password self.password = password
self.folder = folder self.folder = folder
self.persist_token = persist_token
cdef class CommandProcessor: cdef class CommandProcessor:
cdef ApiClient api_client cdef ApiClient api_client
cdef RemoteCommandHandler remote_handler cdef RemoteCommandHandler remote_handler
cdef object command_queue cdef object command_queue
cdef bint running cdef bint running
cdef Inference inference
def __init__(self, args: ParsedArguments): def __init__(self, args: ParsedArguments):
self.api_client = ApiClient(args.email, args.password, args.folder) self.api_client = ApiClient(args.email, args.password, args.folder)
@@ -32,25 +31,31 @@ cdef class CommandProcessor:
self.command_queue = Queue(maxsize=constants.QUEUE_MAXSIZE) self.command_queue = Queue(maxsize=constants.QUEUE_MAXSIZE)
self.remote_handler.start() self.remote_handler.start()
self.running = True self.running = True
model = self.api_client.load_ai_model()
self.inference = Inference(model, self.on_annotation)
def start(self): def start(self):
while self.running: while self.running:
try: try:
command = self.command_queue.get() command = self.command_queue.get()
model = self.api_client.load_ai_model() self.inference.run_inference(command)
Inference(model, self.on_annotation).run_inference(command) self.remote_handler.send(command.client_id, <bytes>'DONE'.encode('utf-8'))
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
cdef on_command(self, RemoteCommand command): cdef on_command(self, RemoteCommand command):
try: try:
if command.command_type == CommandType.INFERENCE: if command.command_type == CommandType.GET_USER:
self.command_queue.put(command) self.get_user(command, self.api_client.get_user())
elif command.command_type == CommandType.LOAD: elif command.command_type == CommandType.LOAD:
response = self.api_client.load_bytes(command.filename) response = self.api_client.load_bytes(command.filename)
self.remote_handler.send(command.client_id, response) self.remote_handler.send(command.client_id, response)
elif command.command_type == CommandType.GET_USER: elif command.command_type == CommandType.INFERENCE:
self.get_user(command, self.api_client.get_user()) self.command_queue.put(command)
elif command.command_type == CommandType.STOP_INFERENCE:
self.inference.stop()
elif command.command_type == CommandType.EXIT:
self.stop()
else: else:
pass pass
except Exception as e: except Exception as e:
@@ -64,28 +69,5 @@ cdef class CommandProcessor:
self.remote_handler.send(cmd.client_id, data) self.remote_handler.send(cmd.client_id, data)
def stop(self): def stop(self):
self.remote_handler.stop()
self.running = False self.running = False
def parse_arguments():
parser = argparse.ArgumentParser(description="Command Processor")
parser.add_argument("-e", "--email", type=str, default="", help="Email")
parser.add_argument("-p", "--pw", type=str, default="", help="Password")
parser.add_argument("-f", "--folder", type=str, default="", help="Folder to API inner folder to download file from")
parser.add_argument("-t", "--persist_token", type=bool, default=True, help="True for persisting token from API")
cdef args = parser.parse_args()
cdef str email = args.email
cdef str password = args.pw
cdef str folder = args.folder
cdef bint persist_token = args.persist_token
return ParsedArguments(email, password, folder, persist_token)
if __name__ == '__main__':
args = parse_arguments()
processor = CommandProcessor(args)
try:
processor.start()
except KeyboardInterrupt:
processor.stop()
+5 -3
View File
@@ -1,7 +1,9 @@
cdef enum CommandType: cdef enum CommandType:
INFERENCE = 1 GET_USER = 10
LOAD = 2 LOAD = 20
GET_USER = 3 INFERENCE = 30
STOP_INFERENCE = 40
EXIT = 100
cdef class RemoteCommand: cdef class RemoteCommand:
cdef public bytes client_id cdef public bytes client_id
+5 -3
View File
@@ -8,9 +8,11 @@ cdef class RemoteCommand:
def __str__(self): def __str__(self):
command_type_names = { command_type_names = {
1: "INFERENCE", 10: "GET_USER",
2: "LOAD", 20: "LOAD",
3: "GET_USER" 30: "INFERENCE",
40: "STOP INFERENCE",
100: "EXIT"
} }
data_str = f'. Data: {len(self.data)} bytes' if self.data else '' data_str = f'. Data: {len(self.data)} bytes' if self.data else ''
return f'{command_type_names[self.command_type]}: {self.filename}{data_str}' return f'{command_type_names[self.command_type]}: {self.filename}{data_str}'
+1 -1
View File
@@ -12,4 +12,4 @@ cdef class RemoteCommandHandler:
cdef _proxy_loop(self) cdef _proxy_loop(self)
cdef _worker_loop(self) cdef _worker_loop(self)
cdef send(self, bytes client_id, bytes data) cdef send(self, bytes client_id, bytes data)
cdef close(self) cdef stop(self)
+4 -1
View File
@@ -1,8 +1,10 @@
import time
import zmq import zmq
from threading import Thread, Event from threading import Thread, Event
from remote_command cimport RemoteCommand from remote_command cimport RemoteCommand
cimport constants cimport constants
cdef class RemoteCommandHandler: cdef class RemoteCommandHandler:
def __init__(self, object on_command): def __init__(self, object on_command):
self._on_command = on_command self._on_command = on_command
@@ -59,8 +61,9 @@ cdef class RemoteCommandHandler:
socket.send_multipart([client_id, data]) socket.send_multipart([client_id, data])
print(f'{len(data)} bytes was sent to client {client_id}') print(f'{len(data)} bytes was sent to client {client_id}')
cdef close(self): cdef stop(self):
self._shutdown_event.set() self._shutdown_event.set()
time.sleep(0.5)
self._router.close() self._router.close()
self._dealer.close() self._dealer.close()
self._context.term() self._context.term()
+22
View File
@@ -0,0 +1,22 @@
import argparse
from main import ParsedArguments, CommandProcessor
def parse_arguments():
parser = argparse.ArgumentParser(description="Command Processor")
parser.add_argument("-e", "--email", type=str, default="", help="Email")
parser.add_argument("-p", "--pw", type=str, default="", help="Password")
parser.add_argument("-f", "--folder", type=str, default="", help="Folder to API inner folder to download file from")
args = parser.parse_args()
return ParsedArguments(args.email, args.pw, args.folder)
def start(args: ParsedArguments):
processor = CommandProcessor(args)
try:
processor.start()
except KeyboardInterrupt:
processor.stop()
if __name__ == '__main__':
start(parse_arguments())
+1 -1
View File
@@ -1 +1 @@
eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1laWQiOiJkOTBhMzZjYS1lMjM3LTRmYmQtOWM3Yy0xMjcwNDBhYzg1NTYiLCJ1bmlxdWVfbmFtZSI6ImFkbWluQGF6YWlvbi5jb20iLCJyb2xlIjoiQXBpQWRtaW4iLCJuYmYiOjE3MzgxNjM2MzYsImV4cCI6MTczODE3ODAzNiwiaWF0IjoxNzM4MTYzNjM2LCJpc3MiOiJBemFpb25BcGkiLCJhdWQiOiJBbm5vdGF0b3JzL09yYW5nZVBpL0FkbWlucyJ9.7VVws5mwGqx--sGopOuZE9iu3dzt1UdVPXeje2KZTYk eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1laWQiOiJkOTBhMzZjYS1lMjM3LTRmYmQtOWM3Yy0xMjcwNDBhYzg1NTYiLCJ1bmlxdWVfbmFtZSI6ImFkbWluQGF6YWlvbi5jb20iLCJyb2xlIjoiQXBpQWRtaW4iLCJuYmYiOjE3MzgzNjUwMjksImV4cCI6MTczODM3OTQyOSwiaWF0IjoxNzM4MzY1MDI5LCJpc3MiOiJBemFpb25BcGkiLCJhdWQiOiJBbm5vdGF0b3JzL09yYW5nZVBpL0FkbWlucyJ9.5teWb-gnhRngV337u_0OyUQ-o2-plN7shrvvKUsckPw
+4 -2
View File
@@ -200,7 +200,9 @@ public partial class Annotator
DgAnnotations.MouseDoubleClick += (sender, args) => DgAnnotations.MouseDoubleClick += (sender, args) =>
{ {
var dgRow = ItemsControl.ContainerFromElement((DataGrid)sender, (args.OriginalSource as DependencyObject)!) as DataGridRow; var dgRow = ItemsControl.ContainerFromElement((DataGrid)sender, (args.OriginalSource as DependencyObject)!) as DataGridRow;
if (dgRow != null)
OpenAnnotationResult((AnnotationResult)dgRow!.Item); OpenAnnotationResult((AnnotationResult)dgRow!.Item);
}; };
DgAnnotations.KeyUp += async (sender, args) => DgAnnotations.KeyUp += async (sender, args) =>
@@ -531,13 +533,13 @@ public partial class Annotator
LvFiles.SelectedIndex += 1; LvFiles.SelectedIndex += 1;
return (MediaFileInfo)LvFiles.SelectedItem; return (MediaFileInfo)LvFiles.SelectedItem;
}); });
LvFiles.Items.Refresh(); Dispatcher.Invoke(() => LvFiles.Items.Refresh());
} }
Dispatcher.Invoke(() => Dispatcher.Invoke(() =>
{ {
_mediaPlayer.Stop();
LvFiles.Items.Refresh(); LvFiles.Items.Refresh();
IsInferenceNow = false; IsInferenceNow = false;
FollowAI = false;
}); });
}, token); }, token);
} }
+2
View File
@@ -8,6 +8,8 @@ namespace Azaion.Common.DTO.Config;
public class AppConfig public class AppConfig
{ {
public ApiConfig ApiConfig { get; set; } = null!;
public QueueConfig QueueConfig { get; set; } = null!; public QueueConfig QueueConfig { get; set; } = null!;
public DirectoriesConfig DirectoriesConfig { get; set; } = null!; public DirectoriesConfig DirectoriesConfig { get; set; } = null!;
+1 -6
View File
@@ -31,16 +31,11 @@ public class PythonInferenceService(ILogger<PythonInferenceService> logger, IOpt
while (true) while (true)
{ {
byte[] bytes = [];
try try
{ {
var annotationStream = dealer.Get<AnnotationImage>(out bytes); var annotationStream = dealer.Get<AnnotationImage>(bytes => bytes.Length == 4 && Encoding.UTF8.GetString(bytes) == "DONE");
if (annotationStream == null) if (annotationStream == null)
{
if (bytes.Length == 4 && Encoding.UTF8.GetString(bytes) == "DONE")
break; break;
continue;
}
await processAnnotation(annotationStream); await processAnnotation(annotationStream);
} }
+9
View File
@@ -0,0 +1,9 @@
namespace Azaion.CommonSecurity.DTO;
public class ApiConfig
{
public string Url { get; set; } = null!;
public int RetryCount {get;set;}
public double TimeoutSeconds { get; set; }
public string ResourcesFolder { get; set; } = "";
}
@@ -18,7 +18,9 @@ public class RemoteCommand(CommandType commandType, string? filename = null, byt
public enum CommandType public enum CommandType
{ {
None = 0, None = 0,
Inference = 1, GetUser = 10,
Load = 2, Load = 20,
GetUser = 3 Inference = 30,
StopInference = 40,
Exit = 100
} }
@@ -0,0 +1,6 @@
namespace Azaion.CommonSecurity.DTO;
public class SecureAppConfig
{
public ApiConfig ApiConfig { get; set; } = null!;
}
+15
View File
@@ -1,3 +1,4 @@
using System.Security.Claims;
using MessagePack; using MessagePack;
namespace Azaion.CommonSecurity.DTO; namespace Azaion.CommonSecurity.DTO;
@@ -8,4 +9,18 @@ public class User
[Key("i")]public string Id { get; set; } [Key("i")]public string Id { get; set; }
[Key("e")]public string Email { get; set; } [Key("e")]public string Email { get; set; }
[Key("r")]public RoleEnum Role { get; set; } [Key("r")]public RoleEnum Role { get; set; }
//For deserializing
public User(){}
public User(IEnumerable<Claim> claims)
{
var claimDict = claims.ToDictionary(x => x.Type, x => x.Value);
Id = claimDict[SecurityConstants.CLAIM_NAME_ID];
Email = claimDict[SecurityConstants.CLAIM_EMAIL];
if (!Enum.TryParse(claimDict[SecurityConstants.CLAIM_ROLE], out RoleEnum role))
role = RoleEnum.None;
Role = role;
}
} }
@@ -0,0 +1,110 @@
using System.IdentityModel.Tokens.Jwt;
using System.Net;
using System.Net.Http.Headers;
using System.Security;
using System.Text;
using Azaion.CommonSecurity.DTO;
using Newtonsoft.Json;
namespace Azaion.CommonSecurity.Services;
public class AzaionApiClient(HttpClient httpClient) : IDisposable
{
const string JSON_MEDIA = "application/json";
private static ApiConfig _apiConfig = null!;
private string Email { get; set; } = null!;
private SecureString Password { get; set; } = new();
private string JwtToken { get; set; } = null!;
public User User { get; set; } = null!;
public static AzaionApiClient Create(ApiCredentials credentials, ApiConfig apiConfig)
{
_apiConfig = apiConfig;
var api = new AzaionApiClient(new HttpClient
{
BaseAddress = new Uri(_apiConfig.Url),
Timeout = TimeSpan.FromSeconds(_apiConfig.TimeoutSeconds)
});
api.EnterCredentials(credentials);
return api;
}
public void EnterCredentials(ApiCredentials credentials)
{
if (string.IsNullOrWhiteSpace(credentials.Email) || string.IsNullOrWhiteSpace(credentials.Password))
throw new Exception("Email or password is empty!");
Email = credentials.Email;
Password = credentials.Password.ToSecureString();
}
public async Task<Stream> GetResource(string fileName, string password, HardwareInfo hardware, CancellationToken cancellationToken = default)
{
var response = await Send(httpClient, new HttpRequestMessage(HttpMethod.Post, $"/resources/get/{_apiConfig.ResourcesFolder}")
{
Content = new StringContent(JsonConvert.SerializeObject(new { fileName, password, hardware }), Encoding.UTF8, JSON_MEDIA)
}, cancellationToken);
return await response.Content.ReadAsStreamAsync(cancellationToken);
}
private async Task Authorize()
{
if (string.IsNullOrEmpty(Email) || Password.Length == 0)
throw new Exception("Email or password is empty! Please do EnterCredentials first!");
var payload = new
{
email = Email,
password = Password.ToRealString()
};
var response = await httpClient.PostAsync(
"login",
new StringContent(JsonConvert.SerializeObject(payload), Encoding.UTF8, JSON_MEDIA));
if (!response.IsSuccessStatusCode)
throw new Exception($"EnterCredentials failed: {response.StatusCode}");
var responseData = await response.Content.ReadAsStringAsync();
var result = JsonConvert.DeserializeObject<LoginResponse>(responseData);
if (string.IsNullOrEmpty(result?.Token))
throw new Exception("JWT Token not found in response");
var handler = new JwtSecurityTokenHandler();
var token = handler.ReadJwtToken(result.Token);
User = new User(token.Claims);
JwtToken = result.Token;
}
private async Task<HttpResponseMessage> Send(HttpClient client, HttpRequestMessage request, CancellationToken cancellationToken = default)
{
if (string.IsNullOrEmpty(JwtToken))
await Authorize();
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", JwtToken);
var response = await client.SendAsync(request, cancellationToken);
if (response.StatusCode == HttpStatusCode.Unauthorized)
{
await Authorize();
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", JwtToken);
response = await client.SendAsync(request, cancellationToken);
}
if (response.IsSuccessStatusCode)
return response;
var result = await response.Content.ReadAsStringAsync(cancellationToken);
throw new Exception($"Failed: {response.StatusCode}! Result: {result}");
}
public void Dispose()
{
httpClient.Dispose();
Password.Dispose();
}
}
@@ -1,4 +1,5 @@
using System.Text; using System.Diagnostics;
using System.Text;
using Azaion.CommonSecurity.DTO; using Azaion.CommonSecurity.DTO;
using Azaion.CommonSecurity.DTO.Commands; using Azaion.CommonSecurity.DTO.Commands;
using MessagePack; using MessagePack;
@@ -9,7 +10,7 @@ namespace Azaion.CommonSecurity.Services;
public interface IResourceLoader public interface IResourceLoader
{ {
Task<MemoryStream> LoadFile(string fileName, CancellationToken ct = default); MemoryStream LoadFileFromPython(string fileName);
} }
public interface IAuthProvider public interface IAuthProvider
@@ -20,27 +21,64 @@ public interface IAuthProvider
public class PythonResourceLoader : IResourceLoader, IAuthProvider public class PythonResourceLoader : IResourceLoader, IAuthProvider
{ {
private readonly ApiCredentials _credentials;
private readonly AzaionApiClient _api;
private readonly DealerSocket _dealer = new(); private readonly DealerSocket _dealer = new();
private readonly Guid _clientId = Guid.NewGuid(); private readonly Guid _clientId = Guid.NewGuid();
public User CurrentUser { get; } public User CurrentUser { get; }
public PythonResourceLoader(ApiCredentials credentials) public PythonResourceLoader(ApiConfig apiConfig, ApiCredentials credentials, AzaionApiClient api)
{ {
//Run python by credentials _credentials = credentials;
_api = api;
//StartPython(apiConfig, credentials);
_dealer.Options.Identity = Encoding.UTF8.GetBytes(_clientId.ToString("N")); _dealer.Options.Identity = Encoding.UTF8.GetBytes(_clientId.ToString("N"));
_dealer.Connect($"tcp://{SecurityConstants.ZMQ_HOST}:{SecurityConstants.ZMQ_PORT}"); _dealer.Connect($"tcp://{SecurityConstants.ZMQ_HOST}:{SecurityConstants.ZMQ_PORT}");
_dealer.SendFrame(MessagePackSerializer.Serialize(new RemoteCommand(CommandType.GetUser))); _dealer.SendFrame(MessagePackSerializer.Serialize(new RemoteCommand(CommandType.GetUser)));
var user = _dealer.Get<User>(out _); var user = _dealer.Get<User>();
if (user == null) if (user == null)
throw new Exception("Can't get user from Auth provider"); throw new Exception("Can't get user from Auth provider");
CurrentUser = user; CurrentUser = user;
} }
private void StartPython( ApiConfig apiConfig, ApiCredentials credentials)
{
//var inferenceExe = LoadPythonFile().GetAwaiter().GetResult();
string outputProcess = "";
string errorProcess = "";
public async Task<MemoryStream> LoadFile(string fileName, CancellationToken ct = default) var path = "azaion-inference.exe";
var arguments = $"-e {credentials.Email} -p {credentials.Password} -f {apiConfig.ResourcesFolder}";
using var process = new Process();
process.StartInfo.FileName = path;
process.StartInfo.Arguments = arguments;
process.StartInfo.UseShellExecute = false;
process.StartInfo.RedirectStandardOutput = true;
process.StartInfo.RedirectStandardError = true;
//process.StartInfo.CreateNoWindow = true;
process.OutputDataReceived += (sender, e) => { if (e.Data != null) Console.WriteLine(e.Data); };
process.ErrorDataReceived += (sender, e) => { if (e.Data != null) Console.WriteLine(e.Data); };
process.Start();
}
public async Task LoadFileFromApi(CancellationToken cancellationToken = default)
{
var hardwareService = new HardwareService();
var hardwareInfo = hardwareService.GetHardware();
var encryptedStream = await _api.GetResource("azaion_inference.exe", _credentials.Password, hardwareInfo, cancellationToken);
var key = Security.MakeEncryptionKey(_credentials.Email, _credentials.Password, hardwareInfo.Hash);
var stream = new MemoryStream();
await encryptedStream.DecryptTo(stream, key, cancellationToken);
stream.Seek(0, SeekOrigin.Begin);
}
public MemoryStream LoadFileFromPython(string fileName)
{ {
try try
{ {
@@ -49,7 +87,7 @@ public class PythonResourceLoader : IResourceLoader, IAuthProvider
if (!_dealer.TryReceiveFrameBytes(TimeSpan.FromMilliseconds(1000), out var bytes)) if (!_dealer.TryReceiveFrameBytes(TimeSpan.FromMilliseconds(1000), out var bytes))
throw new Exception($"Unable to receive {fileName}"); throw new Exception($"Unable to receive {fileName}");
return await Task.FromResult(new MemoryStream(bytes)); return new MemoryStream(bytes);
} }
catch (Exception ex) catch (Exception ex)
{ {
+3 -2
View File
@@ -6,11 +6,12 @@ namespace Azaion.CommonSecurity;
public static class ZeroMqExtensions public static class ZeroMqExtensions
{ {
public static T? Get<T>(this DealerSocket dealer, out byte[] message) public static T? Get<T>(this DealerSocket dealer, Func<byte[], bool>? shouldInterceptFn = null) where T : class
{ {
if (!dealer.TryReceiveFrameBytes(TimeSpan.FromMinutes(2), out var bytes)) if (!dealer.TryReceiveFrameBytes(TimeSpan.FromMinutes(2), out var bytes))
throw new Exception($"Unable to get {typeof(T).Name}"); throw new Exception($"Unable to get {typeof(T).Name}");
message = bytes; if (shouldInterceptFn != null && shouldInterceptFn(bytes))
return null;
return MessagePackSerializer.Deserialize<T>(bytes); return MessagePackSerializer.Deserialize<T>(bytes);
} }
} }
+31 -5
View File
@@ -11,6 +11,7 @@ using Azaion.Common.Events;
using Azaion.Common.Extensions; using Azaion.Common.Extensions;
using Azaion.Common.Services; using Azaion.Common.Services;
using Azaion.CommonSecurity; using Azaion.CommonSecurity;
using Azaion.CommonSecurity.DTO;
using Azaion.CommonSecurity.Services; using Azaion.CommonSecurity.Services;
using Azaion.Dataset; using Azaion.Dataset;
using LibVLCSharp.Shared; using LibVLCSharp.Shared;
@@ -20,6 +21,7 @@ using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
using Newtonsoft.Json;
using Serilog; using Serilog;
using KeyEventArgs = System.Windows.Input.KeyEventArgs; using KeyEventArgs = System.Windows.Input.KeyEventArgs;
@@ -53,14 +55,38 @@ public partial class App
"Azaion.Dataset" "Azaion.Dataset"
]; ];
private ApiConfig ReadConfig()
{
try
{
if (!File.Exists(SecurityConstants.CONFIG_PATH))
throw new FileNotFoundException(SecurityConstants.CONFIG_PATH);
var configStr = File.ReadAllText(SecurityConstants.CONFIG_PATH);
return JsonConvert.DeserializeObject<SecureAppConfig>(configStr)!.ApiConfig;
}
catch (Exception e)
{
Console.WriteLine(e);
return new ApiConfig
{
Url = SecurityConstants.DEFAULT_API_URL,
RetryCount = SecurityConstants.DEFAULT_API_RETRY_COUNT ,
TimeoutSeconds = SecurityConstants.DEFAULT_API_TIMEOUT_SECONDS
};
}
}
private void StartLogin() private void StartLogin()
{ {
new ConfigUpdater().CheckConfig(); new ConfigUpdater().CheckConfig();
var login = new Login(); var login = new Login();
login.CredentialsEntered += async (s, args) => login.CredentialsEntered += (_, credentials) =>
{ {
_resourceLoader = new PythonResourceLoader(args); var apiConfig = ReadConfig();
_securedConfig = await _resourceLoader.LoadFile("secured-config.json"); var api = AzaionApiClient.Create(credentials, apiConfig);
_resourceLoader = new PythonResourceLoader(apiConfig, credentials, api);
_securedConfig = _resourceLoader.LoadFileFromPython("secured-config.json");
AppDomain.CurrentDomain.AssemblyResolve += (_, a) => AppDomain.CurrentDomain.AssemblyResolve += (_, a) =>
{ {
@@ -69,7 +95,7 @@ public partial class App
{ {
try try
{ {
var stream = _resourceLoader.LoadFile($"{assemblyName}.dll").GetAwaiter().GetResult(); var stream = _resourceLoader.LoadFileFromPython($"{assemblyName}.dll");
return Assembly.Load(stream.ToArray()); return Assembly.Load(stream.ToArray());
} }
catch (Exception e) catch (Exception e)
@@ -88,7 +114,7 @@ public partial class App
}; };
StartMain(); StartMain();
await _host.StartAsync(); _host.Start();
EventManager.RegisterClassHandler(typeof(UIElement), UIElement.KeyDownEvent, new RoutedEventHandler(GlobalClick)); EventManager.RegisterClassHandler(typeof(UIElement), UIElement.KeyDownEvent, new RoutedEventHandler(GlobalClick));
_host.Services.GetRequiredService<MainSuite>().Show(); _host.Services.GetRequiredService<MainSuite>().Show();
}; };