I have a function that solves a fairly simple optimization problem in the following way:
    def my_func(input):
        error = 1
        z = torch.autograd.Variable(some tensor, requires_grad=True)
        opt = torch.optim.SGD([z], lr=0.01)
        while error > 1E-4:
            error = some calculation of input and z
            with torch.no_grad():
                opt.zero_grad()
            error.backward()
            opt.step()
If I call this function, everything works fine and I don’t get any errors.
However, I need to call this function in the forward() of an nn.Module that I wrote. As soon as I call my_func from within forward(), I receive the following error:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
Note: I do not need to compute gradients w.r.t. the output of my_func.
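
For reference, here is a minimal sketch of how this situation can arise. It rests on an assumption the description above does not confirm: that the input passed to my_func inside forward() is the output of an earlier layer and so carries its own autograd graph. In that case the first error.backward() frees that graph's buffers, and the second iteration tries to walk through it again. The MyModule class, the detach_input flag, the quadratic objective, and the lr, tol, and max_iter parameters are all hypothetical stand-ins, not the real code. Since no gradients are needed through my_func, detaching the input is one possible way to keep the inner loop from touching the outer graph.

```python
import torch
import torch.nn as nn


def my_func(inp, lr=0.1, tol=1e-4, max_iter=1000):
    # Hypothetical objective standing in for the real one:
    # find z that minimizes ||z - inp||^2 with plain SGD.
    z = torch.zeros_like(inp, requires_grad=True)
    opt = torch.optim.SGD([z], lr=lr)
    for _ in range(max_iter):
        opt.zero_grad()
        error = ((z - inp) ** 2).sum()
        if error.item() <= tol:
            break
        # If inp still carries a graph, this also backpropagates into it
        # and frees its buffers after the first call.
        error.backward()
        opt.step()
    return z.detach()


class MyModule(nn.Module):
    def __init__(self, detach_input):
        super().__init__()
        self.linear = nn.Linear(3, 3)
        self.detach_input = detach_input

    def forward(self, x):
        h = self.linear(x)  # h is connected to self.linear's parameters
        # Assumption: cutting h out of the outer graph is acceptable because
        # no gradient is needed through my_func.
        target = h.detach() if self.detach_input else h
        # Guard in case forward() is ever called under torch.no_grad().
        with torch.enable_grad():
            z = my_func(target)
        return h + z


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(2, 3)
    out = MyModule(detach_input=True)(x)
    out.sum().backward()  # the outer graph is still intact
    print(out.shape)
```

With detach_input=False the same call should fail with a RuntimeError on the second inner iteration, and the first inner backward() would also accumulate unwanted gradients into self.linear's parameters.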