Using ChainDataset to combine IterableDatasets

Dear all,
I am new to PyTorch. For my work, I am using an IterableDataset to generate training data consisting of random numbers drawn from a normal distribution. I read in the documentation that ChainDataset can be used to combine multiple IterableDatasets. I tried it, but it doesn't work as I expected: the output from the DataLoader only contains samples from dataset1, although ChainDataset should combine dataset1 with the other datasets (dataset2, dataset3). Could someone help me with this? Any help is appreciated. Below is a simplified version of my code that shows the same behavior.

Best Regards,

import torch
import torch.utils.data as data

class MyDataset(data.IterableDataset):
    def __init__(self, EPD, HFOV, t1Min_min, t1Min_max, t1Range_min, t1Range_max):
        super().__init__()
        self.EPD = EPD
        self.HFOV = HFOV
        
        self.t1Min_min = t1Min_min
        self.t1Min_max = t1Min_max        
        self.t1Range_min = t1Range_min
        self.t1Range_max = t1Range_max       
        

    def __iter__(self):
        while True:
            yield self.sample_data()
    
    def sample_data(self):
        out = torch.zeros([1,4])
        out[:,0] = self.EPD
        out[:,1] = self.HFOV
        out[:,2].uniform_(self.t1Min_min, self.t1Min_max)
        out[:,3].uniform_(self.t1Range_min, self.t1Range_max)

        return out

dataset1 = MyDataset(1, 2, 0., 0., 0.045, 1.045)
dataset2 = MyDataset(3, 4, 0., 0., 0.045, 1.045)
dataset3 = MyDataset(5, 6, 0., 0., 0.045, 1.045)
dataset_combine = data.ChainDataset([dataset1, dataset2, dataset3])

dataloader = torch.utils.data.DataLoader(dataset_combine, batch_size=6)



for i, data in enumerate(dataloader):
    print(data)
    if i >= 3:
        break

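IterableDatasets don't end automatically: they don't use the __len__ method to determine the length of the data, so iteration only stops when __iter__ stops yielding. In your code snippet __iter__ runs a while True loop, which never exits. ChainDataset iterates over its datasets one after the other, so it never finishes dataset1 and therefore never reaches dataset2 or dataset3.

Instead, stop yielding (e.g. break out of the loop) once your stream has no new data, or based on any other stopping condition.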

Here is a small example:

import torch
from torch.utils.data import IterableDataset, ChainDataset, DataLoader

class MyDataset(IterableDataset):
    def __init__(self, val, max_samples):
        super().__init__()
        self.max_samples = max_samples
        self.val = val

    def __iter__(self):
        # a local counter lets the dataset be iterated again (e.g. in the next epoch)
        for _ in range(self.max_samples):
            yield self.sample_data()

    def sample_data(self):
        out = torch.tensor([self.val])
        return out

dataset1 = MyDataset(1, 5)
dataset2 = MyDataset(2, 5)
dataset3 = MyDataset(3, 5)
dataset_combine = ChainDataset([dataset1, dataset2, dataset3])

dataloader = DataLoader(dataset_combine, batch_size=10)

for i, data in enumerate(dataloader):
    print(data)

> tensor([[1],
        [1],
        [1],
        [1],
        [1],
        [2],
        [2],
        [2],
        [2],
        [2]])
tensor([[3],
        [3],
        [3],
        [3],
        [3]])

Note that you are also overwriting the data module import (import torch.utils.data as data) with the data variable in your DataLoader loop (for i, data in enumerate(dataloader)), so I changed the imports.
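If you would rather keep the original generators infinite and only decide at combination time how many samples to draw from each one, you could wrap each dataset in a small limiting wrapper. The sketch below assumes a simplified version of the original dataset (RandomDataset, with t1Min fixed at 0 because the question samples it from [0., 0.]); Take is a hypothetical helper, not part of torch.utils.data, that uses itertools.islice to stop after n samples.

```python
import itertools

import torch
from torch.utils.data import ChainDataset, DataLoader, IterableDataset


class RandomDataset(IterableDataset):
    """Infinite stream of [EPD, HFOV, t1Min, t1Range] rows, simplified from the question."""

    def __init__(self, EPD, HFOV, t1Range_min, t1Range_max):
        super().__init__()
        self.EPD = EPD
        self.HFOV = HFOV
        self.t1Range_min = t1Range_min
        self.t1Range_max = t1Range_max

    def __iter__(self):
        while True:  # never stops on its own, just like the original
            t1_range = torch.empty(1).uniform_(self.t1Range_min, self.t1Range_max).item()
            # t1Min is fixed at 0.0 here because the question samples it from [0., 0.]
            yield torch.tensor([self.EPD, self.HFOV, 0.0, t1_range])


class Take(IterableDataset):
    """Hypothetical helper: yields at most n samples from another IterableDataset."""

    def __init__(self, dataset, n):
        super().__init__()
        self.dataset = dataset
        self.n = n

    def __iter__(self):
        # islice stops the (possibly infinite) inner iterator after n items
        return itertools.islice(iter(self.dataset), self.n)


datasets = [Take(RandomDataset(epd, hfov, 0.045, 1.045), 2) for epd, hfov in [(1, 2), (3, 4), (5, 6)]]
dataloader = DataLoader(ChainDataset(datasets), batch_size=6)

for batch in dataloader:
    print(batch[:, 0])  # EPD column: two samples from each dataset, in chain order
```

Because the limit lives in the wrapper rather than in the dataset, the same infinite dataset can be reused with different sample counts.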


Thanks @ptrblck, this really helped me. I have a follow-up question, though. When I use DataLoader(dataset_combine, batch_size=10, shuffle=True), I get this error message:

ValueError: DataLoader with IterableDataset: expected unspecified shuffle option, but got shuffle=True

Is there a way to shuffle the combined datasets?


Unfortunately, the DataLoader cannot shuffle an IterableDataset, as this dataset assumes your data comes from e.g. a stream, so there are no indices a sampler could permute.
A map-style dataset can be shuffled, since its length is known before it is used, and the sampler can simply shuffle the indices used to draw each sample.
If you know the length of the dataset and can use indices to load your data, I would recommend using the standard Dataset instead of IterableDataset, as it seems to fit your use case better.
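A minimal sketch of that suggestion, assuming a fixed number of samples per dataset is acceptable: MyMapDataset is a hypothetical map-style variant of the question's dataset (t1Min again fixed at 0), and the datasets are combined with ConcatDataset instead of ChainDataset, so the DataLoader accepts shuffle=True.

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset


class MyMapDataset(Dataset):
    """Hypothetical map-style variant of the question's dataset with a fixed length."""

    def __init__(self, EPD, HFOV, t1Range_min, t1Range_max, length):
        super().__init__()
        self.EPD = EPD
        self.HFOV = HFOV
        self.t1Range_min = t1Range_min
        self.t1Range_max = t1Range_max
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # the index only selects a slot; fresh random values are drawn on every access
        t1_range = torch.empty(1).uniform_(self.t1Range_min, self.t1Range_max).item()
        return torch.tensor([self.EPD, self.HFOV, 0.0, t1_range])


combined = ConcatDataset([
    MyMapDataset(1, 2, 0.045, 1.045, length=5),
    MyMapDataset(3, 4, 0.045, 1.045, length=5),
    MyMapDataset(5, 6, 0.045, 1.045, length=5),
])
# shuffle=True is allowed here because ConcatDataset is map-style with a known length
dataloader = DataLoader(combined, batch_size=6, shuffle=True)

for batch in dataloader:
    print(batch[:, 0])  # EPD values from all three datasets, in shuffled order
```

Since the random values are drawn inside __getitem__, each epoch still sees new numbers; only the number of samples per dataset is fixed.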


@ptrblck - What would happen, however, if we have a very large imaging dataset stored in something like an h5 file? Using a Dataset in this case, with code similar to the one below, would allow us to shuffle. However, the issue is loading the h5 dataset into memory: in my case it consists of a few hundred GB of data, which can make this impractical. Is there a way around this? Thanks!

My code for reference. My understanding was that doing things this way should prevent the whole h5 file(s) from being loaded into memory, but it seems that I am wrong:

import os

import h5py
import torch
import torch.utils.data as data

class DataMapper(data.Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y
        
    def __getitem__(self, index):
        X_volume = torch.from_numpy(self.X[index])
        y_volume = torch.from_numpy(self.y[index])

        return X_volume, y_volume

    def __len__(self):
        return len(self.y)


def get_datasets(data_parameters):
    
    key_X = 'input'
    key_y = 'target'

    X_train_data = h5py.File(os.path.join(data_parameters["data_folder_name"], data_parameters["input_data_train"]), 'r')
    y_train_data = h5py.File(os.path.join(data_parameters["data_folder_name"], data_parameters["target_data_train"]), 'r')
    
    X_validation_data = h5py.File(os.path.join(data_parameters["data_folder_name"], data_parameters["input_data_validation"]), 'r')
    y_validation_data = h5py.File(os.path.join(data_parameters["data_folder_name"], data_parameters["target_data_validation"]), 'r')

       
    return (
        DataMapper( X_train_data[key_X][()], y_train_data[key_y][()] ),
        DataMapper( X_validation_data[key_X][()], y_validation_data[key_y][()] ),
    )

@rasbt has shared some examples of using hdf5 data in this post, which also seems to support shuffling, so you might want to take a look at it.
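In the code above, indexing with [()] (as in X_train_data[key_X][()]) reads the entire HDF5 dataset into a NumPy array, which is why everything ends up in memory. A sketch of a lazy alternative, assuming the question's layout (separate input/target files with datasets named 'input' and 'target'): the dataset only stores the file paths and opens the files on first access, so indexing self.X[index] reads a single sample from disk. Opening the files lazily rather than in __init__ is a commonly used pattern so that each DataLoader worker gets its own handle. LazyH5Dataset is a hypothetical name, and the per-sample arrays are assumed to be at least 1-D.

```python
import h5py
import torch
from torch.utils.data import Dataset


class LazyH5Dataset(Dataset):
    def __init__(self, x_path, y_path, key_X="input", key_y="target"):
        self.x_path, self.y_path = x_path, y_path
        self.key_X, self.key_y = key_X, key_y
        self.X = None  # h5py datasets are opened lazily, once per process
        self.y = None
        # read only the length (metadata), not the data itself
        with h5py.File(y_path, "r") as f:
            self.length = len(f[key_y])

    def _open(self):
        self.x_file = h5py.File(self.x_path, "r")
        self.y_file = h5py.File(self.y_path, "r")
        self.X = self.x_file[self.key_X]
        self.y = self.y_file[self.key_y]

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        if self.X is None:
            self._open()
        # indexing an h5py Dataset with an integer reads just that sample from disk
        X_volume = torch.from_numpy(self.X[index])
        y_volume = torch.from_numpy(self.y[index])
        return X_volume, y_volume
```

It can then be used like any map-style dataset, e.g. DataLoader(LazyH5Dataset(x_path, y_path), batch_size=..., shuffle=True), where x_path and y_path are your HDF5 file paths. Random access into HDF5 can be slow depending on chunking, so throughput may differ from sequential reads.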

Related: I'm trying to chain together local data files with WebDataset data, and I want the system to randomly draw from either dataset.
However, I notice it only ever reads from the first element of the chain, which in this case is the local dataset (torch.utils.data.ChainDataset((local_ds, web_ds))).
How can I make it randomly draw a data point from either dataset? I have .shuffle() in the WebDataset setup, so at least that part is randomly ordered.
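ChainDataset behaves like itertools.chain: it exhausts the first dataset before touching the second, so if local_ds is long (or unbounded), web_ds is reached late or never. A sketch of random interleaving instead: RandomInterleave is a hypothetical helper (not a PyTorch class) that picks the next source at random, optionally weighted, and drops a source once it is exhausted. It assumes single-process loading; with num_workers > 0 each worker would replay all sources unless the underlying datasets split work themselves.

```python
import random

import torch
from torch.utils.data import DataLoader, IterableDataset


class RandomInterleave(IterableDataset):
    """Hypothetical helper: yield samples drawn at random from several iterables."""

    def __init__(self, datasets, weights=None, seed=None):
        super().__init__()
        self.datasets = list(datasets)
        self.weights = list(weights) if weights is not None else [1.0] * len(self.datasets)
        self.seed = seed

    def __iter__(self):
        rng = random.Random(self.seed)
        iterators = [iter(ds) for ds in self.datasets]
        weights = list(self.weights)
        while iterators:
            # choose a source in proportion to its weight
            i = rng.choices(range(len(iterators)), weights=weights)[0]
            try:
                yield next(iterators[i])
            except StopIteration:
                # this source is exhausted; keep drawing from the remaining ones
                del iterators[i]
                del weights[i]
```

For the case above this would be something like DataLoader(RandomInterleave([local_ds, web_ds]), batch_size=...). The weights can be set, for example, proportional to the expected size of each source so that both run out at roughly the same time.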