I am working with a custom Dataset that is passed to the PyTorch DataLoader, but setting shuffle=True raises an error saying the Dataset is an IterableDataset.
In the original conda env with torch 1.1.0, the DataLoader allowed shuffling the data, but in an updated conda env with PyTorch 1.10.1 I get the error that the Dataset is of type IterableDataset.
import collections.abc

import torch


class Dataset(object):
    def __init__(self, examples, fields):
        self.examples = examples
        self.fields = dict(fields)

    def collate_fn(self):
        def collate(batch):
            if len(self.fields) == 1:
                batch = [batch, ]
            else:
                batch = list(zip(*batch))

            tensors = []
            for field, data in zip(self.fields.values(), batch):
                tensor = field.process(data)
                # collections.Sequence was removed in Python 3.10; use collections.abc.Sequence.
                if isinstance(tensor, collections.abc.Sequence) and any(isinstance(t, torch.Tensor) for t in tensor):
                    tensors.extend(tensor)
                else:
                    tensors.append(tensor)

            if len(tensors) > 1:
                return tensors
            else:
                return tensors[0]

        return collate

    def __getitem__(self, i):
        example = self.examples[i]
        data = []
        for field_name, field in self.fields.items():
            data.append(field.preprocess(getattr(example, field_name)))

        if len(data) == 1:
            data = data[0]
        return data

    def __len__(self):
        return len(self.examples)

    def __getattr__(self, attr):
        if attr in self.fields:
            for x in self.examples:
                yield getattr(x, attr)


class PairedDataset(Dataset):
    def __init__(self, examples, fields):
        assert ('image' in fields)
        assert ('text' in fields)
        super(PairedDataset, self).__init__(examples, fields)
        self.image_field = self.fields['image']
        self.text_field = self.fields['text']
The error message using the newer pytorch version is:
ValueError: DataLoader with IterableDataset: expected unspecified shuffle option, but got shuffle=True
Thanks for your reply!
I reduced the code to a minimal example that still uses the relevant class and reproduces the error.
from torch.utils.data import IterableDataset


class Example(object):
    """Defines a single training or test example.
    Stores each column of the example as an attribute.
    """
    @classmethod
    def fromdict(cls, data):
        ex = cls(data)
        return ex

    def __init__(self, data):
        for key, val in data.items():
            super(Example, self).__setattr__(key, val)

    def __setattr__(self, key, value):
        raise AttributeError

    def __hash__(self):
        return hash(tuple(x for x in self.__dict__.values()))

    def __eq__(self, other):
        this = tuple(x for x in self.__dict__.values())
        other = tuple(x for x in other.__dict__.values())
        return this == other

    def __ne__(self, other):
        return not self.__eq__(other)


class Dataset(object):
    def __init__(self, examples):
        self.examples = examples

    def __getitem__(self, i):
        example = self.examples[i]
        data = []
        return data

    def __len__(self):
        return len(self.examples)

    def __getattr__(self, attr):
        for x in self.examples:
            yield getattr(x, attr)


def show_eg():
    examples = [Example.fromdict({'image': "veryimportantpath", 'text': "cute caption of a dog", 'img_id': 55})]
    egdataset = Dataset(examples)
    print("Is it an IterableDataset?", isinstance(egdataset, IterableDataset))


if __name__ == '__main__':
    show_eg()
I further realized that the old torch version cannot run this code because it cannot import IterableDataset.
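For running the same check under both environments, a guarded import is one option. This is only a sketch: `is_iterable_dataset` is a hypothetical helper name, and the fallback assumes that a PyTorch build without IterableDataset (as far as I know the class first shipped in 1.2) has no iterable-style datasets at all.

```python
# Guarded import: older PyTorch builds (such as 1.1.0) do not provide IterableDataset.
try:
    from torch.utils.data import IterableDataset
except ImportError:
    IterableDataset = None


def is_iterable_dataset(obj):
    """Hypothetical helper: report whether obj is treated as an IterableDataset."""
    if IterableDataset is None:
        # Assumption: without the class, nothing can be an iterable-style dataset.
        return False
    return isinstance(obj, IterableDataset)


print("Is it an IterableDataset?", is_iterable_dataset(object()))
```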
I’m looking forward to hearing your thoughts.
I just tried running your minimal solution code using both PyTorch 1.10.1 and the nightly version, and both times I get the following, which seems correct.
Is it an IterableDataset? False
Process finished with exit code 0
Is it possible for you to create a fresh conda environment with Python 3.9, install the latest PyTorch release (1.10.1 or 1.10.2), and re-run your minimal solution code to see if the issue still exists?
If the issue still exists in the new environment, can you please run the following and paste the output?
wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
# For security purposes, please check the contents of collect_env.py before running it.
python collect_env.py
Yes, the output False is what I want; however, in my env it outputs True.
My current env already uses Python 3.9 and has torch 1.10.1.
I pasted the output of your script below.
Collecting environment information...
PyTorch version: 1.10.2
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A
OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.10.2
Libc version: glibc-2.27
Python version: 3.9.7 (default, Sep 16 2021, 13:09:58) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.0-99-generic-x86_64-with-glibc2.27
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: Quadro P2000
Nvidia driver version: 510.47.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] numpy==1.21.2
[pip3] torch==1.10.2
[pip3] torchaudio==0.10.2
[pip3] torchvision==0.11.3
[conda] blas 1.0 mkl
[conda] cudatoolkit 10.2.89 hfd86e86_1
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py39h7f8727e_0
[conda] mkl_fft 1.3.1 py39hd3c417c_0
[conda] mkl_random 1.2.2 py39h51133e4_0
[conda] numpy 1.21.2 py39h20f2e39_0
[conda] numpy-base 1.21.2 py39h79a1101_0
[conda] pytorch 1.10.2 py3.9_cuda10.2_cudnn7.6.5_0 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torch 1.10.1 pypi_0 pypi
[conda] torchaudio 0.10.2 py39_cu102 pytorch
I uninstalled PyTorch from pip and created a new env where I only installed PyTorch 1.10.2, and the issue still arises.
I also tried it on a GPU server, again by creating a new conda env with Python 3.9 and installing only PyTorch 1.10.2, but had the same issue.
These were the only two installations I did:
Hi Nivek,
I have been able to narrow down the problem.
The Dataset is only an IterableDataset when __getattr__ contains the yield statement.
If I replace this with a return, the error disappears.
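A small pure-Python sketch of why the yield matters, independent of PyTorch. The class names `Leaky` and `Strict` are made up for illustration. A function containing yield is a generator function, so calling it returns a generator object without running the body. That means such a `__getattr__` never raises AttributeError, and every attribute name appears to exist. Which check inside PyTorch is affected by this is not established in this thread; the sketch only shows the Python-level behaviour.

```python
class Leaky:
    """Mimics the Dataset above: __getattr__ contains a yield."""

    def __init__(self, examples):
        self.examples = examples

    def __getattr__(self, attr):
        # The yield makes this a generator function: calling it only builds
        # a generator object, and the loop below does not run until iterated.
        for x in self.examples:
            yield getattr(x, attr)


class Strict:
    """Same shape, but unknown attributes are reported as missing."""

    def __init__(self, examples):
        self.examples = examples

    def __getattr__(self, attr):
        # Reached only when normal lookup fails; raising AttributeError
        # is what lets hasattr() answer False.
        raise AttributeError(attr)


leaky = Leaky(examples=[])
strict = Strict(examples=[])

print(hasattr(leaky, "no_such_attribute"))     # True
print(hasattr(strict, "no_such_attribute"))    # False
print(type(leaky.no_such_attribute).__name__)  # generator
```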
I know return is not equivalent to yield, but for now I just hope that the __getattr__ function is not important for the code.
Thanks for guiding me into the right direction!
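If the lazy per-field iteration is still wanted, one possible way to keep it is to move the generator into a helper and let `__getattr__` return it only for known fields, raising AttributeError otherwise. This is a sketch under assumptions, not the original library's code: `FieldDataset` and `_iter_field` are hypothetical names, and `SimpleNamespace` stands in for the Example class.

```python
from types import SimpleNamespace


class FieldDataset:
    def __init__(self, examples, fields):
        self.examples = examples
        self.fields = dict(fields)

    def __len__(self):
        return len(self.examples)

    def _iter_field(self, attr):
        # Hypothetical helper: the generator lives here, not in __getattr__.
        for x in self.examples:
            yield getattr(x, attr)

    def __getattr__(self, attr):
        # Read 'fields' from __dict__ so a missing 'fields' attribute cannot
        # recurse back into __getattr__.
        fields = self.__dict__.get("fields", {})
        if attr in fields:
            return self._iter_field(attr)
        raise AttributeError(f"{type(self).__name__} has no attribute {attr!r}")


ds = FieldDataset(
    [SimpleNamespace(text="a dog"), SimpleNamespace(text="a cat")],
    {"text": None},
)
print(list(ds.text))                      # ['a dog', 'a cat']
print(hasattr(ds, "no_such_attribute"))   # False
```

With this shape, `ds.text` still yields one value per example, while attribute probes for names that are not fields fail in the usual way.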