torch.distributed.checkpoint CUDA OOM with broadcast_from_rank0

When using broadcast_from_rank0, do all ranks except rank 0 have an empty state_dict? If so, could you repost this on the pytorch/pytorch GitHub issue tracker? This sounds more like a bug.
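
For reference, here is a minimal sketch of the setup I am asking about, where only rank 0 materializes the full state_dict and every other rank passes an empty dict. The helper names `state_dict_for_rank` and `load_with_broadcast` and the `load_full` callback are hypothetical, made up for illustration; `StateDictOptions` and `set_model_state_dict` come from `torch.distributed.checkpoint.state_dict`. It assumes the process group is already initialized and the model is already sharded, so your actual code may well differ.

```python
from typing import Any, Callable, Dict


def state_dict_for_rank(
    rank: int, load_full: Callable[[], Dict[str, Any]]
) -> Dict[str, Any]:
    """Return the full state_dict on rank 0 and an empty dict elsewhere.

    Hypothetical helper: `load_full` stands in for however the
    checkpoint is read (for example torch.load on a file).
    """
    # The loader is only called on rank 0, so other ranks never
    # materialize the full checkpoint.
    return load_full() if rank == 0 else {}


def load_with_broadcast(
    model, rank: int, load_full: Callable[[], Dict[str, Any]]
) -> None:
    """Hypothetical wrapper around set_model_state_dict.

    Assumes torch.distributed is initialized and `model` is sharded.
    """
    # Imported lazily so the helper above can be checked without torch.
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        set_model_state_dict,
    )

    options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
    set_model_state_dict(
        model, state_dict_for_rank(rank, load_full), options=options
    )
```

If your non-zero ranks are also loading the full checkpoint before the call, that alone could account for extra memory use, so it is worth mentioning in the issue which of the two patterns you have.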