torch.multiprocessing.spawn.ProcessRaisedException: -- Process 0 terminated with the following error:

I am using multiple GPUs on the same machine to train a network with DistributedDataParallel. I have followed the steps in the PyTorch documentation, but during validation the run fails with a "Process 0 terminated" error. The relevant parts of my script and the full traceback are below.

Step 1:

import os

import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

Step 2:

# ------ Setting up the distributed environment -------
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()
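
My understanding from the PyTorch docs is that each spawned worker should call setup() exactly once before building the model, and cleanup() exactly once at the very end of that worker. A minimal sketch of that intended call pattern (worker here is illustrative, not my actual function):

def worker(rank, world_size):
    setup(rank, world_size)   # join the process group once per process
    # ... build the model, train, validate ...
    cleanup()                 # destroy the process group once, at the very end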

Step 3:
Inside the training loop:

train_loss = model(images, targets)["loss"].to(rank)
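
For context, the surrounding loop is roughly the following (simplified sketch; AMP and logging are omitted, and train_loader / optimizer stand in for my actual objects):

model.train()
for images, targets in train_loader:
    images = images.to(rank)
    optimizer.zero_grad()
    train_loss = model(images, targets)["loss"].to(rank)
    train_loss.backward()
    optimizer.step()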

Step 4:

def main(rank, args):
    print(args)

    print(f"Running Distributed ResNet on rank {rank}.")
    setup(rank, args.world_size)
    torch.manual_seed(0)
    torch.cuda.set_device(rank)

Step 5:

model = recognition.__dict__[args.arch](pretrained=args.pretrained, vocab=vocab).to(rank)
model = DDP(model, device_ids=[rank])

Step 6:

parser.add_argument('--world_size', type=int, default=world_size, help='total number of processes')
parser.add_argument("--local_rank", type=int,
                                help="Local rank. Necessary for using the torch.distributed.launch utility.")
    
args = parser.parse_args()
mp.spawn(train_func, args=(args,), nprocs=args.world_size, join=True)

Step 7:

if __name__ == "__main__":
    # Johnson
    n_gpus = torch.cuda.device_count()
    run_train_model(main, n_gpus)
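
Putting the steps together: mp.spawn(train_func, args=(args,), nprocs=args.world_size) starts args.world_size processes and calls train_func(rank, args), with the process index passed as the first argument. The overall shape I intend for main is the standard pattern from the DDP tutorial (simplified sketch; build_model and train_one_epoch are placeholders for my actual code, and val_loader, batch_transforms and val_metric are built earlier in the script):

def main(rank, args):
    setup(rank, args.world_size)          # Step 2: one process group per worker
    torch.cuda.set_device(rank)
    model = build_model().to(rank)        # placeholder for the doctr recognition model (Step 5)
    model = DDP(model, device_ids=[rank])
    for epoch in range(args.epochs):
        train_one_epoch(model, rank)      # the loop from Step 3 runs in here
        val_loss, exact_match, partial_match = evaluate(
            model, val_loader, batch_transforms, val_metric, amp=args.amp
        )                                 # validation step where the error is raised
    cleanup()                             # Step 2: tear down once, after all epochs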

Traceback:

Validation loss decreased inf --> 0.421184: saving state...                                                                              
Epoch 1/2 - Validation loss: 0.421184 (Exact: 75.37% | Partial: 75.37%)
Traceback (most recent call last):                                                                                                       
  File "/media/cvpr/CM_22/doctr/references/recognition/pytorch_single_lang.py", line 504, in <module>
    run_train_model(main, n_gpus)
  File "/media/cvpr/CM_22/doctr/references/recognition/pytorch_single_lang.py", line 498, in run_train_model
    mp.spawn(train_func, args=(args,), nprocs=args.world_size, join=True)
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 240, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/media/cvpr/CM_22/doctr/references/recognition/pytorch_single_lang.py", line 410, in main
    val_loss, exact_match, partial_match = evaluate(model, val_loader, batch_transforms, val_metric, amp=args.amp)
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/media/cvpr/CM_22/doctr/references/recognition/pytorch_single_lang.py", line 191, in evaluate
    cleanup()
  File "/media/cvpr/CM_22/doctr/references/recognition/pytorch_single_lang.py", line 50, in cleanup
    dist.destroy_process_group()
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 797, in destroy_process_group
    assert pg is not None
AssertionError