Getting an error: Default process group has not been initialized, please make sure to call init_process_group

Hello,

I’ve been trying to move a model from a single GPU to a machine I’ve rented with four GPUs. I wrapped the model in DistributedDataParallel and I’m getting the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-30-d1064068ed08> in <module>
      6 exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
      7 combined_model = combined_model.cuda()
----> 8 combined_model = DDP(combined_model)

~/miniconda3/lib/python3.7/site-packages/torch/nn/parallel/distributed.py in __init__(self, module, device_ids, output_device, dim, broadcast_buffers, process_group, bucket_cap_mb, find_unused_parameters, check_reduction)
    271 
    272         if process_group is None:
--> 273             self.process_group = _get_default_group()
    274         else:
    275             self.process_group = process_group

~/miniconda3/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py in _get_default_group()
    266     """
    267     if not is_initialized():
--> 268         raise RuntimeError("Default process group has not been initialized, "
    269                            "please make sure to call init_process_group.")
    270     return _default_pg

RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.

Here is the code I’m using to try and parallelize the model:

from torch.nn.parallel import DistributedDataParallel as DDP
torch.manual_seed(101)
combined_model = Image_Embedd(embedding_size=train_categorical_embedding_sizes)
criterion = torch.nn.NLLLoss().cuda()
optimizer = torch.optim.Adam(combined_model.parameters(), lr=0.001)
scheduler = ReduceLROnPlateau(optimizer, 'min', patience = 4, verbose = True, min_lr = .00000001)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
combined_model = combined_model.cuda()
combined_model = DDP(combined_model)

Has anyone run into this error before? I’m on an Ubuntu environment with a Jupyter notebook.

You would have to set up your environment as described here.


This is a minimal example.
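Concretely, the missing piece is initializing the default process group before wrapping the model. A rough sketch, assuming a single node, the NCCL backend, and a launcher such as torchrun that sets RANK, WORLD_SIZE, and LOCAL_RANK:

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# create the default process group before constructing DDP
dist.init_process_group(backend='nccl')
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)

combined_model = Image_Embedd(embedding_size=train_categorical_embedding_sizes)
combined_model = combined_model.cuda(local_rank)
combined_model = DDP(combined_model, device_ids=[local_rank])

Note that DDP expects one process per GPU, so this is usually launched as a script with one process per device rather than from a single notebook cell.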

@ptrblck I have trained my model with DistributedDataParallel and have saved the model like this: torch.save(model.module.state_dict(), 'TS_Model.pth').
I want to load the model on another machine that has no GPU, but I get various errors.

1- When I load the model like this:
model = Network()
state_dict = torch.load('TS_model.pth', map_location=torch.device('cpu')).module.state_dict()
model.load_state_dict(state_dict)
model.eval()
I get this error: “Default process group has not been initialized, please make sure to call init_process_group.”

2- When I initialize the environment just like in the training process and then load the model, I get this error: “Distributed package doesn’t have NCCL built in”

I can run this code on my machine totally fine, but I cannot load it on another machine.
How can I overcome this issue?

The code snippets do not quite match.
Assuming you’ve stored the state_dict alone (as would be the recommended way) via:

torch.save(model.module.state_dict(), 'TS_Model.pth')

then torch.load('TS_Model.pth') should only return the state_dict (mapped to the CPU).
However, it seems you are trying to load the “entire” model via:

state_dict = torch.load('TS_model.pth', map_location=torch.device('cpu')).module.state_dict()

as you are calling .module.state_dict() on the output of torch.load.
This would mean that TS_model.pth (note the lowercase m in model) contains the entire DDP model stored via torch.save(model, path), so loading it would try to reinitialize the DDP setup and would fail on the other machine.
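For completeness, the recommended round trip would look roughly like this (a sketch, assuming Network is the plain, non-DDP model class):

# on the training machine, where model is wrapped in DDP
torch.save(model.module.state_dict(), 'TS_Model.pth')

# on the CPU-only machine
model = Network()
state_dict = torch.load('TS_Model.pth', map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
model.eval()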

When I try
state_dict = torch.load('TS_model.pth', map_location=torch.device('cpu'))
I get this error: “Missing key(s) in state_dict: branch1.backbone.conv1.0.weight, branch1.backbone.conv1.1.weight” and so on to the end of the model structure (i.e. branch1.backbone.conv2, …).

This would mean that the keys of the stored state_dict and the current state_dict of the model do not match, which is caused by a change in the model architecture (e.g. if you are re-wrapping the model into another parent model, etc.).
Check which keys are mismatched and where the additional or missing attributes are coming from.
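A small sketch of such a check, using the filename from your snippets:

stored_state_dict = torch.load('TS_model.pth', map_location='cpu')
current_state_dict = model.state_dict()
# keys the model expects but the checkpoint does not contain
print(set(current_state_dict.keys()) - set(stored_state_dict.keys()))
# keys the checkpoint contains but the model does not expect
print(set(stored_state_dict.keys()) - set(current_state_dict.keys()))

Alternatively, model.load_state_dict(state_dict, strict=False) returns the missing and unexpected keys directly.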

Eventually I loaded the model using:

state_dict = torch.load('TSmodel.pth', map_location='cpu')
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:]  # remove the `module.` prefix added by DDP
    new_state_dict[name] = v
# load params
model.load_state_dict(new_state_dict)
model.eval()

But when I feed the input to the model I receive this error:

output = model(img[None, :], step=1)

NVTX functions not installed. Are you sure you have a CUDA build?

Did you add nvtx markers to the forward pass somewhere? If so, you would need to remove them if you want to use a pure CPU build for the model execution.
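For reference, such markers typically look like this inside a forward pass (a sketch with a hypothetical layer name; these calls are what to search for in the model code):

def forward(self, x):
    torch.cuda.nvtx.range_push('block1')  # GPU-only profiling marker
    x = self.block1(x)
    torch.cuda.nvtx.range_pop()
    return x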

Honestly, I do not know whether these markers have been added or not. How can I remove them from the environment? (Just as a reminder, I want to load my model on another machine that has no GPU.)
BTW I have torch 1.13.0.dev20220614+cpu on this machine.

I don’t see how the state_dict loading would add them, so are you even able to run the plain CPU model (without trying to load the pretrained state_dict) on the CPU using some random input?
If not, where does this model come from, as PyTorch would not automatically add GPU-specific code to your model?

So I tried random input, but I have the same error.
Do I have to change the model saving code?

When I want to define my model I have to use SyncBatchNorm for norm_layer:

model = Network(num_classes=11, criterion=nn.CrossEntropyLoss, pretrained_model=None, norm_layer=SyncBatchNorm)

and SyncBatchNorm uses CUDA.

You might want to replace it with a plain nn.BatchNormXd layer as SyncBatchNorm expects to run in a DDP setup with a single process per GPU.
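For CPU-only inference that could look like this (a sketch, assuming Network accepts the norm layer class via norm_layer as in your snippet):

import torch.nn as nn

# build the model with a plain BatchNorm layer instead of SyncBatchNorm
model = Network(num_classes=11, criterion=nn.CrossEntropyLoss, pretrained_model=None, norm_layer=nn.BatchNorm2d)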

Yes, eventually worked :grinning:!!

Many thanks for your help and time :wave:.

Hi! I have the same problem but I don’t know why. I’m not calling .module.state_dict() on the output of torch.load, but the same error still happens.
The save code:

if (opts.local_rank == 0):
    if (opts.checkpoint_epochs != 0 and epoch % opts.checkpoint_epochs == 0) or epoch == opts.n_epochs - 1:
        print('Saving model and state...')
        torch.save(
            {
                'model': get_inner_model(model).state_dict(),
                'optimizer': optimizer.state_dict(),
                'rng_state': torch.get_rng_state(),
                'cuda_rng_state': torch.cuda.get_rng_state_all(),
                'baseline': baseline.state_dict()
            },
            os.path.join(opts.save_dir, 'epoch-{}.pt'.format(epoch))
        )
def get_inner_model(model):
    return model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model

The load code:

load_data = torch_load_cpu(model_filename)
model.load_state_dict({**model.state_dict(), **load_data.get('model', {})})
def torch_load_cpu(load_path):
    return torch.load(load_path, map_location=lambda storage, loc: storage)  # Load on CPU

Then I saw this post, so I tried

torch.load(load_path, map_location=torch.device('cpu'))

But I still get this error:
“RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.”

Your get_inner_model looks alright, but I don’t know what e.g. baseline is and whether it might be raising the error.
Could you post a minimal and executable code snippet in case you get stuck, please?

Thanks for your reply! But the project is a little bit big, so I couldn’t get a small snippet for you. :sob: However, you said that the problem may be related to the baseline part. So I checked the baseline part, and it seems you are right! The code related to loading the model is below.

baseline = RolloutBaseline(model, problem, opts)
class RolloutBaseline(Baseline):
    def state_dict(self):
        return {
            'model': self.model,
            'dataset': self.dataset,
            'epoch': self.epoch
        }

    def load_state_dict(self, state_dict):
        # We make it such that it works whether model was saved as data parallel or not
        load_model = copy.deepcopy(self.model)
        get_inner_model(load_model).load_state_dict(get_inner_model(state_dict['model']).state_dict())
        self._update_model(load_model, state_dict['epoch'], state_dict['dataset'])

baseline = WarmupBaseline(baseline, opts.bl_warmup_epochs, warmup_exp_beta=opts.exp_beta)
class WarmupBaseline(Baseline):
    def state_dict(self):
        # Checkpointing within warmup stage makes no sense, only save inner baseline
        return self.baseline.state_dict()

So when I save the baseline via baseline.state_dict(), I save model but not model.module! Therefore, when I try to load the whole checkpoint (with the saved baseline in it), it will try to load the entire model. Is this true? :pleading_face:

If it is true, is there any solution to only load the model part without the baseline part? (Other than adding model.module to the baseline save code, since the training process takes a long time. :sob:)

This sounds plausible, yes.

But also it seems you are storing the entire model:

    def state_dict(self):
        return {
            'model': self.model,
            'dataset': self.dataset,
            'epoch': self.epoch
        }

instead of its state_dict.

Why wouldn’t you want to save the baseline’s state_dict, too?
I guess something like:

    def state_dict(self):
        return {
            'model': get_inner_model(self.model).state_dict(),
            'dataset': self.dataset,
            'epoch': self.epoch
        }

might work.
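The corresponding load_state_dict would then need to consume the plain dict directly, e.g. (a sketch, assuming _update_model keeps its current signature):

    def load_state_dict(self, state_dict):
        load_model = copy.deepcopy(self.model)
        # state_dict['model'] is now a plain state_dict, not a wrapped module
        get_inner_model(load_model).load_state_dict(state_dict['model'])
        self._update_model(load_model, state_dict['epoch'], state_dict['dataset'])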

Thanks for your reply!! If I save the baseline’s state_dict, I have to train the model again and the training process takes a long time :sob: