Dear all,
I am new to PyTorch. For my work, I am using an IterableDataset to generate training data consisting of random numbers drawn from a normal distribution. I read in the documentation that ChainDataset can be used to combine datasets generated from IterableDataset. I tried to code it, but it doesn’t work as I expected: the output from the DataLoader only consists of samples from dataset1, although ChainDataset should combine dataset1 with the other datasets (dataset2, dataset3). Could someone help me with this? Any help is appreciated. Below is a simplified version of my code, but it demonstrates the same issue.
Best Regards,
import torch
import torch.utils.data as data

class MyDataset(data.IterableDataset):
    def __init__(self, EPD, HFOV, t1Min_min, t1Min_max, t1Range_min, t1Range_max):
        super().__init__()
        self.EPD = EPD
        self.HFOV = HFOV
        self.t1Min_min = t1Min_min
        self.t1Min_max = t1Min_max
        self.t1Range_min = t1Range_min
        self.t1Range_max = t1Range_max

    def __iter__(self):
        while True:
            yield self.sample_data()

    def sample_data(self):
        out = torch.zeros([1, 4])
        out[:, 0] = self.EPD
        out[:, 1] = self.HFOV
        out[:, 2].uniform_(self.t1Min_min, self.t1Min_max)
        out[:, 3].uniform_(self.t1Range_min, self.t1Range_max)
        return out

dataset1 = MyDataset(1, 2, 0., 0., 0.045, 1.045)
dataset2 = MyDataset(3, 4, 0., 0., 0.045, 1.045)
dataset3 = MyDataset(5, 6, 0., 0., 0.045, 1.045)

dataset_combine = data.ChainDataset([dataset1, dataset2, dataset3])
dataloader = torch.utils.data.DataLoader(dataset_combine, batch_size=6)

for i, data in enumerate(dataloader):
    print(data)
    if i >= 3:
        break
IterableDatasets don’t end automatically, as they don’t use the __len__ method to determine the length of the data, and in your particular code snippet you are using a while True loop, which will never exit.
Instead you should break once your stream doesn’t yield new data anymore, or use some other stopping condition.
Here is a small example:
import torch
from torch.utils.data import IterableDataset, ChainDataset, DataLoader

class MyDataset(IterableDataset):
    def __init__(self, val, max_samples):
        self.max_samples = max_samples
        self.index = 0
        self.val = val

    def __iter__(self):
        while self.index < self.max_samples:
            yield self.sample_data()
            self.index += 1

    def sample_data(self):
        out = torch.tensor([self.val])
        return out

dataset1 = MyDataset(1, 5)
dataset2 = MyDataset(2, 5)
dataset3 = MyDataset(3, 5)

dataset_combine = ChainDataset([dataset1, dataset2, dataset3])
dataloader = DataLoader(dataset_combine, batch_size=10)

for i, data in enumerate(dataloader):
    print(data)
> tensor([[1],
[1],
[1],
[1],
[1],
[2],
[2],
[2],
[2],
[2]])
tensor([[3],
[3],
[3],
[3],
[3]])
Note that you are also overwriting the data import (import torch.utils.data as data) with the loop variable data in your DataLoader loop, so I changed the imports in the example.
Thanks @ptrblck, this really helped me. I have a follow-up question, though. When I use DataLoader(dataset_combine, batch_size=10, shuffle=True),
I get the error message:
ValueError: DataLoader with IterableDataset: expected unspecified shuffle option,
but got shuffle=True
Unfortunately, you cannot shuffle an IterableDataset, as this dataset type assumes your data comes from e.g. a stream.
A map-style dataset can be shuffled, since its length is known before using it, so the DataLoader can simply shuffle the indices used to draw each sample.
If you know the length of the dataset and can use indices to load your data, I would recommend using the standard Dataset instead of IterableDataset, as it seems to fit your use case better.
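For illustration, here is a minimal map-style sketch of the earlier example (assuming the number of samples per dataset is known up front); ConcatDataset is the map-style counterpart of ChainDataset, and shuffle=True works because the DataLoader just permutes the indices:

import torch
from torch.utils.data import Dataset, ConcatDataset, DataLoader

class MyMapDataset(Dataset):
    def __init__(self, val, num_samples):
        self.val = val
        self.num_samples = num_samples

    def __len__(self):
        # The length is known up front, which is what makes shuffling possible.
        return self.num_samples

    def __getitem__(self, index):
        return torch.tensor([self.val])

dataset_combine = ConcatDataset([MyMapDataset(1, 5), MyMapDataset(2, 5), MyMapDataset(3, 5)])
dataloader = DataLoader(dataset_combine, batch_size=10, shuffle=True)

for batch in dataloader:
    print(batch)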
@ptrblck - What would happen, however, if we have a very large imaging dataset stored in something like an h5 file? Using Dataset in this case, with code similar to the one below, would allow us to shuffle. However, the issue comes from loading the h5 dataset into memory: in my case this is a few hundred GB of data, which can make it impractical. Is there a way around this in your experience? Thanks!
My code for reference. It’s my understanding that doing things this way should prevent the whole h5 file(s) from being loaded into memory, although it seems that I am wrong:
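Roughly along these lines (a simplified sketch; the file path and the "images" / "labels" keys are placeholders for my actual data):

import h5py
import torch
from torch.utils.data import Dataset

class H5Dataset(Dataset):
    def __init__(self, h5_path):
        # Keep only a file handle; h5py should read slices from disk on demand.
        self.file = h5py.File(h5_path, "r")
        self.images = self.file["images"]
        self.labels = self.file["labels"]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        # Index into the h5 datasets directly, so only one sample is read per call.
        image = torch.from_numpy(self.images[index])
        label = torch.tensor(self.labels[index])
        return image, label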
Related: I’m trying to chain together local data files with WebDataset data, and I want the system to randomly draw from either dataset.
But I notice it only ever reads from the first element of the chain, which in this case is the local dataset (torch.utils.data.ChainDataset((local_ds, web_ds))).
How can one make it randomly draw a data point from anywhere in this combination? I have .shuffle() in the WebDataset setup, so at least that part is randomly ordered.
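For concreteness, this is the kind of behavior I’m after; a toy sketch (RandomChainDataset is just a name I made up here, not an existing PyTorch class) that pulls each sample from a randomly chosen dataset instead of exhausting them in order:

import random
from torch.utils.data import IterableDataset

class RandomChainDataset(IterableDataset):
    def __init__(self, datasets):
        self.datasets = datasets

    def __iter__(self):
        # Keep one iterator per child dataset and draw from a random one each step,
        # dropping iterators as they run out.
        iterators = [iter(d) for d in self.datasets]
        while iterators:
            it = random.choice(iterators)
            try:
                yield next(it)
            except StopIteration:
                iterators.remove(it)

# e.g. loader = DataLoader(RandomChainDataset([local_ds, web_ds]), batch_size=8)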