Problem of inplace operation with InstanceNorm

Here is a problem I met while implementing GP-WGAN. It involves an in-place operation and a second-order derivative. I have simplified the code that reproduces the problem, shown below. (PyTorch version 1.1.0)

import torch
import torch.nn as nn

i = torch.randn((1, 3, 10, 10), requires_grad=True)
main = nn.Sequential(
    nn.Conv2d(3, 3, 3),
    nn.InstanceNorm2d(3),
    nn.ReLU(True),          # in-place ReLU
    nn.Conv2d(3, 3, 3),
)
o = main(i)
# First-order gradient of o w.r.t. i; keep the graph so we can differentiate again
gradients = torch.autograd.grad(outputs=o, inputs=i,
                                grad_outputs=torch.ones(o.size()),
                                retain_graph=True, create_graph=True, only_inputs=True)[0]
# Backward through the gradient itself (second-order derivative)
gradients.mean().backward()

Running the above code raises

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

But if I only use the first-order derivative of o, or change the network definition to

main = nn.Sequential(
    nn.Conv2d(3, 3, 3),
    nn.BatchNorm2d(3),
    nn.ReLU(True),
    nn.Conv2d(3, 3, 3),
)

Or

main = nn.Sequential(
    nn.Conv2d(3, 3, 3),
    nn.ReLU(True),
    nn.Conv2d(3, 3, 3),
)

there is no error. Likewise, if I set the ReLU's inplace argument to False, or remove the last convolution layer, there is no error either. The same happens with LeakyReLU.
May I ask why using InstanceNorm causes this error while there is no problem with BatchNorm?
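For reference, here is a sketch of the inplace=False variant that runs without error for me (same imports and input tensor i as above); the only change from the failing snippet is the ReLU flag:

# Same network, but with a non-in-place ReLU; this variant does not raise the error
main = nn.Sequential(
    nn.Conv2d(3, 3, 3),
    nn.InstanceNorm2d(3),
    nn.ReLU(inplace=False),
    nn.Conv2d(3, 3, 3),
)
o = main(i)
gradients = torch.autograd.grad(outputs=o, inputs=i,
                                grad_outputs=torch.ones(o.size()),
                                retain_graph=True, create_graph=True, only_inputs=True)[0]
gradients.mean().backward()  # second-order backward now succeeds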
