Let custom autograd Function know if create_graph is enabled

Assume you have a custom autograd.Function class that calls autograd.grad during its backward. How can that backward call know whether the create_graph option was set in the 'main code'?
One option would be to always pass create_graph=True in the backward function, but that has obviously unwanted consequences, especially in terms of memory usage.
Here is the solution I came up with; not super elegant, I must admit.

import torch

CREATE_GRAPH = False  # global flag read inside the custom backward

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

class set_create_graph:
    def __init__(self, value):
        self.value = value
    def __enter__(self):
        global CREATE_GRAPH
        self.prev = CREATE_GRAPH
        CREATE_GRAPH = self.value
    def __exit__(self, *args):
        global CREATE_GRAPH
        CREATE_GRAPH = self.prev

class f(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        with torch.set_grad_enabled(True):
            y = x.tanh()
        ctx.save_for_backward(x, y)  # needed so backward can read ctx.saved_tensors
        return y.detach().requires_grad_(y.requires_grad)
    @staticmethod
    def backward(ctx, gd):
        with torch.set_grad_enabled(True):
            x, y = ctx.saved_tensors
            return torch.autograd.grad(y, x, gd, True, CREATE_GRAPH)

s = f.apply(x)
with set_create_graph(True):
    g = torch.autograd.grad(s, x, torch.ones_like(s, requires_grad=True), True, True)
torch.autograd.grad(g, x, torch.ones(3), True)

It would be nice to have a way to replace this global variable CREATE_GRAPH with something that does not require me to wrap things in a specific context.
Thanks for the help!


You can check torch.is_grad_enabled() in the backward, as well as whether gd.requires_grad.
That will tell you if something wants gradients to be computed through your function: if grad mode is enabled and the incoming gradient requires grad, you should create the graph. Otherwise, it is not needed.
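To make this concrete, here is a sketch of the Function above rewritten with that check, so the global flag and the context manager are no longer needed. The tanh example and variable names follow the original post; the key line is computing create_graph from torch.is_grad_enabled() and gd.requires_grad before re-enabling grad mode.

```python
import torch

class F(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        with torch.enable_grad():
            y = x.tanh()
        ctx.save_for_backward(x, y)
        return y.detach().requires_grad_(y.requires_grad)

    @staticmethod
    def backward(ctx, gd):
        x, y = ctx.saved_tensors
        # Build the backward graph only when someone upstream needs
        # higher-order gradients: grad mode is on AND the incoming
        # gradient itself requires grad. Check this BEFORE forcing
        # grad mode back on below.
        create_graph = torch.is_grad_enabled() and gd.requires_grad
        with torch.enable_grad():
            return torch.autograd.grad(y, x, gd, retain_graph=True,
                                       create_graph=create_graph)

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
s = F.apply(x)
# First-order grad with create_graph=True: backward runs in grad mode
# and gd requires grad, so the Function builds the graph automatically.
g, = torch.autograd.grad(s, x, torch.ones_like(s, requires_grad=True),
                         retain_graph=True, create_graph=True)
# Second-order grad now works without any global flag.
gg, = torch.autograd.grad(g, x, torch.ones(3), retain_graph=True)
```

When the first grad call is made without create_graph=True, the backward runs under no-grad mode, create_graph evaluates to False, and no extra graph (or memory) is kept.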
