Model parallelism on 2 GPUs, and how to load the model state dictionary onto the CPU

This is a simple model using model parallelism, which runs on GPUs 0 and 1. How can I save the model after training and load it back so that I can test it on the CPU?

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self, gpu0, gpu1):
        super(Net, self).__init__()
        self.gpu0 = gpu0
        self.gpu1 = gpu1
        # convolutional stage lives on the first GPU
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1), nn.ReLU(),
            nn.MaxPool2d(2), nn.Dropout(0.25), nn.Flatten(1),
        ).cuda(gpu0)
        # classifier stage lives on the second GPU
        self.feat = nn.Sequential(
            nn.Linear(9216, 128), nn.BatchNorm1d(128),
            nn.ReLU(), nn.Dropout2d(0.5), nn.Linear(128, 10),
        ).cuda(gpu1)

    def forward(self, x):
        # move the activation from gpu0 to gpu1 between the two stages
        x = self.conv(x).cuda(self.gpu1)
        x = self.feat(x)
        output = F.log_softmax(x, dim=1)
        return output

You could push all parameters and buffers back to the CPU and store the state_dict:

# after training
model.cpu()
torch.save(model.state_dict(), 'filename.pth')

In another script you would be able to create the model instance, load the state_dict, and run inference.
However, you have unfortunately defined the cuda() calls explicitly in your model, so creating the model instance would always try to push the model to the GPU(s).
I would generally recommend writing device-agnostic code via to() and passing the device into __init__ instead of the GPU id.
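
For example, here is a minimal sketch of the loading script, assuming a Net definition that no longer hard-codes the explicit cuda() calls (so it can be constructed without device arguments):

import torch

# minimal sketch, assuming a Net definition without the explicit cuda() calls
model = Net()
# map_location='cpu' remaps every tensor onto the CPU, regardless of the
# device it was saved from
state_dict = torch.load('filename.pth', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()  # disable dropout and use running batchnorm stats for inference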

@ptrblck The model is saved as a DDP module, and I get this error:

if gpu0 == 0 and epoch == 4:
    model.cpu()
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, epoch)
class Net(nn.Module):
    def __init__(self, gpu0, gpu1):
        super(Net, self).__init__()
        # accept either a GPU id or the string "cpu"
        if gpu0 != "cpu":
            self.gpu0 = "cuda:" + str(gpu0)
            self.gpu1 = "cuda:" + str(gpu1)
        else:
            self.gpu0 = "cpu"
            self.gpu1 = "cpu"
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1), nn.ReLU(),
            nn.MaxPool2d(2), nn.Dropout(0.25), nn.Flatten(1),
        ).to(self.gpu0)
        self.feat = nn.Sequential(
            nn.Linear(9216, 128), nn.BatchNorm1d(128),
            nn.ReLU(), nn.Dropout2d(0.5), nn.Linear(128, 10),
        ).to(self.gpu1)

    def forward(self, x):
        x = self.conv(x).to(self.gpu1)
        x = self.feat(x)
        output = F.log_softmax(x, dim=1)
        return output

model = Net("cpu", "cpu")
PATH = '../model_4.pth'
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['state_dict'])

Error

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-3-ec88a5c31813> in <module>
      2 PATH = '../model_4.pth'
      3 checkpoint = torch.load(PATH)
----> 4 model.load_state_dict(checkpoint['state_dict'])
      5 #optimizer.load_state_dict(checkpoint['optimizer'])

~/.conda/envs/praveen_tf/lib/python3.6/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1043         if len(error_msgs) > 0:
   1044             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1045                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1046         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1047 

RuntimeError: Error(s) in loading state_dict for Net:
	Missing key(s) in state_dict: "conv.0.weight", "conv.0.bias", "conv.2.weight", "conv.2.bias", "feat.0.weight", "feat.0.bias", "feat.1.weight", "feat.1.bias", "feat.1.running_mean", "feat.1.running_var", "feat.4.weight", "feat.4.bias". 
	Unexpected key(s) in state_dict: "module.conv.0.weight", "module.conv.0.bias", "module.conv.2.weight", "module.conv.2.bias", "module.feat.0.weight", "module.feat.0.bias", "module.feat.1.weight", "module.feat.1.bias", "module.feat.1.running_mean", "module.feat.1.running_var", "module.feat.1.num_batches_tracked", "module.feat.4.weight", "module.feat.4.bias". 
For debugging, these are the keys and devices in the checkpoint's state_dict:

PATH = '../model_4.pth'
checkpoint = torch.load(PATH)
state_dict = checkpoint['state_dict']
for k, v in state_dict.items():
    print(k, v.get_device())

Output (a device of -1 means the tensor is on the CPU):

module.conv.0.weight -1
module.conv.0.bias -1
module.conv.2.weight -1
module.conv.2.bias -1
module.feat.0.weight -1
module.feat.0.bias -1
module.feat.1.weight -1
module.feat.1.bias -1
module.feat.1.running_mean -1
module.feat.1.running_var -1
module.feat.1.num_batches_tracked -1
module.feat.4.weight -1
module.feat.4.bias -1
model = Net("cpu", "cpu")
model_dict = model.state_dict()
for k, v in model_dict.items():
    print(k, v.get_device())

Output:

conv.0.weight -1
conv.0.bias -1
conv.2.weight -1
conv.2.bias -1
feat.0.weight -1
feat.0.bias -1
feat.1.weight -1
feat.1.bias -1
feat.1.running_mean -1
feat.1.running_var -1
feat.1.num_batches_tracked -1
feat.4.weight -1
feat.4.bias -1

Solved the problem by stripping the module. prefix that DDP adds when it wraps the model. @ptrblck, thank you for all the help!

from collections import OrderedDict

# strip the 'module.' prefix that DistributedDataParallel adds to every key
state_dict = checkpoint['state_dict']
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:]  # remove 'module.'
    new_state_dict[name] = v

model.load_state_dict(new_state_dict)
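
An alternative, if you can change the checkpointing code instead: a DistributedDataParallel wrapper exposes the underlying model as model.module, so saving its state_dict avoids the prefix in the first place. A sketch, mirroring the checkpoint format above:

save_checkpoint({
    'epoch': epoch + 1,
    'state_dict': model.module.state_dict(),  # keys without 'module.' prefix
    'optimizer': optimizer.state_dict(),
}, epoch)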

Can you provide an example of a device-agnostic model declaration? I am running into a similar issue, but I am not sure how to implement what you are saying.

Generally, don't use device-specific calls such as model.cuda() unless you require the code to run on a GPU. Device-agnostic code lets the user specify the device, or checks e.g. whether a GPU is available:

device = "cuda" if torch.cuda.is_available() else "cpu"

and then move the model and inputs to this device:

model.to(device)
data = data.to(device)

This allows users with different setups to execute your script properly.
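
Applied to the model in this thread, a minimal device-agnostic sketch could look like this: no .cuda() or .to() calls inside __init__, so the caller decides where it runs (plain nn.Dropout is used in the classifier here, since the activation is already flattened at that point):

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    # no device handling inside the model: it is created on the CPU
    # and moved wholesale by the caller via .to(device)
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1), nn.ReLU(),
            nn.MaxPool2d(2), nn.Dropout(0.25), nn.Flatten(1),
        )
        self.feat = nn.Sequential(
            nn.Linear(9216, 128), nn.BatchNorm1d(128),
            nn.ReLU(), nn.Dropout(0.5), nn.Linear(128, 10),
        )

    def forward(self, x):
        x = self.conv(x)
        x = self.feat(x)
        return F.log_softmax(x, dim=1)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = Net().to(device)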

OK, but if you have a big model that requires more than one GPU, is there a way to still have device-agnostic code?

It depends on which use cases you would want to allow. E.g., multiple GPUs could be used if the needed number of devices is detected, and the code could otherwise fall back to CPU-only execution.
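
For example, reusing the two-device Net(gpu0, gpu1) definition from earlier in this thread, one sketch of that detection logic:

import torch

# use model parallelism across two GPUs if available,
# otherwise fall back to CPU-only execution
if torch.cuda.device_count() >= 2:
    model = Net(0, 1)          # conv on cuda:0, feat on cuda:1
else:
    model = Net("cpu", "cpu")  # everything on the CPU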