Files
ai-training/train.py
T
Oleksandr Bezdieniezhnykh 142c6c4de8 Refactor constants management to use Pydantic BaseModel for configuration
- Replaced module-level path variables in constants.py with a structured Pydantic Config class.
- Updated all relevant modules (train.py, augmentation.py, exports.py, dataset-visualiser.py, manual_run.py) to access paths through the new config structure.
- Fixed bugs related to image processing and model saving.
- Enhanced test infrastructure to accommodate the new configuration approach.

This refactor improves code maintainability and clarity by centralizing configuration management.
2026-03-27 18:18:30 +02:00

179 lines
6.2 KiB
Python

import concurrent.futures
import glob
import os
import random
import shutil
import subprocess
import threading
from datetime import datetime
from os import path, replace, listdir, makedirs, scandir
from os.path import abspath
from pathlib import Path
from time import sleep

import yaml
from ultralytics import YOLO

import constants
from api_client import ApiCredentials, ApiClient
from cdn_manager import CDNCredentials, CDNManager
from dto.annotationClass import AnnotationClass
from exports import export_tensorrt, upload_model, export_onnx
from inference.onnx_engine import OnnxEngine
from security import Security
from utils import Dotdict
# Dataset folder name for today's run: configured prefix + current date
# rendered with the configured date format.
today_folder = f'{constants.prefix}{datetime.now():{constants.date_format}}'
# Train/valid/test split percentages; form_dataset() expects them to sum to 100.
train_set = 70
valid_set = 20
test_set = 10
# NOTE(review): not referenced anywhere in this file — presumably consumed by a
# sibling module (augmentation/manual_run); confirm before removing.
old_images_percentage = 75
# Number of class entries written into data.yaml (COCO-style default of 80).
DEFAULT_CLASS_NUM = 80
# Shared progress counter mutated by copy_annotations() worker threads.
total_files_copied = 0
def form_dataset():
    """Rebuild today's dataset directory and split it into train/valid/test.

    Wipes any existing dataset folder for today, collects every regular file
    from the processed-images directory, shuffles them, and delegates the
    copying of each split to copy_annotations().
    """
    dataset_dir = path.join(constants.config.datasets_dir, today_folder)
    shutil.rmtree(dataset_dir, ignore_errors=True)
    makedirs(dataset_dir)

    # Keep only regular files; subdirectories are ignored.
    with scandir(constants.config.processed_images_dir) as entries:
        images = [entry for entry in entries if entry.is_file()]

    print(f'Got {len(images)} images. Start shuffling...')
    random.shuffle(images)

    n_train = int(len(images) * train_set / 100.0)
    n_valid = int(len(images) * valid_set / 100.0)

    print(f'Start copying...')
    copy_annotations(images[:n_train], 'train')
    copy_annotations(images[n_train:n_train + n_valid], 'valid')
    # Whatever remains after train+valid becomes the test split.
    copy_annotations(images[n_train + n_valid:], 'test')
def copy_annotations(images, folder):
    """Copy *images* and their YOLO label files into the dataset split *folder*.

    Images with a valid label land in <dataset>/<folder>/{images,labels};
    images whose label is missing or corrupted are diverted to the configured
    corrupted-images/labels directories instead.

    Args:
        images: iterable of os.DirEntry objects for the image files.
        folder: split name ('train', 'valid' or 'test').

    Fixes over the original:
    - the final summary printed an unused local `copied` that was always 0;
    - the shared counter was incremented without synchronization across
      ThreadPoolExecutor workers (lost updates);
    - executor.map results were never consumed, so worker exceptions were
      silently discarded;
    - a *missing* label file was still shutil.copy'd in the corrupted branch,
      raising FileNotFoundError.
    """
    global total_files_copied
    total_files_copied = 0
    # Serializes counter updates across worker threads.
    counter_lock = threading.Lock()

    def copy_image(image):
        global total_files_copied
        label_name = f'{Path(image.path).stem}.txt'
        label_path = path.join(constants.config.processed_labels_dir, label_name)
        if check_label(label_path):
            shutil.copy(image.path, path.join(destination_images, image.name))
            shutil.copy(label_path, path.join(destination_labels, label_name))
        else:
            shutil.copy(image.path, path.join(constants.config.corrupted_images_dir, image.name))
            # check_label() also fails when the label file does not exist —
            # only copy it if there is actually something to copy.
            if path.exists(label_path):
                shutil.copy(label_path, path.join(constants.config.corrupted_labels_dir, label_name))
            print(f'Label {label_path} is corrupted! Copy with its image to the corrupted directory ({constants.config.corrupted_labels_dir})')
        with counter_lock:
            total_files_copied += 1
            if total_files_copied % 1000 == 0:
                print(f'{total_files_copied} copied...')

    today_dataset = path.join(constants.config.datasets_dir, today_folder)
    destination_images = path.join(today_dataset, folder, 'images')
    makedirs(destination_images, exist_ok=True)
    destination_labels = path.join(today_dataset, folder, 'labels')
    makedirs(destination_labels, exist_ok=True)
    makedirs(constants.config.corrupted_images_dir, exist_ok=True)
    makedirs(constants.config.corrupted_labels_dir, exist_ok=True)

    print(f'Copying annotations to {destination_images} and {destination_labels} folders:')
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Drain the iterator so any exception raised in a worker propagates
        # here instead of being dropped (map() defers exceptions until the
        # results are consumed).
        list(executor.map(copy_image, images))
    print(f'Copied all {total_files_copied} annotations to {destination_images} and {destination_labels} folders')
