Why is my DistributedDataParallel slower than DataParallel when my dataset is not fully loaded in memory?

I am using a custom dataset and wrote a custom data.Dataset class to load it.

import torch
from PIL import Image
from torch.utils import data


class MyDataset(data.Dataset):
    """Reads and decodes one image from disk on every __getitem__ call."""

    def __init__(self, datasets, transform=None, target_transform=None):
        # `datasets` is a list of (image_path, label) pairs.
        self.datasets = datasets
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.datasets)

    def __getitem__(self, index):
        image = Image.open(self.datasets[index][0])
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(self.datasets[index][1], dtype=torch.long)
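
The datasets argument is a list of (image_path, label) pairs, so the dataset is constructed roughly like this (the paths and transform below are only illustrative, not my real file list):

    import torchvision.transforms as transforms

    # Illustrative only; `samples` stands in for the real (image_path, label) list.
    samples = [("images/img_0001.png", 0), ("images/img_0002.png", 1)]
    training_dataset = MyDataset(samples, transform=transforms.ToTensor())
    image, target = training_dataset[0]   # the image is re-read and decoded from disk on every access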

When I started training on my 4-GPU machine, I found that, contrary to what the PyTorch documentation says, DistributedDataParallel is slower than DataParallel!

I reviewed my code carefully and tried different configurations and batch sizes (especially the DataLoader num_workers) to see what would make DistributedDataParallel run faster than DataParallel as expected, but nothing worked.

The only change that made DistributedDataParallel faster was loading the whole dataset into memory during initialization!

class Inmemory_Dataset(data.Dataset):
    """Decodes every image once in __init__ and serves tensors from memory."""

    def __init__(self, datasets, transform=None, target_transform=None):
        self.datasets = datasets
        image_list = []
        target_list = []
        # Note: the loop variable must not be named `data`, or it would shadow
        # the imported torch.utils.data module.
        for path, label in datasets:
            image = Image.open(path)
            if transform:
                image = transform(image)
            image_list.append(image)
            target_list.append(label)
        # transform is expected to return tensors (e.g. transforms.ToTensor()),
        # so the whole dataset can be stacked into two in-memory tensors.
        self.images = torch.stack(image_list)
        self.targets = torch.tensor(target_list, dtype=torch.long)

    def __len__(self):
        return len(self.datasets)

    def __getitem__(self, index):
        return self.images[index], self.targets[index]

After this change, DistributedDataParallel became 30% faster, but I do not think this is how it should be. What if my dataset does not fit into memory?

Below I highlight the main parts where I set up both DataParallel and DistributedDataParallel. Notice that the overall effective batch size is the same in both cases: with DataParallel a single process uses batch_size × number of GPUs, while with DistributedDataParallel each of the per-GPU processes uses batch_size.

DataParallel:

    batch_size = 100
    if torch.cuda.device_count() > 1:
        print("Using DataParallel...")
        model = nn.DataParallel(model)
        # One process drives all GPUs; DataParallel splits each batch across them,
        # so the batch size is scaled by the number of GPUs.
        batch_size = batch_size * torch.cuda.device_count()

DistributedDataParallel:

def train(gpu, args):
    # Global rank of this process across all nodes/GPUs.
    rank = args.nr * args.gpus + gpu
    dist.init_process_group(backend='nccl', init_method='env://',
                            world_size=args.world_size, rank=rank)
    batch_size = 100  # per process, so the effective batch size is 100 x world_size

    # ... model creation and moving it to `gpu` omitted here ...
    model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        training_dataset, num_replicas=args.world_size, rank=rank)
    training_dataloader = torch.utils.data.DataLoader(
        dataset=training_dataset,
        batch_size=batch_size,
        shuffle=False,          # the DistributedSampler handles shuffling/partitioning
        num_workers=1,
        sampler=train_sampler,
        pin_memory=True)
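
For reference, here is a sketch of how such a train function is typically launched, with one process per GPU via torch.multiprocessing.spawn (parse_args, the master address, and the port below are placeholders, not my exact code):

    import os
    import torch.multiprocessing as mp

    def main():
        args = parse_args()                       # placeholder for the real argument parsing (nodes, gpus, nr, ...)
        args.world_size = args.gpus * args.nodes
        os.environ['MASTER_ADDR'] = 'localhost'   # placeholder
        os.environ['MASTER_PORT'] = '12355'       # placeholder
        # spawn calls train(gpu, args) once per GPU on this node
        mp.spawn(train, nprocs=args.gpus, args=(args,))

train_sampler.set_epoch(epoch) should also be called at the start of every epoch so that each epoch uses a different shuffle.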

Hey @ammary-mo, how did you measure the delay? Since DataParallel and DistributedDataParallel are only involved in the forward and backward passes, could you please use elapsed_time to measure the data loading, forward, and backward delays separately? See the following discussion. It's possible that if multiple DDP processes try to read from the same file, contention leads to a data loading perf regression. If that's the case, the solution would be to implement a more performant data loader.
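
Something along these lines would work as a rough sketch (criterion and optimizer below just stand for whatever you already use in your loop):

    import time
    import torch

    fwd_start = torch.cuda.Event(enable_timing=True)
    fwd_end = torch.cuda.Event(enable_timing=True)
    bwd_start = torch.cuda.Event(enable_timing=True)
    bwd_end = torch.cuda.Event(enable_timing=True)

    data_t0 = time.perf_counter()
    for images, targets in training_dataloader:
        data_time = time.perf_counter() - data_t0          # time spent waiting for the next batch
        images = images.cuda(gpu, non_blocking=True)
        targets = targets.cuda(gpu, non_blocking=True)

        fwd_start.record()
        loss = criterion(model(images), targets)
        fwd_end.record()

        bwd_start.record()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        bwd_end.record()

        torch.cuda.synchronize()                           # make sure the recorded events have completed
        print(f"data {data_time * 1000:.1f} ms | "
              f"forward {fwd_start.elapsed_time(fwd_end):.1f} ms | "
              f"backward+step {bwd_start.elapsed_time(bwd_end):.1f} ms")
        data_t0 = time.perf_counter()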

cc @VitalyFedyunin @glaringlee for DataLoader and DataSampler.


Hey @mrshenli

Thanks for tuning in. I understand why you are asking to measure the forward pass for both DP and DDP. But from the end-user perspective, I care more about the overall training time, which includes both model synchronization and data loading. If DDP is faster only in the forward pass but (somehow) loses the time it saved during data loading, then overall DP will be the better option! Don't you agree?

Is it possible that, if your data is small enough to fit entirely into memory, the DDP setup overhead just adds time without any performance improvement? In other words: GPU utilization is low enough that you just can't see the gains of using multiple GPUs.


Good question @Alexey_Demyanchuk.

But my answer is no. My data is too large to fit in memory; that is why I started by loading it from disk. However, when I found that DDP was slower than DP, as I mentioned in the question, I started comparing both with different configurations to see what would work. Eventually, the only change that made DDP faster was reducing the data size and loading it into memory. I hope this clarifies the situation.

My suggestion would be to profile the pipeline. Is GPU utilization near 100% throughout training? If it is not, you could have some sort of preprocessing bottleneck (I/O- or CPU-bound). Does that make sense in your case?

I am not sure why I should check GPU utilization; it depends on multiple factors, such as batch size and even image size.

Anyhow, I checked GPU utilization and it is low in all cases, but this is not due to a data-loading bottleneck; rather, it is because I am using a small model and a small, MNIST-like dataset just to compare performance.

@mrshenli, it seems you are right. It looks like a dataloading issue to me.

In DDP, with only two workers (num_workers=2), it is clear that data loading time is a bottleneck, as one worker/batch consistently takes more time to load than the other:
[screenshot: per-batch loading times with num_workers=2]

The solution is to increase num_workers up to the point where the data loading delay is hidden, as in this post.

But it seems that this solution works well only with DataParallel, but not with DistributedDataParallel.

  • DataParallel, num_workers=16
    This works fine: only the first batch takes 3 seconds, and all subsequent batches take almost no time.
    [screenshot: per-batch loading times]

  • DistributedDataParallel, num_workers=16
    The first batch takes 13 seconds to load, which is far too long, and adding more workers makes it take even longer.
    [screenshot: per-batch loading times]

My explanation is that the data sampler adds overhead in DistributedDataParallel that is not present in DataParallel.
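
To sanity-check that, one quick test is to time the sampler on its own, with no image decoding involved (just a sketch, reusing training_dataset, args, and rank from the code above):

    import time
    from torch.utils.data.distributed import DistributedSampler

    # Time only the index generation of the DistributedSampler; no file I/O happens here.
    sampler = DistributedSampler(training_dataset, num_replicas=args.world_size, rank=rank)
    start = time.perf_counter()
    indices = list(iter(sampler))
    print(f"rank {rank}: drew {len(indices)} indices in {time.perf_counter() - start:.4f} s")

If this finishes in milliseconds, the overhead is probably in the workers themselves rather than in the sampler's index generation.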

Below again is my data sampler code, in case there is any issue with it or room for improvement:

batch_size = 100
model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
train_sampler = torch.utils.data.distributed.DistributedSampler(
    training_dataset, num_replicas=args.world_size, rank=rank)
training_dataloader = torch.utils.data.DataLoader(
    dataset=training_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=16,
    sampler=train_sampler,
    pin_memory=True)
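
One thing that might help, depending on the PyTorch version, is keeping the worker processes alive between epochs so their startup cost is paid only once; this is just an untested sketch using the persistent_workers and prefetch_factor options of DataLoader from newer releases:

    training_dataloader = torch.utils.data.DataLoader(
        dataset=training_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=16,
        sampler=train_sampler,
        pin_memory=True,
        persistent_workers=True,   # keep worker processes alive across epochs
        prefetch_factor=2)         # number of batches preloaded by each worker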

I hope @VitalyFedyunin and @glaringlee can get back to us for advice about DataLoader and DataSampler.


Hello, did you solve this problem? I am running into the same issue.

Could you open a new post with the dataloader tag and include a script reproducing your issue? This thread is two years old now, so there have been many changes.