Files
ai-training/train.py
T
Alex Bezdieniezhnykh 2fa864018f upload model to cdn and api
switch to yolov11
2025-03-03 23:36:10 +02:00

259 lines
9.0 KiB
Python

import io
import os
import random
import shutil
import subprocess
from datetime import datetime
from os import path, replace, listdir, makedirs, scandir
from os.path import abspath
from pathlib import Path
from utils import Dotdict
import yaml
from ultralytics import YOLO
import constants
from azaion_api import ApiCredentials, Api
from cdn_manager import CDNCredentials, CDNManager
from security import Security
from constants import (processed_images_dir,
processed_labels_dir,
annotation_classes,
prefix, date_format,
datasets_dir, models_dir,
corrupted_images_dir, corrupted_labels_dir, sample_dir)
# Name of the dataset folder for the current run: the project prefix followed by
# the current date rendered with date_format. Evaluated once, at import time.
today_folder = f'{prefix}{datetime.now():{date_format}}'
# Location of today's dataset inside datasets_dir.
today_dataset = path.join(datasets_dir, today_folder)
# Split sizes, in percent of the selected images (see form_dataset).
train_set = 70
valid_set = 20
# NOTE(review): test_set is not read anywhere in the visible code; form_dataset
# gives the test subset whatever is left after the train and valid slices.
test_set = 10
# Percentage of the old images that form_dataset mixes into a new dataset.
old_images_percentage = 75
# Number of class names (and the `nc` value) written to data.yaml by create_yaml.
DEFAULT_CLASS_NUM = 80
def form_dataset(from_date: datetime):
    """Assemble today's train/valid/test dataset from the processed images.

    An image counts as new when ``from_date`` is None or when its modification
    day is later than ``from_date``; every other image counts as old. A random
    ``old_images_percentage`` percent of the old images is mixed in with the
    new ones, the combined list is shuffled and split into ``train_set`` and
    ``valid_set`` percent, and the remainder becomes the test subset. The
    dataset's data.yaml is written at the end.

    Args:
        from_date: Cut-off date, or None to treat every image as new.
    """
    makedirs(today_dataset, exist_ok=True)

    fresh = []
    stale = []
    with scandir(processed_images_dir) as entries:
        for entry in entries:
            if not entry.is_file():
                continue
            # Compare by calendar day: drop the time-of-day part of the mtime.
            modified_day = datetime.fromtimestamp(entry.stat().st_mtime).replace(
                hour=0, minute=0, second=0, microsecond=0)
            if from_date is None or modified_day > from_date:
                fresh.append(entry)
            else:
                # Old images are gathered too, to avoid overfitting on new data only.
                stale.append(entry)

    random.shuffle(stale)
    stale_quota = int(len(stale) * old_images_percentage / 100.0)
    print(f'Got {len(fresh)} new images and {stale_quota} of old images (to prevent overfitting). Shuffling them...')

    selected = fresh + stale[:stale_quota]
    random.shuffle(selected)

    total = len(selected)
    train_end = int(total * train_set / 100.0)
    valid_end = train_end + int(total * valid_set / 100.0)
    subsets = (
        ('train', selected[:train_end]),
        ('valid', selected[train_end:valid_end]),
        ('test', selected[valid_end:]),
    )
    for subset_name, subset_images in subsets:
        copy_annotations(subset_images, subset_name)

    create_yaml()
def copy_annotations(images, folder):
    """Copy images and their label files into one subset of today's dataset.

    For every image the label ``<image stem>.txt`` is looked up in
    ``processed_labels_dir``. When ``check_label`` accepts it, the image and
    the label are copied to ``<today_dataset>/<folder>/images`` and
    ``.../labels``. Otherwise the image goes to ``corrupted_images_dir`` and
    the label, if it exists at all, to ``corrupted_labels_dir``.

    Args:
        images: Iterable of directory entries exposing ``.path`` and ``.name``.
        folder: Subset folder name inside today's dataset (e.g. 'train').
    """
    destination_images = path.join(today_dataset, folder, 'images')
    destination_labels = path.join(today_dataset, folder, 'labels')
    for directory in (destination_images, destination_labels,
                      corrupted_images_dir, corrupted_labels_dir):
        makedirs(directory, exist_ok=True)

    copied = 0
    print(f'Copying annotations to {destination_images} and {destination_labels} folders:')
    for image in images:
        label_name = f'{Path(image.path).stem}.txt'
        label_path = path.join(processed_labels_dir, label_name)
        if check_label(label_path):
            shutil.copy(image.path, path.join(destination_images, image.name))
            shutil.copy(label_path, path.join(destination_labels, label_name))
        else:
            shutil.copy(image.path, path.join(corrupted_images_dir, image.name))
            # check_label() also rejects a label file that does not exist; copying
            # it unconditionally would raise FileNotFoundError and abort the run.
            if path.exists(label_path):
                shutil.copy(label_path, path.join(corrupted_labels_dir, label_name))
                print(f'Label {label_path} is corrupted! Copy with its image to the corrupted directory ({corrupted_labels_dir})')
            else:
                print(f'Label {label_path} is missing! Copy its image to the corrupted directory ({corrupted_images_dir})')
        # Progress counter: counts every processed image, rejected ones included.
        copied = copied + 1
        if copied % 1000 == 0:
            print(f'{copied} copied...')
    print(f'Copied all {copied} annotations to {destination_images} and {destination_labels} folders')
