Dataloader re-initialize dataset after each iteration?

That’s an interesting use case. You are right, variable-sized inputs won’t work with my first approach.
Luckily, Python’s multiprocessing module provides a dict that can be shared between processes, via Manager.
Here is a small example using it:

from multiprocessing import Manager

import torch
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    def __init__(self, shared_dict, length):
        self.shared_dict = shared_dict
        self.length = length
        
    def __getitem__(self, index):
        if index not in self.shared_dict:
            print('Adding {} to shared_dict'.format(index))
            self.shared_dict[index] = torch.tensor(index)
        return self.shared_dict[index]
        
    def __len__(self):
        return self.length


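# The guard is needed on platforms that start workers with "spawn"
# (Windows, macOS); without it each worker would re-run this setup code.
if __name__ == '__main__':
    # Init: the manager process owns the dict, workers access it through a proxy
    manager = Manager()
    shared_dict = manager.dict()
    dataset = MyDataset(shared_dict, length=100)

    loader = DataLoader(
        dataset,
        batch_size=10,
        num_workers=6,
        shuffle=True,
        pin_memory=True
    )

    # First loop will add data to the shared_dict
    for x in loader:
        print(x)

    # The second loop will just get the data
    for x in loader:
        print(x)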
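
In case it helps to see why a plain dict is not enough with num_workers > 0, here is a rough stdlib-only sketch. The names fill_cache, plain_dict and shared_snapshot are made up for this illustration, and I am assuming the "fork" start method (POSIX only) so the snippet runs top to bottom as a script. Each child process writes into its own copy of the plain dict, so the parent never sees those entries, while writes to the Manager dict go through a proxy to a single server process:

```python
import multiprocessing as mp


def fill_cache(cache, index):
    # Hypothetical stand-in for Dataset.__getitem__ filling a cache.
    cache[index] = index * index


# Assumption: "fork" start method (Linux/other POSIX), chosen so that no
# __main__ guard is needed in this short demo.
ctx = mp.get_context("fork")

plain_dict = {}
manager = ctx.Manager()
shared_dict = manager.dict()

for cache in (plain_dict, shared_dict):
    workers = [ctx.Process(target=fill_cache, args=(cache, i)) for i in range(4)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

# Copy the proxy contents into a regular dict before stopping the manager.
shared_snapshot = dict(shared_dict.items())
manager.shutdown()

print(len(plain_dict))                   # 0: children only changed their copies
print(sorted(shared_snapshot.items()))   # [(0, 0), (1, 1), (2, 4), (3, 9)]
```

The same thing happens inside the DataLoader workers, which is why the dataset above takes the Manager dict instead of creating a regular dict in __init__.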

Would that work for you? I guess you are using a custom collate function to create your batch?
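
If you don't have one yet, here is a minimal sketch of what such a collate function could look like. I am assuming 1D samples of different lengths that should be zero-padded to the longest sample in the batch; VariableLengthDataset and pad_collate are made-up names for this example, and your real data might need a different strategy:

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader


class VariableLengthDataset(Dataset):
    # Hypothetical dataset: sample i is a 1D tensor of length lengths[i].
    def __init__(self, lengths):
        self.lengths = lengths

    def __getitem__(self, index):
        return torch.full((self.lengths[index],), float(index))

    def __len__(self):
        return len(self.lengths)


def pad_collate(batch):
    # Keep the original lengths so the padding can be masked out later.
    lengths = torch.tensor([sample.size(0) for sample in batch])
    # Zero-pad every sample to the longest one in this batch.
    padded = pad_sequence(batch, batch_first=True, padding_value=0.0)
    return padded, lengths


dataset = VariableLengthDataset([3, 5, 2, 4])
loader = DataLoader(dataset, batch_size=2, collate_fn=pad_collate)

for padded, lengths in loader:
    print(padded.shape, lengths)
```

The default collate function tries to stack the samples and fails when their shapes differ, which is why collate_fn is passed explicitly here.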
