remove unnecessary middle epochs

This commit is contained in:
Alex Bezdieniezhnykh
2025-05-31 19:02:57 +03:00
parent 80c2433141
commit 44c9e87bd4
2 changed files with 19 additions and 12 deletions
+8 -5
View File
@@ -1,8 +1,10 @@
import glob
import os
import shutil
from os import path
import constants
import train
from constants import models_dir, prefix
from augmentation import Augmentator
# Augmentator().augment_annotations()
@@ -10,11 +12,12 @@ from augmentation import Augmentator
# train.resume_training('/azaion/dev/ai-training/runs/detect/train12/weights/last.pt')
result_dir = '/azaion/dev/ai-training/runs/detect/train12'
model_dir = path.join(models_dir, f'{prefix}2025-05-18') model_dir = path.join(constants.models_dir, f'{constants.prefix}2025-05-18')
shutil.copytree(result_dir, model_dir, dirs_exist_ok=True)
model_path = path.join(models_dir, f'{prefix[:-1]}.pt') shutil.copytree(result_dir, model_dir, dirs_exist_ok=True)
shutil.copy(path.join(model_dir, 'weights', 'best.pt'), model_path) for file in glob.glob(path.join(model_dir, 'weights', 'epoch*')):
os.remove(file)
shutil.copy(path.join(model_dir, 'weights', 'best.pt'), constants.CURRENT_PT_MODEL)
train.export_current_model()
print('success!')
+11 -7
View File
@@ -1,4 +1,5 @@
import concurrent.futures
import glob
import os
import random
import shutil
@@ -146,18 +147,21 @@ def resume_training(last_pt_path):
def train_dataset():
    form_dataset()
    create_yaml()
model_name = 'yolo11m.yaml' model = YOLO('yolo11m.yaml')
model = YOLO(model_name)
    results = model.train(data=abspath(path.join(today_dataset, 'data.yaml')),
epochs=120, epochs=120, # Empirically set for good performance and relatively not so long training
batch=11, # (360k of annotations on 1 RTX4090 takes 11.5 days of training :( )
imgsz=1280, batch=11, # reflects current GPU memory, 24Gb (batch 11 gets ~22Gb, batch 12 fails on 24.2Gb)
save_period=1, imgsz=1280, # 1280p is a tradeoff between quality and speed
workers=24) save_period=1, # for resuming in case of power outages / other issues
workers=24) # loading data workers. Bound to cpus count
    model_dir = path.join(models_dir, today_folder)
    shutil.copytree(results.save_dir, model_dir)
for file in glob.glob(path.join(model_dir, 'weights', 'epoch*')): # remove unnecessary middle epochs
os.remove(file)
    shutil.copy(path.join(model_dir, 'weights', 'best.pt'), constants.CURRENT_PT_MODEL)