Model parallelism on 2 GPUs: how to load the model state dictionary on the CPU

This is a simple model-parallel network that runs on GPUs 0 and 1. How can I save the model after training and load it back so that I can test it on the CPU?

class Net(nn.Module):
    """Two-stage CNN split across two GPUs (model parallelism).

    The convolutional feature extractor ``conv`` is placed on GPU ``gpu0`` and
    the fully connected head ``feat`` on GPU ``gpu1``; ``forward`` copies the
    intermediate activations from the first GPU to the second.

    Args:
        gpu0: CUDA device index for the convolutional stage.
        gpu1: CUDA device index for the classifier head.

    Note:
        Because ``.cuda()`` is called explicitly, building this model always
        requires the GPUs to be present, even if you only want CPU inference.
    """

    def __init__(self, gpu0, gpu1):
        super().__init__()
        self.gpu0, self.gpu1 = gpu0, gpu1

        conv_layers = [
            nn.Conv2d(1, 32, 3, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25),
            nn.Flatten(1),
        ]
        self.conv = nn.Sequential(*conv_layers).cuda(gpu0)

        # NOTE(review): Dropout2d receives the 2-D (N, 128) output of the
        # BatchNorm1d here; plain nn.Dropout is the non-deprecated equivalent.
        head_layers = [
            nn.Linear(9216, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout2d(0.5),
            nn.Linear(128, 10),
        ]
        self.feat = nn.Sequential(*head_layers).cuda(gpu1)

    def forward(self, x):
        """Return per-class log-probabilities (on ``gpu1``) for batch ``x``."""
        features = self.conv(x)
        # Hop the activations over to the second GPU before the head runs.
        logits = self.feat(features.cuda(self.gpu1))
        return F.log_softmax(logits, dim=1)

You could push all parameters and buffers back to the CPU and store the state_dict:

# after training
# DistributedDataParallel / DataParallel keep the real network in `.module`;
# saving the wrapper's state_dict would prefix every key with "module.", and
# those keys would not load into a plain Net instance. Unwrap it first.
net = getattr(model, "module", model)
net.cpu()
torch.save(net.state_dict(), 'filename.pth')

In another script you would be able to create the model instance, load the state_dict, and perform the inference.
However, unfortunately you have defined the cuda() calls explicitly in your model, so that creating the model instance would always try to push the model to the GPU(s).
I would generally recommend writing device-agnostic code with to() and passing the device into __init__ instead of the GPU id.

@ptrblck The model is saved as a DDP module, and I get this error:

 if gpu0 == 0 and epoch == 4:
     # Copy the weights to CPU instead of calling model.cpu(). Only the process
     # with gpu0 == 0 enters this branch, and model.cpu() would move that live
     # replica off the GPU in the middle of training.
     # NOTE: `model` is DDP-wrapped here, so every key carries a "module."
     # prefix; save model.module.state_dict() to get plain Net keys instead.
     cpu_state = {k: v.cpu() for k, v in model.state_dict().items()}
     save_checkpoint({
         'epoch': epoch + 1,
         'state_dict': cpu_state,
         # Optimizer state tensors may still live on the GPU; load this
         # checkpoint with map_location="cpu" on a CPU-only machine.
         'optimizer' : optimizer.state_dict(),
     }, epoch)
class Net(nn.Module):
    """Two-stage CNN that can be split across two devices (model parallelism).

    The convolutional stage ``conv`` is placed on ``self.gpu0`` and the
    classifier head ``feat`` on ``self.gpu1``; ``forward`` moves the
    activations between them.

    Args:
        gpu0: Device for the convolutional stage. ``"cpu"`` puts the *whole*
            model on the CPU (``gpu1`` is then ignored). Otherwise a CUDA
            index (``0`` or ``"0"``) or a device spec such as ``"cuda:1"``
            or ``torch.device("cuda", 1)``.
        gpu1: Device for the classifier head, in the same forms as ``gpu0``.
    """

    def __init__(self, gpu0, gpu1):
        super().__init__()
        if isinstance(gpu0, str) and gpu0 == "cpu":
            # "cpu" for gpu0 moves both stages to the CPU, which is what
            # CPU-only inference needs.
            self.gpu0 = "cpu"
            self.gpu1 = "cpu"
        else:
            self.gpu0 = self._resolve_device(gpu0)
            self.gpu1 = self._resolve_device(gpu1)
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1), nn.ReLU(), nn.Conv2d(32, 64, 3, 1), nn.ReLU(),
            nn.MaxPool2d(2), nn.Dropout(0.25), nn.Flatten(1),
        ).to(self.gpu0)
        # nn.Dropout rather than Dropout2d: the head sees 2-D (N, 128)
        # activations, for which Dropout2d is deprecated and behaves like
        # plain element-wise dropout anyway. Submodule indices, and therefore
        # state_dict keys, are unchanged.
        self.feat = nn.Sequential(
            nn.Linear(9216, 128), nn.BatchNorm1d(128),
            nn.ReLU(), nn.Dropout(0.5), nn.Linear(128, 10),
        ).to(self.gpu1)

    @staticmethod
    def _resolve_device(dev):
        """Map a CUDA index or a device spec to a device string for ``.to()``."""
        # Bare indices keep the original "cuda:<id>" mapping; anything else
        # (e.g. "cuda:1" or a torch.device) is used as-is instead of becoming
        # the invalid "cuda:cuda:1".
        if isinstance(dev, int) or (isinstance(dev, str) and dev.isdigit()):
            return "cuda:" + str(dev)
        return str(dev)

    def forward(self, x):
        """Return per-class log-probabilities (on ``self.gpu1``) for batch ``x``."""
        # No-op when x is already on the first stage's device.
        x = x.to(self.gpu0)
        x = self.conv(x).to(self.gpu1)
        x = self.feat(x)
        output = F.log_softmax(x, dim=1)
        return output

