How to apply DDP to an algorithm that uses model parameters outside the forward function

Hi, I find it quite difficult to use DDP to train a model with an additional loss function outside the forward function.


Training Procedure

The model (M, based on ProxylessNAS) has two sets of parameters:

  • neural network weights W,
  • architecture parameters (operator weights) A,

The steps to update A are as follows (a minimal sketch follows the list):

  1. randomly sample a sub-network Msub, with parameters Wsub, based on the probability matrix A
  2. loss1 = Msub(data)
  3. loss2 is directly calculated from A, e.g. Latency(A) = 3*A01 + 2*A02
  4. loss = f(loss1, loss2); loss.backward()
  5. update A
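
To make this concrete, here is a minimal single-process sketch of one such iteration. `SuperNet`, `alpha`, and the latency coefficients are just illustrative stand-ins, not the actual ProxylessNAS code, and the ProxylessNAS gradient estimator that routes gradients to A through the sampled gates is omitted to keep the sketch short.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SuperNet(nn.Module):
    def __init__(self, in_dim=16, num_classes=10, num_ops=3):
        super().__init__()
        # Candidate operators; each sampled sub-network Msub uses only one of them.
        self.ops = nn.ModuleList([nn.Linear(in_dim, num_classes) for _ in range(num_ops)])
        # Architecture parameters A: one logit per candidate operator.
        self.alpha = nn.Parameter(torch.zeros(num_ops))

    def forward(self, x):
        # Step 1: sample a sub-network according to the probabilities derived from A.
        # (The ProxylessNAS binary-gate gradient estimator for A is omitted here.)
        probs = F.softmax(self.alpha, dim=0)
        idx = int(torch.multinomial(probs, 1))
        return self.ops[idx](x)

model = SuperNet()
arch_opt = torch.optim.Adam([model.alpha], lr=1e-3)

data = torch.randn(8, 16)
target = torch.randint(0, 10, (8,))

loss1 = F.cross_entropy(model(data), target)              # step 2
latency = torch.tensor([3.0, 2.0, 4.0])                   # assumed per-operator costs
loss2 = (F.softmax(model.alpha, dim=0) * latency).sum()   # step 3: A used outside forward
loss = loss1 + 0.1 * loss2                                 # step 4
loss.backward()
arch_opt.step()                                            # step 5: update A
```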

The Problem

I set find_unused_parameters=True and it raises an error during backward propagation.

RuntimeError: Expected to mark a variable ready only once. This error is caused by use of a module parameter outside the `forward` function. The return value of the `forward` function is inspected by the distributed data parallel wrapper to figure out if any of the module's parameters went unused. If this is the case, it knows they won't receive gradients in a backward pass. If any of those parameters are then used outside `forward`, this error condition is triggered. You can disable unused parameter detection by passing the keyword argument `find_unused_parameters=False` to `torch.nn.parallel.DistributedDataParallel`.

The problem is that, since only a part of the model (Msub) is used in each iteration, DDP won't receive gradients for the parameters that don't belong to Wsub. If I set find_unused_parameters=False, it crashes in the next forward pass.
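
For reference, the distributed setup looks roughly like this. This is a sketch reusing the illustrative `SuperNet` from above, not my exact code, and it assumes the process group environment variables are set by the launcher (e.g. `torch.distributed.launch`):

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="gloo")  # gloo so the sketch also runs on CPU
# SuperNet is the illustrative supernet from the sketch above.
ddp_model = DDP(SuperNet(), find_unused_parameters=True)

data = torch.randn(8, 16)
target = torch.randint(0, 10, (8,))

loss1 = F.cross_entropy(ddp_model(data), target)
latency = torch.tensor([3.0, 2.0, 4.0])
# A (alpha) is used here, outside forward(). Per the error message, using a
# parameter that DDP already marked as unused triggers
# "Expected to mark a variable ready only once" during backward().
loss2 = (F.softmax(ddp_model.module.alpha, dim=0) * latency).sum()
(loss1 + 0.1 * loss2).backward()
```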

Does anyone have any idea to solve the problem?

In this case, I’m assuming that A is part of the network M. During the backward pass, DDP considers A unused (it did not contribute to the forward output) and marks it as ready for reduction, but we then attempt to re-mark A’s parameters when they receive gradients in the backward pass.

Is it possible to create a wrapper network that wraps steps 2 and 3 into a single module and returns both loss1 and loss2 (or the values needed to compute it)? That seems like it would avoid the double-mark issue.
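
Something along these lines might work. This is only a rough sketch reusing the illustrative `SuperNet` from your post above; the `SearchWrapper` name and the loss weighting are assumptions, not the real ProxylessNAS code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SearchWrapper(nn.Module):
    def __init__(self, supernet: nn.Module, latency: torch.Tensor):
        super().__init__()
        self.supernet = supernet
        self.register_buffer("latency", latency)  # assumed per-operator costs

    def forward(self, x, target):
        # Step 2: loss1 from the sampled sub-network.
        loss1 = F.cross_entropy(self.supernet(x), target)
        # Step 3: loss2 computed from A inside forward(), so DDP sees that A was used.
        probs = F.softmax(self.supernet.alpha, dim=0)
        loss2 = (probs * self.latency).sum()
        return loss1, loss2

# Single-process usage; in the distributed setting you would wrap SearchWrapper
# itself with DDP(..., find_unused_parameters=True) and call it the same way.
wrapper = SearchWrapper(SuperNet(), torch.tensor([3.0, 2.0, 4.0]))
loss1, loss2 = wrapper(torch.randn(8, 16), torch.randint(0, 10, (8,)))
(loss1 + 0.1 * loss2).backward()
```

Since A is now only touched inside forward() and contributes to a returned output, DDP's inspection of the forward outputs should see that it was used, which is what should avoid the double-mark issue.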

Also, the version of PyTorch you are using and a script to reproduce the issue would be valuable.


In this case, I’m assuming that A is part of the network M. During the backward pass, DDP considers A unused (it did not contribute to the forward output) and marks it as ready for reduction, but we then attempt to re-mark A’s parameters when they receive gradients in the backward pass.

Yes, this is the cause. Thanks for the explanation.

Is it possible to create a wrapper network that wraps steps 2 and 3 into a single module and returns both loss1 and loss2 (or the values needed to compute it)? That seems like it would avoid the double-mark issue.

That’s a good idea, although it means a lot of work :rofl:.

I’m using PyTorch 1.4. The code is a little bit messy; I’ll try to make a demo if the above suggestion doesn’t work.