Module.state_dict() is wrong when using DataParallel

I have a module like this:

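import copy

import torch.nn as nn

class Block(nn.Module):
    def __init__(self, net):
        super(Block, self).__init__()
        self.net = net
        # Independent copy with the same structure (same state_dict keys)
        self.net_copy = copy.deepcopy(net)

    def forward(self, x):
        # Sync the copy's weights from net on every forward pass
        self.net_copy.load_state_dict(self.net.state_dict())
        return self.net(x)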

The net is an nn.Sequential() module. With PyTorch >= 1.5 and nn.DataParallel on multiple GPUs, net_copy.state_dict().keys() differs from net.state_dict().keys(). With PyTorch 1.4, or on a single GPU, the problem doesn't appear. How can I make sure that net and net_copy are exactly the same?
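To see exactly which keys diverge, one option is to compare the two key sets directly. The helper below is a hypothetical diagnostic (`diff_state_dict_keys` is not a PyTorch API); on an ordinary, non-replicated module it should report no differences, while inside a DataParallel replica it would list the keys that went missing.

```python
import copy

import torch.nn as nn


def diff_state_dict_keys(a: nn.Module, b: nn.Module):
    """Return (keys only in a, keys only in b) for two modules' state_dicts."""
    keys_a = set(a.state_dict().keys())
    keys_b = set(b.state_dict().keys())
    return sorted(keys_a - keys_b), sorted(keys_b - keys_a)


net = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))
net_copy = copy.deepcopy(net)
# Outside of a DataParallel replica both lists are empty
only_net, only_copy = diff_state_dict_keys(net, net_copy)
print(only_net, only_copy)
```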

This is probably due to this PR: https://github.com/pytorch/pytorch/pull/33907

In v1.5, parameters on replicated models are no longer treated as leaves, since they shouldn't be. If you really need to access those replicated parameters, you can probably get them from _former_parameters and add them to the state_dict manually?
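A minimal sketch of "add them to the state_dict manually", assuming replicas expose `_former_parameters` (a private attribute, so it may change between releases). The function name `collect_state` is hypothetical. It falls back to the regular `_parameters` dict when `_former_parameters` is absent, so the same code also works on an ordinary module (CPU, single GPU, or outside DataParallel). Tensors are returned as-is rather than detached, similar to `state_dict(keep_vars=True)`, and non-persistent buffers and state_dict hooks are not handled.

```python
from collections import OrderedDict

import torch.nn as nn


def collect_state(module: nn.Module, prefix: str = "", out=None):
    """Gather parameters and buffers under state_dict()-style dotted keys."""
    if out is None:
        out = OrderedDict()
    # Replicas keep their (non-leaf) parameter copies in _former_parameters and
    # have an empty _parameters dict; ordinary modules have no such attribute.
    params = getattr(module, "_former_parameters", None)
    if params is None:
        params = module._parameters
    for name, p in params.items():
        if p is not None:
            out[prefix + name] = p
    for name, buf in module._buffers.items():
        if buf is not None:
            out[prefix + name] = buf
    # Recurse into children with a dotted prefix, as state_dict() does
    for name, child in module._modules.items():
        if child is not None:
            collect_state(child, prefix + name + ".", out)
    return out
```

On a plain module the collected keys should line up with `state_dict()`, which makes it easy to sanity-check before trying it inside a replica's `forward`.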

cc @ngimel, please correct me if I am wrong. Also, any thoughts on whether we should make state_dict() consistent between v1.4 and v1.5?

In order to access _former_parameters, we would need access to the replica, right? Can you help me figure out how to access _former_parameters in the OP's example?
Or how to recreate the state dict in some other way?

Hey @aashaka

DataParallel.forward basically calls replicas[i].forward(inputs[i], ...), with each call running in its own thread. So during execution, the self variable in the forward function is the replica, and you can use self._former_parameters to access the field from within forward.
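The referenced source snippet isn't reproduced in this thread, so here is a simplified sketch of the same flow built from the public helpers in `torch.nn.parallel` (`scatter`, `replicate`, `parallel_apply`, `gather`). It is not the actual `DataParallel.forward`: it skips the device checks and kwargs handling, and the function name is hypothetical. The key point is that `parallel_apply` invokes each replica from its own thread, so `self` inside `forward` is that replica.

```python
import torch
import torch.nn as nn
from torch.nn.parallel import gather, parallel_apply, replicate, scatter


def dp_forward_sketch(module: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Rough outline of one DataParallel step. Assumes `module` already lives on
    cuda:0 when two or more GPUs are visible, as DataParallel requires."""
    device_ids = list(range(torch.cuda.device_count()))
    if len(device_ids) < 2:
        # Nothing to replicate; DataParallel likewise skips replication here.
        return module(x)
    # Split the batch along dim 0: one tuple of positional args per device
    inputs = scatter((x,), device_ids)
    # Per-device replicas whose parameters are non-leaf copies of the originals
    replicas = replicate(module, device_ids[: len(inputs)])
    # Runs replicas[i](*inputs[i]), one thread per device
    outputs = parallel_apply(replicas, inputs)
    # Concatenate the per-device outputs on the first device
    return gather(outputs, device_ids[0])
```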

I managed to recreate the state_dict using code modeled on Module.state_dict(). Thanks for your help.

I noticed that _former_parameters exists in 1.5.1 but not in 1.5.0. In 1.5.0 it seems tricky to get the parameters if we don't know their names in advance (though still possible, since they are attached to the replica with setattr). Any suggestions for this?
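One heuristic for 1.5.0, offered as an unverified assumption: if `replicate()` attaches each replicated parameter with `setattr` as a plain (non-`nn.Parameter`) tensor, then `Module.__setattr__` stores it in the instance `__dict__` rather than in `_parameters`, so scanning `vars(module)` for tensors can recover names and values without knowing them in advance. The helper name is hypothetical, and it will also return any other tensor attributes you stored on the module yourself, so filter as needed.

```python
import torch
import torch.nn as nn


def tensor_attributes(module: nn.Module):
    """Return tensors stored directly in the module instance's __dict__.

    Module.__setattr__ sends a tensor that is not an nn.Parameter (and whose
    name is not already a registered parameter or buffer) to the ordinary
    instance __dict__, which is the only place this function looks.
    """
    return {
        name: value
        for name, value in vars(module).items()
        if isinstance(value, torch.Tensor)
    }
```

Applied recursively over `module._modules` with a dotted prefix, this gives the same kind of mapping that `_former_parameters` provides in later versions.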

Hey @aashaka, yep, we added _former_parameters after v1.5 to fix the regression caused by https://github.com/pytorch/pytorch/pull/33907.

If this has become very inconvenient for you, I would suggest switching to DistributedDataParallel. There is more discussion here: https://github.com/pytorch/pytorch/issues/36268
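A minimal sketch of a DistributedDataParallel setup. It runs as a single CPU process with the `gloo` backend and a file-based rendezvous only so that it executes anywhere; real training would launch one process per GPU and pass `device_ids=[rank]`. With one device per process, each process owns an ordinary copy of the model rather than per-forward replicas, so `ddp.module.state_dict()` has the usual keys.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def single_process_ddp_demo():
    # File-based rendezvous at a path that does not exist yet
    init_file = os.path.join(tempfile.mkdtemp(), "ddp_init")
    dist.init_process_group(
        backend="gloo", init_method=f"file://{init_file}", rank=0, world_size=1
    )
    try:
        net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
        ddp = DDP(net)  # CPU module: leave device_ids unset
        out = ddp(torch.randn(3, 4))
        keys = list(ddp.module.state_dict().keys())
    finally:
        dist.destroy_process_group()
    return out, keys


out, keys = single_process_ddp_demo()
print(out.shape, keys)
```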

Thanks a lot. I have one last question. Like the OP, I need to recreate the state dict on every forward pass, and I see roughly an 8x increase in training time compared to plain PyTorch DataParallel. Any ideas why this might be the case?

from collections import OrderedDict

import torch

def create_state_dict_new(main_module):
    state_dict_data = OrderedDict()

    def state_dict_recursion(this_module, state_dict_data, prefix=''):
        # DataParallel replicas (PyTorch >= 1.5.1) keep their parameters in
        # _former_parameters; their own _parameters dict is empty.
        if hasattr(this_module, "_former_parameters"):
            for name, param in this_module._former_parameters.items():
                if param is not None:
                    state_dict_data[prefix + name] = param
        for name, buf in this_module._buffers.items():
            if buf is not None:
                state_dict_data[prefix + name] = buf
        # Recurse into child modules with a dotted prefix, as state_dict() does
        for name, module in this_module._modules.items():
            if module is not None:
                state_dict_recursion(module, state_dict_data, prefix + name + '.')
    state_dict_recursion(main_module._modules['model'], state_dict_data)
    return state_dict_data

class ModelWrapper(torch.nn.Module):
    def __init__(self, model):
        super(ModelWrapper, self).__init__()
        self.model = model
    def forward(self, x):
        # Inside DataParallel, self is the per-device replica
        state_list = create_state_dict_new(self)
        # Call the wrapped (replicated) module, not the global `model`
        return self.model(x)

# `model` is the existing network being wrapped
model = torch.nn.DataParallel(ModelWrapper(model))

Could you please measure the time spent in create_state_dict_new?
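A minimal timing helper for that measurement, using only the standard library (`time_call` is a hypothetical name). Inside `ModelWrapper.forward` one could write something like `state_list, secs = time_call(create_state_dict_new, self, repeats=1)` and log `secs` per replica. This measures host-side Python time only; timing GPU work the same way would also need `torch.cuda.synchronize()` first, since CUDA kernels run asynchronously.

```python
import time
from typing import Any, Callable, Tuple


def time_call(fn: Callable[..., Any], *args: Any, repeats: int = 100) -> Tuple[Any, float]:
    """Call fn(*args) `repeats` times; return (last result, mean seconds per call)."""
    if repeats < 1:
        raise ValueError("repeats must be >= 1")
    result = None
    start = time.perf_counter()
    for _ in range(repeats):
        result = fn(*args)
    elapsed = time.perf_counter() - start
    return result, elapsed / repeats
```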

DataParallel launches the forward function in a separate thread for each replica. If you have 4 GPUs, there will be 4 threads executing create_state_dict_new independently. However, due to the Python GIL, those threads cannot execute the function's Python code in parallel, which further exacerbates the delay.
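A small standard-library illustration of the GIL point: a pure-Python, CPU-bound function (a stand-in for per-replica bookkeeping such as create_state_dict_new, not the real thing) run once per "device" serially and then in one thread per "device". On a standard CPython build with the GIL, the threaded run is typically no faster than the serial one; free-threaded builds may behave differently.

```python
import threading
import time


def busy_work(n: int) -> int:
    """Pure-Python, CPU-bound loop; holds the GIL while it runs."""
    total = 0
    for i in range(n):
        total += i * i
    return total


def run_serial(n: int, workers: int):
    start = time.perf_counter()
    results = [busy_work(n) for _ in range(workers)]
    return results, time.perf_counter() - start


def run_threaded(n: int, workers: int):
    results = [None] * workers

    def target(idx: int) -> None:
        results[idx] = busy_work(n)

    threads = [threading.Thread(target=target, args=(i,)) for i in range(workers)]
    start = time.perf_counter()
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results, time.perf_counter() - start


if __name__ == "__main__":
    serial_res, serial_t = run_serial(500_000, 4)
    threaded_res, threaded_t = run_threaded(500_000, 4)
    assert serial_res == threaded_res
    print(f"serial: {serial_t:.2f}s  threaded: {threaded_t:.2f}s")
```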