model = Net("cpu", "cpu")
PATH = '../model_4.pth'
# map_location="cpu" remaps any CUDA tensors stored in the checkpoint (the
# optimizer state, for example, may have been saved from the GPU) so this
# load also works on a machine without a GPU.
checkpoint = torch.load(PATH, map_location="cpu")
model.load_state_dict(checkpoint['state_dict'])

Error

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-3-ec88a5c31813> in <module>
      2 PATH = '../model_4.pth'
      3 checkpoint = torch.load(PATH)
----> 4 model.load_state_dict(checkpoint['state_dict'])
      5 #optimizer.load_state_dict(checkpoint['optimizer'])

~/.conda/envs/praveen_tf/lib/python3.6/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1043         if len(error_msgs) > 0:
   1044             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1045                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1046         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1047 

RuntimeError: Error(s) in loading state_dict for Net:
	Missing key(s) in state_dict: "conv.0.weight", "conv.0.bias", "conv.2.weight", "conv.2.bias", "feat.0.weight", "feat.0.bias", "feat.1.weight", "feat.1.bias", "feat.1.running_mean", "feat.1.running_var", "feat.4.weight", "feat.4.bias". 
	Unexpected key(s) in state_dict: "module.conv.0.weight", "module.conv.0.bias", "module.conv.2.weight", "module.conv.2.bias", "module.feat.0.weight", "module.feat.0.bias", "module.feat.1.weight", "module.feat.1.bias", "module.feat.1.running_mean", "module.feat.1.running_var", "module.feat.1.num_batches_tracked", "module.feat.4.weight", "module.feat.4.bias". 
# Inspect the saved checkpoint: print each key together with the device its
# tensor was saved on (get_device() returns -1 for CPU tensors). Note that
# every key starts with "module.", because the state_dict came from the DDP
# wrapper.
PATH = '../model_4.pth'
checkpoint = torch.load(PATH)
state_dict = checkpoint['state_dict']
for k, v in state_dict.items():
    print(k, v.get_device())

Output:

module.conv.0.weight -1
module.conv.0.bias -1
module.conv.2.weight -1
module.conv.2.bias -1
module.feat.0.weight -1
module.feat.0.bias -1
module.feat.1.weight -1
module.feat.1.bias -1
module.feat.1.running_mean -1
module.feat.1.running_var -1
module.feat.1.num_batches_tracked -1
module.feat.4.weight -1
module.feat.4.bias -1
# Same check for a freshly built CPU model. Its keys have no "module." prefix,
# which is exactly the mismatch reported by load_state_dict above.
model = Net("cpu", "cpu")
model_dict = model.state_dict()
for k, v in model_dict.items():
    print(k, v.get_device())

Output:

conv.0.weight -1
conv.0.bias -1
conv.2.weight -1
conv.2.bias -1
feat.0.weight -1
feat.0.bias -1
feat.1.weight -1
feat.1.bias -1
feat.1.running_mean -1
feat.1.running_var -1
feat.1.num_batches_tracked -1
feat.4.weight -1
feat.4.bias -1

I solved the problem by stripping the `module.` prefix that DDP adds to every key. @ptrblck, thank you for all the help!

# A state_dict taken from a DistributedDataParallel wrapper prefixes every key
# with "module."; strip it so the keys match a plain Net instance.
prefix = 'module.'
state_dict = checkpoint['state_dict']
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    # Strip only keys that actually carry the prefix. Blindly slicing k[7:]
    # would mangle keys from a checkpoint saved without the DDP wrapper.
    name = k[len(prefix):] if k.startswith(prefix) else k
    new_state_dict[name] = v
# new_state_dict can now be passed to model.load_state_dict(...).