How to load a dataset from batches

I have processed ImageNet data that has been split into ten (standard) ‘binary’ batches {train_data_batch_1, train_data_batch_2, …, train_data_batch_10}.

What is the best way to unpack these ‘binary’ batches and use them later in a PyTorch Dataset class? Is there a Dataset class that accepts batches as inputs?
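For reference, a single batch file of this kind can usually be unpacked with `pickle` and NumPy. The sketch below assumes a CIFAR-style layout: each file is a pickled dict with a `'data'` array of shape N x 3*H*W (uint8, channel-major rows) and a `'labels'` list that is 1-based. The function name `load_batch` and the `label_offset` parameter are hypothetical; adjust the keys, image size and offset to whatever your files actually contain.

```python
import pickle

import numpy as np


def load_batch(path, image_size=32, label_offset=1):
    """Unpickle one batch file and return (images, labels) as NumPy arrays.

    Assumed format (not confirmed for your files): a dict with
    'data' -> N x 3*H*W uint8 and 'labels' -> list of ints.
    """
    with open(path, "rb") as f:
        # encoding="latin1" only matters for files pickled under Python 2
        batch = pickle.load(f, encoding="latin1")
    data = np.asarray(batch["data"], dtype=np.uint8)
    # (N, 3*H*W) -> (N, 3, H, W), assuming channel-major rows as in CIFAR-10
    images = data.reshape(-1, 3, image_size, image_size)
    # Shift assumed 1-based labels to 0-based (use label_offset=0 if already 0-based)
    labels = np.asarray(batch["labels"], dtype=np.int64) - label_offset
    return images, labels
```

The returned arrays can then be wrapped with `torch.from_numpy` inside whatever Dataset you end up using.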

If you would like to keep the batch size the data was stored with, you could lazily load each batch file in your Dataset and use a batch size of 1 in your DataLoader.
Here is a small example using a file written with torch.save:

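import torch
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    def __init__(self, paths):
        # one file path per stored batch
        self.paths = paths
        
    def __getitem__(self, index):
        # each "sample" is a whole pre-built batch loaded from disk
        x = torch.load(self.paths[index])
        return x
    
    def __len__(self):
        return len(self.paths)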
    

batch_size = 10
data = torch.randn(batch_size, 3, 24, 24)
path = './tmp_data.pth'
torch.save(data, path)

dataset = MyDataset([path])
loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=2
)

x = next(iter(loader))
x = x.squeeze(0)  # remove only the extra dim added by the DataLoader
print(x.shape)
> torch.Size([10, 3, 24, 24])

However, if you would like to use a different batch size, this becomes a bit more complicated.
Let me know if that’s the case, or if the current approach would work for you.
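A minimal sketch of the more general case (any batch size), assuming the files are CIFAR-style pickles as described in the question, with `'data'` as N x 3*H*W uint8 and 1-based `'labels'`. Each file is turned into a `TensorDataset` and the files are chained with `ConcatDataset`, so every individual image becomes indexable. The helper names `batch_to_dataset` and `make_loader` are hypothetical.

```python
import pickle

import numpy as np
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset


def batch_to_dataset(path, image_size=32, label_offset=1):
    # Assumed format: {'data': N x 3*H*W uint8, 'labels': list of 1-based ints}
    with open(path, "rb") as f:
        batch = pickle.load(f, encoding="latin1")
    images = torch.from_numpy(
        np.asarray(batch["data"], dtype=np.uint8).reshape(-1, 3, image_size, image_size)
    )
    labels = torch.as_tensor(np.asarray(batch["labels"], dtype=np.int64) - label_offset)
    return TensorDataset(images, labels)


def make_loader(paths, batch_size=64, shuffle=True, image_size=32, label_offset=1):
    # Every sample of every file becomes individually indexable,
    # so the DataLoader is free to use any batch size and to shuffle across files
    dataset = ConcatDataset(
        [batch_to_dataset(p, image_size, label_offset) for p in paths]
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
```

The trade-off is that all ten batches are held in RAM at once, which is simple and fast to index but may not fit for larger image sizes.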

Thanks a bunch!

I will give the method you suggested a go and get back to you.

NB. I think the batches of the data I am using can be treated in a similar way to the CIFAR-10 batches of the torchvision datasets (using pickle to read each batch file, as shown in Cifar10_pytorch_dataset). But that is a bit complicated, so I was thinking of a Dataset class to which we could pass the batch files and the meta file to get everything done on the fly.
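A Dataset along those lines could look like the following sketch. It takes the list of batch files plus an optional meta file, records how many samples each file holds, maps a global index to (file, offset) with `bisect`, and keeps only the most recently used file in memory. The class name `PickledBatchDataset`, the CIFAR-style keys `'data'`/`'labels'`, the 1-based labels, and the meta key `'label_names'` are all assumptions, not a confirmed format for your files.

```python
import bisect
import pickle

import numpy as np
import torch
from torch.utils.data import Dataset


def _unpickle(path):
    with open(path, "rb") as f:
        return pickle.load(f, encoding="latin1")


class PickledBatchDataset(Dataset):
    """Per-sample Dataset over pickled batch files (hypothetical class).

    Assumes each file is {'data': N x 3*H*W uint8, 'labels': [...]} and the
    optional meta file holds class names under `meta_key`.
    """

    def __init__(self, batch_paths, meta_path=None, meta_key="label_names",
                 image_size=32, label_offset=1, transform=None):
        self.batch_paths = list(batch_paths)
        self.image_size = image_size
        self.label_offset = label_offset
        self.transform = transform
        self.classes = _unpickle(meta_path)[meta_key] if meta_path else None
        # One pass over the files to record how many samples each one holds
        sizes = [len(_unpickle(p)["labels"]) for p in self.batch_paths]
        self.cum_sizes = np.cumsum(sizes).tolist()
        self._cached_idx, self._cached = None, None

    def __len__(self):
        return self.cum_sizes[-1] if self.cum_sizes else 0

    def _load(self, file_idx):
        # Keep only the most recently used batch file in memory
        if file_idx != self._cached_idx:
            self._cached = _unpickle(self.batch_paths[file_idx])
            self._cached_idx = file_idx
        return self._cached

    def __getitem__(self, index):
        if index < 0:
            index += len(self)
        if not 0 <= index < len(self):
            raise IndexError(index)
        # Map the global index to (file, offset inside that file)
        file_idx = bisect.bisect_right(self.cum_sizes, index)
        start = self.cum_sizes[file_idx - 1] if file_idx > 0 else 0
        batch = self._load(file_idx)
        row = np.asarray(batch["data"][index - start], dtype=np.uint8)
        image = torch.from_numpy(
            row.reshape(3, self.image_size, self.image_size).copy()
        )
        label = int(batch["labels"][index - start]) - self.label_offset
        if self.transform is not None:
            image = self.transform(image)
        return image, label
```

Because only one file is cached per process, `shuffle=True` in the DataLoader would jump between files and reload them constantly; with this design, `shuffle=False` (or a sampler that shuffles within each file) keeps loading cheap.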