Distributed Data Parallel repeated pickling

I’m trying to move from plain DataParallel (DP) to DistributedDataParallel (DDP). My code works, but it is very slow.
I narrowed the problem down to a weird behaviour: the dataset, which I pass as an argument to mp.spawn, is pickled and unpickled every time I create a new iterator from the loader.

Why is this the case? Surely all arguments should be pickled and unpickled only once, when the processes are created.

Rough code:

import torch.multiprocessing as mp

def train(rank, dataset):
    setup_ddp(rank)
    model = create_model()
    loader = create_loader(dataset)

    train_iter = iter(loader)

    for i in range(NUM_ITERS):
        try:
            batch, targets = next(train_iter)
        except StopIteration:
            # Loader exhausted: start a new epoch with a fresh iterator.
            # Creating it re-pickles and unpickles the dataset every time.
            train_iter = iter(loader)
            batch, targets = next(train_iter)

        step(model, batch, targets)

    cleanup_ddp(rank)


if __name__ == "__main__":
    dataset = create_dataset()
    mp.spawn(train, args=(dataset,), nprocs=WORLD_SIZE)
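
For anyone who wants to reproduce this: a rough way to observe the pickling (TracedDataset is a made-up wrapper, not something from my actual code) is to add a __getstate__ hook that logs every pickle call:

import os
from torch.utils.data import Dataset

class TracedDataset(Dataset):
    """Hypothetical wrapper that logs every time it is pickled."""
    def __init__(self, inner):
        self.inner = inner

    def __len__(self):
        return len(self.inner)

    def __getitem__(self, idx):
        return self.inner[idx]

    def __getstate__(self):
        # pickle calls this whenever the dataset is shipped to a new
        # process: once per process for mp.spawn, and once per DataLoader
        # worker when the workers use a spawn start method.
        print(f"pickling dataset in pid {os.getpid()}")
        return self.__dict__

Wrapping the dataset in this before passing it to mp.spawn prints one line per process spawn, which is how the repeated pickling shows up.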

I figured out that the problem was in the loader, not in DDP. Each loader had four workers, and those worker processes were re-created every time a new iterator was constructed; with 8 DDP processes that meant 32 workers were spawned, each receiving a pickled copy of the dataset, every time. Reducing the number of workers solved the problem.
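
For completeness, here is a minimal sketch of a create_loader that avoids the respawn cost entirely instead of just shrinking it, assuming it wraps a plain torch.utils.data.DataLoader (the batch size and worker count are made up):

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def create_loader(dataset):
    # Shard the dataset across DDP ranks (assumes the process group
    # was already initialized in setup_ddp).
    sampler = DistributedSampler(dataset)
    return DataLoader(
        dataset,
        batch_size=64,            # hypothetical value
        sampler=sampler,
        num_workers=2,            # fewer workers per rank
        persistent_workers=True,  # keep workers (and their copy of the
                                  # dataset) alive across iter(loader)
                                  # calls; requires num_workers > 0
    )

With persistent_workers=True (available since PyTorch 1.7), the workers survive StopIteration, so creating a fresh iterator each epoch no longer re-spawns them or re-pickles the dataset.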