fixed inference bugs

add DONE during inference, correct handling on C# side
This commit is contained in:
Alex Bezdieniezhnykh
2025-02-01 02:09:11 +02:00
parent e7afa96a0b
commit 739759628a
23 changed files with 324 additions and 95 deletions
+9
View File
@@ -0,0 +1,9 @@
namespace Azaion.CommonSecurity.DTO;
public class ApiConfig
{
public string Url { get; set; } = null!;
public int RetryCount {get;set;}
public double TimeoutSeconds { get; set; }
public string ResourcesFolder { get; set; } = "";
}
@@ -18,7 +18,9 @@ public class RemoteCommand(CommandType commandType, string? filename = null, byt
public enum CommandType
{
None = 0,
Inference = 1,
Load = 2,
GetUser = 3
GetUser = 10,
Load = 20,
Inference = 30,
StopInference = 40,
Exit = 100
}
@@ -0,0 +1,6 @@
namespace Azaion.CommonSecurity.DTO;
public class SecureAppConfig
{
public ApiConfig ApiConfig { get; set; } = null!;
}
+15
View File
@@ -1,3 +1,4 @@
using System.Security.Claims;
using MessagePack;
namespace Azaion.CommonSecurity.DTO;
@@ -8,4 +9,18 @@ public class User
[Key("i")]public string Id { get; set; }
[Key("e")]public string Email { get; set; }
[Key("r")]public RoleEnum Role { get; set; }
//For deserializing
public User(){}
public User(IEnumerable<Claim> claims)
{
var claimDict = claims.ToDictionary(x => x.Type, x => x.Value);
Id = claimDict[SecurityConstants.CLAIM_NAME_ID];
Email = claimDict[SecurityConstants.CLAIM_EMAIL];
if (!Enum.TryParse(claimDict[SecurityConstants.CLAIM_ROLE], out RoleEnum role))
role = RoleEnum.None;
Role = role;
}
}
@@ -0,0 +1,110 @@
using System.IdentityModel.Tokens.Jwt;
using System.Net;
using System.Net.Http.Headers;
using System.Security;
using System.Text;
using Azaion.CommonSecurity.DTO;
using Newtonsoft.Json;
namespace Azaion.CommonSecurity.Services;
public class AzaionApiClient(HttpClient httpClient) : IDisposable
{
const string JSON_MEDIA = "application/json";
private static ApiConfig _apiConfig = null!;
private string Email { get; set; } = null!;
private SecureString Password { get; set; } = new();
private string JwtToken { get; set; } = null!;
public User User { get; set; } = null!;
public static AzaionApiClient Create(ApiCredentials credentials, ApiConfig apiConfig)
{
_apiConfig = apiConfig;
var api = new AzaionApiClient(new HttpClient
{
BaseAddress = new Uri(_apiConfig.Url),
Timeout = TimeSpan.FromSeconds(_apiConfig.TimeoutSeconds)
});
api.EnterCredentials(credentials);
return api;
}
public void EnterCredentials(ApiCredentials credentials)
{
if (string.IsNullOrWhiteSpace(credentials.Email) || string.IsNullOrWhiteSpace(credentials.Password))
throw new Exception("Email or password is empty!");
Email = credentials.Email;
Password = credentials.Password.ToSecureString();
}
public async Task<Stream> GetResource(string fileName, string password, HardwareInfo hardware, CancellationToken cancellationToken = default)
{
var response = await Send(httpClient, new HttpRequestMessage(HttpMethod.Post, $"/resources/get/{_apiConfig.ResourcesFolder}")
{
Content = new StringContent(JsonConvert.SerializeObject(new { fileName, password, hardware }), Encoding.UTF8, JSON_MEDIA)
}, cancellationToken);
return await response.Content.ReadAsStreamAsync(cancellationToken);
}
private async Task Authorize()
{
if (string.IsNullOrEmpty(Email) || Password.Length == 0)
throw new Exception("Email or password is empty! Please do EnterCredentials first!");
var payload = new
{
email = Email,
password = Password.ToRealString()
};
var response = await httpClient.PostAsync(
"login",
new StringContent(JsonConvert.SerializeObject(payload), Encoding.UTF8, JSON_MEDIA));
if (!response.IsSuccessStatusCode)
throw new Exception($"EnterCredentials failed: {response.StatusCode}");
var responseData = await response.Content.ReadAsStringAsync();
var result = JsonConvert.DeserializeObject<LoginResponse>(responseData);
if (string.IsNullOrEmpty(result?.Token))
throw new Exception("JWT Token not found in response");
var handler = new JwtSecurityTokenHandler();
var token = handler.ReadJwtToken(result.Token);
User = new User(token.Claims);
JwtToken = result.Token;
}
private async Task<HttpResponseMessage> Send(HttpClient client, HttpRequestMessage request, CancellationToken cancellationToken = default)
{
if (string.IsNullOrEmpty(JwtToken))
await Authorize();
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", JwtToken);
var response = await client.SendAsync(request, cancellationToken);
if (response.StatusCode == HttpStatusCode.Unauthorized)
{
await Authorize();
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", JwtToken);
response = await client.SendAsync(request, cancellationToken);
}
if (response.IsSuccessStatusCode)
return response;
var result = await response.Content.ReadAsStringAsync(cancellationToken);
throw new Exception($"Failed: {response.StatusCode}! Result: {result}");
}
public void Dispose()
{
httpClient.Dispose();
Password.Dispose();
}
}
@@ -1,4 +1,5 @@
using System.Text;
using System.Diagnostics;
using System.Text;
using Azaion.CommonSecurity.DTO;
using Azaion.CommonSecurity.DTO.Commands;
using MessagePack;
@@ -9,7 +10,7 @@ namespace Azaion.CommonSecurity.Services;
public interface IResourceLoader
{
Task<MemoryStream> LoadFile(string fileName, CancellationToken ct = default);
MemoryStream LoadFileFromPython(string fileName);
}
public interface IAuthProvider
@@ -20,27 +21,64 @@ public interface IAuthProvider
public class PythonResourceLoader : IResourceLoader, IAuthProvider
{
private readonly ApiCredentials _credentials;
private readonly AzaionApiClient _api;
private readonly DealerSocket _dealer = new();
private readonly Guid _clientId = Guid.NewGuid();
public User CurrentUser { get; }
public PythonResourceLoader(ApiCredentials credentials)
public PythonResourceLoader(ApiConfig apiConfig, ApiCredentials credentials, AzaionApiClient api)
{
//Run python by credentials
_credentials = credentials;
_api = api;
//StartPython(apiConfig, credentials);
_dealer.Options.Identity = Encoding.UTF8.GetBytes(_clientId.ToString("N"));
_dealer.Connect($"tcp://{SecurityConstants.ZMQ_HOST}:{SecurityConstants.ZMQ_PORT}");
_dealer.SendFrame(MessagePackSerializer.Serialize(new RemoteCommand(CommandType.GetUser)));
var user = _dealer.Get<User>(out _);
var user = _dealer.Get<User>();
if (user == null)
throw new Exception("Can't get user from Auth provider");
CurrentUser = user;
}
private void StartPython( ApiConfig apiConfig, ApiCredentials credentials)
{
//var inferenceExe = LoadPythonFile().GetAwaiter().GetResult();
string outputProcess = "";
string errorProcess = "";
public async Task<MemoryStream> LoadFile(string fileName, CancellationToken ct = default)
var path = "azaion-inference.exe";
var arguments = $"-e {credentials.Email} -p {credentials.Password} -f {apiConfig.ResourcesFolder}";
using var process = new Process();
process.StartInfo.FileName = path;
process.StartInfo.Arguments = arguments;
process.StartInfo.UseShellExecute = false;
process.StartInfo.RedirectStandardOutput = true;
process.StartInfo.RedirectStandardError = true;
//process.StartInfo.CreateNoWindow = true;
process.OutputDataReceived += (sender, e) => { if (e.Data != null) Console.WriteLine(e.Data); };
process.ErrorDataReceived += (sender, e) => { if (e.Data != null) Console.WriteLine(e.Data); };
process.Start();
}
public async Task LoadFileFromApi(CancellationToken cancellationToken = default)
{
var hardwareService = new HardwareService();
var hardwareInfo = hardwareService.GetHardware();
var encryptedStream = await _api.GetResource("azaion_inference.exe", _credentials.Password, hardwareInfo, cancellationToken);
var key = Security.MakeEncryptionKey(_credentials.Email, _credentials.Password, hardwareInfo.Hash);
var stream = new MemoryStream();
await encryptedStream.DecryptTo(stream, key, cancellationToken);
stream.Seek(0, SeekOrigin.Begin);
}
public MemoryStream LoadFileFromPython(string fileName)
{
try
{
@@ -49,7 +87,7 @@ public class PythonResourceLoader : IResourceLoader, IAuthProvider
if (!_dealer.TryReceiveFrameBytes(TimeSpan.FromMilliseconds(1000), out var bytes))
throw new Exception($"Unable to receive {fileName}");
return await Task.FromResult(new MemoryStream(bytes));
return new MemoryStream(bytes);
}
catch (Exception ex)
{
+3 -2
View File
@@ -6,11 +6,12 @@ namespace Azaion.CommonSecurity;
public static class ZeroMqExtensions
{
public static T? Get<T>(this DealerSocket dealer, out byte[] message)
public static T? Get<T>(this DealerSocket dealer, Func<byte[], bool>? shouldInterceptFn = null) where T : class
{
if (!dealer.TryReceiveFrameBytes(TimeSpan.FromMinutes(2), out var bytes))
throw new Exception($"Unable to get {typeof(T).Name}");
message = bytes;
if (shouldInterceptFn != null && shouldInterceptFn(bytes))
return null;
return MessagePackSerializer.Deserialize<T>(bytes);
}
}