Files
ai-training/train.py
T
Alex Bezdieniezhnykh eba8b62db8 train.py fix
2024-06-19 03:47:57 +03:00

115 lines
3.8 KiB
Python

import random
import shutil
from datetime import datetime
from os import path, replace, remove, listdir, makedirs
from os.path import abspath
from pathlib import Path

from ultralytics import YOLOv10

from constants import current_images_dir, current_labels_dir, annotation_classes, prefix, date_format

# Weights file that training starts from (when present) and that the best
# weights of each run are copied back into.
latest_model = f'models/{prefix}latest.pt'
# Name of the dataset folder built today: prefix followed by today's date
# rendered with date_format.
today_folder = f'{prefix}{datetime.now():{date_format}}'
# Percentage split of the image pool into train / valid / test subsets.
# form_dataset() reads only train_set and valid_set; the test subset receives
# whatever remains after those two.
train_set, valid_set, test_set = 70, 20, 10
def form_dataset(seed=0):
    """Split the current image pool into a dated train/valid/test dataset.

    Every regular file in ``current_images_dir`` is assigned to one subset
    according to the ``train_set`` / ``valid_set`` percentages (the test
    subset takes the remainder) and handed to move_annotations(), which moves
    it with its label into ``datasets/<today_folder>/<subset>/``. Finally
    create_yaml() writes the dataset's ``data.yaml``.

    :param seed: seed for the shuffle that decides subset membership. A fixed
        seed makes the split reproducible for the same set of file names.
    """
    makedirs(path.join('datasets', today_folder), exist_ok=True)
    # os.listdir() returns names in arbitrary, filesystem-dependent order and
    # may include sub-directories. Keep regular files only, sort to get a
    # stable starting order, then shuffle so that any ordering in the file
    # names does not decide which subset an image lands in.
    images = sorted(
        name for name in listdir(current_images_dir)
        if path.isfile(path.join(current_images_dir, name))
    )
    random.Random(seed).shuffle(images)
    train_size = int(len(images) * train_set / 100.0)
    valid_size = int(len(images) * valid_set / 100.0)
    move_annotations(images[:train_size], 'train')
    move_annotations(images[train_size:train_size + valid_size], 'valid')
    move_annotations(images[train_size + valid_size:], 'test')
    create_yaml()
def move_annotations(images, folder):
    """Move valid image/label pairs into ``datasets/<today_folder>/<folder>/``.

    For each image, the label ``<stem>.txt`` is looked up in
    ``current_labels_dir`` and validated (possibly cleaned in place) by
    check_label(). Valid pairs are moved into the subset's ``images`` and
    ``labels`` directories. Images without a usable label are deleted,
    together with their rejected label file if one exists.

    :param images: image file names (not paths) inside ``current_images_dir``.
    :param folder: subset directory name, e.g. ``'train'``, ``'valid'``, ``'test'``.
    """
    subset_dir = path.join('datasets', today_folder, folder)
    destination_images = path.join(subset_dir, 'images')
    destination_labels = path.join(subset_dir, 'labels')
    makedirs(destination_images, exist_ok=True)
    makedirs(destination_labels, exist_ok=True)
    for image_name in images:
        image_path = path.join(current_images_dir, image_name)
        label_name = f'{Path(image_name).stem}.txt'
        label_path = path.join(current_labels_dir, label_name)
        if not check_label(label_path):
            remove(image_path)
            # check_label() also rejects a label file that exists but has no
            # usable lines. Remove it as well so no orphan label is left in
            # the pool once its image is gone.
            if path.exists(label_path):
                remove(label_path)
            continue
        replace(image_path, path.join(destination_images, image_name))
        replace(label_path, path.join(destination_labels, label_name))
def _label_line_in_range(line):
    """Return True if every value after the class id parses as a float in [0, 1].

    The first whitespace-separated field (the class id) is not checked.
    A blank line has no values and therefore passes.
    """
    try:
        values = [float(token) for token in line.split()[1:]]
    except ValueError:
        # A malformed number makes the line unusable. Drop the line instead of
        # aborting the whole dataset build.
        return False
    # YOLO label coordinates are normalized, so anything outside [0, 1]
    # (including negatives) is invalid.
    return all(0.0 <= value <= 1.0 for value in values)


