BatchSampler mistakes batch size for channel number

import torch
from torch.utils.data import SequentialSampler, BatchSampler, DataLoader

def data_loader(dataset, indices_subset, batch_size=5, drop_last=False):
    pin_memory = torch.cuda.is_available()
    return DataLoader(dataset, sampler=BatchSampler(SequentialSampler(indices_subset), 
                                  batch_size=batch_size, drop_last=drop_last), pin_memory=pin_memory)

I use a batch sampler in my data loader, and each batch it yields has size (1, 5, 256, 256, 256) when I run the iterator and check the data size like this:

for i, (X, _) in enumerate(dataloader):
    print(X.size())

However, I would like the size to be (5, 1, 256, 256, 256), since the channel number should be 1 and the batch size should be 5.
I can permute the data after calling enumerate to get the right size, but I don’t want to modify the standardized SupervisedTrainer module (imported from monai.engines, which is in fact based on pytorch-ignite) that I intend to use for downstream processing in my project. How should I customize the dataloader definition so that the iterator itself provides batched input data of the right size to a 3D classification model, without my having to modify any imported packages used downstream? Thanks in advance to anyone who would like to help out!
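To make the cause of the extra leading dimension concrete, here is a minimal sketch. It assumes a hypothetical ToyVolumes dataset with tiny 4x4x4 volumes standing in for the real 256x256x256 data. Because the BatchSampler is passed as `sampler`, the dataset receives a list of 5 indices and already returns a stacked batch; the DataLoader's own default batch_size=1 then wraps that batch once more.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler


class ToyVolumes(Dataset):
    """Hypothetical stand-in for the real dataset: 20 tiny 4x4x4 volumes."""

    def __init__(self):
        self.data = torch.zeros(20, 4, 4, 4)
        self.labels = torch.zeros(20, dtype=torch.long)

    def __getitem__(self, index):
        # With a BatchSampler passed as `sampler`, `index` is a list of 5 indices,
        # so this already returns a stacked (5, 4, 4, 4) tensor.
        return self.data[index], self.labels[index]

    def __len__(self):
        return len(self.data)


dataset = ToyVolumes()
batch_sampler = BatchSampler(
    SequentialSampler(range(len(dataset))), batch_size=5, drop_last=False
)

# The DataLoader's own batch_size defaults to 1, so each "sample"
# (really a batch of 5) gets one more leading dimension when collated.
loader = DataLoader(dataset, sampler=batch_sampler)
X, _ = next(iter(loader))
print(X.shape)  # torch.Size([1, 5, 4, 4, 4])
```

So the leading 1 comes from the DataLoader's automatic batching, not from the sampler treating the batch size as a channel count.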

Set the batch_size of the DataLoader (not the custom sampler) to None and it should work.

I have tried that before. If I set batch_size to None, the data shape becomes (5, 256, 256, 256), which is not wrong as such, since it is exactly the size of the batched data. However, it creates a problem downstream with SupervisedTrainer: I get a RuntimeError about an incompatibility with the weight size, saying that an input with 1 channel is expected but it got 5 channels. This is again the error of mistaking the batch size for the channel number. Somehow the pytorch-ignite based engine recreates each batch with size (1, 5, 256, 256, 256) through its internal run generator.
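The behaviour described here can be reproduced with a small sketch, again using a hypothetical ToyVolumes dataset and a stand-in Conv3d layer rather than the real model. With batch_size=None the batch arrives as (5, 4, 4, 4), with no channel axis, and a 3D convolution expecting 1 input channel rejects it. Recent PyTorch versions accept a 4D tensor as a single unbatched sample, which may be why the error message reports a (1, 5, ...) input; that part is an assumption about the original setup.

```python
import torch
from torch import nn
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler


class ToyVolumes(Dataset):
    """Hypothetical stand-in: 20 tiny 4x4x4 volumes without a channel axis."""

    def __init__(self):
        self.data = torch.zeros(20, 4, 4, 4)
        self.labels = torch.zeros(20, dtype=torch.long)

    def __getitem__(self, index):
        return self.data[index], self.labels[index]

    def __len__(self):
        return len(self.data)


dataset = ToyVolumes()
batch_sampler = BatchSampler(
    SequentialSampler(range(len(dataset))), batch_size=5, drop_last=False
)

# batch_size=None disables the DataLoader's automatic batching,
# so whatever __getitem__ returns is passed through unchanged.
loader = DataLoader(dataset, sampler=batch_sampler, batch_size=None)
X, _ = next(iter(loader))
print(X.shape)  # torch.Size([5, 4, 4, 4])

# Stand-in for the first layer of a 3D model that expects 1 input channel.
conv = nn.Conv3d(in_channels=1, out_channels=2, kernel_size=3, padding=1)
try:
    conv(X)
    raised = False
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```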

The error seems to be raised because the channel dimension is missing. What is the shape of each sample when you iterate the dataset (not the DataLoader)? I guess the channel dim might already be missing there. In that case you could simply unsqueeze it in Dataset.__getitem__ or inside the DataLoader loop.
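As a minimal sketch of what unsqueezing does to the shapes involved (the tensors below are hypothetical placeholders, not the real data): for an already batched tensor the channel axis belongs at position 1, while for a single volume it belongs at position 0.

```python
import torch

# Hypothetical batch as returned for a list of 5 indices: (batch, D, H, W)
batch = torch.zeros(5, 4, 4, 4)
batch_with_channel = batch.unsqueeze(1)  # insert the channel axis after the batch axis
print(batch_with_channel.shape)  # torch.Size([5, 1, 4, 4, 4])

# Hypothetical single volume as returned for one integer index: (D, H, W)
volume = torch.zeros(4, 4, 4)
volume_with_channel = volume.unsqueeze(0)  # the channel axis comes first
print(volume_with_channel.shape)  # torch.Size([1, 4, 4, 4])
```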

Here is how the dataset is implemented. The dataset is stored in .h5 format, and each image is stored as a flattened array (1 × 16777216) that is reshaped to its original 3D shape, i.e. (256, 256, 256), using the ‘DataShape’ attribute stored with the image data.

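import os

import h5py
import torch
from torch.utils.data import Dataset
# Compose was not imported in the original post; torchvision's Compose is assumed here.
from torchvision.transforms import Compose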

class HDF5Dataset(Dataset):
    """Dataset to load data from the dataset .h5 files
    :param project_config: project configuration contains all necessary project-level information
    :type project_config: dict
    :param transform: transformation to apply to the images
    :type transform: list
    :param target_transform: transforms to add to each target/label
    :type target_transform: list
    :param max_size: maximum size of data to draw from
    :type max_size: int or None
    """

    def __init__(self, project_config, transform=[], target_transform=[], max_size=None):
        dir_path = os.path.join(project_config['OutputPath'], 'Dataset')
        file_name = project_config['DatasetName']
        super(HDF5Dataset, self).__init__()
        self.dir_path = dir_path
        self.inputs = self._load_h5_file_with_data(file_name, 'data')
        self.data_shape = tuple(self.inputs.attrs['DataShape'])
        self.labels = self._load_h5_file_with_data(file_name, 'labels')
        self.transform = Compose([self._from_numpy, self._reshape] + transform)
        self.target_transform = Compose([self._from_numpy, self._to_torch_int] + target_transform)
        self.max_size = max_size

    def __getitem__(self, index):
        # With a BatchSampler, index is a list of indices, so a whole batch is read at once
        inputs = self.transform(self.inputs[index])
        labels = self.target_transform(self.labels[index])
        return inputs, labels

    def __len__(self):
        return self.max_size if self.max_size else self.inputs.shape[0]

    def _load_h5_file_with_data(self, file_name, data_name):
        path = os.path.join(self.dir_path, file_name)
        # Open read-only; the file handle stays open for the lifetime of the dataset
        file = h5py.File(path, 'r')
        data = file[data_name]
        return data

    def _reshape(self, tensor):
        # Reshape the flattened tensor in .h5 dataset file according to the specified data shape
        return torch.reshape(tensor, tuple([tensor.shape[0]] + list(self.data_shape)))

    def _from_numpy(self, tensor):
        return torch.from_numpy(tensor).float()

    def _to_torch_int(self, tensor):
        # transformation for labels
        return torch.reshape(tensor, (-1,)).type(torch.ByteTensor)

What modifications would you suggest for the __getitem__ method in my posted code? Now the returned ‘inputs’ of the __getitem__ method provides me with a batch instead of one single image.

I added inputs = torch.unsqueeze(inputs, 1) after the line inputs = self.transform(self.inputs[index]), and now it is working! Thanks a lot for your kind help!
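For readers landing here later, a minimal sketch of the combined fix on a hypothetical ToyVolumes dataset (tiny 4x4x4 volumes in place of the real HDF5 data). It assumes the unsqueeze in __getitem__ is used together with batch_size=None on the DataLoader, as suggested earlier in the thread; the Conv3d layer is only a stand-in for a 3D classification model.

```python
import torch
from torch import nn
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler


class ToyVolumes(Dataset):
    """Hypothetical stand-in for HDF5Dataset: 20 tiny 4x4x4 volumes."""

    def __init__(self):
        self.data = torch.zeros(20, 4, 4, 4)
        self.labels = torch.zeros(20, dtype=torch.long)

    def __getitem__(self, index):
        inputs = self.data[index]            # list of 5 indices -> (5, 4, 4, 4)
        inputs = torch.unsqueeze(inputs, 1)  # add the channel axis -> (5, 1, 4, 4, 4)
        return inputs, self.labels[index]

    def __len__(self):
        return len(self.data)


dataset = ToyVolumes()
batch_sampler = BatchSampler(
    SequentialSampler(range(len(dataset))), batch_size=5, drop_last=False
)

# Assumption: automatic batching is disabled so the batch built in
# __getitem__ is passed through without an extra leading dimension.
loader = DataLoader(dataset, sampler=batch_sampler, batch_size=None)

# Stand-in for the first layer of a 3D model with 1 input channel.
conv = nn.Conv3d(in_channels=1, out_channels=2, kernel_size=3, padding=1)

shapes = []
for X, y in loader:
    out = conv(X)
    shapes.append((tuple(X.shape), tuple(out.shape)))

print(shapes[0])  # ((5, 1, 4, 4, 4), (5, 2, 4, 4, 4))
```

Nothing downstream of the DataLoader needs to change with this approach, since each batch already arrives in (batch, channel, D, H, W) layout.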