How to detect the Variable(s) in the computational graph that force(s) 'retain_graph=True'?


I am trying to train a model that requires ‘mini-batches of batches’ (something like the columns of each row of a table). Each row is a mini-batch, and the column values within that row form the batch. However, the error should be back-propagated only once per mini-batch (row), so I sum the losses over the columns within a row and then back-propagate.

With this design, I observe that PyTorch requires ‘retain_graph=True’ from mini-batch to mini-batch (row to row), which eventually leads to an out-of-memory error on the GPU. How can I track down the ‘Variables’ that are forcing ‘retain_graph’ to be ‘True’? How can I circumvent this issue?


You are reusing a non-leaf variable. Consider b in the following (the original snippet, made runnable by adding the import and the backward calls that actually trigger the error):

```python
import torch

want_to_fail = True

a = torch.randn(5, 5, requires_grad=True)
b = a ** 2              # b is a non-leaf variable

c = b.sum()
c.backward()            # frees the graph that produced b

if not want_to_fail:
    b = a ** 2          # recompute b to get a fresh graph

d = b.sum()
d.backward()            # RuntimeError if b was not recomputed
```

There are more subtle ways to trick yourself into having a non-leaf variable, e.g. a = torch.randn(5, 5, requires_grad=True).cuda(). As far as autograd is concerned, the .cuda() call is a computation, so the resulting a is not a leaf.
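For illustration, a minimal sketch of how leaf status behaves (the device guard is only there so the snippet also runs on CPU-only machines):

```python
import torch

# Created directly with requires_grad=True: a true leaf.
a = torch.randn(5, 5, requires_grad=True)
print(a.is_leaf, a.grad_fn)        # True None

# Any operation -- including .cuda() -- produces a new, non-leaf tensor.
if torch.cuda.is_available():
    b = torch.randn(5, 5, requires_grad=True).cuda()
    print(b.is_leaf)               # False

# One way to keep a leaf on the GPU: create it on the device directly.
device = "cuda" if torch.cuda.is_available() else "cpu"
c = torch.randn(5, 5, device=device, requires_grad=True)
print(c.is_leaf)                   # True
```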

You can tell whether a given variable is a leaf by checking a.is_leaf, or equivalently by checking a.grad_fn. If it has a grad_fn that is not None, it requires grad and is not a leaf. These are precisely the variables that you cannot reuse after your first backprop (unless you recompute them or pass retain_graph=True).
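To tie this back to the original question: if the column losses of each row are summed and backward is called once per row on a freshly built graph, retain_graph=True is never needed. A minimal sketch, assuming a simple linear model and random data (the names model, data, row, col are illustrative, not from the original post):

```python
import torch

model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

data = torch.randn(3, 5, 4)     # 3 rows (mini-batches), 5 columns per row
target = torch.randn(3, 5, 1)

for row in range(data.size(0)):              # one backward per row
    opt.zero_grad()
    # Sum the per-column losses within this row into a single scalar.
    loss = sum(
        torch.nn.functional.mse_loss(model(data[row, col]), target[row, col])
        for col in range(data.size(1))
    )
    loss.backward()   # the graph is rebuilt fresh for each row,
    opt.step()        # so retain_graph=True is never required
```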

Best regards