I’m using DDP to train Neural Architecture Search networks, which contain a controller and a model network. During training, my controller predicts a model architecture that maximizes reward. The call looks like this:
# both model and controller are torch.nn.DistributedDataParallel
arch = controller.forward(conditions)
model.module.set_arch(arch) # modifies the model's internal architecture
output = model.forward(input)...
However, in DDP docs I noticed the following:
.. warning::
You should never try to change your model’s parameters after wrapping
up your model with DistributedDataParallel. In other words, when
wrapping up your model with DistributedDataParallel, the constructor of
DistributedDataParallel will register the additional gradient
reduction functions on all the parameters of the model itself at the
time of construction. If you change the model’s parameters after
the DistributedDataParallel construction, this is not supported and
unexpected behaviors can happen, since some parameters’ gradient
reduction functions might not get called.
So I’m just wondering: what is the correct way to do this? Or is NAS simply not compatible with DDP?
model.module.set_arch(arch) # modified model internal architecture.
By doing the above, are you removing parameters from the model or adding new parameters into the model? If so, then it won’t work with DDP, as DDP creates communication buckets at construction time using the parameters returned by model.parameters(). Hence, if model.parameters() later returns a different set of parameters, DDP won’t adapt to it.
To make it work, you can create a new DDP instance from the modified model whenever the model is updated. But all DDP processes need to do the same thing at the same time, using the same model.
If it just changes the value of those parameters, it should be fine.
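The re-wrapping approach can be sketched as follows. This is a minimal single-process illustration using the gloo backend; the Linear/Sequential models and the port number are stand-ins, and in a real job every rank would execute the same re-wrap step at the same point:

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# single-process process group, just for demonstration
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(4, 2)
ddp_model = DDP(model)

# architecture changed: the parameter set is different now, so re-wrap,
# letting the DDP ctor rebuild its buckets and hooks for the new parameters
model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))
ddp_model = DDP(model)

out = ddp_model(torch.randn(3, 4))
dist.destroy_process_group()
```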
I believe this is replacing the tensor. You can use self._arch.copy_(arch) to overwrite the value in place. See the code below.
import torch
x = torch.zeros(2, 2)
y = torch.ones(2, 2)
print("x storage: ", x.data_ptr())
print("y storage: ", y.data_ptr())
x = y
print("x storage: ", x.data_ptr())
z = torch.zeros(2, 2) + 2
print("z storage: ", z.data_ptr())
x.copy_(z)
print("x storage: ", x.data_ptr())
print(x)
The outputs are:
x storage: 94191491020800
y storage: 94191523992320
x storage: 94191523992320
z storage: 94191523994816
x storage: 94191523992320
tensor([[2., 2.],
[2., 2.]])
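Applied to the NAS setting above, set_arch can copy into the existing tensor instead of rebinding the attribute, so any pointer held at DDP construction time stays valid. A sketch; the NASModel class, the 4-element arch vector, and storing _arch as a buffer are all made-up details for illustration:

```python
import torch

class NASModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # arch settings stored as a buffer; the tensor object is created once
        self.register_buffer("_arch", torch.zeros(4))
        self.fc = torch.nn.Linear(4, 2)

    def set_arch(self, arch):
        # in-place copy: same storage, so anything holding a pointer
        # to this tensor still sees the updated values
        with torch.no_grad():
            self._arch.copy_(arch)

model = NASModel()
before = model._arch.data_ptr()
model.set_arch(torch.ones(4))
# model._arch.data_ptr() is still `before`; only the values changed
```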
This might be it. If the DDP wrapper kept a pointer to my arch settings, it will not see the new value, since the new value lives at a different pointer.
So does that mean that DDP.module params is a stale copy of our model??
I believe so, as DDP remembers the variables at construction time.
And there might be more to it than that. DDP might not be able to read that value at all, because DDP registers a backward hook on each parameter and relies on those hooks to be notified of when and what to read. Those hooks are installed at DDP construction time as well, so if you create a new variable and assign it to self._arch, that hook is lost.
cc @albanD is the above statement on variable hook correct?
IIUC, that will still remove DDP autograd hooks on self._arch.
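The hook-loss behavior is easy to demonstrate on a plain tensor with register_hook (a small standalone sketch, not DDP itself): the hook lives on the tensor object, so rebinding the name to a fresh tensor silently drops it.

```python
import torch

calls = []
x = torch.ones(2, requires_grad=True)
x.register_hook(lambda g: calls.append("fired"))
x.sum().backward()   # hook fires once

# "replacing" x: a brand-new tensor; the hook stays on the old object
x = torch.ones(2, requires_grad=True)
x.sum().backward()   # hook does not fire
```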
Question: do you need the backward pass to compute gradients for self._arch? If not, you can explicitly set self._arch.requires_grad = False before passing the model to the DDP constructor, to tell DDP to ignore self._arch. Then the above assignment would work.
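That could look like the sketch below (the NASModel class and the 4-element arch parameter are invented for illustration; the DDP wrap is commented out since it requires an initialized process group). Note that nn.Module requires the replacement value for a parameter attribute to be a Parameter as well:

```python
import torch

class NASModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._arch = torch.nn.Parameter(torch.zeros(4))
        self.fc = torch.nn.Linear(4, 2)

model = NASModel()
# mark _arch before wrapping, so the DDP ctor skips it entirely
# (no gradient bucket, no autograd hook registered for it)
model._arch.requires_grad = False
# ddp_model = torch.nn.parallel.DistributedDataParallel(model)

# replacing the whole parameter is now safe, since DDP never tracks it
model._arch = torch.nn.Parameter(torch.ones(4), requires_grad=False)
```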