def check_label(label_path):
    """Tell whether a label file exists and contains only acceptable values.

    Every line is split on whitespace; the first token (presumably the class
    id of a YOLO label line) is skipped and each remaining token must parse as
    a number that is not greater than 1.

    Args:
        label_path: Path of the label text file.

    Returns:
        False when the file does not exist, when a value is greater than 1 or
        when a value is not a number; True otherwise (an empty file is True).
    """
    if not path.exists(label_path):
        return False
    with open(label_path, 'r') as f:
        for line in f:
            # split() without an argument tolerates tabs, repeated spaces and the
            # trailing newline; split(' ') produced empty tokens that broke float().
            for val in line.split()[1:]:
                try:
                    if float(val) > 1:
                        return False
                except ValueError:
                    # A token that is not a number means the label is malformed.
                    return False
    return True
def create_yaml():
    """Write data.yaml into today's dataset folder.

    The file lists the name of every entry of ``annotation_classes``, pads the
    list with ``Class-<n>`` placeholders up to ``DEFAULT_CLASS_NUM`` names, and
    then records ``nc`` and the relative image folders of the three subsets.
    """
    print('creating yaml...')

    known_names = [f'- {annotation_classes[key].name}' for key in annotation_classes]
    # Placeholders are numbered from one past the last known class.
    first_placeholder = len(annotation_classes) + 1
    placeholders = [f'- Class-{number}'
                    for number in range(first_placeholder, DEFAULT_CLASS_NUM + 1)]

    rows = [
        'names:',
        *known_names,
        *placeholders,
        f'nc: {DEFAULT_CLASS_NUM}',
        'test: test/images',
        'train: train/images',
        'val: valid/images',
        '',
    ]

    today_yaml = abspath(path.join(today_dataset, 'data.yaml'))
    with open(today_yaml, 'w', encoding='utf-8') as f:
        f.write(''.join(f'{row}\n' for row in rows))
def revert_to_processed_data(date):
    """Move a dated dataset back into the processed folders and delete it.

    Every file under the test/train/valid ``images`` and ``labels`` folders of
    the dataset ``<prefix><date>`` is moved (replacing a same-named file) into
    ``processed_images_dir`` / ``processed_labels_dir``; the dataset folder is
    removed afterwards.

    Args:
        date: Date part of the dataset folder name.
    """
    date_dataset = path.join(datasets_dir, f'{prefix}{date}')
    makedirs(processed_images_dir, exist_ok=True)
    makedirs(processed_labels_dir, exist_ok=True)

    targets = (('images', processed_images_dir), ('labels', processed_labels_dir))
    for subset in ('test', 'train', 'valid'):
        for kind, target_dir in targets:
            source_dir = path.join(date_dataset, subset, kind)
            for file_name in listdir(source_dir):
                replace(path.join(source_dir, file_name), path.join(target_dir, file_name))

    shutil.rmtree(date_dataset)
def get_latest_model():
    """Find the most recent model folder inside ``models_dir``.

    Each direct sub-directory whose name, with ``prefix`` removed, parses as a
    ``%Y-%m-%d`` date is a candidate; its weights are expected at
    ``<dir>/weights/best.pt`` (the file's existence is not checked here).

    Returns:
        A ``(date, path)`` tuple for the candidate with the latest date, or
        ``(None, None)`` when ``models_dir`` is missing or holds no dated
        sub-directory.
    """
    # os.walk() yields nothing for a directory it cannot read, so next() needs a
    # default; without it a missing models_dir raised StopIteration.
    _, dir_names, _ = next(os.walk(models_dir), (models_dir, [], []))

    candidates = []
    for dir_name in dir_names:
        try:
            dir_date = datetime.strptime(dir_name.replace(prefix, ''), '%Y-%m-%d')
        except ValueError:
            # Not a dated model folder; ignore it instead of failing the lookup.
            continue
        candidates.append({'date': dir_date,
                           'path': path.join(models_dir, dir_name, 'weights', 'best.pt')})

    if not candidates:
        return None, None
    last_model = sorted(candidates, key=lambda item: item['date'])[-1]
    return last_model['date'], last_model['path']
