Updating a model's parameters during DDP training

I’m using DDP to train Neural Architecture Search networks, which contain a controller and a model network. During training, my controller predicts a model architecture that maximizes reward. The call looks like this:

# both model and controller are wrapped in torch.nn.parallel.DistributedDataParallel
arch = controller.forward(conditions)
model.module.set_arch(arch)  # modifies the model's internal architecture
output = model.forward(input)...

However, I noticed the following warning in the DDP docs:

.. warning::
You should never try to change your model’s parameters after wrapping
up your model with DistributedDataParallel. In other words, when
wrapping up your model with DistributedDataParallel, the constructor of
DistributedDataParallel will register the additional gradient
reduction functions on all the parameters of the model itself at the
time of construction. If you change the model’s parameters after
the DistributedDataParallel construction, this is not supported and
unexpected behaviors can happen, since some parameters’ gradient
reduction functions might not get called.

So I’m just wondering: what is the correct way to do this? Or is NAS simply not suitable for DDP?

model.module.set_arch(arch)  # modifies the model's internal architecture

  1. By doing the above, are you removing parameters from the model or adding new parameters to it? If so, it won’t work with DDP, as DDP creates its communication buckets at construction time from the parameters returned by model.parameters(). Hence, if model.parameters() later returns a different set of parameters, DDP won’t adapt to it.
    • To make it work, you can create a new DDP instance from the modified model whenever the model changes (see the sketch after this list). But all DDP processes need to do the same thing at the same time with the same model.
  2. If it just changes the values of those parameters, it should be fine.
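A minimal sketch of the first option, assuming a hypothetical set_arch() that may add or remove parameters (the helper name and device_id argument are illustrative, not part of any API); the key constraint is that every rank rebuilds the wrapper at the same point with the same architecture:

from torch.nn.parallel import DistributedDataParallel as DDP

def rewrap_after_arch_change(ddp_model, arch, device_id):
    # Unwrap the underlying module and let it change its parameter set.
    base = ddp_model.module
    base.set_arch(arch)  # hypothetical: may add or remove parameters
    # A fresh DDP instance registers new reduction buckets and autograd
    # hooks for the updated set of parameters.
    return DDP(base, device_ids=[device_id])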

Can you clarify the difference between modifying and replacing?

def __init__(self):
    self._arch = torch.zeros(<shape>, requires_grad=True)
def set_arch(self, arch):
    self._arch = arch  # is this modifying or replacing?

I believe this is replacing. You can use self._arch.copy_(arch) to overwrite the value instead. See the code below.

import torch

x = torch.zeros(2, 2)
y = torch.ones(2, 2)
print("x storage: ", x.data_ptr())
print("y storage: ", y.data_ptr())
x = y  # rebinding: the name x now points at y's storage
print("x storage: ", x.data_ptr())
z = torch.zeros(2, 2) + 2
print("z storage: ", z.data_ptr())
x.copy_(z)  # in-place copy: values change, storage stays the same
print("x storage: ", x.data_ptr())
print(x)

The output is:

x storage:  94191491020800
y storage:  94191523992320
x storage:  94191523992320
z storage:  94191523994816
x storage:  94191523992320
tensor([[2., 2.],
        [2., 2.]])
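Applied to the set_arch() question above, a sketch of the in-place alternative (the class and shape are illustrative): overwriting in place keeps the same tensor object, so anything holding a reference to self._arch still sees the new values.

import torch

class ModelWithArch(torch.nn.Module):  # illustrative stand-in for the real model
    def __init__(self):
        super().__init__()
        self._arch = torch.zeros(4, requires_grad=True)

    def set_arch(self, arch):
        # In-place copy: same storage, new values. The copy itself should not be
        # recorded by autograd, hence the no_grad block (copy_ on a leaf tensor
        # that requires grad would otherwise raise an error).
        with torch.no_grad():
            self._arch.copy_(arch)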

This might be it. If the DDP wrapper keeps a pointer to my arch tensor, it will not see the new value, since the replacement tensor lives at a different pointer.
So does that mean that the parameters in DDP.module are a stale copy of our model?

So does that mean that the parameters in DDP.module are a stale copy of our model?

I believe so, as DDP remembers the variables at construction time.

And there might be more to it than that: DDP might not be able to read that value at all. DDP registers a backward hook on each parameter and relies on those hooks to tell it when and what to read. The hooks are installed at DDP construction time as well. If you create a new variable and assign it to self._arch, that hook might be lost.
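To see why rebinding can lose a hook, here is a small standalone illustration using a plain tensor hook as a stand-in for DDP's gradient-reduction hook (this is not DDP itself):

import torch

arch = torch.zeros(2, requires_grad=True)
arch.register_hook(lambda grad: print("hook fired"))  # stand-in for DDP's gradient hook

arch.sum().backward()   # prints "hook fired": the hook lives on this tensor

arch = torch.ones(2, requires_grad=True)  # rebinding the name: a brand-new tensor, no hook
arch.sum().backward()   # prints nothing; the old hook was never attached to this tensor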

cc @albanD, is the above statement about the variable hooks correct?

Hi,

Yes, I think this note explicitly warns you against doing this. You should not change the Parameters.

As a side note, you should never call your module's forward() directly; call module(input) instead.
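For illustration, with a plain nn.Linear as a stand-in for the real model:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
x = torch.randn(1, 4)

y = model(x)            # preferred: __call__ runs any registered hooks around forward
# y = model.forward(x)  # bypasses the hook machinery; avoid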

What if I modify my forward function such that:

def forward(self, input, arch):
    self._arch = arch

Will this work?
Also, does DDP keep the values in DDP.module up to date?

IIUC, that will still drop the DDP autograd hooks on self._arch.

Question: do you need the backward pass to compute gradients for self._arch? If not, you can explicitly set self._arch.requires_grad = False before passing the model to the DDP constructor, to tell DDP to ignore self._arch. Then the above assignment would work; see the sketch below.
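A minimal, self-contained sketch of that suggestion (NASModel, the shapes, and the single-process gloo group are purely illustrative): the flag is flipped off before the DDP constructor runs, so no gradient is expected for _arch, and rebinding it inside forward() no longer interferes with DDP's bookkeeping.

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

class NASModel(nn.Module):  # hypothetical stand-in for the real model
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)
        self._arch = torch.zeros(3, requires_grad=True)  # plain tensor attribute

    def forward(self, x, arch):
        self._arch = arch        # plain rebinding; DDP does not track _arch
        return self.layer(x)

if __name__ == "__main__":
    # Single-process gloo group, just to make the example runnable on its own.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    model = NASModel()
    model._arch.requires_grad = False  # before wrapping: no gradient expected for _arch
    ddp_model = DDP(model)

    out = ddp_model(torch.randn(1, 4), torch.ones(3))  # arch is passed through forward
    out.sum().backward()  # gradients flow only to the layer's parameters
    dist.destroy_process_group()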

Thank you. My model is now performing as expected :slight_smile:
