Why does requires_grad = True lead to missing keys in the state_dict only when using DataParallel?

Hi everyone! I'm having trouble understanding what happens to state_dicts when using DataParallel. Once I wrap a model whose submodule was built with requires_grad = True in DataParallel, that submodule's state_dict().keys() come back empty inside forward.
My code has these two classes:

import torch
import torch.nn as nn
import torchvision
from torchvision import models

class LTE(torch.nn.Module):
    def __init__(self, requires_grad=True):
        super(LTE, self).__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()

        # Copy the first layers of VGG19 (conv + ReLU) into slice1
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            # Freeze the pretrained VGG weights
            for param in self.slice1.parameters():
                param.requires_grad = requires_grad


class TTSR(nn.Module):
    def __init__(self):
        super(TTSR, self).__init__()
        self.LTE      = LTE(requires_grad=True) 
        self.LTE_copy = LTE(requires_grad=False)
    def forward(self, sr=None):
        print("LTE "+str(torch.cuda.current_device())+"    "+str(self.LTE.state_dict().keys()))
        print("LTE_copy "+str(torch.cuda.current_device())+"    "+str(self.LTE_copy.state_dict().keys()))

If I now initialize a TTSR model and run the forward function:

device = torch.device('cuda')
_model = TTSR().to(device)
a = _model(sr=1)

Then, this is my output:

LTE 0    odict_keys(['slice1.0.weight', 'slice1.0.bias'])
LTE_copy 0    odict_keys(['slice1.0.weight', 'slice1.0.bias'])

This is fine so far! But if I use DataParallel now (with two GPUs),

model = nn.DataParallel(_model, list(range(2)))
a = model(sr=1)

my output is as follows:

LTE 0    odict_keys([])
LTE_copy 0    odict_keys(['slice1.0.weight', 'slice1.0.bias'])
LTE 1    odict_keys([])
LTE_copy 1    odict_keys([])
  • Why does requires_grad = True lead to the keys ['slice1.0.weight', 'slice1.0.bias'] missing from the state_dict? This only happens once I wrap the model in DataParallel.
  • Did I understand correctly that DataParallel always has one main GPU where the model is stored (DataParallel imbalanced memory usage), and that this is why the state_dicts of LTE 1 & LTE_copy 1 are empty?

I suspect that you don't see the keys because of this line, which is called from DataParallel.forward. Basically, you should not try to access the replicas on the different GPUs directly, because they are handled differently than regular modules. Also, we recommend using DistributedDataParallel instead of this class for multi-GPU training, even on a single node. See: Use nn.parallel.DistributedDataParallel instead of multiprocessing or nn.DataParallel and Distributed Data Parallel.
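
If you just need the parameters for inspection or checkpointing, read them from the DataParallel wrapper (or from its .module attribute, which is the original, non-replicated model) outside of forward, rather than from the per-GPU replicas. A minimal sketch, assuming the TTSR class from above, two visible GPUs, and a placeholder checkpoint filename:

import torch
import torch.nn as nn

_model = TTSR().to(torch.device('cuda'))
model = nn.DataParallel(_model, device_ids=list(range(2)))

# The wrapper's own state_dict still contains every parameter;
# the keys are simply prefixed with "module.":
print(model.state_dict().keys())

# model.module is the original, non-replicated model, so its
# sub-modules can be inspected or saved as usual:
print(model.module.LTE.state_dict().keys())
torch.save(model.module.state_dict(), 'ttsr.pth')  # placeholder filename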
