When using worker_init_fn, dataset length breaks

Setting: My dataset wraps a large pandas dataframe. The dataset is replicated in each dataloader worker, which fills my RAM and limits the number of workers I can use. So I thought I would split it across workers via worker_init_fn.

Problem: When I initialize the dataset the way I do, I get index errors.

Sample code showing how I do things:

import torch
import numpy as np
from time import time
from torch.utils.data import DataLoader, Dataset
# %%

class Data(Dataset):
    def __init__(self) -> None:
        self.indices = np.arange(1000)

    def __getitem__(self, index):
        return self.indices[index]

    def __len__(self):
        return len(self.indices)

    @staticmethod
    def worker_init_fn(worker_id):
        worker_info = torch.utils.data.get_worker_info()
        print(f"setting up worker {worker_id}")
        dataset = worker_info.dataset  # this worker's copy of the dataset
        all_ids = dataset.indices
        subset = np.array_split(all_ids, worker_info.num_workers)[worker_id]
        print(f"worker {worker_id} subset size: {len(subset)} / {len(all_ids)}\n")
        dataset.indices = subset

loader = DataLoader(Data(), num_workers=4, worker_init_fn=Data.worker_init_fn)
# %%
t0 = time()
s = 0
for i in loader:
    s += i

> IndexError: index 250 is out of bounds for axis 0 with size 250

What am I doing wrong?

I don’t think you’re supposed to change the length of the dataset at that point. For map-style (indexed) datasets, the indices each worker receives are global: they come from the sampler, which by default knows nothing about your per-worker split. So a worker that now holds only 250 rows is still asked for indices up to 999, hence the IndexError.
I think the conceptually easiest way to split the data is to use an IterableDataset, where each worker just yields items until its shard is exhausted, so you sidestep the length entirely. The IterableDataset documentation has an example along the lines of your use case.
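A minimal sketch of that approach, mirroring the 1000-element toy setup from your code (the class name IterData is mine):

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, IterableDataset


class IterData(IterableDataset):
    """Each worker iterates only over its own shard of the data."""

    def __init__(self):
        self.indices = np.arange(1000)

    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        if info is None:
            # single-process loading (num_workers=0): yield everything
            shard = self.indices
        else:
            # split the full array into disjoint shards, one per worker
            shard = np.array_split(self.indices, info.num_workers)[info.id]
        yield from shard


if __name__ == "__main__":
    loader = DataLoader(IterData(), num_workers=4)
    # each worker exhausts its own shard, so every item appears exactly once
    total = sum(int(x) for x in loader)
    print(total)  # 499500 == sum(range(1000))
```

Note that with an IterableDataset there is no sampler, so if you need shuffling you have to do it yourself inside `__iter__` (e.g. permute each worker's shard per epoch).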

Best regards



Thanks, that avoids the problem.