import copy
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, net):
        super(Block, self).__init__()
        self.net = net
        self.net_copy = copy.deepcopy(net)

    def forward(self, x):
        self.net_copy.load_state_dict(self.net.state_dict())
        return self.net(x)
The net is an nn.Sequential module. With PyTorch >= 1.5 and nn.DataParallel on multiple GPUs, net_copy.state_dict().keys() differs from net.state_dict().keys(). With PyTorch 1.4, or on a single GPU, this problem doesn't appear. How can I make sure that net and net_copy are exactly the same?
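For reference, outside of DataParallel the deep copy and the original do expose identical state-dict keys, so load_state_dict succeeds. A minimal single-device sketch (hypothetical layer sizes) of the invariant that breaks inside the >= 1.5 replicas:

```python
import copy
import torch.nn as nn

# stand-in for the OP's nn.Sequential net (sizes are arbitrary)
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
net_copy = copy.deepcopy(net)

# On a single device (or PyTorch 1.4) the keys match exactly...
print(list(net.state_dict().keys()) == list(net_copy.state_dict().keys()))  # True

# ...so copying the state across succeeds.
net_copy.load_state_dict(net.state_dict())
```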
In v1.5, parameters on replicated models are no longer considered leaves, as they shouldn't be. If you really need to access those replicated parameters, you can probably get them from _former_parameters and manually add them to the state_dict?
cc @ngimel please correct me if I am wrong. And any thoughts on whether we should make state_dict() consistent between v1.4 and v1.5?
In order to access _former_parameters, we would need access to a replica, right? Can you help me figure out how to access _former_parameters in the OP's example?
Or how to recreate state dict in some other manner?
In the implementation of DataParallel.forward, the module is replicated and then each replica is invoked as replicas[i].forward(inputs[i], ...). So during execution, the self variable inside the forward function is the replica itself; hence, you can use self._former_parameters to access that field from within forward.
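The full torch source is not reproduced here, but the key behavior can be sketched in plain Python: parallel_apply runs each replica's forward in its own thread, so whatever object the thread was handed is what self refers to. The Replica class below is hypothetical, standing in for a model replica:

```python
import threading

class Replica:
    """Hypothetical stand-in for one model replica."""
    def __init__(self, tag):
        self.tag = tag

    def forward(self, x):
        # `self` here is whichever replica this thread was given, so
        # replica-local fields (like _former_parameters on a real
        # DataParallel replica) are reachable as self.<field>.
        return (self.tag, x * 2)

    __call__ = forward

def parallel_apply_sketch(replicas, inputs):
    """Sketch of parallel_apply: run each replica's forward in a thread."""
    results = [None] * len(replicas)

    def worker(i):
        results[i] = replicas[i](inputs[i])  # calls replicas[i].forward(...)

    threads = [threading.Thread(target=worker, args=(i,))
               for i in range(len(replicas))]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results

print(parallel_apply_sketch([Replica(0), Replica(1)], [10, 20]))
# [(0, 20), (1, 40)]
```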
I managed to recreate the state_dict using code similar to the built-in state_dict implementation. Thanks for your help.
I noticed that _former_parameters exists in 1.5.1 but not in 1.5.0. It seems tricky to get the parameters in 1.5.0 if we do not know their names in advance (though it should still be possible, since replication sets them via setattr). Any suggestions for this?
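One possible approach, assuming (as the comment above suggests) that replication in 1.5.0 leaves the parameters as plain tensor attributes set via setattr rather than registered entries in _parameters: scan each module's __dict__ for tensor-valued attributes. The sketch below simulates that situation on a single CPU module; scan_plain_tensor_attrs is a hypothetical helper, not a torch API:

```python
import torch

def scan_plain_tensor_attrs(module, prefix=""):
    """Collect tensor-valued plain attributes on a module tree, i.e.
    tensors set via setattr rather than registered in _parameters."""
    found = {}
    for name, value in vars(module).items():
        if isinstance(value, torch.Tensor) and not name.startswith("_"):
            found[prefix + name] = value
    for name, child in module._modules.items():
        if child is not None:
            found.update(scan_plain_tensor_attrs(child, prefix + name + "."))
    return found

# Simulate what replication effectively does: the parameter is no longer
# a registered leaf, just a plain tensor attribute on the module.
lin = torch.nn.Linear(2, 2, bias=False)
w = lin.weight.detach()
del lin._parameters["weight"]
lin.weight = w  # plain setattr; no longer tracked in lin._parameters

print(sorted(scan_plain_tensor_attrs(lin)))  # ['weight']
```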
Thanks a lot. I have one last question. Like the OP, I need to recreate the state dict on every forward pass, and I see about an 8x increase in training time compared to the original PyTorch DataParallel. Any ideas why this might be the case?
from collections import OrderedDict

import torch

def create_state_dict_new(main_module):
    state_dict_data = OrderedDict()

    def state_dict_recursion(this_module, state_dict_data, prefix=''):
        if hasattr(this_module, "_former_parameters"):
            for name, param in this_module._former_parameters.items():
                if param is not None:
                    state_dict_data[prefix + name] = param
        for name, buf in this_module._buffers.items():
            if buf is not None:
                state_dict_data[prefix + name] = buf
        for name, module in this_module._modules.items():
            if module is not None:
                state_dict_recursion(module, state_dict_data, prefix + name + '.')

    state_dict_recursion(main_module._modules['model'], state_dict_data)
    return state_dict_data

class ModelWrapper(torch.nn.Module):
    def __init__(self, model):
        super(ModelWrapper, self).__init__()
        self.model = model

    def forward(self, x):
        state_list = create_state_dict_new(self)
        return self.model(x)  # was `model(x)`: must go through self.model

model = torch.nn.DataParallel(ModelWrapper(model))
Could you please measure the time spent in create_state_dict_new?
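A minimal way to take that measurement, using a small stdlib timing helper (the helper and its placement inside the wrapper's forward are suggestions, not part of the original code):

```python
import time

def timed(fn, *args, **kwargs):
    """Return (result, elapsed_seconds) for a single call to fn."""
    t0 = time.perf_counter()
    out = fn(*args, **kwargs)
    return out, time.perf_counter() - t0

# Hypothetical placement inside ModelWrapper.forward:
#   state_list, elapsed = timed(create_state_dict_new, self)
#   print(f"create_state_dict_new: {elapsed * 1000:.2f} ms")

# Standalone demonstration of the helper:
result, elapsed = timed(sum, range(1000))
print(result, elapsed >= 0.0)  # 499500 True
```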
The forward function is launched in each thread: with 4 GPUs, there are 4 threads executing create_state_dict_new independently. However, due to the Python GIL, those threads cannot run the function concurrently, which further exacerbates the delay.