Does the DataLoader re-initialize the dataset after each iteration?

Hi,
I wrote a dataset class that has a dictionary called image_pool. Each time __getitem__ is called, I first check whether the image exists in the pool; if not, I load it from disk and save it into the pool. However, I find that in the second iteration the dictionary becomes empty, and the same happens in all later iterations. What is the reason?


If you use multiple workers, an in-place manipulation of your Dataset won’t be persisted as far as I know, since each worker gets its own copy of the Dataset.
We had a similar discussion about caching some data, and I’ve created a small example using shared arrays here.
Maybe this could also help in your case.
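For reference, here is a minimal sketch of what such a shared-array cache could look like (this is not the code from the linked post; it assumes all images share the same fixed shape, and SharedCacheDataset and its arguments are just illustrative names):

import torch
from torch.utils.data import Dataset, DataLoader


class SharedCacheDataset(Dataset):
    def __init__(self, length, shape=(3, 32, 32)):
        self.length = length
        # share_memory_() moves the tensors into shared memory, so all
        # DataLoader workers operate on the same underlying storage
        self.cache = torch.zeros(length, *shape).share_memory_()
        self.cached = torch.zeros(length, dtype=torch.bool).share_memory_()

    def __getitem__(self, index):
        if not self.cached[index]:
            # the expensive load from disk would go here; faked with randn
            self.cache[index] = torch.randn(self.cache.shape[1:])
            self.cached[index] = True
        return self.cache[index]

    def __len__(self):
        return self.length


loader = DataLoader(SharedCacheDataset(100), batch_size=10, num_workers=2)
for x in loader:   # the first epoch fills the shared cache
    pass
for x in loader:   # later epochs read from it
    pass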

Hi, thank you so much for your reply. However, in my project the images have different sizes, so they cannot be saved as arrays with the same shape. Do you have any suggestions?

That’s an interesting use case. You are right, variable-sized input won’t work with my first approach.
However, luckily Python provides an implementation of a shared dict via multiprocessing.Manager.
Here is a small example using it:

from multiprocessing import Manager

import torch
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    def __init__(self, shared_dict, length):
        self.shared_dict = shared_dict
        self.length = length
        
    def __getitem__(self, index):
        # cache the sample in the manager dict, which is shared between
        # all DataLoader worker processes
        if index not in self.shared_dict:
            print('Adding {} to shared_dict'.format(index))
            self.shared_dict[index] = torch.tensor(index)
        return self.shared_dict[index]
        
    def __len__(self):
        return self.length


# Init
manager = Manager()
shared_dict = manager.dict()
dataset = MyDataset(shared_dict, length=100)

loader = DataLoader(
    dataset,
    batch_size=10,
    num_workers=6,
    shuffle=True,
    pin_memory=True
)

# First loop will add data to the shared_dict
for x in loader:
    print(x)

# The second loop will just get the data
for x in loader:
    print(x)

Would that work for you? I guess you are using a custom collate function to create your batch?


Hi, thank you very much for your kind reply. It seems the dictionary cannot store PIL images, but it can store numpy arrays and tensors. I prefer saving PIL images because the built-in image transformations mainly work on PIL images. A clumsy solution is to convert the image to a numpy array when saving it and convert it back to PIL when loading it. Any better idea?

Are you sure the dict throws this error?
It seems to work on my machine.
I guess the DataLoader’s default_collate method throws an error like

TypeError: batch must contain tensors, numbers, dicts or lists

If you really want to return PIL.Images, you could write your own collate_fn.
Usually you would apply the transformation in __getitem__ and store the processed sample.
Wouldn’t that work?
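For reference, a minimal sketch of such a collate_fn that simply returns the PIL images as a plain Python list instead of trying to stack them into a tensor (the function name is made up):

def pil_list_collate(batch):
    # batch is a list of whatever __getitem__ returned;
    # return it unchanged so no tensor conversion is attempted
    return list(batch)

loader = DataLoader(dataset, batch_size=10, collate_fn=pil_list_collate)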

It works now. Thank you so much!

Hi, the problem is like this: when I manually call __getitem__, it works fine, but when I use the DataLoader it breaks down.
Basically, I save the PIL images in the shared dict. Every time __getitem__ is called, I take one image from the dictionary, apply the transformations (random scale, crop, ToTensor) and return it. What is the problem?

Traceback (most recent call last):
  File "/home/zhangchi/pycharm-2017.3.4/helpers/pydev/pydev_run_in_console.py", line 53, in run_file
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/home/zhangchi/pycharm-2017.3.4/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/home/zhangchi/Dropbox/iccv/data/train/dataset_mask_train_memory.py", line 211, in <module>
    for data in trainloader:
  File "/home/zhangchi/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 267, in __next__
    return self._process_next_batch(batch)
  File "/home/zhangchi/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 301, in _process_next_batch
    raise batch.exc_type(batch.exc_msg)
multiprocessing.managers.RemoteError: 
---------------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/zhangchi/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 55, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/zhangchi/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 55, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/zhangchi/Dropbox/iccv/data/train/dataset_mask_train_memory.py", line 129, in __getitem__
    support_rgb,support_mask=self.load_pil(support_name,sample_class)
  File "/home/zhangchi/Dropbox/iccv/data/train/dataset_mask_train_memory.py", line 97, in load_pil
    self.shared_dict[(name, sample_class)] = rgb, mask
  File "<string>", line 2, in __setitem__
  File "/home/zhangchi/anaconda3/lib/python3.6/multiprocessing/managers.py", line 772, in _callmethod
    raise convert_to_error(kind, result)
multiprocessing.managers.RemoteError: 
---------------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/zhangchi/anaconda3/lib/python3.6/multiprocessing/managers.py", line 228, in serve_client
    request = recv()
  File "/home/zhangchi/anaconda3/lib/python3.6/multiprocessing/connection.py", line 251, in recv
    return _ForkingPickler.loads(buf.getbuffer())
AttributeError: Can't get attribute 'JpegImageFile' on <module 'PIL.JpegImagePlugin' from '/home/zhangchi/anaconda3/lib/python3.6/site-packages/PIL/JpegImagePlugin.py'>
---------------------------------------------------------------------------
---------------------------------------------------------------------------

I modified your example a little to show the problem.

from multiprocessing import Manager
import torchvision
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image

class MyDataset(Dataset):
    def __init__(self, shared_dict, length):
        self.shared_dict = shared_dict
        self.length = length
        self.scale = torchvision.transforms.Resize([321, 321])
        self.ToTensor = torchvision.transforms.ToTensor()

    def __getitem__(self, index):
        if index not in self.shared_dict:
            print('Adding {} to shared_dict'.format(index))
            # storing the PIL.Image itself (instead of a tensor) triggers the error
            self.shared_dict[index] = Image.open('test_img.png')

        return self.ToTensor(self.scale(self.shared_dict[index]))

    def __len__(self):
        return self.length


# Init
manager = Manager()
shared_dict = manager.dict()
dataset = MyDataset(shared_dict, length=100)

loader = DataLoader(
    dataset,
    batch_size=2,
    num_workers=6,
    shuffle=True,
    pin_memory=True
)

# First loop will add data to the shared_dict
for x in loader:
    pass

# The second loop will just get the data
for x in loader:
    pass

Would it work if you apply both transformations on the image before saving it to the dict, such that the tensor will be stored instead of the image?

If a tensor is saved in the dict, it works well, but the transformations are random (crop, scale), so it does not make sense to save the transformed tensors. :frowning:

It’s weird that it is OK if num_workers is 1.

Thanks for debugging.
I think the easiest way would be to store the data as a numpy array (or tensor) and convert it back to a PIL.Image for the random transformations.
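Roughly like this sketch (the paths argument and the transform choices are just for illustration, not your actual pipeline):

from multiprocessing import Manager

import numpy as np
import torchvision
from PIL import Image
from torch.utils.data import Dataset


class CachedDataset(Dataset):
    def __init__(self, shared_dict, paths):
        self.shared_dict = shared_dict
        self.paths = paths  # list of image file paths
        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.RandomResizedCrop(321),
            torchvision.transforms.ToTensor(),
        ])

    def __getitem__(self, index):
        if index not in self.shared_dict:
            img = Image.open(self.paths[index]).convert('RGB')
            # numpy arrays pickle cleanly through the manager dict
            self.shared_dict[index] = np.array(img)
        # rebuild the PIL.Image so the random transforms are applied
        # freshly in every epoch
        img = Image.fromarray(self.shared_dict[index])
        return self.transform(img)

    def __len__(self):
        return len(self.paths)


manager = Manager()
dataset = CachedDataset(manager.dict(), paths=['test_img.png'])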

Maybe. I will try. Thank you for your help.

Hi @ptrblck, I have the same issue. I was wondering if the following would work:
use functools.lru_cache() on __getitem__ and iterate over the DataLoader once so that all __getitem__ calls are cached. Now increase num_workers from 0 to some value and expect the cached data to be shared across the processes.

Will this work?

I would see a few problematic points with this approach, which you might have already thought about:

  • usually you are transforming the data in __getitem__, which would yield slightly different results in each call for the same sample. This would be a cache miss and you would most likely never hit a cached sample. Are you thus only using static transformations? If so, why wouldn’t preloading the entire dataset into a single tensor work?
  • “so that all getitem are cached” → this would mean that you are able to preload the entire dataset (which is not possible if the dataset is too large for your host RAM or you just don’t want to waste it to store the data). If so, how large is the data?
  • “cached data is shared” → I don’t know how the lru_cache works internally, but to share data between multiple processes you would need to use shared object types to be able to use it for IPC (see the sketch below).
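
As a quick illustration of that last point (a sketch, not from this thread): wrapping the loading function in functools.lru_cache gives every DataLoader worker process its own independent cache, as the printed PIDs show.

import functools
import os

import torch
from torch.utils.data import Dataset, DataLoader


@functools.lru_cache(maxsize=None)
def load_sample(index):
    # printed once per (process, index) pair, showing that each worker
    # keeps its own separate cache
    print('loading {} in pid {}'.format(index, os.getpid()))
    return torch.tensor(float(index))


class LruDataset(Dataset):
    def __getitem__(self, index):
        return load_sample(index)

    def __len__(self):
        return 10


loader = DataLoader(LruDataset(), batch_size=2, num_workers=2)
for epoch in range(2):
    # the second epoch prints again, since fresh worker processes
    # start with empty caches
    for x in loader:
        pass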

Oh! I see. The RAM isn’t large enough to hold the whole dataset; I missed that. Thanks anyway.

functools.lru_cache() actually will not work because of the issue mentioned in this thread.