def check_label(label_path):
    """Return True if the YOLO label file at *label_path* is valid.

    A label is valid when the file exists and, on every line, each field
    after the leading class id parses as a float in the normalised range
    [0, 1].

    Fixes over the original: negative coordinates are now rejected (YOLO
    coordinates are normalised to [0, 1]), and a malformed numeric token
    marks the label corrupted instead of raising ValueError (which was
    silently swallowed inside the thread-pool workers).
    """
    if not path.exists(label_path):
        return False
    with open(label_path, 'r') as f:
        for line in f:
            # Field 0 is the class id; the rest are normalised coordinates.
            for val in line.split(' ')[1:]:
                try:
                    coord = float(val)
                except ValueError:
                    return False
                if coord > 1 or coord < 0:
                    return False
    return True
def create_yaml():
    """Write data.yaml for today's dataset.

    The file lists the class names (falling back to 'Class-<n>' for ids
    without a configured annotation class), the class count, and the three
    split image directories, in the layout Ultralytics expects.
    """
    print('creating yaml...')
    annotation_classes = AnnotationClass.read_json()
    # One '- <name>' entry per class id, padded with placeholders up to
    # DEFAULT_CLASS_NUM so the index positions stay stable.
    lines = ['names:']
    lines.extend(
        f'- {annotation_classes[i].name}' if i in annotation_classes else f'- Class-{i + 1}'
        for i in range(DEFAULT_CLASS_NUM)
    )
    lines.extend([
        f'nc: {DEFAULT_CLASS_NUM}',
        'test: test/images',
        'train: train/images',
        'val: valid/images',
        '',
    ])
    dataset_dir = path.join(constants.config.datasets_dir, today_folder)
    yaml_path = abspath(path.join(dataset_dir, 'data.yaml'))
    with open(yaml_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(lines) + '\n')
def resume_training(last_pt_path):
    """Resume an interrupted YOLO training run from the checkpoint at *last_pt_path*.

    Bug fix: the original passed ``data=yaml`` — i.e. the imported PyYAML
    *module object* — instead of a dataset config path. We now build the
    path to today's data.yaml, mirroring train_dataset().

    Args:
        last_pt_path: path to the 'last.pt' checkpoint to resume from.

    Returns:
        The Ultralytics training results object.
    """
    today_yaml = abspath(path.join(constants.config.datasets_dir, today_folder, 'data.yaml'))
    model = YOLO(last_pt_path)
    return model.train(data=today_yaml,
                       resume=True,
                       epochs=constants.config.training.epochs,
                       batch=constants.config.training.batch,
                       imgsz=constants.config.training.imgsz,
                       save_period=constants.config.training.save_period,
                       workers=constants.config.training.workers)
def train_dataset():
    """Build today's dataset, train a YOLO model on it, and publish the best weights.

    After training, the run directory is copied under the models dir,
    intermediate epoch checkpoints are deleted, and best.pt replaces the
    configured current .pt model.
    """
    form_dataset()
    create_yaml()

    cfg = constants.config.training
    dataset_dir = path.join(constants.config.datasets_dir, today_folder)
    model = YOLO(cfg.model)
    results = model.train(data=abspath(path.join(dataset_dir, 'data.yaml')),
                          epochs=cfg.epochs,
                          batch=cfg.batch,
                          imgsz=cfg.imgsz,
                          save_period=cfg.save_period,
                          workers=cfg.workers)

    model_dir = path.join(constants.config.models_dir, today_folder)
    shutil.copytree(results.save_dir, model_dir)
    # Intermediate 'epoch*' checkpoints are only needed during training.
    for checkpoint in glob.glob(path.join(model_dir, 'weights', 'epoch*')):
        os.remove(checkpoint)
    shutil.copy(path.join(model_dir, 'weights', 'best.pt'), constants.config.current_pt_model)
def export_current_model():
    """Export the current .pt model to ONNX and upload it, encrypted, via the API."""
    export_onnx(constants.config.current_pt_model)
    with open(constants.config.current_onnx_model, 'rb') as onnx_file:
        payload = onnx_file.read()
    encryption_key = Security.get_model_encryption_key()
    client = ApiClient()
    client.upload_big_small_resource(payload, 'azaion.onnx', constants.MODELS_FOLDER, encryption_key)
if __name__ == '__main__':
    # Full pipeline: build today's dataset and train on it, then export and
    # upload the resulting model.
    train_dataset()
    export_current_model()
    print('success!')