Multi-GPU training with DistributedDataParallel fails after the last epoch completes

I followed this tutorial to enable distributed training for my model (one machine with 2 GPUs):
https://pytorch.org/tutorials/intermediate/ddp_tutorial.html

The high-level structure of my training code is shown below:

def train(rank, config):
   # setup
   torch.distributed.init_process_group(
      backend='nccl', init_method='tcp://localhost:6666', 
      world_size=config['world_size'], rank=rank)

   # actual training code
   # wrap the (already constructed) model for this GPU; model/dataset creation omitted here
   model = torch.nn.parallel.DistributedDataParallel(module=model, device_ids=[rank])
   dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, 
      shuffle=False, sampler=torch.utils.data.distributed.DistributedSampler(dataset))
   for epoch in range(config['epochs']):
      ...
   print(f'Done - rank {rank}')

   # cleanup
   torch.distributed.destroy_process_group()

if __name__ == '__main__':
   config = load_config()
   torch.multiprocessing.spawn(
      train, args=(config,), nprocs=config['world_size'])
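
For context, the parts of config used above look roughly like this (a sketch only; the real load_config() is not shown, and the epoch count is just an example value):

config = {
   'world_size': 2,   # one process per GPU
   'epochs': 10,      # example value
}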

In some cases it runs successfully, but in other cases (with the exact same input arguments) it prints a long failure message after all epochs are done (I can see the Done message from all processes):

PyThreadState_Clear: warning: thread still has a frame
*** Error in 'opt/conda/bin/python': corrupted size vs. prev_size: 0x00007f38c40008f0 ***
###### Backtrace: ######
....
###### Memory map: ######
....

Exception: process 1 terminated with signal SIGABRT