mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 09:06:35 +00:00
fix tensor rt engine
This commit is contained in:
committed by
Alex Bezdieniezhnykh
parent
5b89a21b36
commit
06a23525a6
@@ -16,13 +16,14 @@ from azaion_api import ApiCredentials, Api
|
||||
from cdn_manager import CDNCredentials, CDNManager
|
||||
from constants import (processed_images_dir,
|
||||
processed_labels_dir,
|
||||
annotation_classes,
|
||||
prefix, date_format,
|
||||
datasets_dir, models_dir,
|
||||
corrupted_images_dir, corrupted_labels_dir, sample_dir)
|
||||
from exports.export import form_data_sample
|
||||
from dto.annotationClass import AnnotationClass
|
||||
|
||||
from security import Security
|
||||
from utils import Dotdict
|
||||
from exports import export_tensorrt, upload_model
|
||||
|
||||
today_folder = f'{prefix}{datetime.now():{date_format}}'
|
||||
today_dataset = path.join(datasets_dir, today_folder)
|
||||
@@ -120,6 +121,7 @@ def check_label(label_path):
|
||||
def create_yaml():
|
||||
print('creating yaml...')
|
||||
lines = ['names:']
|
||||
annotation_classes = AnnotationClass.read_json()
|
||||
for i in range(DEFAULT_CLASS_NUM):
|
||||
if i in annotation_classes:
|
||||
lines.append(f'- {annotation_classes[i].name}')
|
||||
@@ -217,36 +219,15 @@ def validate(model_path):
|
||||
pass
|
||||
|
||||
|
||||
def upload_model(model_path: str, size_small_in_kb: int=3):
|
||||
# model = YOLO(model_path)
|
||||
# model.export(format="onnx", imgsz=1280, nms=True, batch=4)
|
||||
onnx_model = path.dirname(model_path) + Path(model_path).stem + '.onnx'
|
||||
|
||||
with open(onnx_model, 'rb') as f_in:
|
||||
onnx_bytes = f_in.read()
|
||||
|
||||
key = Security.get_model_encryption_key()
|
||||
onnx_encrypted = Security.encrypt_to(onnx_bytes, key)
|
||||
|
||||
part1_size = min(size_small_in_kb * 1024, int(0.9 * len(onnx_encrypted)))
|
||||
onnx_part_small = onnx_encrypted[:part1_size] # slice bytes for part1
|
||||
onnx_part_big = onnx_encrypted[part1_size:]
|
||||
|
||||
with open(constants.CONFIG_FILE, "r") as f:
|
||||
config_dict = yaml.safe_load(f)
|
||||
d_config = Dotdict(config_dict)
|
||||
cdn_c = Dotdict(d_config.cdn)
|
||||
api_c = Dotdict(d_config.api)
|
||||
cdn_manager = CDNManager(CDNCredentials(cdn_c.host, cdn_c.access_key, cdn_c.secret_key))
|
||||
cdn_manager.upload(cdn_c.bucket, 'azaion.onnx.big', onnx_part_big)
|
||||
|
||||
api = Api(ApiCredentials(api_c.url, api_c.user, api_c.pw, api_c.folder))
|
||||
api.upload_file('azaion.onnx.small', onnx_part_small)
|
||||
|
||||
if __name__ == '__main__':
|
||||
# model_path = train_dataset(from_scratch=True)
|
||||
# validate(path.join('runs', 'detect', 'train7', 'weights', 'best.pt'))
|
||||
# form_data_sample(500)
|
||||
# convert2rknn()
|
||||
model_path = 'azaion.pt'
|
||||
export_tensorrt(model_path)
|
||||
engine_model_path = f'{Path(model_path).stem}.engine'
|
||||
upload_model(engine_model_path, engine_model_path)
|
||||
|
||||
upload_model('azaion-2024-10-26.onnx')
|
||||
onnx_model_path = f'{Path(model_path).stem}.onnx'
|
||||
upload_model(onnx_model_path, onnx_model_path)
|
||||
|
||||
Reference in New Issue
Block a user