I’m trying to use PyTorch to compute a complex function but still need the gradient of the output with respect to the inputs. For example:

```
import torch

a = torch.tensor([1, 2, 3], dtype=torch.float32, requires_grad=True)
b = torch.tensor([3, 2, 1], dtype=torch.float32, requires_grad=True)
v = torch.tensor([0.2], dtype=torch.float32, requires_grad=True)

# here precalc represents some (fairly expensive) sequence of operations
precalc = a.dot(b) + a * b + a * a + b * b

def calc(precalc, v):
    z = torch.randn(3, 1000)
    batch = v * precalc.matmul(z)
    return torch.relu(batch).mean()
```

So the precalc tensor is a fixed function of a and b (but expensive to compute).

When I call calc the first time with `x = calc(precalc, v)` followed by `x.backward()`, I get the correct gradients for a, b, and v. However, if I call calc a second time, i.e. `x2 = calc(precalc, v)` followed by `x2.backward()`, PyTorch raises an error saying I’m trying to backward through the graph a second time and should retain the graph.
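For reference, here is a self-contained script (same tensors as above) that reproduces what I’m seeing; the second `backward()` raises the error:

```
import torch

torch.manual_seed(0)

a = torch.tensor([1, 2, 3], dtype=torch.float32, requires_grad=True)
b = torch.tensor([3, 2, 1], dtype=torch.float32, requires_grad=True)
v = torch.tensor([0.2], dtype=torch.float32, requires_grad=True)

# expensive-to-compute but small tensor, graph rooted at a and b
precalc = a.dot(b) + a * b + a * a + b * b

def calc(precalc, v):
    z = torch.randn(3, 1000)
    batch = v * precalc.matmul(z)
    return torch.relu(batch).mean()

x = calc(precalc, v)
x.backward()                 # works; a.grad, b.grad, v.grad are populated

x2 = calc(precalc, v)
err = None
try:
    x2.backward()            # fails: precalc's saved buffers were freed above
except RuntimeError as e:
    err = e
print(err)
```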

I would ideally like to free the graph of calc after each backward pass, since precalc is expensive to compute but small in memory, while calc itself is cheap to compute but uses a lot of memory.
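As far as I understand, passing `retain_graph=True` avoids the error, but it keeps the *entire* graph alive between calls, including calc’s large intermediates, which is exactly what I’d like to avoid:

```
import torch

a = torch.tensor([1, 2, 3], dtype=torch.float32, requires_grad=True)
b = torch.tensor([3, 2, 1], dtype=torch.float32, requires_grad=True)
v = torch.tensor([0.2], dtype=torch.float32, requires_grad=True)

precalc = a.dot(b) + a * b + a * a + b * b

def calc(precalc, v):
    z = torch.randn(3, 1000)
    batch = v * precalc.matmul(z)
    return torch.relu(batch).mean()

x = calc(precalc, v)
x.backward(retain_graph=True)   # no error, but the whole graph stays in memory
x2 = calc(precalc, v)
x2.backward()                   # succeeds; gradients accumulate into a, b, v
```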

Is there a way to achieve this?

Thanks.