DataLoader treats custom Dataset as IterableDataset in newer PyTorch version

Hi all,

I am working with a custom Dataset that is used with the PyTorch DataLoader, but passing shuffle=True raises an error claiming that the Dataset is an IterableDataset.
In the original conda env with torch 1.1.0 the DataLoader shuffled the data without complaint, but in an updated conda env with PyTorch 1.10.1 it fails with the error that the Dataset is of type IterableDataset.

import collections.abc

import torch


class Dataset(object):
    def __init__(self, examples, fields):
        self.examples = examples
        self.fields = dict(fields)

    def collate_fn(self):
        def collate(batch):
            if len(self.fields) == 1:
                batch = [batch, ]
            else:
                batch = list(zip(*batch))

            tensors = []
            for field, data in zip(self.fields.values(), batch):
                tensor = field.process(data)
                # collections.Sequence is a deprecated alias (removed in Python 3.10)
                if isinstance(tensor, collections.abc.Sequence) and any(isinstance(t, torch.Tensor) for t in tensor):
                    tensors.extend(tensor)
                else:
                    tensors.append(tensor)

            if len(tensors) > 1:
                return tensors
            else:
                return tensors[0]

        return collate

    def __getitem__(self, i):
        example = self.examples[i]
        data = []
        for field_name, field in self.fields.items():
            data.append(field.preprocess(getattr(example, field_name)))

        if len(data) == 1:
            data = data[0]
        return data

    def __len__(self):
        return len(self.examples)

    def __getattr__(self, attr):
        if attr in self.fields:
            for x in self.examples:
                yield getattr(x, attr)

class PairedDataset(Dataset):
    def __init__(self, examples, fields):
        assert ('image' in fields)
        assert ('text' in fields)
        super(PairedDataset, self).__init__(examples, fields)
        self.image_field = self.fields['image']
        self.text_field = self.fields['text']
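The collate closure above receives the list of values that __getitem__ produced for each example in the batch and transposes it with zip(*batch), so that each field can process a whole column at once. A simplified, torch-free sketch of just that step, assuming a hypothetical UpperField in place of the repository's real field classes (the tensor-flattening and single-result unwrapping steps are left out):

```python
class UpperField:
    """Hypothetical stand-in for a field; process() handles a whole column."""

    def process(self, values):
        return [v.upper() for v in values]


def collate(fields, batch):
    # Simplified version of the closure above: with a single field each item
    # already is that field's value, so the batch is wrapped as one column;
    # otherwise the per-example tuples are transposed into per-field columns.
    if len(fields) == 1:
        batch = [batch]
    else:
        batch = list(zip(*batch))
    return [field.process(column) for field, column in zip(fields.values(), batch)]


fields = {"image": UpperField(), "text": UpperField()}
batch = [("img1", "a dog"), ("img2", "a cat")]  # two examples from __getitem__
print(collate(fields, batch))  # [['IMG1', 'IMG2'], ['A DOG', 'A CAT']]
```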

With the newer PyTorch version, the error message is:

ValueError: DataLoader with IterableDataset: expected unspecified shuffle option, but got shuffle=True
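The DataLoader decides whether a dataset is iterable-style via isinstance(dataset, IterableDataset), and an iterable-style dataset has no indices to permute, so a shuffle request is rejected. The sketch below is a simplified mimic of that decision, not PyTorch's source; IterableStyleBase, MapStyle and Streamed are hypothetical stand-ins. It illustrates that shuffle=True is only the trigger, and the real question is why the isinstance check succeeds for a plain class.

```python
class IterableStyleBase:
    """Hypothetical stand-in for torch.utils.data.IterableDataset."""


def check_loader_args(dataset, shuffle=False):
    """Simplified sketch of a DataLoader-style dataset-kind check (not PyTorch code)."""
    if isinstance(dataset, IterableStyleBase):
        # Iterable-style datasets are consumed as a stream, so nothing to shuffle.
        if shuffle:
            raise ValueError(
                "DataLoader with IterableDataset: expected unspecified shuffle "
                f"option, but got shuffle={shuffle}"
            )
        return "iterable"
    return "map"


class MapStyle:
    def __getitem__(self, i):
        return i

    def __len__(self):
        return 3


class Streamed(IterableStyleBase):
    def __iter__(self):
        return iter(range(3))


print(check_loader_args(MapStyle(), shuffle=True))  # map
try:
    check_loader_args(Streamed(), shuffle=True)
except ValueError as err:
    print(err)
```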

The original code is here: https://github.com/aimagelab/meshed-memory-transformer/blob/master/data/dataset.py, although I realise I can’t expect anyone to sift through all of it.
I’d be grateful if anyone could point me to a possible origin of the error.
Thanks in advance,

It would be helpful if you could give a small, runnable code snippet in which your Dataset is created and passed into DataLoader.

I checked instances of your versions of Dataset and PairedDataset with isinstance(ds, IterableDataset), and both return False.

Hi Nivek,

Thanks for your reply!
I made a minimal example that still uses the relevant class and reproduces the problem.

from torch.utils.data import IterableDataset  

class Example(object):
    """Defines a single training or test example.
    Stores each column of the example as an attribute.
    """
    @classmethod
    def fromdict(cls, data):
        ex = cls(data)
        return ex

    def __init__(self, data):
        for key, val in data.items():
            super(Example, self).__setattr__(key, val)

    def __setattr__(self, key, value):
        raise AttributeError

    def __hash__(self):
        return hash(tuple(x for x in self.__dict__.values()))

    def __eq__(self, other):
        this = tuple(x for x in self.__dict__.values())
        other = tuple(x for x in other.__dict__.values())
        return this == other

    def __ne__(self, other):
        return not self.__eq__(other)



class Dataset(object):
    def __init__(self, examples):
        self.examples = examples

    def __getitem__(self, i):
        example = self.examples[i]
        data = []
        return data

    def __len__(self):
        return len(self.examples)

    def __getattr__(self, attr):
        for x in self.examples:
            yield getattr(x, attr)



def show_eg():
    examples = [Example.fromdict({'image':"veryimportantpath", 'text': "cute caption of a dog", 'img_id' : 55})]
    egdataset= Dataset(examples)
    print("Is it an IterableDataset?", isinstance(egdataset, IterableDataset))


if __name__ == '__main__':
    show_eg()

I also realized that the old torch version (1.1.0) cannot run this code, because IterableDataset cannot be imported there.
I’m looking forward to hearing your thoughts.
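If the same script has to run in both environments, the import can be guarded. A minimal sketch, assuming that IterableDataset first appeared in PyTorch 1.2 (so torch 1.1.0 raises ImportError); is_iterable_style is a hypothetical helper, not a PyTorch function:

```python
try:
    # Assumption: IterableDataset was added in PyTorch 1.2, so this fails on 1.1.0
    # (and also when torch is not installed at all).
    from torch.utils.data import IterableDataset
except ImportError:
    IterableDataset = None


def is_iterable_style(dataset):
    """True only if IterableDataset exists and dataset is an instance of it."""
    return IterableDataset is not None and isinstance(dataset, IterableDataset)


print("Is it an IterableDataset?", is_iterable_style(object()))
```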

I just tried running your minimal example with both PyTorch 1.10.1 and the nightly version, and both times I get the following output, which seems correct.

Is it an IterableDataset? False

Process finished with exit code 0

Is it possible for you to create a fresh conda environment with Python 3.9, install the latest PyTorch release (1.10.1 or 1.10.2), and re-run your minimal example to see if the issue still exists?

If the issue still exists in the new environment, can you please run the following and paste the output?

wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
# For security purposes, please check the contents of collect_env.py before running it.
python collect_env.py

Yes, False is the expected output; however, in my environment it prints True.
My current env already uses Python 3.9 and has torch 1.10.1.
I pasted the output of your script below.

Collecting environment information...
PyTorch version: 1.10.2
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.10.2
Libc version: glibc-2.27

Python version: 3.9.7 (default, Sep 16 2021, 13:09:58)  [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.0-99-generic-x86_64-with-glibc2.27
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: Quadro P2000
Nvidia driver version: 510.47.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.21.2
[pip3] torch==1.10.2
[pip3] torchaudio==0.10.2
[pip3] torchvision==0.11.3
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               10.2.89              hfd86e86_1  
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] mkl                       2021.4.0           h06a4308_640  
[conda] mkl-service               2.4.0            py39h7f8727e_0  
[conda] mkl_fft                   1.3.1            py39hd3c417c_0  
[conda] mkl_random                1.2.2            py39h51133e4_0  
[conda] numpy                     1.21.2           py39h20f2e39_0  
[conda] numpy-base                1.21.2           py39h79a1101_0  
[conda] pytorch                   1.10.2          py3.9_cuda10.2_cudnn7.6.5_0    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torch                     1.10.1                   pypi_0    pypi
[conda] torchaudio                0.10.2               py39_cu102    pytorch

It seems like your environment has two copies of PyTorch installed (1.10.2 with conda and 1.10.1 with pip) and that may be causing issues.
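A quick way to see which copies are registered, and which one Python would actually import, is to ask importlib without importing torch. A minimal sketch, assuming Python 3.8+ for importlib.metadata; installed_copies and import_origin are hypothetical helper names. Conda packages do not always ship pip-visible metadata, so a single entry is not proof that only one copy is installed.

```python
import importlib.metadata
import importlib.util


def installed_copies(dist_name):
    """List (version, location) for every installed distribution called dist_name."""
    copies = []
    for dist in importlib.metadata.distributions():
        name = dist.metadata.get("Name")
        if name and name.lower() == dist_name.lower():
            copies.append((dist.version, str(dist.locate_file(""))))
    return copies


def import_origin(module_name):
    """Path of the module Python would import, found without importing it."""
    spec = importlib.util.find_spec(module_name)
    return spec.origin if spec is not None else None


print("torch distributions:", installed_copies("torch"))
print("torch would be imported from:", import_origin("torch"))
```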

Is that a new environment that you just created? If not, can you create a new environment to see if the issue still persists?

I am unable to reproduce this error on my end, but I will keep an eye out for possible solution.

I uninstalled PyTorch from pip and created a new env in which I only installed PyTorch 1.10.2, and the issue still arises.
I also tried it on a GPU server, again creating a new conda env with Python 3.9 and installing only PyTorch 1.10.2, but the issue persisted.
These were the only two commands I ran:

conda create -n newenv python=3.9   
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch

Thanks for your help.

Hi Nivek,
I have been able to narrow down the problem.
The Dataset is only detected as an IterableDataset when __getattr__ contains the yield statement.
If I replace it with a return, the error disappears.
I know return is not equivalent to yield, but for now I’m just hoping that the __getattr__ function is not important for the rest of the code.
Thanks for guiding me in the right direction!
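The likely explanation: because __getattr__ contains yield, Python compiles it as a generator function. Calling it never runs the body; it immediately returns a generator object (simply empty when attr is not a field). So every failed attribute lookup on the Dataset returns a generator instead of raising AttributeError, and hasattr(dataset, name) is True for any name at all. Code that probes an instance for an optional attribute is misled by this; presumably some check behind isinstance(..., IterableDataset) in PyTorch 1.10 does such a probe, although that is not confirmed here. A sketch of a fix that keeps the field iteration, using hypothetical BrokenDataset and FixedDataset classes and SimpleNamespace in place of Example: return a generator expression explicitly, and raise AttributeError for everything else.

```python
import copy
from types import SimpleNamespace


class BrokenDataset:
    def __init__(self, examples, fields):
        self.examples = examples
        self.fields = dict(fields)

    def __getattr__(self, attr):
        # `yield` makes this a generator function: calling it only creates a
        # generator, so no AttributeError is ever raised for unknown names.
        if attr in self.fields:
            for x in self.examples:
                yield getattr(x, attr)


class FixedDataset:
    def __init__(self, examples, fields):
        self.examples = examples
        self.fields = dict(fields)

    def __getattr__(self, attr):
        # Read through __dict__ so a half-initialised object (e.g. during copy
        # or unpickling) cannot recurse back into __getattr__ via self.fields.
        fields = self.__dict__.get("fields", {})
        if attr in fields:
            # Return a generator explicitly instead of *being* a generator function.
            return (getattr(x, attr) for x in self.examples)
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {attr!r}")


examples = [SimpleNamespace(text="a dog"), SimpleNamespace(text="a cat")]
broken = BrokenDataset(examples, {"text": None})
fixed = FixedDataset(examples, {"text": None})

print(hasattr(broken, "anything_at_all"))  # True
print(hasattr(fixed, "anything_at_all"))   # False
print(list(fixed.text))                    # ['a dog', 'a cat']

try:
    copy.copy(broken)  # copy probes special attributes and gets a generator back
except TypeError as err:
    print("copying BrokenDataset failed:", err)
```

Reading fields through self.__dict__ is a defensive choice: with a plain self.fields, a lookup that happens before fields is set would call __getattr__ again and recurse until RecursionError.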