mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 08:46:36 +00:00
142c6c4de8
- Replaced module-level path variables in constants.py with a structured Pydantic Config class. - Updated all relevant modules (train.py, augmentation.py, exports.py, dataset-visualiser.py, manual_run.py) to access paths through the new config structure. - Fixed bugs related to image processing and model saving. - Enhanced test infrastructure to accommodate the new configuration approach. This refactor improves code maintainability and clarity by centralizing configuration management.
179 lines
6.2 KiB
Python
179 lines
6.2 KiB
Python
import concurrent.futures
import glob
import os
import random
import shutil
import subprocess
import threading
from datetime import datetime
from os import path, replace, listdir, makedirs, scandir
from os.path import abspath
from pathlib import Path
from time import sleep

import yaml
from ultralytics import YOLO

import constants
from api_client import ApiCredentials, ApiClient
from cdn_manager import CDNCredentials, CDNManager
from dto.annotationClass import AnnotationClass
from exports import export_tensorrt, upload_model, export_onnx
from inference.onnx_engine import OnnxEngine
from security import Security
from utils import Dotdict
|
|
|
|
# Name of today's dataset/model folder, e.g. '<prefix><formatted date>'.
today_folder = f'{constants.prefix}{datetime.now():{constants.date_format}}'

# Percentage split of the shuffled image set across the dataset partitions.
train_set = 70
valid_set = 20
test_set = 10
# NOTE(review): declared but never referenced in this file — presumably used by
# another module to mix in a share of older images; confirm before removing.
old_images_percentage = 75

# Number of class entries written into data.yaml (80 matches the COCO default;
# gaps not present in the annotation JSON are filled with placeholder names).
DEFAULT_CLASS_NUM = 80
# Shared progress counter, updated by copy_annotations() worker threads.
total_files_copied = 0
|
|
|
|
|
|
def form_dataset():
    """Create today's dataset folder and split processed images into train/valid/test."""
    today_dataset = path.join(constants.config.datasets_dir, today_folder)
    # Start from a clean slate: discard any earlier run of today's dataset.
    shutil.rmtree(today_dataset, ignore_errors=True)
    makedirs(today_dataset)

    with scandir(constants.config.processed_images_dir) as entries:
        images = [entry for entry in entries if entry.is_file()]

    print(f'Got {len(images)} images. Start shuffling...')
    random.shuffle(images)

    # Convert the percentage split into element counts; 'test' takes the remainder.
    total = len(images)
    train_size = int(total * train_set / 100.0)
    valid_size = int(total * valid_set / 100.0)

    print(f'Start copying...')
    splits = (
        (images[:train_size], 'train'),
        (images[train_size:train_size + valid_size], 'valid'),
        (images[train_size + valid_size:], 'test'),
    )
    for subset, split_name in splits:
        copy_annotations(subset, split_name)
|
|
|
|
|
|
def copy_annotations(images, folder):
    """Copy *images* and their YOLO label files into dataset subfolder *folder*.

    Images whose label fails validation (see check_label) are diverted, together
    with the label when it exists, to the corrupted-files directories. Copying
    is parallelised with a thread pool; the module-level ``total_files_copied``
    counter tracks progress.

    Args:
        images: Iterable of os.DirEntry objects for the image files to copy.
        folder: Dataset split name ('train', 'valid' or 'test').
    """
    global total_files_copied
    total_files_copied = 0
    # Serialises updates of the shared progress counter across worker threads
    # (the bare += in the previous version was not thread-safe).
    counter_lock = threading.Lock()

    def copy_image(image):
        global total_files_copied
        with counter_lock:
            total_files_copied += 1
            count = total_files_copied

        label_name = f'{Path(image.path).stem}.txt'
        label_path = path.join(constants.config.processed_labels_dir, label_name)
        if check_label(label_path):
            shutil.copy(image.path, path.join(destination_images, image.name))
            shutil.copy(label_path, path.join(destination_labels, label_name))
        else:
            shutil.copy(image.path, path.join(constants.config.corrupted_images_dir, image.name))
            # The label may be corrupt *or* missing entirely; only copy when present
            # (previously a missing label raised FileNotFoundError here).
            if path.exists(label_path):
                shutil.copy(label_path, path.join(constants.config.corrupted_labels_dir, label_name))
            print(f'Label {label_path} is corrupted! Copy with its image to the corrupted directory ({constants.config.corrupted_labels_dir})')

        if count % 1000 == 0:
            print(f'{count} copied...')

    today_dataset = path.join(constants.config.datasets_dir, today_folder)
    destination_images = path.join(today_dataset, folder, 'images')
    makedirs(destination_images, exist_ok=True)

    destination_labels = path.join(today_dataset, folder, 'labels')
    makedirs(destination_labels, exist_ok=True)

    makedirs(constants.config.corrupted_images_dir, exist_ok=True)
    makedirs(constants.config.corrupted_labels_dir, exist_ok=True)

    print(f'Copying annotations to {destination_images} and {destination_labels} folders:')
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # list() forces the lazy map iterator, so exceptions raised inside
        # copy_image propagate instead of being silently discarded.
        list(executor.map(copy_image, images))

    # Previously printed a 'copied' local that was never incremented (always 0).
    print(f'Copied all {total_files_copied} annotations to {destination_images} and {destination_labels} folders')
