Hi!
Setting: My dataset wraps a large pandas dataframe. The dataset is replicated for each dataloader worker, filling my ram, limiting the workers i can use. So I thought to split it via worker_init_fn.
Problem: When I initialize the dataset the way I do, I get index errors.
Sample code showing how I do things:
#%%
import torch
import numpy as np
from time import time, sleep
from torch.utils.data import DataLoader, Dataset
# %%
class Data(Dataset):
def __init__(self) -> None:
super().__init__()
self.indices = np.arange(1000)
def __getitem__(self, index):
return self.indices[index]
def __len__(self):
return len(self.indices)
@staticmethod
def worker_init_fn(worker_id):
worker_info = torch.utils.data.get_worker_info()
print(f"setting up worker {worker_id}")
dataset = worker_info.dataset
all_ids = dataset.indices
subset = np.array_split(all_ids, worker_info.num_workers)[worker_id]
print(f"worker {worker_id} subset size: {len(subset)} / {len(all_ids)}\n")
dataset.indices = subset
loader = DataLoader(Data(), num_workers=4, worker_init_fn=Data.worker_init_fn)
# %%
t0 = time()
s = 0
for i in loader:
print(i)
s += i
sleep(0.01)
print(time()-t0)
> IndexError: index 250 is out of bounds for axis 0 with size 250
What am I doing wrong?