InstanceNorm resulting in "Trying to backward through the graph a second time" when drop-in replacing BatchNorm?

I’m working with a WGAN whose model initially uses BatchNorm2d layers. With BatchNorm2d, my code runs without a problem. However, when I drop-in replace the BatchNorm2d layers with InstanceNorm2d, I end up with the "Trying to backward through the graph a second time" error.
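For context, this is the kind of swap I mean. A minimal sketch, assuming a typical convolutional critic block; the architecture here is hypothetical and only the norm-layer replacement is the point:

```python
import torch.nn as nn

def critic_block(in_ch, out_ch, use_instance_norm=False):
    # Drop-in swap: everything stays the same except the normalization layer.
    norm = (nn.InstanceNorm2d(out_ch) if use_instance_norm
            else nn.BatchNorm2d(out_ch))
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
        norm,
        nn.LeakyReLU(0.2, inplace=True),
    )
```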

Having worked with WGANs, I’ve seen this error before. However, I don’t believe it should occur when replacing BatchNorm2d with InstanceNorm2d and making no other changes. While trying to debug the issue, I added retain_graph=True to all my backward() calls and to the WGAN’s torch.autograd.grad(), but the error still occurs.
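To be concrete about where these calls sit, here is a hedged sketch of a standard WGAN-GP gradient penalty computed with torch.autograd.grad(create_graph=True). The names (critic, real, fake, lambda_gp) are placeholders and this is not my exact code:

```python
import torch

def gradient_penalty(critic, real, fake, lambda_gp=10.0):
    # Random interpolation between real and fake samples.
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = critic(interpolates)
    # First-order gradient w.r.t. the interpolates; create_graph=True so the
    # penalty itself can later be backpropagated (second-order gradient).
    grads, = torch.autograd.grad(
        outputs=scores,
        inputs=interpolates,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )
    grads = grads.view(grads.size(0), -1)
    return lambda_gp * ((grads.norm(2, dim=1) - 1) ** 2).mean()
```

The penalty returned here is added to the critic loss, and it is the backward() on that combined loss where the behavior differs between BatchNorm and InstanceNorm in my case.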

Does anyone happen to have an explanation as to why this should occur for InstanceNorm when it doesn’t for BatchNorm? And how I might be able to resolve this? Thank you!

To update: I had added retain_graph=True to every backward() except the final one (which should never need to retain the graph). This final backward is the one used for the gradient penalty, after the torch.autograd.grad() call, and notably it is also the backward where the error occurs. However, adding retain_graph=True to this backward fixes the problem. This seems strange to me, since it means the same backward that is trying to use the freed graph is also the one doing the undesired freeing. Does anyone know why this might be happening? Again, it only occurs when InstanceNorm replaces BatchNorm, which suggests to me that InstanceNorm is doing something strange when used in a second-order gradient.
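For clarity, this is roughly where the workaround goes. A sketch of the critic update, assuming the gradient_penalty helper from the earlier snippet; critic, opt_d, real, and fake are placeholders for my actual training loop, and retain_graph=True here is the workaround, not a root-cause fix:

```python
def critic_step(critic, opt_d, real, fake):
    opt_d.zero_grad()
    gp = gradient_penalty(critic, real, fake)  # uses torch.autograd.grad internally
    loss_d = critic(fake.detach()).mean() - critic(real).mean() + gp
    # With BatchNorm2d, loss_d.backward() runs fine; with InstanceNorm2d the
    # same call raises "Trying to backward through the graph a second time"
    # unless the graph is retained.
    loss_d.backward(retain_graph=True)
    opt_d.step()
    return loss_d.item()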

Hi, I ran into the same situation. Adding retain_graph=True when calling backward() fixed the problem for me. Thanks.