Modify DataLoader to not mix files from different directories in a batch

Hello there,

I want to load image sequences of a fixed length into batches of the same size (for example sequence length = batch size = 7).

There are multiple directories, each containing the images of one sequence, and the number of images per directory varies. The sequences from different directories are not related to each other.

With my current code, I can process several subdirectories, but if there are not enough images in one directory to fill a batch, the remaining images are taken from the next directory. I would like to avoid this.

Instead, a batch should be discarded if the current directory does not have enough images left to fill it, and the next batch should start with images from the next directory. This way, unrelated image sequences are never mixed in the same batch. If a directory does not have enough images to create even a single batch, it should be skipped completely.

So for example, with a sequence length of 7 (see the small sketch after the list):

  • directory A has 15 images → 2 batches of 7 images are created; the rest are ignored
  • directory B has 10 images → 1 batch of 7 images is created; the rest are ignored
  • directory C has 3 images → the directory is skipped entirely
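
In other words, each directory with n images contributes n // 7 full batches, and the remaining n % 7 images are dropped. A quick sanity check of that arithmetic for the three directories above:

seq_len = 7
for name, n in [("A", 15), ("B", 10), ("C", 3)]:
    # n // seq_len full batches, n % seq_len leftover images
    print(f"directory {name}: {n // seq_len} batch(es), {n % seq_len} image(s) dropped")
# directory A: 2 batch(es), 1 image(s) dropped
# directory B: 1 batch(es), 3 image(s) dropped
# directory C: 0 batch(es), 3 image(s) dropped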

I'm still learning, but I think this can be done with a custom batch sampler?
Unfortunately, I am having some problems with this.
Maybe someone can help me find a solution.

This is my current code:

import os

import numpy as np
import torch
from skimage import io  # assuming scikit-image, which provides io.imread
from torch.utils.data import DataLoader, Dataset, Sampler


class MainDataset(Dataset):

    def __init__(self, img_dir, use_folder_name=False):
        self.gt_images = self._load_main_dataset(img_dir)
        self.dataset_len = len(self.gt_images)
        self.use_folder_name = use_folder_name

    def __len__(self):
        return self.dataset_len

    def __getitem__(self, idx):
        img_path = self.gt_images[idx]
        img_name = self._get_name(img_path)

        gt = self._load_img(img_path)

        # Skip non-image files
        if gt is None:
            return None

        gt = torch.from_numpy(gt).permute(2, 0, 1)  # HWC -> CHW

        return gt, img_name

    def _get_name(self, img_path):
        if self.use_folder_name:
            # Name of the directory the image lives in
            return img_path.split(os.sep)[-2]
        else:
            # File name without its extension
            return img_path.split(os.sep)[-1].split('.')[0]

    def _load_main_dataset(self, img_dir):
        # A single file path is wrapped in a list
        if not os.path.isdir(img_dir):
            return [img_dir]

        gt_images = []
        for root, dirs, files in os.walk(img_dir):
            for file in files:
                if not is_valid_file(file):
                    continue
                gt_images.append(os.path.join(root, file))

        gt_images.sort()

        return gt_images

    def _load_img(self, img_path):
        gt_image = io.imread(img_path)
        # getBitDepth is a project helper; bd / 3 gives the per-channel
        # bit depth, so this normalizes pixel values to [0, 1]
        gt_image_bd = getBitDepth(gt_image)
        gt_image = np.array(gt_image).astype(np.float32) / ((2 ** (gt_image_bd / 3)) - 1)

        return gt_image

def is_valid_file(file_name: str):
    valid_image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.tif']

    for ext in valid_image_extensions:
        if file_name.lower().endswith(ext):
            return True

    return False

sequence_data_store = MainDataset(img_dir=sdr_img_dir, use_folder_name=True)

sequence_loader = DataLoader(sequence_data_store, num_workers=0, pin_memory=False)
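
One pitfall in the code above: __getitem__ returns None for non-image files, and PyTorch's default collate function raises an error when a batch contains None. A minimal workaround is a custom collate_fn that filters out the None samples before batching (skip_none_collate is a name I made up, not part of the original code):

from torch.utils.data.dataloader import default_collate

def skip_none_collate(batch):
    # Drop samples for which __getitem__ returned None (non-image files)
    batch = [sample for sample in batch if sample is not None]
    if not batch:
        return None  # the training loop has to skip these empty batches
    return default_collate(batch)

sequence_loader = DataLoader(sequence_data_store, num_workers=0,
                             pin_memory=False, collate_fn=skip_none_collate)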

I fixed the problem by creating a custom batch sampler that builds each batch from the files of a single directory only. Any leftover group that is smaller than the batch size is skipped inside the sampler itself, so incomplete batches never reach the main program.

class CustomBatchSampler(Sampler):
    def __init__(self, data_source, batch_size):
        self.data_source = data_source
        self.batch_size = batch_size
        self.num_samples = len(data_source)
        self.path_to_indices = self._get_path_to_indices()

    def _get_path_to_indices(self):
        # Group dataset indices by directory; this relies on the dataset
        # being created with use_folder_name=True so that _get_name
        # returns the directory name
        path_to_indices = {}
        for i, img_path in enumerate(self.data_source.gt_images):
            img_dir = self.data_source._get_name(img_path)
            if img_dir not in path_to_indices:
                path_to_indices[img_dir] = []
            path_to_indices[img_dir].append(i)
        return path_to_indices

    def __iter__(self):
        for paths_indices in self.path_to_indices.values():
            for i in range(0, len(paths_indices), self.batch_size):
                batch = paths_indices[i:i + self.batch_size]
                # Skip the leftover group at the end of a directory so a
                # batch is never topped up with images from the next one
                if len(batch) == self.batch_size:
                    yield batch

    def __len__(self):
        # Count only full batches, consistent with __iter__
        return sum(len(indices) // self.batch_size for indices in self.path_to_indices.values())

Then in the main program:

sequence_data_store = MainDataset(img_dir=sdr_img_dir, use_folder_name=True)
sequence_batch_sampler = CustomBatchSampler(sequence_data_store, batch_size)
sequence_loader = DataLoader(sequence_data_store, batch_sampler=sequence_batch_sampler)
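
Since the sampler already skips incomplete groups, the main loop can consume the loader directly. A short sketch of what the batches look like (assuming use_folder_name=True and that no samples were filtered out as None):

for gt_batch, img_names in sequence_loader:
    # gt_batch: tensor of shape (batch_size, C, H, W), one chunk of a sequence
    # img_names: with use_folder_name=True these are directory names,
    # so every entry within a batch is identical
    assert gt_batch.shape[0] == batch_size
    assert len(set(img_names)) == 1

Note that the batches arrive in sorted directory order. If you need shuffling for training, shuffle the list of complete batches inside __iter__ rather than the individual indices, so the frames of each sequence stay together and in order.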