mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 22:36:36 +00:00
Refactor constants management to use Pydantic BaseModel for configuration
- Replaced module-level path variables in constants.py with a structured Pydantic Config class.
- Updated all relevant modules (train.py, augmentation.py, exports.py, dataset-visualiser.py, manual_run.py) to access paths through the new config structure.
- Fixed bugs related to image processing and model saving.
- Enhanced test infrastructure to accommodate the new configuration approach.

This refactor improves code maintainability and clarity by centralizing configuration management.
This commit is contained in:
@@ -16,11 +16,6 @@ from ultralytics import YOLO
|
||||
import constants
|
||||
from api_client import ApiCredentials, ApiClient
|
||||
from cdn_manager import CDNCredentials, CDNManager
|
||||
from constants import (processed_images_dir,
|
||||
processed_labels_dir,
|
||||
prefix, date_format,
|
||||
datasets_dir, models_dir,
|
||||
corrupted_images_dir, corrupted_labels_dir, sample_dir)
|
||||
from dto.annotationClass import AnnotationClass
|
||||
from inference.onnx_engine import OnnxEngine
|
||||
|
||||
@@ -28,8 +23,7 @@ from security import Security
|
||||
from utils import Dotdict
|
||||
from exports import export_tensorrt, upload_model, export_onnx
|
||||
|
||||
today_folder = f'{prefix}{datetime.now():{date_format}}'
|
||||
today_dataset = path.join(datasets_dir, today_folder)
|
||||
today_folder = f'{constants.prefix}{datetime.now():{constants.date_format}}'
|
||||
train_set = 70
|
||||
valid_set = 20
|
||||
test_set = 10
|
||||
@@ -40,10 +34,11 @@ total_files_copied = 0
|
||||
|
||||
|
||||
def form_dataset():
|
||||
today_dataset = path.join(constants.config.datasets_dir, today_folder)
|
||||
shutil.rmtree(today_dataset, ignore_errors=True)
|
||||
makedirs(today_dataset)
|
||||
images = []
|
||||
with scandir(processed_images_dir) as imd:
|
||||
with scandir(constants.config.processed_images_dir) as imd:
|
||||
for image_file in imd:
|
||||
if not image_file.is_file():
|
||||
continue
|
||||
@@ -69,26 +64,27 @@ def copy_annotations(images, folder):
|
||||
global total_files_copied
|
||||
total_files_copied += 1
|
||||
label_name = f'{Path(image.path).stem}.txt'
|
||||
label_path = path.join(processed_labels_dir, label_name)
|
||||
label_path = path.join(constants.config.processed_labels_dir, label_name)
|
||||
if check_label(label_path):
|
||||
shutil.copy(image.path, path.join(destination_images, image.name))
|
||||
shutil.copy(label_path, path.join(destination_labels, label_name))
|
||||
else:
|
||||
shutil.copy(image.path, path.join(corrupted_images_dir, image.name))
|
||||
shutil.copy(label_path, path.join(corrupted_labels_dir, label_name))
|
||||
print(f'Label {label_path} is corrupted! Copy with its image to the corrupted directory ({corrupted_labels_dir})')
|
||||
shutil.copy(image.path, path.join(constants.config.corrupted_images_dir, image.name))
|
||||
shutil.copy(label_path, path.join(constants.config.corrupted_labels_dir, label_name))
|
||||
print(f'Label {label_path} is corrupted! Copy with its image to the corrupted directory ({constants.config.corrupted_labels_dir})')
|
||||
|
||||
if total_files_copied % 1000 == 0:
|
||||
print(f'{total_files_copied} copied...')
|
||||
|
||||
today_dataset = path.join(constants.config.datasets_dir, today_folder)
|
||||
destination_images = path.join(today_dataset, folder, 'images')
|
||||
makedirs(destination_images, exist_ok=True)
|
||||
|
||||
destination_labels = path.join(today_dataset, folder, 'labels')
|
||||
makedirs(destination_labels, exist_ok=True)
|
||||
|
||||
makedirs(corrupted_images_dir, exist_ok=True)
|
||||
makedirs(corrupted_labels_dir, exist_ok=True)
|
||||
makedirs(constants.config.corrupted_images_dir, exist_ok=True)
|
||||
makedirs(constants.config.corrupted_labels_dir, exist_ok=True)
|
||||
|
||||
copied = 0
|
||||
print(f'Copying annotations to {destination_images} and {destination_labels} folders:')
|
||||
@@ -127,6 +123,7 @@ def create_yaml():
|
||||
lines.append(f'val: valid/images')
|
||||
lines.append('')
|
||||
|
||||
today_dataset = path.join(constants.config.datasets_dir, today_folder)
|
||||
today_yaml = abspath(path.join(today_dataset, 'data.yaml'))
|
||||
with open(today_yaml, 'w', encoding='utf-8') as f:
|
||||
f.writelines([f'{line}\n' for line in lines])
|
||||
@@ -137,38 +134,38 @@ def resume_training(last_pt_path):
|
||||
model = YOLO(last_pt_path)
|
||||
results = model.train(data=yaml,
|
||||
resume=True,
|
||||
epochs=120,
|
||||
batch=11,
|
||||
imgsz=1280,
|
||||
save_period=1,
|
||||
workers=24)
|
||||
epochs=constants.config.training.epochs,
|
||||
batch=constants.config.training.batch,
|
||||
imgsz=constants.config.training.imgsz,
|
||||
save_period=constants.config.training.save_period,
|
||||
workers=constants.config.training.workers)
|
||||
|
||||
|
||||
def train_dataset():
|
||||
form_dataset()
|
||||
create_yaml()
|
||||
model = YOLO('yolo11m.yaml')
|
||||
model = YOLO(constants.config.training.model)
|
||||
|
||||
today_dataset = path.join(constants.config.datasets_dir, today_folder)
|
||||
results = model.train(data=abspath(path.join(today_dataset, 'data.yaml')),
|
||||
epochs=120, # Empirically set for good performance and relatively not so long training
|
||||
# (360k of annotations on 1 RTX4090 takes 11.5 days of training :( )
|
||||
batch=11, # reflects current GPU memory, 24Gb (batch 11 gets ~22Gb, batch 12 fails on 24.2Gb)
|
||||
imgsz=1280, # 1280p is a tradeoff between quality and speed
|
||||
save_period=1, # for resuming in case of power outages / other issues
|
||||
workers=24) # loading data workers. Bound to cpus count
|
||||
epochs=constants.config.training.epochs,
|
||||
batch=constants.config.training.batch,
|
||||
imgsz=constants.config.training.imgsz,
|
||||
save_period=constants.config.training.save_period,
|
||||
workers=constants.config.training.workers)
|
||||
|
||||
model_dir = path.join(models_dir, today_folder)
|
||||
model_dir = path.join(constants.config.models_dir, today_folder)
|
||||
|
||||
shutil.copytree(results.save_dir, model_dir)
|
||||
for file in glob.glob(path.join(model_dir, 'weights', 'epoch*')): # remove unnecessary middle epochs
|
||||
os.remove(file)
|
||||
shutil.copy(path.join(model_dir, 'weights', 'best.pt'), constants.CURRENT_PT_MODEL)
|
||||
shutil.copy(path.join(model_dir, 'weights', 'best.pt'), constants.config.current_pt_model)
|
||||
|
||||
|
||||
def export_current_model():
|
||||
export_onnx(constants.CURRENT_PT_MODEL)
|
||||
export_onnx(constants.config.current_pt_model)
|
||||
api_client = ApiClient()
|
||||
with open(constants.CURRENT_ONNX_MODEL, 'rb') as binary_file:
|
||||
with open(constants.config.current_onnx_model, 'rb') as binary_file:
|
||||
onnx_bytes = binary_file.read()
|
||||
|
||||
key = Security.get_model_encryption_key()
|
||||
|
||||
Reference in New Issue
Block a user