"""Augment newly arrived training images on the GPU with NVIDIA DALI.

Every image in ``data_images_dir`` that has a matching ``<stem>.txt`` label
file in ``data_labels_dir`` and is not yet present (by file name) in
``processed_images_dir`` is decoded and augmented. For each such image, the
script writes the decoded original plus ``NUM_AUGMENTATIONS`` augmented
variants, each with a label file, into ``processed_images_dir`` /
``processed_labels_dir``.

Label files hold one object per line as ``class x y w h``. The geometric label
transforms below treat x/y/w/h as box centre and size normalized to 0..1
(YOLO style). NOTE(review): confirm this matches the dataset's label format.
"""
import math
import os
import time
from pathlib import Path

import cv2
import nvidia.dali as dali
import nvidia.dali.fn as fn

from constants import (
    data_images_dir,
    data_labels_dir,
    processed_images_dir,
    processed_labels_dir,
)

NUM_AUGMENTATIONS = 7

# Pipeline outputs produced per augmentation, in this order:
# augmented image, crop start [x, y], crop shape [w, h], flip flag, rotation angle.
_OUTPUTS_PER_AUGMENTATION = 5


def _clip01(value):
    """Clamp ``value`` to the closed interval [0, 1]."""
    return min(max(value, 0.0), 1.0)