|
|
|
|
|
|
def check_label(label_path):
    """Return True if the YOLO label file at *label_path* exists and is valid.

    Each label line is ``<class_id> <x> <y> <w> <h>`` with coordinates
    normalised to [0, 1]. The file is considered corrupted when any coordinate
    falls outside that range or cannot be parsed as a number.

    Args:
        label_path: Filesystem path of the .txt label file to validate.

    Returns:
        bool: True for a valid, existing label file; False otherwise.
    """
    if not path.exists(label_path):
        return False
    with open(label_path, 'r') as f:
        for line in f:
            # Skip the class id (token 0); validate every coordinate token.
            for val in line.split(' ')[1:]:
                try:
                    coord = float(val)
                except ValueError:
                    # Non-numeric token: previously this crashed the caller;
                    # treat the file as corrupted instead.
                    return False
                # Negative coordinates are as invalid as values > 1; the old
                # check (`> 1`) silently accepted them.
                if not 0.0 <= coord <= 1.0:
                    return False
    return True
|
|
|
|
|
|
def create_yaml():
    """Write data.yaml (class names, class count, split paths) for today's dataset."""
    print('creating yaml...')
    annotation_classes = AnnotationClass.read_json()

    lines = ['names:']
    for class_id in range(DEFAULT_CLASS_NUM):
        # Use the real class name when annotated, otherwise a numbered placeholder.
        if class_id in annotation_classes:
            name = annotation_classes[class_id].name
        else:
            name = f'Class-{class_id + 1}'
        lines.append(f'- {name}')

    lines.append(f'nc: {DEFAULT_CLASS_NUM}')
    lines.extend([
        'test: test/images',
        'train: train/images',
        'val: valid/images',
        '',
    ])

    today_yaml = abspath(path.join(constants.config.datasets_dir, today_folder, 'data.yaml'))
    with open(today_yaml, 'w', encoding='utf-8') as f:
        f.write('\n'.join(lines) + '\n')
|
|
|
|
|
|
|
|
def resume_training(last_pt_path, data_yaml=None):
    """Resume an interrupted YOLO training run from a saved checkpoint.

    Args:
        last_pt_path: Path to the ``last.pt`` checkpoint to resume from.
        data_yaml: Optional explicit path to the dataset ``data.yaml``.
            Defaults to today's dataset yaml, mirroring train_dataset().

    Returns:
        The ultralytics training results object.
    """
    if data_yaml is None:
        # BUG FIX: the previous version passed the imported PyYAML module
        # (``data=yaml``) as the dataset; point at the dataset yaml instead.
        data_yaml = abspath(path.join(constants.config.datasets_dir, today_folder, 'data.yaml'))

    model = YOLO(last_pt_path)
    training = constants.config.training
    return model.train(data=data_yaml,
                       resume=True,
                       epochs=training.epochs,
                       batch=training.batch,
                       imgsz=training.imgsz,
                       save_period=training.save_period,
                       workers=training.workers)
|
|
|
|
|
|
def train_dataset():
    """Assemble today's dataset, train a fresh YOLO model, and archive the results."""
    form_dataset()
    create_yaml()

    training = constants.config.training
    model = YOLO(training.model)

    today_dataset = path.join(constants.config.datasets_dir, today_folder)
    data_yaml = abspath(path.join(today_dataset, 'data.yaml'))
    results = model.train(data=data_yaml,
                          epochs=training.epochs,
                          batch=training.batch,
                          imgsz=training.imgsz,
                          save_period=training.save_period,
                          workers=training.workers)

    # Archive this run's output directory next to the other dated models.
    model_dir = path.join(constants.config.models_dir, today_folder)
    shutil.copytree(results.save_dir, model_dir)

    # Drop intermediate epoch checkpoints; only best/last weights are kept.
    for weight_file in glob.glob(path.join(model_dir, 'weights', 'epoch*')):
        os.remove(weight_file)

    # Promote best.pt as the current production model.
    shutil.copy(path.join(model_dir, 'weights', 'best.pt'), constants.config.current_pt_model)
|
|
|
|
|
|
def export_current_model():
    """Export the current .pt model to ONNX, encrypt it, and upload it via the API."""
    export_onnx(constants.config.current_pt_model)

    with open(constants.config.current_onnx_model, 'rb') as binary_file:
        onnx_bytes = binary_file.read()

    key = Security.get_model_encryption_key()
    api_client = ApiClient()
    api_client.upload_big_small_resource(onnx_bytes, 'azaion.onnx', constants.MODELS_FOLDER, key)
|
|
|
|
|
|
if __name__ == '__main__':
    # Full pipeline: build today's dataset, train, then export/upload the model.
    train_dataset()
    export_current_model()
    print('success!')
|