Strange behavior of inplace operations and backwards in WGAN-GP

Recently I’ve been implementing WGAN-GP by myself.
For those of you who are not familiar with WGAN-GP: instead of weight clipping, it uses a gradient penalty to enforce the 1-Lipschitz constraint on the WGAN’s discriminator (critic) function.

This gradient penalty term is a function of the gradient of the discriminator’s output with respect to its input, evaluated on a special kind of input (random interpolations between real and generated samples). Because the penalty is itself backpropagated during training, we are essentially taking a second derivative, for which we use `torch.autograd.grad`.
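
For readers new to the technique, here is a minimal sketch of a WGAN-GP gradient penalty, assuming a toy fully connected critic and the penalty weight of 10 used in the WGAN-GP paper. The names `gradient_penalty` and `lambda_gp` are illustrative, not part of any library. The key detail is `create_graph=True`, which makes the computed input gradient differentiable so the penalty can be backpropagated into the critic's weights.

```python
import torch
from torch import nn

def gradient_penalty(critic, real, fake, lambda_gp=10.0):
    """Sketch of the WGAN-GP penalty: lambda * (||grad_x D(x_hat)||_2 - 1)^2."""
    batch_size = real.size(0)
    # One interpolation coefficient per sample, broadcast over the other dims.
    alpha = torch.rand(batch_size, *([1] * (real.dim() - 1)), device=real.device)
    interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = critic(interp)
    # create_graph=True keeps the gradient in the graph, so the penalty
    # itself can be differentiated (this is the "second derivative").
    grads, = torch.autograd.grad(outputs=scores.sum(), inputs=interp,
                                 create_graph=True)
    grad_norm = grads.reshape(batch_size, -1).norm(2, dim=1)
    return lambda_gp * ((grad_norm - 1) ** 2).mean()

# Hypothetical toy critic and data, just to exercise the function.
critic = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
real = torch.randn(4, 8)
fake = torch.randn(4, 8)
gp = gradient_penalty(critic, real, fake)
gp.backward()  # backward through the gradient computation
```

In a real training loop the penalty would be added to the critic loss before calling `backward()`, and `fake` would usually be detached from the generator.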

My original problem was that, even though I pass `retain_graph=True` to `grad`, calling backward on the loss raises the error ‘Trying to backward through the graph a second time…’.
From debugging, I learned that the error was caused by a residual layer in my model that performed an in-place ReLU. I removed the in-place operation and my code now works, but I’m still not sure what exactly is going wrong, and it really annoys me.
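
The workaround described above (no in-place activation) can be sketched as follows. The `ResidualBlock` below is hypothetical, since the original layer isn't shown; it simply uses `nn.ReLU(inplace=False)` and runs the same double-backprop pattern. Note that in `torch.autograd.grad`, `retain_graph` defaults to the value of `create_graph`, so passing `create_graph=True` already keeps the graph alive for the later `backward()` call.

```python
import torch
from torch import nn

class ResidualBlock(nn.Module):
    """Hypothetical 1-D residual block that avoids in-place ReLU."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv1d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(channels, channels, kernel_size=3, padding=1)
        # inplace=False (the default) leaves each activation's input untouched.
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        out = self.relu(self.conv1(x))
        out = self.conv2(out)
        return self.relu(x + out)

block = ResidualBlock(4)
x = torch.randn(2, 4, 10, requires_grad=True)
y = block(x).mean()
# retain_graph defaults to create_graph, so the graph of y is kept.
grad, = torch.autograd.grad(y, x, create_graph=True)
loss = grad.pow(2).mean() + y
loss.backward()
```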

I have constructed a toy example that, to my understanding, suffers from the same issue. Can anyone please help me understand what exactly is going wrong?

```python
import torch
from torch import nn
# from torchviz import make_dot  # optional: only needed to plot the graph

def double_backprop(inputs, net):
    # Use the `inputs` argument rather than the global `x`.
    y = net(inputs).mean()
    grad, = torch.autograd.grad(y, inputs, create_graph=True, retain_graph=True)
    return grad.pow(2).mean() + y

class TestNet(nn.Module):
    """A network for testing double backprop."""

    def __init__(self):
        super().__init__()
        # Layers are created once here instead of on every forward call.
        self.conv = nn.Conv1d(4, 100, 1)
        self.relu_inplace = nn.ReLU(inplace=True)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(500, 1)

    def forward(self, input):
        output = input.transpose(1, 2)
        output = self.conv(output)
        # If I remove either the second ReLU layer, or the inplace argument, this works.
        output = self.relu_inplace(output)
        output = self.relu(output)
        output = output.view(-1, 500)
        output = self.fc(output)
        return output

model = TestNet()
x = torch.randn((64, 50, 4), requires_grad=True)

out = double_backprop(x, model)
out.backward()  # backpropagate the combined loss, as in a training step
# make_dot(out)
```

OMG. This is probably another instance of a previously reported issue. I finally found a second person who sees this problem!!!

Yeah it definitely seems like the same thing!

How do you approach debugging such problems? I would love to be able to understand more about such issues by myself. As you can see from my code, I tried plotting the computational graph with graphviz (via torchviz) and debugging, but it wasn’t really helpful since, as far as I understand, the graph does not show which buffers are still in memory. Any tips?

I usually just use gdb and print statements. :slight_smile: They may not be the most productive tools, but they usually do the job. But in this case the graph is complicated (double backward + inplace operations on views), so I debugged for an afternoon and put it down for more urgent tasks. I definitely still plan to keep trying.
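
One more tool, not mentioned above, that can help: PyTorch's anomaly detection (`torch.autograd.detect_anomaly`), which records the forward-pass traceback of each op so that an error during backward also reports where the offending op was created. You can also walk the graph yourself through `grad_fn.next_functions`. The sketch below does both; `graph_nodes` is a hypothetical helper, and like torchviz it only shows node types, not which saved buffers have already been freed.

```python
import torch
from torch import nn

def graph_nodes(fn, depth=0, max_depth=20):
    """Hypothetical helper: list (depth, node type) pairs below an autograd node."""
    if fn is None or depth > max_depth:
        return []
    nodes = [(depth, type(fn).__name__)]
    for next_fn, _ in fn.next_functions:
        nodes.extend(graph_nodes(next_fn, depth + 1, max_depth))
    return nodes

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
x = torch.randn(3, 4, requires_grad=True)

# Anomaly mode is slow; enable it only while debugging.
with torch.autograd.detect_anomaly():
    y = model(x).mean()
    grad, = torch.autograd.grad(y, x, create_graph=True)
    loss = grad.pow(2).mean() + y
    for depth, name in graph_nodes(loss.grad_fn):
        print("  " * depth + name)
    loss.backward()
```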

About in-place operations on views: we have (somewhat complicated) logic for how the graph should be updated in this case. If you are interested, take a look at the original proposal and the patch that implemented it, which has evolved a bit beyond the proposal.
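
As a small illustration of that graph rewriting (a sketch, not the patch itself): after an in-place op on a view, autograd replaces the base tensor's `grad_fn` with a `CopySlices` node so gradients flow through the modified slice, and the version counter shared by the view and its base is bumped. `_version` is an internal attribute, used here only for inspection. The base is `x * 2` rather than `x` because an in-place op on a view of a leaf that requires grad raises an error.

```python
import torch

x = torch.randn(3, 4, requires_grad=True)
base = x * 2          # non-leaf tensor that requires grad
view = base[0]        # view sharing storage (and version counter) with base

view.relu_()          # in-place op on the view

# The base's history has been rewritten to include the in-place op.
print(type(base.grad_fn).__name__)  # CopySlices
print(base._version)                # bumped by the in-place op

base.sum().backward()
# Rows 1 and 2 get gradient 2; row 0 is additionally masked by the ReLU.
print(x.grad)
```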