mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 12:56:35 +00:00
refactor augmentation to class, update classes.json, fix small bugs
This commit is contained in:
@@ -38,6 +38,7 @@ DEFAULT_CLASS_NUM = 80
|
||||
total_files_copied = 0
|
||||
|
||||
def form_dataset(from_date: datetime):
|
||||
|
||||
makedirs(today_dataset, exist_ok=True)
|
||||
images = []
|
||||
old_images = []
|
||||
@@ -67,7 +68,6 @@ def form_dataset(from_date: datetime):
|
||||
copy_annotations(images[:train_size], 'train')
|
||||
copy_annotations(images[train_size:train_size + valid_size], 'valid')
|
||||
copy_annotations(images[train_size + valid_size:], 'test')
|
||||
create_yaml()
|
||||
|
||||
|
||||
def copy_annotations(images, folder):
|
||||
@@ -174,21 +174,22 @@ def get_latest_model():
|
||||
|
||||
|
||||
def train_dataset(existing_date=None, from_scratch=False):
|
||||
latest_date, latest_model = get_latest_model()
|
||||
latest_date, latest_model = get_latest_model() if not from_scratch else None, None
|
||||
|
||||
if existing_date is not None:
|
||||
cur_folder = f'{prefix}{existing_date}'
|
||||
cur_dataset = path.join(datasets_dir, f'{prefix}{existing_date}')
|
||||
else:
|
||||
# form_dataset(latest_date)
|
||||
# create_yaml()
|
||||
if from_scratch:
|
||||
shutil.rmtree(today_dataset)
|
||||
form_dataset(latest_date)
|
||||
create_yaml()
|
||||
cur_folder = today_folder
|
||||
cur_dataset = today_dataset
|
||||
|
||||
model_name = latest_model if latest_model is not None and path.isfile(latest_model) and not from_scratch else 'yolo11m.yaml'
|
||||
print(f'Initial model: {model_name}')
|
||||
model = YOLO(model_name)
|
||||
model.info['author'] = 'LLC Azaion'
|
||||
|
||||
yaml = abspath(path.join(cur_dataset, 'data.yaml'))
|
||||
results = model.train(data=yaml,
|
||||
@@ -222,7 +223,7 @@ def validate(model_path):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# model_path = train_dataset(from_scratch=True)
|
||||
model_path = train_dataset(from_scratch=True)
|
||||
# validate(path.join('runs', 'detect', 'train7', 'weights', 'best.pt'))
|
||||
# form_data_sample(500)
|
||||
# convert2rknn()
|
||||
|
||||
Reference in New Issue
Block a user