mp.spawn and DataLoader performance issue

I wonder whether this is a known issue and whether there is an explanation for the following performance difference. I compare runs of two similar scripts:

a) iterate over the batches of a DataLoader in a child process spawned with mp.spawn (a single process is spawned)

b) iterate over the batches of a DataLoader in the main function and launch the script with torch.distributed.launch

As a result, I observe a ~3x slowdown when using mp.spawn. Most of the extra time shows up in the first batch of every epoch (~7-8 s with mp.spawn vs. 0.2-1.4 s with torch.distributed.launch).

> python check_mp_dist.py /tmp/cifar10/ -j 4 -b 512
Epoch: [0][ 0/97] Time  8.532 ( 8.532)
Epoch: [0][30/97] Time  0.001 ( 0.297)
Epoch: [0][60/97] Time  0.105 ( 0.166)
Epoch: [0][90/97] Time  0.000 ( 0.120)    
Epoch: [1][ 0/97] Time  7.419 ( 7.419)  
Epoch: [1][30/97] Time  0.000 ( 0.264)   
Epoch: [1][60/97] Time  0.088 ( 0.148)    
Epoch: [1][90/97] Time  0.000 ( 0.108)   
Epoch: [2][ 0/97] Time  7.369 ( 7.369)    
Epoch: [2][30/97] Time  0.001 ( 0.263)    
Epoch: [2][60/97] Time  0.037 ( 0.147)    
Epoch: [2][90/97] Time  0.000 ( 0.108)    
Execution time: 32.965460194973275

vs

> python -m torch.distributed.launch --nproc_per_node=1 --use_env  check_dist_launch.py /tmp/cifar10/ -j 4 -b 512
Epoch: [0][ 0/97] Time  1.398 ( 1.398)    
Epoch: [0][30/97] Time  0.001 ( 0.069)    
Epoch: [0][60/97] Time  0.113 ( 0.050)    
Epoch: [0][90/97] Time  0.000 ( 0.043)    
Epoch: [1][ 0/97] Time  0.212 ( 0.212)    
Epoch: [1][30/97] Time  0.035 ( 0.035)    
Epoch: [1][60/97] Time  0.037 ( 0.032)    
Epoch: [1][90/97] Time  0.060 ( 0.032)    
Epoch: [2][ 0/97] Time  0.207 ( 0.207)   
Epoch: [2][30/97] Time  0.018 ( 0.034)    
Epoch: [2][60/97] Time  0.103 ( 0.033)    
Epoch: [2][90/97] Time  0.006 ( 0.031)    
Execution time: 10.863449607044458

Code launched with torch.distributed.launch (check_dist_launch.py):

import argparse
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets
from torchvision.transforms import Compose, ToTensor, Normalize, Pad, RandomCrop, RandomHorizontalFlip


def get_train_loader(path, **kwargs):
    train_transform = Compose(
        [
            Pad(4),
            RandomCrop(32, fill=128),
            RandomHorizontalFlip(),
            ToTensor(),
            Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    )
        
    train_dataset = datasets.CIFAR10(root=path, train=True, download=False, transform=train_transform)

    train_sampler = DistributedSampler(train_dataset)
    train_loader = DataLoader(
        train_dataset,
        sampler=train_sampler,
        pin_memory=True,
        drop_last=True,
        **kwargs
    )
        
    return train_loader


parser = argparse.ArgumentParser(description='Check MP training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N',
                    help='per-process mini-batch size (default: 256)')
parser.add_argument('-p', '--print-freq', default=30, type=int,
                    metavar='N', help='print frequency (default: 30)')
parser.add_argument('--world-size', default=1, type=int,
                    help='total number of processes for distributed training')
parser.add_argument('--node_rank', default=0, type=int,
                    help='node rank for distributed training')


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print(' '.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


def training(index, args):
    rank = index + args.num_procs_per_node * args.node_rank
    dist.init_process_group(
        backend="nccl", 
        init_method="tcp://localhost:2233",
        world_size=args.world_size, 
        rank=rank
    )
    
    torch.cuda.set_device(index)
    
    train_loader = get_train_loader(args.data, batch_size=args.batch_size, num_workers=args.workers)

    for epoch in range(3):
        
        batch_time = AverageMeter('Time', ':6.3f')
        data_time = AverageMeter('Data', ':6.3f')
        progress = ProgressMeter(
            len(train_loader),
            [batch_time, data_time, ],
            prefix="Epoch: [{}]".format(epoch)
        )

        end = time.time()

        for i, (images, target) in enumerate(train_loader):
            # measure data loading time
            data_time.update(time.time() - end)
            
            images = images.to("cuda", non_blocking=True)
            target = target.to("cuda", non_blocking=True)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

    
    dist.destroy_process_group()


def main():
    args = parser.parse_args()
    num_procs_per_node = torch.cuda.device_count()
    args.num_procs_per_node = num_procs_per_node

    import os    
    
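    # with --use_env, torch.distributed.launch exports LOCAL_RANK instead of passing --local_rank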
    local_rank = int(os.environ["LOCAL_RANK"])
    
    t_start = time.perf_counter()

    training(local_rank, args)

    print("Execution time: {}".format(time.perf_counter() - t_start))


if __name__ == '__main__':
    main()

Code using mp.spawn / mp.start_processes (check_mp_dist.py):

import argparse
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets
from torchvision.transforms import Compose, ToTensor, Normalize, Pad, RandomCrop, RandomHorizontalFlip


def get_train_loader(path, **kwargs):
    train_transform = Compose(
        [
            Pad(4),
            RandomCrop(32, fill=128),
            RandomHorizontalFlip(),
            ToTensor(),
            Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    )
        
    train_dataset = datasets.CIFAR10(root=path, train=True, download=False, transform=train_transform)

    train_sampler = DistributedSampler(train_dataset)
    train_loader = DataLoader(
        train_dataset,
        sampler=train_sampler,
        pin_memory=True,
        drop_last=True,
        **kwargs
    )
        
    return train_loader


parser = argparse.ArgumentParser(description='Check MP training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N',
                    help='per-process mini-batch size (default: 256)')
parser.add_argument('-p', '--print-freq', default=30, type=int,
                    metavar='N', help='print frequency (default: 30)')
parser.add_argument('--world-size', default=1, type=int,
                    help='total number of processes for distributed training')
parser.add_argument('--node_rank', default=0, type=int,
                    help='node rank for distributed training')


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print(' '.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


def training(index, args):
    rank = index + args.num_procs_per_node * args.node_rank
    dist.init_process_group(
        backend="nccl", 
        init_method="tcp://localhost:2233",
        world_size=args.world_size, 
        rank=rank
    )
    
    torch.cuda.set_device(index)
    
    train_loader = get_train_loader(args.data, batch_size=args.batch_size, num_workers=args.workers)

    for epoch in range(3):
        
        batch_time = AverageMeter('Time', ':6.3f')
        data_time = AverageMeter('Data', ':6.3f')
        progress = ProgressMeter(
            len(train_loader),
            [batch_time, data_time, ],
            prefix="Epoch: [{}]".format(epoch)
        )

        end = time.time()

        for i, (images, target) in enumerate(train_loader):
            # measure data loading time
            data_time.update(time.time() - end)
            
            images = images.to("cuda", non_blocking=True)
            target = target.to("cuda", non_blocking=True)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

    
    dist.destroy_process_group()


def main():
    mp.set_start_method("spawn")

    args = parser.parse_args()
    num_procs_per_node = torch.cuda.device_count()
    args.num_procs_per_node = num_procs_per_node

    t_start = time.perf_counter()

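    # mp.start_processes(..., start_method="spawn") is what mp.spawn does under the hood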
    mp.start_processes(training, args=(args, ), nprocs=args.num_procs_per_node, start_method="spawn")

    print("Execution time: {}".format(time.perf_counter() - t_start))

if __name__ == '__main__':
    main()

PS: This issue may be related to the topic "Multiprocessing slows dataloaders at the each beginning epoch".
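
In case it helps to narrow this down, here is a rough, self-contained micro-benchmark sketch (not part of the scripts above; whether the DataLoader workers in check_mp_dist.py actually end up using the "spawn" start method is only an assumption on my part). It times worker creation plus the first batch of a fresh DataLoader under the "fork" and "spawn" start methods, with a small synthetic in-memory dataset:

import time

import torch
from torch.utils.data import DataLoader, TensorDataset


def first_batch_time(ctx):
    # small synthetic CIFAR10-shaped dataset, so no disk I/O is involved
    dataset = TensorDataset(
        torch.randn(2048, 3, 32, 32),
        torch.randint(0, 10, (2048,)),
    )
    loader = DataLoader(
        dataset,
        batch_size=512,
        num_workers=4,
        pin_memory=True,
        multiprocessing_context=ctx,  # "fork" (Linux default) or "spawn"
    )
    start = time.perf_counter()
    next(iter(loader))  # includes worker start-up + fetching the first batch
    return time.perf_counter() - start


if __name__ == '__main__':
    for ctx in ["fork", "spawn"]:
        print("{}: first batch in {:.3f} s".format(ctx, first_batch_time(ctx)))

If the "spawn" timings are much larger than the "fork" ones, that would at least be consistent with the per-epoch spikes above, since the DataLoader starts a fresh set of workers at the beginning of every epoch.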