torch.autograd.grad throws a RuntimeError with DistributedDataParallel

While training a WGAN-GP with DistributedDataParallel (DDP), I ran into several errors. This is the relevant part of my training step:

# generator step
fake_g_pred=self.model['D'](outputs)
gen_loss=self.criteron['adv'](fake_g_pred,True)
gen_loss.backward()
self.optimizer['G'].step()
self.lr_scheduler['G'].step()

# gradient penalty on random interpolates between real and generated images
# one epsilon per sample, broadcast over C, H, W
epsilon=torch.rand(images.size(0),1,1,1,device=self.device)
# detach the generator output and make the interpolate a leaf that requires grad
hat=(outputs.detach().mul(1-epsilon)+images.mul(epsilon)).requires_grad_(True)
dis_hat_loss=self.model['D'](hat)
grad=torch.autograd.grad(
                outputs=dis_hat_loss,inputs=hat,
                grad_outputs=torch.ones_like(dis_hat_loss),
                retain_graph=True,create_graph=True
            )[0]
# flatten each sample's gradient so the norm covers all of C*H*W
grad_penalty=((grad.reshape(grad.size(0),-1).norm(2,dim=1)-1)**2).mean()
grad_penalty.backward()
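For reference, here is a minimal single-process sketch (no DDP) of the same penalty term, useful for checking it in isolation. It assumes a critic that returns one score per sample; gradient_penalty and the toy nn.Sequential critic are illustrative names, not part of the training code above.

```python
import torch
import torch.nn as nn


def gradient_penalty(critic: nn.Module, real: torch.Tensor, fake: torch.Tensor) -> torch.Tensor:
    """WGAN-GP penalty: mean of (||grad_x critic(x_hat)||_2 - 1)^2 over random interpolates."""
    # One interpolation coefficient per sample, broadcast over C, H, W.
    eps = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    # Detach the generator output so the penalty only produces critic gradients.
    x_hat = (eps * real + (1 - eps) * fake.detach()).requires_grad_(True)
    scores = critic(x_hat)
    # create_graph=True keeps the graph so the penalty itself can be backpropagated.
    (grad,) = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
    )
    # Flatten each sample's gradient so the norm is taken over all of C*H*W.
    grad_norm = grad.reshape(grad.size(0), -1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()


if __name__ == "__main__":
    torch.manual_seed(0)
    # Toy linear critic, for illustration only.
    critic = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))
    real = torch.randn(4, 3, 8, 8)
    fake = torch.randn(4, 3, 8, 8)
    gp = gradient_penalty(critic, real, fake)
    gp.backward()  # accumulates into the critic parameters' .grad
    print(gp.item())
```

With a linear critic the input gradient equals the weight vector for every sample, so the penalty reduces to (||W|| - 1)^2, which makes it easy to sanity-check.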

The GAN runs fine for one iteration (i.e. a single batch from the DataLoader), but the second iteration fails with:

RuntimeError: Expected to have finished reduction in the prior 
iteration before starting a new one. This error indicates that your module 
has parameters that were not used in producing loss. You can enable 
unused parameter detection by (1) passing the keyword argument 
`find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`; 
(2) making sure all `forward` function outputs participate in calculating loss. 
If you already have done the above two steps, then the distributed data parallel 
module wasn't able to locate the output tensors in the return 
value of your module's `forward` function. 
Please include the loss function and the structure 
of the return value of `forward` of your module
 when reporting this issue (e.g. list, dict, iterable).

The error is raised at this line:

fake_g_pred=self.model['D'](outputs)

Is this an intrinsic bug in DDP? Can anyone tell me how to debug it, or suggest an alternative way to
train a WGAN with gradient penalty using DistributedDataParallel?

Hi,

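I'm afraid this is expected: DDP only works with .backward() and does not support autograd.grad.
DDP synchronizes gradients through hooks that expect gradients to be accumulated into the parameters' .grad attributes, which autograd.grad does not do, so the reduction DDP expects for that forward pass is never completed; that is what the error on the next forward pass reports.
You will have to use .backward() for it to work properly.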
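If the penalty has to be computed with autograd.grad, one alternative is to skip the DDP wrapper and average gradients by hand, as in the classic data-parallel pattern. The sketch below assumes a process group is already initialized (e.g. with torch.distributed.init_process_group), that D is a plain nn.Module on every rank, and that every rank produces gradients for the same set of parameters (otherwise all_reduce calls will not match up). The helper names broadcast_parameters and average_gradients are hypothetical, not a PyTorch API.

```python
import torch
import torch.distributed as dist
import torch.nn as nn


def broadcast_parameters(model: nn.Module, src: int = 0) -> None:
    """Copy rank `src`'s weights to every replica (DDP does this at construction)."""
    for p in model.parameters():
        dist.broadcast(p.data, src=src)


def average_gradients(model: nn.Module) -> None:
    """Average .grad across all processes; call after backward(), before step()."""
    world_size = dist.get_world_size()
    for p in model.parameters():
        # Assumes every rank has a gradient for exactly the same parameters.
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
            p.grad.div_(world_size)
```

Call broadcast_parameters(D) once after building the model. Each iteration, on every rank: zero the gradients, compute the discriminator loss including the penalty (using torch.autograd.grad as in the question), call .backward(), then average_gradients(D) before optimizer.step(). Since communication no longer overlaps with the backward pass as it does inside DDP, this is likely slower.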