def check_label(label_path):
    """Validate a label file and strip lines with out-of-range values.

    Lines whose values after the class id are not all numbers in [0, 1] are
    removed. If any line was removed, the file is rewritten with the
    remaining lines; otherwise it is left untouched.

    :param label_path: path to the ``.txt`` label file.
    :returns: ``False`` if the file does not exist or no lines remain after
        filtering, ``True`` otherwise.
    """
    if not path.exists(label_path):
        return False
    with open(label_path, 'r') as f:
        lines = f.readlines()
    # Filter into a new list. Removing from `lines` while iterating over it
    # would skip elements, and removing the same line once per bad value would
    # raise ValueError.
    kept = [line for line in lines if _label_line_in_range(line)]
    if not kept:
        return False
    if len(kept) == len(lines):
        return True
    with open(label_path, 'w') as label_write:
        label_write.writelines(kept)
    return True
def create_yaml():
    """Write ``datasets/<today_folder>/data.yaml`` for today's dataset.

    The file contains a ``names`` list with one ``- <name>`` entry per item of
    ``annotation_classes`` (in its iteration order), the class count ``nc``,
    and the ``test`` / ``train`` / ``val`` image directories as relative
    paths. It ends with an extra empty line.
    """
    # NOTE(review): names are emitted in annotation_classes' iteration order;
    # presumably that order must match the class ids used in the label files
    # — confirm against constants.
    entries = [
        'names:',
        *(f'- {annotation_classes[key].name}' for key in annotation_classes),
        f'nc: {len(annotation_classes)}',
        'test: test/images',
        'train: train/images',
        'val: valid/images',
        '',
    ]
    yaml_path = abspath(path.join('datasets', today_folder, 'data.yaml'))
    with open(yaml_path, 'w', encoding='utf-8') as yaml_file:
        yaml_file.write(''.join(entry + '\n' for entry in entries))
def revert_to_current(date):
    """Move every file of a dated dataset back into ``datasets/<prefix>current``.

    Images and labels from the ``test``, ``train`` and ``valid`` subsets of
    ``datasets/<prefix><date>`` are moved into the ``images`` and ``labels``
    directories of ``datasets/<prefix>current``. The dated dataset directory
    (including anything else left in it) is then deleted.

    :param date: the part of the dataset folder name that follows ``prefix``.
    """
    def revert_dir(src_dir, dest_dir):
        files = listdir(src_dir)
        # os.replace() fails with FileNotFoundError if the destination
        # directory does not exist, so make sure it does.
        makedirs(dest_dir, exist_ok=True)
        for file in files:
            replace(path.join(src_dir, file), path.join(dest_dir, file))

    date_dataset = path.join('datasets', f'{prefix}{date}')
    # NOTE(review): form_dataset() reads from current_images_dir /
    # current_labels_dir, while this writes to datasets/<prefix>current/...;
    # confirm both refer to the same location.
    current_dataset = path.join('datasets', f'{prefix}current')
    for subset in ['test', 'train', 'valid']:
        revert_dir(path.join(date_dataset, subset, 'images'), path.join(current_dataset, 'images'))
        revert_dir(path.join(date_dataset, subset, 'labels'), path.join(current_dataset, 'labels'))
    shutil.rmtree(date_dataset)
if __name__ == '__main__':
    # Dataset preparation; uncomment to build a fresh dated dataset first.
    # form_dataset()
    # create_yaml()

    # latest_model is always a non-empty string, so existence of the file is
    # what decides whether to continue from it or start from the
    # 'yolov10x.yaml' model definition.
    m = latest_model if path.exists(latest_model) else 'yolov10x.yaml'
    print(f'Initial model: {m}')
    model = YOLOv10(m)
    # NOTE(review): dataset folder is pinned to a fixed date; presumably meant
    # to be today_folder once form_dataset() is re-enabled.
    folder = f'{prefix}2024-06-18'
    yaml = abspath(path.join('datasets', folder, 'data.yaml'))
    results = model.train(data=yaml, epochs=100, batch=10, imgsz=640)
    # shutil.copy() fails if the target directory does not exist yet.
    makedirs(path.dirname(latest_model), exist_ok=True)
    shutil.copy(f'{results.save_dir}/weights/best.pt', latest_model)
    # Allow re-training on the same dataset folder without copytree failing.
    shutil.copytree(results.save_dir, f'models/{folder}', dirs_exist_ok=True)
    shutil.rmtree('runs')
    # NOTE(review): hard-coded prefix; possibly meant f'models/{prefix}latest'.
    # Guarded so a missing directory does not crash the script after training.
    if path.isdir('models/zombobase-latest'):
        shutil.rmtree('models/zombobase-latest')