mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 10:36:35 +00:00
70abaded05
simplify learning, improve readme
95 lines
3.1 KiB
Python
95 lines
3.1 KiB
Python
from os import path, replace, remove, listdir, makedirs
|
|
from os.path import abspath
|
|
import shutil
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from ultralytics import YOLOv10
|
|
from constants import current_images_dir, current_labels_dir, annotation_classes, prefix
|
|
|
|
|
|
def get_yaml_path(date_str):
    """Return the dataset folder and the absolute ``data.yaml`` path for a date.

    Args:
        date_str: Suffix appended to ``prefix`` to name the dataset folder,
            e.g. a ``YYYY-MM-DD`` date string.

    Returns:
        A ``(dataset_dir, yaml_path)`` tuple where ``dataset_dir`` is the
        relative path ``datasets/<prefix><date_str>`` and ``yaml_path`` is the
        absolute path of ``data.yaml`` inside that folder.
    """
    dataset_dir = path.join('datasets', f'{prefix}{date_str}')
    yaml_file = path.join(dataset_dir, 'data.yaml')
    return dataset_dir, abspath(yaml_file)
|
|
|
|
|
|
# Dataset folder and data.yaml path for the split produced today. Evaluated
# once, at import time, from the local date.
today_dataset, today_yaml = get_yaml_path(datetime.now().strftime('%Y-%m-%d'))

# Split ratios in percent. In this file only train_set and valid_set are read
# (by form_dataset()); the test split receives whatever remains, which is
# nominally test_set percent.
train_set, valid_set, test_set = 70, 20, 10
|
|
|
|
|
|
def form_dataset(seed=42):
    """Split the current images/labels into today's train/valid/test dataset.

    Image files in ``current_images_dir`` are moved (not copied) into
    ``<today_dataset>/{train,valid,test}/{images,labels}``. The subset sizes
    follow the ``train_set`` / ``valid_set`` percentages, and the test subset
    receives the remainder. Images without a matching label file are deleted
    by ``move_annotations()``. ``data.yaml`` is written afterwards via
    ``create_yaml()``.

    Args:
        seed: Seed for the shuffle that assigns images to subsets. The same
            set of file names and the same seed always yield the same split.
    """
    import random

    makedirs(today_dataset, exist_ok=True)

    # os.listdir() returns entries in arbitrary, file-system dependent order
    # (alphabetical on some platforms, hash order on others). Slicing it
    # directly made the split platform-dependent and potentially biased by
    # file name. Sort for a stable baseline, then shuffle with a fixed seed.
    # Sub-directories are skipped: move_annotations() would try to remove()
    # them as "unlabeled images" and fail.
    images = sorted(
        name for name in listdir(current_images_dir)
        if path.isfile(path.join(current_images_dir, name))
    )
    random.Random(seed).shuffle(images)

    train_size = int(len(images) * train_set / 100.0)
    valid_size = int(len(images) * valid_set / 100.0)

    move_annotations(images[:train_size], 'train')
    move_annotations(images[train_size:train_size + valid_size], 'valid')
    # Everything left over (nominally test_set percent) becomes the test split.
    move_annotations(images[train_size + valid_size:], 'test')

    create_yaml()
|
|
|
|
|
|
def move_annotations(images, folder):
    """Move images and their label files into one subset of today's dataset.

    For each image name, the label ``<stem>.txt`` is looked up in
    ``current_labels_dir``.

    - If the label exists, the image is moved into
      ``<today_dataset>/<folder>/images`` and the label into
      ``<today_dataset>/<folder>/labels``.
    - If it does not, the image is deleted from ``current_images_dir``.

    Args:
        images: Image file names (not paths) located in ``current_images_dir``.
        folder: Subset folder name, e.g. ``'train'``, ``'valid'`` or ``'test'``.
    """
    split_dir = path.join(today_dataset, folder)
    images_out = path.join(split_dir, 'images')
    labels_out = path.join(split_dir, 'labels')
    makedirs(images_out, exist_ok=True)
    makedirs(labels_out, exist_ok=True)

    for name in images:
        src_image = path.join(current_images_dir, name)
        txt_name = Path(name).stem + '.txt'
        src_label = path.join(current_labels_dir, txt_name)

        if path.exists(src_label):
            replace(src_image, path.join(images_out, name))
            replace(src_label, path.join(labels_out, txt_name))
            continue

        # NOTE(review): unlabeled images are deleted outright rather than kept
        # (e.g. as background samples); confirm this data loss is intended.
        remove(src_image)