def _transform_labels(labels, crop_start, crop_shape, flipped, angle, in_size, out_size):
    """Map normalized boxes through the pipeline's crop -> flip -> rotate chain.

    Args:
        labels: rows ``[x_center, y_center, width, height, class]`` normalized
            to the original image, as returned by ``DataLoader._read_labels``.
        crop_start: ``(x, y)`` of the crop window, relative to the original image.
        crop_shape: ``(width, height)`` of the crop window, relative to the
            original image.
        flipped: whether the cropped image was mirrored horizontally.
        angle: rotation in degrees about the image centre. Positive means
            counter-clockwise on screen with the origin at the top-left, as in
            ``cv2.getRotationMatrix2D``. NOTE(review): DALI ``fn.rotate``
            documents the same convention; spot-check on a sample image.
        in_size: ``(width, height)`` in pixels of the image fed to the rotation.
        out_size: ``(width, height)`` in pixels of the rotated output canvas.

    Returns:
        New rows in the same layout. Boxes that fall completely outside the
        crop are dropped. A rotated box is replaced by the axis-aligned
        envelope of its four rotated corners.
    """
    start_x, start_y = crop_start
    crop_w, crop_h = crop_shape
    in_w, in_h = in_size
    out_w, out_h = out_size
    radians = math.radians(angle)
    cos_a, sin_a = math.cos(radians), math.sin(radians)

    transformed = []
    for cx, cy, bw, bh, class_num in labels:
        # 1. Crop: express the box relative to the crop window, clipped to it.
        x1 = _clip01((cx - bw / 2 - start_x) / crop_w)
        x2 = _clip01((cx + bw / 2 - start_x) / crop_w)
        y1 = _clip01((cy - bh / 2 - start_y) / crop_h)
        y2 = _clip01((cy + bh / 2 - start_y) / crop_h)
        if x2 <= x1 or y2 <= y1:
            continue  # object is entirely outside the crop window

        # 2. Horizontal flip mirrors x around the centre line.
        if flipped:
            x1, x2 = 1.0 - x2, 1.0 - x1

        # 3. Rotate the corners about the input centre; the canvas is re-centred.
        xs, ys = [], []
        for nx, ny in ((x1, y1), (x2, y1), (x1, y2), (x2, y2)):
            px = nx * in_w - in_w / 2
            py = ny * in_h - in_h / 2
            xs.append((px * cos_a + py * sin_a + out_w / 2) / out_w)
            ys.append((-px * sin_a + py * cos_a + out_h / 2) / out_h)
        x1, x2 = _clip01(min(xs)), _clip01(max(xs))
        y1, y2 = _clip01(min(ys)), _clip01(max(ys))
        if x2 <= x1 or y2 <= y1:
            continue

        transformed.append([(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1, class_num])
    return transformed


class DataLoader:
    """Finds unprocessed image/label pairs and writes augmented copies of them."""

    def __init__(self, batch_size=32):
        """Create the output directories if needed.

        Args:
            batch_size: number of images the DALI pipeline processes per run.
        """
        self.batch_size = batch_size
        os.makedirs(processed_images_dir, exist_ok=True)
        os.makedirs(processed_labels_dir, exist_ok=True)

    def _read_labels(self, labels_path):
        """Read a label file into rows of ``[x, y, w, h, class]``.

        Each non-blank line must be ``class x y w h``. A decimal comma in the
        coordinates is accepted. The class id is kept as the original string.

        Raises:
            ValueError: if a line does not hold exactly four coordinates, or a
                coordinate is not a number.
        """
        rows = []
        with open(labels_path, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                fields = line.split()
                if not fields:
                    continue  # tolerate blank / trailing lines
                class_num, *str_coordinates = fields
                if len(str_coordinates) != 4:
                    raise ValueError(
                        f'{labels_path}:{line_no}: expected "class x y w h", got {line.strip()!r}'
                    )
                coordinates = [float(n.replace(',', '.')) for n in str_coordinates]
                coordinates.append(class_num)
                rows.append(coordinates)
        return rows

    def _get_image_label_pairs(self):
        """Return ``(image_path, labels_path)`` for images not yet processed.

        An image counts as processed when a file with the same name exists in
        ``processed_images_dir``. Images without a matching label file are skipped.
        """
        processed_images = {entry.name for entry in os.scandir(processed_images_dir)}
        pairs = []
        for image_file in os.scandir(data_images_dir):
            if not image_file.is_file() or image_file.name in processed_images:
                continue
            image_path = os.path.join(data_images_dir, image_file.name)
            labels_path = os.path.join(data_labels_dir, f'{Path(image_path).stem}.txt')
            if os.path.exists(labels_path):
                pairs.append((image_path, labels_path))
        return pairs

    def create_dali_pipeline(self, file_paths):
        """Build (but do not run) the DALI augmentation pipeline.

        Args:
            file_paths: image paths to read, in order. For backward
                compatibility, a ``str``/``os.PathLike`` is still accepted and
                is treated as a DALI file-list text file whose entries are
                relative to ``data_images_dir``.

        Returns:
            A DALI pipeline. For each augmentation ``k``, every ``run()``
            yields ``_OUTPUTS_PER_AUGMENTATION`` outputs starting at index
            ``k * _OUTPUTS_PER_AUGMENTATION``: the augmented image (GPU), the
            crop start ``[x, y]`` and crop shape ``[w, h]`` (relative), the
            horizontal-flip flag, and the rotation angle in degrees. The last
            output is the decoded, un-augmented image (GPU).
        """
        if isinstance(file_paths, (str, os.PathLike)):
            reader_args = {'file_root': data_images_dir, 'file_list': os.fspath(file_paths)}
        else:
            reader_args = {'files': [os.fspath(p) for p in file_paths]}

        @dali.pipeline_def(batch_size=self.batch_size, num_threads=32, device_id=0)
        def augmentation_pipeline():
            jpegs, _ = fn.readers.file(**reader_args, random_shuffle=False, name='Reader')
            images = fn.decoders.image(jpegs, device='mixed')

            outputs = []
            for _ in range(NUM_AUGMENTATIONS):
                # Random crop keeping 80-100% of each side. fn.slice replaces
                # fn.random_resized_crop because the labels need the crop
                # window, which random_resized_crop does not expose.
                crop_shape = fn.random.uniform(range=(0.8, 1.0), shape=[2])
                crop_start = fn.random.uniform(range=(0.0, 1.0), shape=[2]) * (1.0 - crop_shape)
                aug_image = fn.slice(
                    images,
                    rel_start=crop_start,
                    rel_shape=crop_shape,
                    out_of_bounds_policy='trim_to_shape',  # absorb 1-px rounding overshoot
                )

                flip = fn.random.coin_flip()
                aug_image = fn.flip(aug_image, horizontal=flip)

                # brightness / contrast are multipliers (1.0 = unchanged): +-5%.
                aug_image = fn.brightness_contrast(
                    aug_image,
                    brightness=fn.random.uniform(range=(0.95, 1.05)),
                    contrast=fn.random.uniform(range=(0.95, 1.05)),
                )

                angle = fn.random.uniform(range=(-25, 25))
                aug_image = fn.rotate(aug_image, angle=angle, fill_value=0)

                # NOTE(review): stddev is in pixel-intensity units, so <= 0.1 is
                # barely visible on uint8 images; confirm the intended strength.
                aug_image = fn.noise.gaussian(
                    aug_image, mean=0, stddev=fn.random.uniform(range=(0, 0.1))
                )

                # hue is an additive shift in degrees; saturation and value are
                # multipliers. The original +-8 is read as +-8% here (assumed intent).
                aug_image = fn.hsv(
                    aug_image,
                    hue=fn.random.uniform(range=(-8, 8)),
                    saturation=fn.random.uniform(range=(0.92, 1.08)),
                    value=fn.random.uniform(range=(0.92, 1.08)),
                )

                outputs += [aug_image, crop_start, crop_shape, flip, angle]

            # Also include original image
            outputs.append(images)
            return tuple(outputs)

        return augmentation_pipeline()

    def process_batch(self):
        """Augment all unprocessed image/label pairs and write the results.

        For each pair, the ``NUM_AUGMENTATIONS`` variants are written first,
        with labels mapped through the same crop/flip/rotation as their
        pixels. The decoded original is written last, because its presence in
        ``processed_images_dir`` is what marks the pair as done. Returns
        immediately when there is nothing new.
        """
        image_label_pairs = self._get_image_label_pairs()
        if not image_label_pairs:
            # The DALI file reader rejects an empty file set.
            return

        pipeline = self.create_dali_pipeline([img_path for img_path, _ in image_label_pairs])
        pipeline.build()

        originals_idx = NUM_AUGMENTATIONS * _OUTPUTS_PER_AUGMENTATION
        for batch_start in range(0, len(image_label_pairs), self.batch_size):
            batch_pairs = image_label_pairs[batch_start:batch_start + self.batch_size]

            # Image outputs live on the GPU and must be copied to host before
            # being read; the random parameters are already CPU outputs.
            outputs = list(pipeline.run())
            for aug_idx in range(NUM_AUGMENTATIONS):
                image_idx = aug_idx * _OUTPUTS_PER_AUGMENTATION
                outputs[image_idx] = outputs[image_idx].as_cpu()
            originals = outputs[originals_idx].as_cpu()

            # Only the first len(batch_pairs) samples belong to this batch; any
            # further samples in the run are ignored.
            for img_idx, (orig_img_path, orig_labels_path) in enumerate(batch_pairs):
                orig_labels = self._read_labels(orig_labels_path)
                orig_image = originals.at(img_idx)
                orig_h, orig_w = orig_image.shape[:2]

                for aug_idx in range(NUM_AUGMENTATIONS):
                    base = aug_idx * _OUTPUTS_PER_AUGMENTATION
                    aug_images, crop_starts, crop_shapes, flips, angles = outputs[
                        base:base + _OUTPUTS_PER_AUGMENTATION
                    ]
                    aug_image = aug_images.at(img_idx)
                    start_x, start_y = crop_starts.at(img_idx).tolist()
                    shape_w, shape_h = crop_shapes.at(img_idx).tolist()
                    aug_labels = _transform_labels(
                        orig_labels,
                        crop_start=(start_x, start_y),
                        crop_shape=(shape_w, shape_h),
                        flipped=bool(flips.at(img_idx).item()),
                        angle=float(angles.at(img_idx).item()),
                        in_size=(shape_w * orig_w, shape_h * orig_h),
                        out_size=(aug_image.shape[1], aug_image.shape[0]),
                    )
                    self._write_image_and_labels(
                        aug_image, orig_img_path, aug_labels, aug_idx=aug_idx
                    )

                self._write_image_and_labels(
                    orig_image, orig_img_path, orig_labels, is_original=True
                )

    def _write_image_and_labels(self, image, orig_img_path, labels, is_original=False, aug_idx=None):
        """Write one image and its label file into the processed directories.

        Args:
            image: HWC RGB array, as produced by the DALI image decoder.
            orig_img_path: path of the source image; its stem and suffix name the outputs.
            labels: rows ``[x, y, w, h, class]`` to write as ``class x y w h``.
            is_original: keep the source file name instead of adding a suffix.
            aug_idx: 0-based augmentation index, written as ``_<aug_idx + 1>``.

        Raises:
            RuntimeError: if OpenCV fails to encode the image.
        """
        path = Path(orig_img_path)
        stem = path.stem if is_original else f'{path.stem}_{aug_idx + 1}'
        img_name = f'{stem}{path.suffix}'
        label_name = f'{stem}.txt'

        # Labels are written before the image so that the image file, which
        # _get_image_label_pairs checks, appears only once both files exist.
        lines = [
            f'{ann[4]} {round(ann[0], 5)} {round(ann[1], 5)} {round(ann[2], 5)} {round(ann[3], 5)}\n'
            for ann in labels
        ]
        with open(os.path.join(processed_labels_dir, label_name), 'w', encoding='utf-8') as f:
            f.writelines(lines)

        # DALI decodes to RGB while OpenCV encodes BGR. Encode in the format
        # the file extension names, rather than always writing JPEG bytes.
        bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        ok, encoded = cv2.imencode(path.suffix or '.jpg', bgr)
        if not ok:
            raise RuntimeError(f'Could not encode image {img_name}')
        encoded.tofile(os.path.join(processed_images_dir, img_name))


def main():
    """Process new images forever, re-scanning the input directory every 2 minutes."""
    while True:
        loader = DataLoader()
        loader.process_batch()
        print('All processed, waiting for 2 minutes...')
        time.sleep(120)


if __name__ == '__main__':
    main()