Hello there,
I want to load image sequences of a fixed length into batches of the same size (for example sequence length = batch size = 7).
There are multiple directories each with images from a sequence with varying number of images. The sequences from different directories are not related to each other.
With my current code, I can process several subdirectories, but if there are not enough images in one directory to fill a batch, the remaining images are taken from the next directory. I would like to avoid this.
Instead, a batch should be discarded if there are not enough images in the current directory and instead the batch should only be filled with images from the next directory. This way, I want to avoid mixing unrelated image sequences in the same batch. If a directory does not have enough images to create even a single batch, it should be skipped completely.
So for example with a sequence length of 7:
- directory A has 15 images → 2 batches each with 7 images are created; the remaining image is ignored
- directory B has 10 images → 1 batch with 7 images is created; the rest are ignored
- directory C has 3 images → the directory is skipped entirely
I'm still learning, but I think this can be done with a custom batch sampler?
Unfortunately, I have some problems with this.
Maybe someone can help me find a solution.
This is my current code:
class MainDataset(Dataset):
    """Dataset yielding ``(image_tensor, name)`` pairs for every valid image
    found under ``img_dir``.

    The directory is walked recursively at construction time and the resulting
    file paths are kept in sorted order, so indexing is stable across runs.
    """

    def __init__(self, img_dir, use_folder_name=False):
        # Collect all image paths up front; the dataset length is fixed after this.
        self.gt_images = self._load_main_dataset(img_dir)
        self.dataset_len = len(self.gt_images)
        self.use_folder_name = use_folder_name

    def __len__(self):
        return self.dataset_len

    def __getitem__(self, idx):
        path = self.gt_images[idx]
        name = self._get_name(path)
        image = self._load_img(path)
        if image is None:
            # Non-image content: signal "skip" to the caller.
            # NOTE(review): the default DataLoader collate function cannot
            # handle None items — confirm a collate_fn filters these out.
            return None
        # HWC float array -> CHW tensor layout expected by torch models.
        return torch.from_numpy(image).permute(2, 0, 1), name

    def _get_name(self, img_dir):
        # Either the parent folder name or the file stem (text before the
        # first dot), depending on how the dataset was configured.
        parts = img_dir.split(os.sep)
        if self.use_folder_name:
            return parts[-2]
        return parts[-1].split('.')[0]

    def _load_main_dataset(self, img_dir):
        # A plain file path is wrapped as a one-element dataset.
        if not os.path.isdir(img_dir):
            return [img_dir]
        found = [
            os.path.join(root, fname)
            for root, _dirs, files in os.walk(img_dir)
            for fname in files
            if is_valid_file(fname)
        ]
        found.sort()
        return found

    def _load_img(self, img_path):
        raw = io.imread(img_path)
        bit_depth = getBitDepth(raw)
        # Normalize to [0, 1] by the maximum representable channel value.
        # assumes getBitDepth returns the total bits across 3 channels,
        # hence the division by 3 — TODO confirm against its definition.
        return np.array(raw).astype(np.float32) / ((2 ** (bit_depth / 3)) - 1)
def is_valid_file(file_name: str):
    """Return True when *file_name* has a recognized image extension.

    The comparison is case-insensitive; everything else is rejected.
    """
    # str.endswith accepts a tuple, so one call covers all extensions.
    return file_name.lower().endswith(
        ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.tif')
    )
# Build the dataset over the image tree; item names come from the parent
# folder because use_folder_name=True.
sequence_data_store = MainDataset(img_dir=sdr_img_dir, use_folder_name=True)
# No batch_size/sampler given, so the DataLoader uses its defaults
# (batch_size=1, sequential order). A custom per-directory batch sampler
# would be attached here via the batch_sampler= argument.
sequence_loader = DataLoader(sequence_data_store, num_workers=0, pin_memory=False)