Precomputing part of a loss function

I’m trying to use PyTorch to compute a fairly complex function and I still need the gradients of the output with respect to the inputs. For example:

import torch

a = torch.tensor([1, 2, 3], dtype=torch.float32, requires_grad=True)
b = torch.tensor([3, 2, 1], dtype=torch.float32, requires_grad=True)
v = torch.tensor([0.2], dtype=torch.float32, requires_grad=True)

# here precalc represents some (fairly expensive) sequence of operations;
# the result is just a small tensor of shape (3,)
precalc = a.dot(b) + a*b + a*a + b*b

def calc(precalc, v):
    # z is resampled on every call, so each call builds a fresh graph on top of precalc
    z = torch.randn(3, 1000)
    batch = v * precalc.matmul(z)
    return torch.relu(batch).mean()

So the precalc tensor is a fixed function of a and b (but expensive to compute).

When I call calc for the first time as x = calc(precalc, v) and then x.backward(), I get the correct gradients for a, b and v. However, if I call calc a second time, i.e. x2 = calc(precalc, v) followed by x2.backward(), PyTorch raises the error that I am trying to backward through the graph a second time and should retain the graph.
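
To make the failure concrete, this is the call sequence I mean (the exact error wording may differ slightly between PyTorch versions):

x = calc(precalc, v)
x.backward()       # works: a.grad, b.grad and v.grad are populated

x2 = calc(precalc, v)
x2.backward()      # RuntimeError: trying to backward through the graph a second time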

I would ideally like to free the memory-heavy part of the graph built inside calc after each backward, while keeping the part for precalc, since precalc is expensive to compute but small in memory, whereas calc itself is cheap to compute but uses a lot of memory.
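
The obvious workaround I can see is to retain the graph on every backward call, but as far as I understand that keeps the whole graph alive, including calc's large intermediate tensors, which is exactly what I want to avoid:

x = calc(precalc, v)
x.backward(retain_graph=True)    # nothing in x's graph is freed, so calc's large intermediates stick around

x2 = calc(precalc, v)
x2.backward(retain_graph=True)   # works, but only because every graph is being kept alive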

Is there a way to achieve this?

Thanks.