add checkpoints and config system

convert from bbox oriented and pascal xml
fixes
This commit is contained in:
Alex Bezdieniezhnykh
2024-06-18 21:32:15 +03:00
parent b7b8b8fd27
commit 66987f4d95
6 changed files with 182 additions and 59 deletions
+61 -26
View File
@@ -1,13 +1,18 @@
import os.path
import time
from datetime import datetime
from pathlib import Path
import albumentations as A
import cv2
from constants import current_images_dir, current_labels_dir, annotation_classes
import numpy as np
from dateutil.relativedelta import relativedelta
from config import Config
from constants import current_images_dir, current_labels_dir, annotation_classes, prefix, date_format, \
current_dataset_dir
from dto.imageLabel import ImageLabel
labels_dir = 'labels'
images_dir = 'images'
config = Config()
def image_processing(img_ann: ImageLabel) -> [ImageLabel]:
@@ -22,7 +27,7 @@ def image_processing(img_ann: ImageLabel) -> [ImageLabel]:
A.RandomBrightnessContrast(always_apply=True)],
bbox_params=A.BboxParams(format='yolo')),
A.Compose([A.ShiftScaleRotate(scale_limit=0.2, always_apply=True),
A.VerticalFlip(always_apply=True),],
A.VerticalFlip(always_apply=True), ],
bbox_params=A.BboxParams(format='yolo')),
A.Compose([A.ShiftScaleRotate(scale_limit=0.2, always_apply=True)],
bbox_params=A.BboxParams(format='yolo')),
@@ -36,7 +41,7 @@ def image_processing(img_ann: ImageLabel) -> [ImageLabel]:
try:
res = transform(image=img_ann.image, bboxes=img_ann.labels)
path = Path(img_ann.image_path)
name = f'{path.stem}_{i+1}'
name = f'{path.stem}_{i + 1}'
img = ImageLabel(
image=res['image'],
labels=res['bboxes'],
@@ -56,11 +61,12 @@ def write_result(img_ann: ImageLabel, show_image=False):
if show_image:
img_ann.visualize(annotation_classes)
cv2.imwrite(img_ann.image_path, img_ann.image)
cv2.imencode('.jpg', img_ann.image)[1].tofile(img_ann.image_path)
print(f'{img_ann.image_path} written')
with open(img_ann.labels_path, 'w') as f:
lines = [f'{ann[4]} {round(ann[0], 5)} {round(ann[1], 5)} {round(ann[2], 5)} {round(ann[3], 5)}\n' for ann in img_ann.labels]
lines = [f'{ann[4]} {round(ann[0], 5)} {round(ann[1], 5)} {round(ann[2], 5)} {round(ann[3], 5)}\n' for ann in
img_ann.labels]
f.writelines(lines)
f.close()
print(f'{img_ann.labels_path} written')
@@ -89,40 +95,69 @@ def process_image(img_ann):
image_path=os.path.join(current_images_dir, Path(img_ann.image_path).name),
labels_path=os.path.join(current_labels_dir, Path(img_ann.labels_path).name)
))
os.remove(img_ann.image_path)
os.remove(img_ann.labels_path)
# os.remove(img_ann.image_path)
# os.remove(img_ann.labels_path)
def get_checkpoint():
if config.checkpoint is not None:
return config.checkpoint
dates = []
for directory in os.listdir('models'):
try:
dates.append(datetime.strptime(directory[len(prefix):], date_format))
except:
continue
if len(dates) == 0:
return datetime.now() - relativedelta(years=1)
else:
return max(dates)
def main():
last_date = checkpoint = get_checkpoint()
while True:
images = os.listdir(images_dir)
if len(images) == 0:
time.sleep(5)
continue
images = []
with os.scandir(config.images_dir) as imd:
for image_file in imd:
if not image_file.is_file():
continue
mod_time = datetime.fromtimestamp(image_file.stat().st_mtime)
if mod_time > checkpoint:
images.append(image_file)
last_date = max(last_date, mod_time)
for image in images:
for image_file in images:
try:
image_path = os.path.join(images_dir, image)
labels_path = os.path.join(labels_dir, f'{Path(image_path).stem}.txt')
image_path = os.path.join(config.images_dir, image_file.name)
labels_path = os.path.join(config.labels_dir, f'{Path(image_path).stem}.txt')
image = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
process_image(ImageLabel(
image_path=image_path,
image=cv2.imread(image_path),
image=image,
labels_path=labels_path,
labels=read_labels(labels_path)
))
except Exception as e:
print(f'Error appeared {e}')
if last_date != checkpoint:
checkpoint = config.checkpoint = last_date
config.write()
time.sleep(5)
try:
os.remove(image_path)
except OSError:
pass
try:
os.remove(labels_path)
except OSError:
pass
def check_labels():
for label in os.listdir(os.path.join(current_dataset_dir, 'labels')):
with open(os.path.join(current_dataset_dir, 'labels', label), 'r') as f:
lines = f.readlines()
for line in lines:
list_c = line.split(' ')[1:]
for l in list_c:
if float(l) > 1:
print('Error!')
if __name__ == '__main__':
main()
check_labels()
# main()