def train_dataset(existing_date=None, from_scratch=False):
    """Train a model on a dataset and store the resulting weights.

    Args:
        existing_date: Date part of an already formed dataset folder to train
            on. When None, a new dataset is formed for today from images newer
            than the latest model.
        from_scratch: When True, start from the 'yolo11m.yaml' definition even
            if weights of a previous model are available.

    Returns:
        Path of the copied best weights, ``<models_dir>/<prefix minus its last
        character>.pt``.
    """
    latest_date, latest_model = get_latest_model()

    if existing_date is not None:
        cur_folder = f'{prefix}{existing_date}'
        cur_dataset = path.join(datasets_dir, cur_folder)
    else:
        form_dataset(latest_date)
        cur_folder = today_folder
        cur_dataset = today_dataset

    continue_training = (latest_model is not None
                         and path.isfile(latest_model)
                         and not from_scratch)
    model_name = latest_model if continue_training else 'yolo11m.yaml'
    print(f'Initial model: {model_name}')
    model = YOLO(model_name)

    # Named data_yaml so the local does not shadow the imported `yaml` module.
    data_yaml = abspath(path.join(cur_dataset, 'data.yaml'))
    results = model.train(data=data_yaml,
                          epochs=120,
                          batch=14,
                          imgsz=1280,
                          save_period=1,
                          workers=24)

    model_dir = path.join(models_dir, cur_folder)
    # The folder already exists when the same date is trained again; without
    # dirs_exist_ok the copy raised FileExistsError after training had finished.
    # Files left there by an earlier run are overwritten or kept, not removed.
    shutil.copytree(results.save_dir, model_dir, dirs_exist_ok=True)

    model_path = path.join(models_dir, f'{prefix[:-1]}.pt')
    shutil.copy(path.join(model_dir, 'weights', 'best.pt'), model_path)
    shutil.rmtree('runs')
    return model_path
def convert2rknn():
    """Run ``convert.sh`` in ./orangepi5, then export the latest model to ONNX.

    The script's exit code is not inspected.
    """
    subprocess.call(['bash', 'convert.sh'], cwd="./orangepi5")
    _, newest_weights = get_latest_model()
    YOLO(newest_weights).export(format="onnx")
def form_data_sample(size=300):
    """Copy a random sample of processed images into a freshly created ``sample_dir``.

    Besides the images, ``azaion_subset.txt`` is written into the sample
    folder with one ``./<image name>`` line per copied image.

    Args:
        size: Maximum number of images to copy.
    """
    with scandir(processed_images_dir) as imd:
        candidates = [entry for entry in imd if entry.is_file()]

    print('shuffling images')
    random.shuffle(candidates)
    chosen = candidates[:size]

    # Start from an empty sample folder on every call.
    shutil.rmtree(sample_dir, ignore_errors=True)
    makedirs(sample_dir, exist_ok=True)

    for entry in chosen:
        shutil.copy(entry.path, path.join(sample_dir, entry.name))

    listing_path = path.join(sample_dir, 'azaion_subset.txt')
    with open(listing_path, 'w', encoding='utf-8') as f:
        f.writelines(f'./{entry.name}\n' for entry in chosen)
def validate(model_path):
    """Load the model stored at ``model_path`` and run its validation.

    The resulting metrics are not returned.
    """
    YOLO(model_path).val()
def upload_model(model_path: str):
    """Encrypt the ONNX file that sits next to ``model_path`` and upload it in two parts.

    The encrypted bytes are cut into a small head (at most 10 KiB, and never
    more than 90% of the data) and the remaining tail. The tail is uploaded
    to the CDN bucket as 'azaion.onnx.big' and the head through the API as
    'azaion.onnx.small'. Credentials are read from ``constants.CONFIG_FILE``.

    Args:
        model_path: Path of the model weights; only its directory and stem are
            used, to locate ``<stem>.onnx`` in the same directory.

    Raises:
        FileNotFoundError: If the ONNX file or the config file does not exist.
    """
    # No export happens here: the ONNX file must already exist next to the weights.
    # Join directory and file name properly; plain string concatenation dropped
    # the separator ('models/azaion.pt' became 'modelsazaion.onnx').
    onnx_model = path.join(path.dirname(model_path), Path(model_path).stem + '.onnx')

    with open(onnx_model, 'rb') as f_in:
        onnx_bytes = f_in.read()

    key = Security.get_model_encryption_key()
    onnx_encrypted = Security.encrypt_to(onnx_bytes, key)

    part1_size = min(10 * 1024, int(0.9 * len(onnx_encrypted)))
    onnx_part_small = onnx_encrypted[:part1_size]
    onnx_part_big = onnx_encrypted[part1_size:]

    with open(constants.CONFIG_FILE, "r") as f:
        config_dict = yaml.safe_load(f)
    d_config = Dotdict(config_dict)
    cdn_c = Dotdict(d_config.cdn)
    api_c = Dotdict(d_config.api)

    cdn_manager = CDNManager(CDNCredentials(cdn_c.host, cdn_c.access_key, cdn_c.secret_key))
    cdn_manager.upload(cdn_c.bucket, 'azaion.onnx.big', onnx_part_big)

    api = Api(ApiCredentials(api_c.url, api_c.user, api_c.pw, api_c.folder))
    api.upload_file('azaion.onnx.small', onnx_part_small)
if __name__ == '__main__':
    # Script entry point. Alternative steps that can be run from here instead:
    #   model_path = train_dataset('2024-10-26', from_scratch=True)
    #   validate(path.join('runs', 'detect', 'train7', 'weights', 'best.pt'))
    #   form_data_sample(500)
    #   convert2rknn()
    model_path = 'azaion.pt'
    upload_model(model_path)