Saving state dict with optimizer state sharding

I am trying to use optimizer state sharding (OSS, via torch.distributed.optim.ZeroRedundancyOptimizer) to train a large model on a single machine with 4 GPUs, and I am running into a problem when I call optimizer.consolidate_state_dict(). The distributed data parallel job hangs until the 30-minute NCCL timeout (1800000 ms) expires and then dies. The final error log is below:

[E ProcessGroupNCCL.cpp:781] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=95188, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1807405 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:781] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=95187, OpType=ALLREDUCE, Timeout(ms)=1800000) ran for 1808126 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:781] [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=95187, OpType=ALLREDUCE, Timeout(ms)=1800000) ran for 1808124 milliseconds before timing out.

Traceback (most recent call last):
  File "train.py", line 143, in <module>
    main()
  File "base.py", line 456, in _consolidate_state_dict
    self.opt.consolidate_state_dict()
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/optim/zero_redundancy_optimizer.py", line 537, in consolidate_state_dict
    local_state_dict = _broadcast_object(
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/optim/zero_redundancy_optimizer.py", line 106, in _broadcast_object
    data_recv_tensor = torch.empty([int(length_tensor.item())], dtype=torch.uint8, device=device)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate more than 1EB memory.

It doesn’t make sense that it is trying to allocate more than 1 EB (one exabyte) of memory!
This is with PyTorch 1.13. Any clue what the issue might be?
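
One detail in the watchdog lines above may be worth checking: the ranks do not all report the same collective. The snippet below is only a diagnostic sketch, not a fix. It parses the three watchdog lines copied from the log and groups ranks by the (SeqNum, OpType) they timed out in. The helper name group_ranks_by_collective and the regular expression are made up for this illustration and assume the log format shown above.

```python
import re
from collections import defaultdict

# Assumed format: the NCCL watchdog timeout lines shown in the log above.
WATCHDOG_RE = re.compile(
    r"\[Rank (?P<rank>\d+)\].*?SeqNum=(?P<seq>\d+), OpType=(?P<op>\w+)"
)


def group_ranks_by_collective(log_text):
    """Map (SeqNum, OpType) to the sorted list of ranks that timed out in it."""
    groups = defaultdict(list)
    for line in log_text.splitlines():
        match = WATCHDOG_RE.search(line)
        if match:
            key = (int(match["seq"]), match["op"])
            groups[key].append(int(match["rank"]))
    return {key: sorted(ranks) for key, ranks in groups.items()}


# The three watchdog lines, copied verbatim from the error log.
LOG = """\
[E ProcessGroupNCCL.cpp:781] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=95188, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1807405 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:781] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=95187, OpType=ALLREDUCE, Timeout(ms)=1800000) ran for 1808126 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:781] [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=95187, OpType=ALLREDUCE, Timeout(ms)=1800000) ran for 1808124 milliseconds before timing out.
"""

groups = group_ranks_by_collective(LOG)
for (seq, op), ranks in sorted(groups.items()):
    print(f"SeqNum={seq} OpType={op}: ranks {ranks}")
if len(groups) > 1:
    print("Ranks timed out in different collectives.")
```

On this log, rank 0 lands in a BROADCAST at SeqNum 95188 while ranks 1 and 2 land in an ALLREDUCE at SeqNum 95187. That pattern would be consistent with consolidate_state_dict() being reached on rank 0 only while the other ranks carry on with training; the PyTorch documentation for consolidate_state_dict() warns that it needs to be called on all ranks. Whether that is what happens here depends on code not shown in the post, so treat it as something to verify rather than a diagnosis.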