DataLoader with NumPy much slower when num_workers > 0

Hello everyone,

I have been working on a project where the data and features are stored in NumPy arrays, and I found that the DataLoader was quite slow when num_workers > 0, so I decided to recreate the issue with a dummy example:

import numpy as np
from torch.utils.data import DataLoader


class NumpyDataset:
    def __init__(self, size: int):
        self.data = np.random.rand(size, 2)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return {
            "feature": self.data[i][0],
            "target": self.data[i][1]
        }

SIZE = 10**5         # varied across runs: 1e5, 1e6, 1e7 (see the table below)
BATCH_SIZE = 32      # varied across runs: 32, 64, 128
NUM_WORKERS = 0      # varied across runs: 0, 8

ds = NumpyDataset(size=SIZE)
dl = DataLoader(dataset=ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
for _ in dl:
    pass
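
A minimal timing sweep over these parameters, assuming plain wall-clock timing with time.perf_counter (sketch only; the exact benchmark code may differ), could look like:

import itertools
import time

# Sweep the parameter grid from the table below and time a full pass over the loader.
for size, batch_size, num_workers in itertools.product(
        [10**5, 10**6, 10**7], [32, 64, 128], [0, 8]):
    ds = NumpyDataset(size=size)
    dl = DataLoader(dataset=ds, batch_size=batch_size, num_workers=num_workers)
    start = time.perf_counter()
    for _ in dl:
        pass
    total_time = time.perf_counter() - start
    print(f"size={size:.0e} batch_size={batch_size} num_workers={num_workers} "
          f"total_time={total_time:.4f}s")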

After running some benchmarks for different combinations of parameters, these were the results:

size     batch_size   num_workers   total_time (s)
1e+05        32            0             0.041524
1e+05        32            8             0.387215
1e+05        64            0             0.0182004
1e+05        64            8             0.260331
1e+05       128            0             0.0164798
1e+05       128            8             0.184617
1e+06        32            0             2.30033
1e+06        32            8            28.3145
1e+06        64            0             1.73181
1e+06        64            8            14.0961
1e+06       128            0             1.51957
1e+06       128            8             8.15612
1e+07        32            0            22.3278
1e+07        32            8           281.27
1e+07        64            0            15.7327
1e+07        64            8           151.014
1e+07       128            0            14.3264
1e+07       128            8            75.8562

From these results I could see that:

  • num_workers = 0 is around one order of magnitude faster than num_workers = 8
  • The gap appears to shrink with larger batch sizes

Does anyone know why this might be happening?
Is it recommended to stick to single-process loading (num_workers = 0) when dealing with NumPy?

Thank you!

I don't think this effect necessarily depends on the usage of NumPy; it might be the expected overhead of using multiple processes only to index an already preloaded dataset.
Multiple workers are especially beneficial if you are lazily loading and processing the data, i.e. if a single sample is loaded and transformed in each __getitem__ call. In that case each worker creates a full batch in the background while the main process is busy with the actual model training.
In your example the dataset is already preloaded in __init__, so each worker only indexes samples from its own copy of the data and creates the batch, which can yield an overall slowdown due to the added multiprocessing overhead.
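
As a sketch of that lazy-loading pattern (hypothetical file layout, not from the post above): each __getitem__ reads and transforms a single sample, so the workers have real work to do in the background.

import numpy as np
from torch.utils.data import Dataset


class LazyNumpyDataset(Dataset):
    def __init__(self, file_paths):
        # Keep only lightweight metadata here; the heavy I/O happens lazily per sample.
        self.file_paths = file_paths

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, i):
        sample = np.load(self.file_paths[i])       # per-sample disk read
        feature = sample[:-1].astype(np.float32)   # per-sample transform
        target = sample[-1].astype(np.float32)
        return {"feature": feature, "target": target}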


Hey @ptrblck,
I have a similar issue despite loading and processing the data in the __getitem__ call instead of __init__. Is there any other possible reason for this?

Assuming you are lazily loading the samples in __getitem__, I would guess you might see bad performance if the actual loading is the bottleneck and thus blocks the other parts of the pipeline (which could be the case if you are reading from, e.g., a slow network drive).
Could you profile the data loader alone and see how the speed changes for different numbers of workers?
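
For example (sketch only, with a placeholder TensorDataset standing in for the real dataset), you could iterate over the DataLoader alone under torch.profiler for several worker counts:

import torch
from torch.profiler import ProfilerActivity, profile
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.randn(100_000, 8), torch.randn(100_000))  # placeholder dataset

for num_workers in [0, 1, 2, 4, 8]:
    dl = DataLoader(ds, batch_size=64, num_workers=num_workers)
    with profile(activities=[ProfilerActivity.CPU]) as prof:
        for _ in dl:       # data loading only, no model involved
            pass
    print(f"num_workers={num_workers}")
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))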

I used the torch profiler and here is what I got:

DataLoader time (microseconds) for various num_workers:
0: 1,484,033
1: 13,361,716
2: 14,867,360
3: 14,596,160
4: 15,243,254
5: 14,412,214
20: 15,291,521

I want to understand the multiprocessing happening here further. Are the operations in __init__ done in the main process, while only the operations inside __getitem__ happen in each of the workers? If so, in XavierB's case the overhead of copying the data loaded in __init__ to each worker would be the bottleneck. Is it correct that __init__ should contain as few operations as possible?
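
One way to see what runs where is a small sketch (hypothetical class) that prints the process id from __init__ and __getitem__ via torch.utils.data.get_worker_info():

import os

import torch
from torch.utils.data import DataLoader, Dataset, get_worker_info


class WhoLoadsMe(Dataset):
    def __init__(self):
        # Runs once, in the process that constructs the dataset (the main process);
        # each worker later receives a copy of the resulting object.
        print(f"__init__ in pid={os.getpid()}")
        self.data = torch.arange(8)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        info = get_worker_info()  # None in the main process, a WorkerInfo inside a worker
        where = "main" if info is None else f"worker {info.id}"
        print(f"__getitem__({i}) in pid={os.getpid()} ({where})")
        return self.data[i]


for _ in DataLoader(WhoLoadsMe(), batch_size=4, num_workers=2):
    pass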

Also, after the workers load the data, do they put the loaded data in a sort of "queue" (though maybe queue is not the right term here) to be fed to the GPU during training? If so, I suspect there is an optimum num_workers, because if we keep increasing that number we just create an unnecessarily long queue (and perhaps consume extra memory)?
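
On the "queue" intuition: with num_workers > 0, each worker pre-loads a few batches in advance into a queue that the main process consumes (and then moves to the GPU inside the training loop), so a very large num_workers mainly grows that queue and its memory footprint. A sketch of the DataLoader arguments that control this (placeholder dataset, illustrative values):

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.randn(10_000, 8), torch.randn(10_000))  # placeholder dataset

dl = DataLoader(
    ds,
    batch_size=64,
    num_workers=4,            # worth sweeping; the optimum is workload-dependent
    prefetch_factor=2,        # batches pre-loaded per worker (only valid with num_workers > 0)
    persistent_workers=True,  # keep workers alive between epochs instead of respawning them
    pin_memory=True,          # page-locked host memory speeds up host-to-GPU copies
)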

Hi @ptrblck, I am having a similar problem when loading NumPy files in my Dataset. Data loading is so slow that the GPU sits waiting for data, resulting in no GPU utilisation.

Here is a sample dataset:

import glob
import os

import numpy as np
import torch
from torch.utils.data import Dataset

# padding() and normalize_image() are project-specific helpers defined elsewhere.


class CNNRegressionDataset(Dataset):
    def __init__(self, dataset_path, cfg):
        super(CNNRegressionDataset, self).__init__()
        self.target_dir = os.path.join(dataset_path, 'y')
        self.input_dir = os.path.join(dataset_path, 'x')

        self.cfg = cfg

    def __len__(self):
        return len(glob.glob(os.path.join(self.input_dir, "*.npy"), recursive=True))

    def __getitem__(self, idx):
        input_file = glob.glob(os.path.join(self.input_dir, "*.npy"), recursive=True)[idx]
        target_file = os.path.join(self.target_dir, os.path.basename(input_file))

        x = np.load(input_file)
        if self.cfg.DATASETS.DEPTH_ONLY:
            x = x[:, :, -1]
        if x.shape[0] != self.cfg.DATASETS.DESIRED_SIZE[0] or x.shape[1] != self.cfg.DATASETS.DESIRED_SIZE[1]:
            x = padding(x, self.cfg.DATASETS.DESIRED_SIZE, num_channels=self.cfg.DATASETS.NUM_CHANNELS)
        x = x.reshape(x.shape[0], x.shape[1], self.cfg.DATASETS.NUM_CHANNELS)
        y = torch.tensor(np.load(target_file))
        # changing mm to meter
        # if np.max(x) > 10:
        #     x = x / 1000
        #     y = y / torch.tensor(1000000)
        # # changing meter to centimeter
        # else:
        #     x = x * 100

        if self.cfg.DATASETS.NORMALISE_DEPTH:
            x = normalize_image(x, input_file)
        x = torch.tensor(x)
        x = x.permute(2, 0, 1)
        return x, y

Dataset initialisation:

train_dataset = CNNRegressionDataset(dataset_path=dataset_path,cfg=cfg)
# train_dataset = datasets.FakeData(50000, (1,224, 224), 1, transforms.ToTensor())

Here is the DataLoader:


train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=128, shuffle=(train_sampler is None),
        num_workers=0, pin_memory=True, sampler=train_sampler)

Using CNNRegressionDataset the GPU utilisation is 0%, while with FakeData it is 100%.

For general data loading advice you can check this post.
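
As a concrete illustration of that kind of advice (sketch only, with the cfg-specific preprocessing omitted; not a confirmed fix for this exact setup): the dataset above re-runs glob.glob inside both __len__ and __getitem__, so one common change is to scan the directory once in __init__ and keep only the per-sample np.load in __getitem__.

import glob
import os

import numpy as np
import torch
from torch.utils.data import Dataset


class CachedFileListDataset(Dataset):
    def __init__(self, dataset_path):
        super().__init__()
        self.input_dir = os.path.join(dataset_path, 'x')
        self.target_dir = os.path.join(dataset_path, 'y')
        # Scan the directory once instead of on every __len__/__getitem__ call.
        self.input_files = sorted(
            glob.glob(os.path.join(self.input_dir, "*.npy"), recursive=True))

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, idx):
        input_file = self.input_files[idx]
        target_file = os.path.join(self.target_dir, os.path.basename(input_file))
        x = torch.from_numpy(np.load(input_file)).float()
        y = torch.from_numpy(np.load(target_file)).float()
        return x, y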
