That’s an interesting use case. You are right, variable-sized inputs won’t work with my first approach.
However, Python’s multiprocessing module luckily provides a shared dict implementation via a Manager.
Here is a small example using it:
```python
from multiprocessing import Manager

import torch
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    def __init__(self, shared_dict, length):
        self.shared_dict = shared_dict
        self.length = length

    def __getitem__(self, index):
        # Cache the sample in the shared dict on first access
        if index not in self.shared_dict:
            print('Adding {} to shared_dict'.format(index))
            self.shared_dict[index] = torch.tensor(index)
        return self.shared_dict[index]

    def __len__(self):
        return self.length


# Init
manager = Manager()
shared_dict = manager.dict()
dataset = MyDataset(shared_dict, length=100)
loader = DataLoader(
    dataset,
    batch_size=10,
    num_workers=6,
    shuffle=True,
    pin_memory=True
)

# First loop will add data to the shared_dict
for x in loader:
    print(x)

# The second loop will just get the data
for x in loader:
    print(x)
```
Would that work for you? I assume you are already using a custom collate function to create your batches from the variable-sized samples?
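In case you don’t have one yet, here is a minimal sketch of such a collate function that pads all samples in a batch to the longest one (the padding strategy, the `pad_collate` name, and the dummy dataset are just assumptions for illustration; adapt it to your actual sample shapes):

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader


class VarLenDataset(Dataset):
    # Hypothetical dataset returning 1D tensors of different lengths
    def __len__(self):
        return 100

    def __getitem__(self, index):
        length = index % 10 + 1
        return torch.full((length,), float(index))


def pad_collate(batch):
    # batch is a list of 1D tensors with different lengths;
    # pad them to the longest sample and also return the original lengths
    lengths = torch.tensor([x.size(0) for x in batch])
    padded = pad_sequence(batch, batch_first=True)
    return padded, lengths


loader = DataLoader(VarLenDataset(), batch_size=10, collate_fn=pad_collate)
for padded, lengths in loader:
    print(padded.shape, lengths)
```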