Hi
Assume you have a custom autograd.Function whose backward calls autograd.grad.
How can that backward know whether create_graph=True was passed in the ‘main code’?
One option would be to always set create_graph=True inside backward, but that obviously has unwanted consequences, especially in terms of memory usage.
Here is the solution I came up with; not super elegant, I must admit.
```python
import torch

x = torch.tensor([1, 2.0, 3.0], requires_grad=True)

# Global flag read by f.backward to decide whether to build a graph of the gradient.
CREATE_GRAPH = False

class set_create_graph:
    """Context manager that temporarily sets the global CREATE_GRAPH flag."""
    def __init__(self, value):
        self.value = value

    def __enter__(self):
        global CREATE_GRAPH
        self.prev = CREATE_GRAPH
        CREATE_GRAPH = self.value

    def __exit__(self, *args):
        global CREATE_GRAPH
        CREATE_GRAPH = self.prev

class f(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Record a local graph for tanh so backward can differentiate through it.
        with torch.set_grad_enabled(True):
            y = x.tanh()
        ctx.save_for_backward(x, y)
        return y.detach().requires_grad_(y.requires_grad)

    @staticmethod
    def backward(ctx, gd):
        with torch.set_grad_enabled(True):
            x, y = ctx.saved_tensors
            # Only build a graph of the gradient when the caller asked for it.
            return torch.autograd.grad(y, x, gd, retain_graph=True, create_graph=CREATE_GRAPH)

s = f.apply(x)
with set_create_graph(True):
    g = torch.autograd.grad(s, x, torch.ones_like(s, requires_grad=True),
                            retain_graph=True, create_graph=True)
torch.autograd.grad(g, x, torch.ones(3), retain_graph=True)
```
It would be nice to replace this global CREATE_GRAPH variable with something that does not require wrapping the call in a specific context.
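One possible direction, sketched under an assumption rather than anything confirmed here: the autograd engine appears to run a Function's backward with grad mode equal to the caller's create_graph value. If that holds, calling torch.is_grad_enabled() at the very top of backward (before any set_grad_enabled(True)) recovers the flag without a global. TanhFn is a hypothetical name for the same tanh example; the usage lines check that the returned gradient carries a graph only when create_graph=True.

```python
import torch


class TanhFn(torch.autograd.Function):
    """Hypothetical example: y = tanh(x), with backward computed via autograd.grad."""

    @staticmethod
    def forward(ctx, x):
        # forward runs with grad disabled; re-enable it to record a local graph.
        with torch.enable_grad():
            y = x.tanh()
        ctx.save_for_backward(x, y)
        return y.detach()

    @staticmethod
    def backward(ctx, grad_out):
        # Assumption: the engine runs backward in grad mode iff the caller
        # passed create_graph=True, so read the flag BEFORE enabling grad below.
        create_graph = torch.is_grad_enabled()
        x, y = ctx.saved_tensors
        with torch.enable_grad():
            # retain_graph=True keeps the local tanh graph for repeated backward calls.
            (gx,) = torch.autograd.grad(
                y, x, grad_out, retain_graph=True, create_graph=create_graph
            )
        return gx


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# First-order only: the returned gradient should carry no graph.
(g1,) = torch.autograd.grad(TanhFn.apply(x), x, torch.ones(3))
print(g1.requires_grad)  # False

# Double backward: the gradient keeps a graph and can be differentiated again.
(g2,) = torch.autograd.grad(TanhFn.apply(x), x, torch.ones(3), create_graph=True)
(h,) = torch.autograd.grad(g2, x, torch.ones(3))
```

If this behavior holds in your PyTorch version, the set_create_graph context manager and the global flag are no longer needed, and call sites stay unchanged.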
Thanks for the help!