|
|
|
|
|
|
def create_yaml(classes=None, yaml_path=None):
    """Write the ``data.yaml`` descriptor for a dataset split.

    The file lists the class names in the mapping's iteration order, the
    class count (``nc``), and the ``test``/``train``/``val`` image folders as
    relative paths.

    Args:
        classes: Mapping whose values expose a ``.name`` attribute.
            Defaults to ``annotation_classes``.
        yaml_path: Destination file. Defaults to ``today_yaml``.
    """
    import json

    if classes is None:
        classes = annotation_classes
    if yaml_path is None:
        yaml_path = today_yaml

    lines = ['names:']
    # Class names are written as double-quoted scalars. json.dumps produces a
    # valid YAML double-quoted string, so names such as "No"/"On" (loaded as
    # booleans by YAML 1.1 parsers) or names containing ': ' or '#' survive.
    # NOTE(review): the names list is positional (index i <-> class id i), so
    # this relies on `classes` iterating in class-id order; confirm in constants.
    for key in classes:
        lines.append(f'- {json.dumps(classes[key].name, ensure_ascii=False)}')
    lines.append(f'nc: {len(classes)}')

    lines.append('test: test/images')
    lines.append('train: train/images')
    lines.append('val: valid/images')
    lines.append('')

    with open(yaml_path, 'w', encoding='utf-8') as f:
        f.writelines(f'{line}\n' for line in lines)
|
|
|
|
|
|
def revert_to_current(date):
    """Undo ``form_dataset()`` for one dated dataset.

    Every image and label found under ``datasets/<prefix><date>/{test,train,valid}``
    is moved back into ``current_images_dir`` / ``current_labels_dir``. These
    are the same folders ``form_dataset()`` reads from. The dated dataset
    folder, including its ``data.yaml``, is then deleted.

    Args:
        date: Dataset suffix as accepted by ``get_yaml_path()``, e.g. ``'2024-06-09'``.
    """
    def revert_dir(src_dir, dest_dir):
        # A partially formed dataset may lack some subset folders; skip them
        # instead of letting listdir() raise before the cleanup below.
        if not path.isdir(src_dir):
            return
        makedirs(dest_dir, exist_ok=True)
        for file in listdir(src_dir):
            # os.replace overwrites a same-named file already in dest_dir.
            replace(path.join(src_dir, file), path.join(dest_dir, file))

    # Same folder naming as get_yaml_path()/form_dataset().
    date_dataset, _ = get_yaml_path(date)
    for subset in ['test', 'train', 'valid']:
        revert_dir(path.join(date_dataset, subset, 'images'), current_images_dir)
        revert_dir(path.join(date_dataset, subset, 'labels'), current_labels_dir)
    # Only reached when every move above succeeded.
    shutil.rmtree(date_dataset)
|
|
|
|
|
|
if __name__ == '__main__':
    # Optional preparation step: split the current annotations into today's dataset.
    # form_dataset()
    # create_yaml()

    # `'a.pt' or 'b.yaml'` always evaluated to 'a.pt' (a non-empty string is
    # truthy). Fall back to the base model config only when the checkpoint is
    # actually missing.
    latest_weights = 'datasets/zombobase-latest.pt'
    model = YOLOv10(latest_weights if path.exists(latest_weights) else 'yolov10x.yaml')

    _, data_yaml = get_yaml_path('2024-06-09')
    results = model.train(data=data_yaml, epochs=2, batch=10, imgsz=640)
    print(results)

    # The previous `path.join(results['save_dir'], '/weights/best.pt')` had
    # two problems:
    # - It returned '/weights/best.pt', because os.path.join drops everything
    #   before an absolute component.
    # - It subscripted the object returned by model.train(), which in
    #   ultralytics is a metrics object (or None), not a dict.
    # Take the checkpoint paths from the trainer instead, preferring best.pt.
    trainer = model.trainer
    res_model = trainer.best if path.exists(trainer.best) else trainer.last
    print(res_model)
    shutil.copy(res_model, 'datasets/zombobase-latest2.pt')
|