DistributedDataParallel vs DataParallel: data loading too slow for the first batch of every epoch in the distributed setting

I am trying to train a video classification model. I wrote a custom video dataset that essentially reads pre-extracted video frames from SSD. I want to train on a cluster of GPU machines with 4 GPUs per node.

While training on 1 machine with 4 GPUs, I have the following observations under two settings:

Case 1. DistributedDataParallel: with 4 threads for the machine (1 thread per GPU), the data loading time for the first batch of every epoch is very high (~110 seconds)
Case 2. DataParallel: with 4 threads for the machine, the data loading time for the first batch of every epoch is significantly lower than in Case 1 (~1.5 seconds)

I still want to use DistributedDataParallel since I want to train on multiple machines, but the extra ~110 seconds at the start of every epoch is too much. How can I improve the distributed setting?

Logs for reference.
DataParallel, 4 threads
Epoch: [0][0/7508] Time 13.270 (13.270) Data 1.521 (1.521) Loss 6.2721 (6.2721) Acc@1 0.000 (0.000) Acc@5 0.000 (0.000)
Epoch: [0][10/7508] Time 0.265 (1.459) Data 0.000 (0.138) Loss 17.9221 (17.1892) Acc@1 0.000 (0.284) Acc@5 0.000 (2.273)
Epoch: [0][20/7508] Time 0.265 (0.890) Data 0.000 (0.077) Loss 20.7100 (14.7189) Acc@1 0.000 (0.149) Acc@5 0.000 (1.786)

DistributedDataParallel, 4 threads (1 thread per GPU)
Epoch: [0][0/7508] Time 117.339 (117.339) Data 114.749 (114.749) Loss 6.3962 (6.3962) Acc@1 0.000 (0.000) Acc@5 0.000 (0.000)
Epoch: [0][0/7508] Time 117.070 (117.070) Data 110.291 (110.291) Loss 6.3759 (6.3759) Acc@1 0.000 (0.000) Acc@5 0.000 (0.000)
Epoch: [0][0/7508] Time 117.479 (117.479) Data 114.120 (114.120) Loss 6.3918 (6.3918) Acc@1 0.000 (0.000) Acc@5 0.000 (0.000)
Epoch: [0][0/7508] Time 116.495 (116.495) Data 112.885 (112.885) Loss 6.0654 (6.0654) Acc@1 0.000 (0.000) Acc@5 0.000 (0.000)
Epoch: [0][10/7508] Time 0.248 (10.814) Data 0.000 (10.262) Loss 13.6280 (14.8321) Acc@1 0.000 (0.000) Acc@5 0.000 (0.000)
Epoch: [0][10/7508] Time 0.248 (10.870) Data 0.000 (10.030) Loss 12.6716 (16.3162) Acc@1 12.500 (1.136) Acc@5 12.500 (2.273)
Epoch: [0][10/7508] Time 0.252 (10.904) Data 0.000 (10.375) Loss 6.9328 (14.4093) Acc@1 0.000 (1.136) Acc@5 25.000 (3.409)
Epoch: [0][10/7508] Time 0.251 (10.891) Data 0.000 (10.432) Loss 12.2168 (13.2482) Acc@1 0.000 (0.000) Acc@5 0.000 (0.000)
Epoch: [0][20/7508] Time 0.252 (5.813) Data 0.000 (5.260) Loss 6.3584 (13.0522) Acc@1 0.000 (0.595) Acc@5 0.000 (1.190)
Epoch: [0][20/7508] Time 0.254 (5.831) Data 0.000 (5.440) Loss 7.1645 (12.1273) Acc@1 0.000 (0.595) Acc@5 0.000 (1.786)
Epoch: [0][20/7508] Time 0.250 (5.825) Data 0.000 (5.470) Loss 6.9019 (12.8164) Acc@1 0.000 (0.595) Acc@5 0.000 (0.595)
Epoch: [0][20/7508] Time 0.252 (5.784) Data 0.000 (5.381) Loss 6.9181 (11.9140) Acc@1 0.000 (0.000) Acc@5 0.000 (0.000)

For the training script I am using a modified version of https://github.com/pytorch/examples/blob/master/imagenet/main.py
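For context, the Time / Data columns in the logs above come from the same timing code as in that example's train loop; roughly (paraphrased fragment, data_time / batch_time are the AverageMeter helpers from the example):

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # "Data": how long we waited for the DataLoader to hand over this batch
        data_time.update(time.time() - end)

        output = model(images.cuda(non_blocking=True))
        loss = criterion(output, target.cuda(non_blocking=True))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # "Time": the full iteration, i.e. data waiting + forward/backward
        batch_time.update(time.time() - end)
        end = time.time()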

Hey, I am having exactly the same issue. Was your problem solved?

@C_Ashraf Is it the data loading that is really slow in your case, or the actual DistributedDataParallel execution? If you have a small self-contained example demonstrating the problem, it would be easier to narrow down the issue.

My workflow is kind of complex and I do not have a self-contained example, but I will try to explain it in as much detail as possible. I have a very large dataset that I cannot load into memory, so I wrote a custom dataset class:

import os
from bisect import bisect

import numpy as np
import torch


class BigDataset(torch.utils.data.Dataset):
    def __init__(self, data_paths):
        # Memory-map each .npy file so nothing is read into RAM up front
        self.data_memmaps = [np.load(path, mmap_mode='r') for path in data_paths]
        # start_indices[i] = global index of the first sample stored in file i
        self.start_indices = [0] * len(data_paths)
        self.data_count = 0
        for index, memmap in enumerate(self.data_memmaps):
            self.start_indices[index] = self.data_count
            self.data_count += memmap.shape[0]

    def __len__(self):
        return self.data_count

    def __getitem__(self, index):
        # Find which file the global index falls into, then the offset within it
        memmap_index = bisect(self.start_indices, index) - 1
        index_in_memmap = index - self.start_indices[memmap_index]
        data = self.data_memmaps[memmap_index][index_in_memmap]
        return index, torch.from_numpy(data)

Next, I read the locations of all the files (my data is split across multiple files):

    data_paths = [os.path.join(file_path, f'data/feature{index}.npy')
                  for index in range(2)]

    dataset = BigDataset(data_paths)

Since this dataset contains both the train and validation data, I need to split it. I generate train and val indices and use the following code for the train and val dataloaders:

    if args.distributed:
        # Wrap the index subsets so each rank only sees its own shard
        train_sampler = torch.utils.data.distributed.DistributedSampler(torch.utils.data.Subset(dataset, train_indices))
        val_sampler = torch.utils.data.distributed.DistributedSampler(torch.utils.data.Subset(dataset, val_indices))
    else:
        train_sampler = SubsetRandomSampler(train_indices)
        val_sampler = SubsetRandomSampler(val_indices)

    train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size,
                                               num_workers=args.workers, sampler=train_sampler,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size,
                                             num_workers=args.workers, sampler=val_sampler,
                                             pin_memory=True)
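For completeness, the epoch loop follows the same structure as the example, which calls set_epoch on the distributed sampler before each epoch; a rough sketch of that part:

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # makes the DistributedSampler reshuffle differently each epoch
            train_sampler.set_epoch(epoch)
        train(train_loader, model, criterion, optimizer, epoch, args)
        validate(val_loader, model, criterion, args)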

I am pretty sure that the dataloader is causing the issue:

Epoch: [0][ 0/51]       Time 220.202 (220.202)  Data 205.658 (205.658)  Loss 39.61639 (39.61639)        Accuracy   0.00 (  0.00)
Epoch: [0][ 0/51]       Time 220.181 (220.181)  Data 205.639 (205.639)  Loss 43.61139 (43.61139)        Accuracy   0.00 (  0.00)
Epoch: [0][ 0/51]       Time 220.229 (220.229)  Data 205.687 (205.687)  Loss 35.34707 (35.34707)        Accuracy   0.00 (  0.00)
Epoch: [0][ 0/51]       Time 220.228 (220.228)  Data 205.683 (205.683)  Loss 56.56057 (56.56057)        Accuracy   0.00 (  0.00)

Epoch: [0][ 1/51]       Time  0.917 (110.549)   Data  0.000 (102.820)   Loss 20.94585 (32.27862)        Accuracy   0.00 (  0.00)
Epoch: [0][ 1/51]       Time  0.917 (110.560)   Data  0.000 (102.829)   Loss 63.88563 (51.75101)        Accuracy   0.00 (  0.00)
Epoch: [0][ 1/51]       Time  0.917 (110.573)   Data  0.000 (102.844)   Loss 23.30010 (29.32359)        Accuracy   0.00 (  0.00)
Epoch: [0][ 1/51]       Time  0.917 (110.572)   Data  0.000 (102.842)   Loss 33.03528 (44.79793)        Accuracy   0.00 (  0.00)

I followed the same procedure as described in https://github.com/pytorch/examples/blob/master/imagenet/main.py

Is the way I am loading my data (not directly into memory) causing this issue?

I am able to reproduce the same behavior using the ImageNet example with tiny_image_dataset, both with a batch size of 256 on two GPUs and with a batch size of 512 on two GPUs.

Also, I suspect this could be due to a dataloader memory leak. If I use my entire dataset (120 GB), I see an out-of-memory (OOM) kill before any batch is trained. I looked at the PyTorch discussion forum, and it looks like this is a long-standing open issue. Any help solving this issue will be appreciated. Thanks.

@VitalyFedyunin I was wondering if you could help out here, since this seems like a dataloader issue.

Hi! There is no (known) leak; it is more a problem of misunderstanding how memory works in the Python + forking world. We are aware of this issue and are planning to fix it this year (or sooner).

Some workarounds are discussed here: https://github.com/pytorch/pytorch/issues/13246#issuecomment-612396143
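For example, the kind of workaround discussed there is to keep per-item metadata in numpy arrays instead of plain Python lists/dicts, so the workers' reference-count updates do not gradually turn copy-on-write pages into real copies after the fork. A minimal sketch (the class name and loading logic here are just illustrative):

import numpy as np
import torch

class PathsDataset(torch.utils.data.Dataset):
    def __init__(self, paths):
        # A Python list of str objects has its refcounts touched from every
        # worker process, which forces page copies. A single numpy array of
        # fixed-width bytes is one contiguous buffer that stays shared.
        self.paths = np.array(paths).astype(np.bytes_)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index].decode('utf-8')
        sample = np.load(path)          # load one item from disk
        return torch.from_numpy(sample)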

I don’t know if you solved this problem, but I have noticed this happening to me as well. I could actually tell something was wrong just by comparing the times across datasets of various lengths, reducing the size of the same dataset, and reducing/increasing the number of workers.

Basically, using the COCO dataset wrapped in a torch.utils.data.Dataset, I noticed that changing the total number of images did not reduce this time; changing the batch size also did not affect the initial batch loading time; and increasing the number of workers dramatically raised this loading time.

Finally, I understood that the cause was the way the processes are spawned at the beginning, which I had replicated from the PyTorch ImageNet sample code. There, torch.multiprocessing.spawn is used, but it should be avoided. The better practice, and perhaps the only good one to date, is to rely on the official launch module (torch.distributed.launch), which is built on subprocess.Popen; a rough launch sketch is below. I do not know the precise difference, and this should be investigated further.
Hope this helps someone else save some time, either in training or in looking for a solution.
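For reference, launching through that module looks roughly like this (exact flags depend on the PyTorch version); the launcher starts one process per GPU and passes each one a --local_rank argument:

# one node with 4 GPUs -> 4 processes, one per GPU
python -m torch.distributed.launch --nproc_per_node=4 main.py

# inside main.py, each process picks up its rank and binds to one GPU
import argparse
import torch
import torch.distributed as dist

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')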

Data loading depends heavily on parameters like num_workers, pin_memory, and the batch size that fits on your GPU. Try changing num_workers in factors of 2 and watch how the iteration/epoch speed changes. Setting num_workers to the maximum can also cause large overheads and slow down your data loading.
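A quick way to compare is to time how long the first batch takes for a few settings, e.g. (sketch, reusing the dataset/sampler names from earlier in the thread):

import time
import torch

for num_workers in [2, 4, 8, 16]:
    loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size,
                                         sampler=train_sampler, pin_memory=True,
                                         num_workers=num_workers)
    start = time.time()
    next(iter(loader))  # worker startup + first batch
    print(f'num_workers={num_workers}: first batch after {time.time() - start:.1f}s')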


For a similar issue, setting args.workers=8 worked for me; my data loading times went down from ~3 seconds to ~0 seconds. I’m using a modified version of the PyTorch ImageNet example, with 3 GPUs on the same server. The num_workers passed to the DataLoader is computed as int((args.workers + gpus_per_node - 1) / gpus_per_node), which is int((8 + 3 - 1) / 3), i.e. 3 DataLoader workers per GPU.
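That is roughly the relevant adjustment from the example (variable name as in the example):

    # split the requested worker count evenly across the GPUs on this node
    args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
    # args.workers = 8, ngpus_per_node = 3  ->  int(10 / 3) = 3 workers per GPU process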