diff --git a/manual_run.py b/manual_run.py index b3292f1..c1b56d8 100644 --- a/manual_run.py +++ b/manual_run.py @@ -1,8 +1,10 @@ +import glob +import os import shutil from os import path +import constants import train -from constants import models_dir, prefix from augmentation import Augmentator # Augmentator().augment_annotations() @@ -10,11 +12,12 @@ from augmentation import Augmentator # train.resume_training('/azaion/dev/ai-training/runs/detect/train12/weights/last.pt') result_dir = '/azaion/dev/ai-training/runs/detect/train12' -model_dir = path.join(models_dir, f'{prefix}2025-05-18') -shutil.copytree(result_dir, model_dir, dirs_exist_ok=True) +model_dir = path.join(constants.models_dir, f'{constants.prefix}2025-05-18') -model_path = path.join(models_dir, f'{prefix[:-1]}.pt') -shutil.copy(path.join(model_dir, 'weights', 'best.pt'), model_path) +shutil.copytree(result_dir, model_dir, dirs_exist_ok=True) +for file in glob.glob(path.join(model_dir, 'weights', 'epoch*')): + os.remove(file) +shutil.copy(path.join(model_dir, 'weights', 'best.pt'), constants.CURRENT_PT_MODEL) train.export_current_model() print('success!') \ No newline at end of file diff --git a/train.py b/train.py index 1850973..00efae3 100644 --- a/train.py +++ b/train.py @@ -1,4 +1,5 @@ import concurrent.futures +import glob import os import random import shutil @@ -146,18 +147,21 @@ def resume_training(last_pt_path): def train_dataset(): form_dataset() create_yaml() - model_name = 'yolo11m.yaml' - model = YOLO(model_name) + model = YOLO('yolo11m.yaml') results = model.train(data=abspath(path.join(today_dataset, 'data.yaml')), - epochs=120, - batch=11, - imgsz=1280, - save_period=1, - workers=24) + epochs=120, # Empirically set for good performance and relatively not so long training + # (360k of annotations on 1 RTX4090 takes 11.5 days of training :( ) + batch=11, # reflects current GPU memory, 24Gb (batch 11 gets ~22Gb, batch 12 fails on 24.2Gb) + imgsz=1280, # 1280p is a tradeoff between 
quality and speed + save_period=1, # for resuming in case of power outages / other issues + workers=24) # data-loading workers; bound to CPU count model_dir = path.join(models_dir, today_folder) + shutil.copytree(results.save_dir, model_dir) + for file in glob.glob(path.join(model_dir, 'weights', 'epoch*')): # remove unnecessary middle epochs + os.remove(file) shutil.copy(path.join(model_dir, 'weights', 'best.pt'), constants.CURRENT_PT_MODEL)