I have a function that solves a fairly simple optimization problem in the following way:
```python
import torch

def my_func(input):
    error = 1
    # z is the variable the inner loop optimizes; some_tensor stands in for its initial value
    z = torch.autograd.Variable(some_tensor, requires_grad=True)
    opt = torch.optim.SGD([z], lr=0.01)
    while error > 1E-4:
        # placeholder for the actual objective computed from input and z
        error = some_calculation(input, z)
        with torch.no_grad():
            opt.zero_grad()
            error.backward()
            opt.step()
```
If I call this function on its own, everything works fine and I don't get any errors.
However, I need to call this function inside forward() of an nn.Module that I wrote. As soon as I call my_func from within forward(), I receive the following error:
```
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
```
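For context, here is roughly the shape of the module; this is a minimal sketch with placeholder names (MyModule, some_layer), not my actual code:

```python
import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.some_layer = nn.Linear(10, 10)  # placeholder layer

    def forward(self, x):
        h = self.some_layer(x)  # h is attached to the outer autograd graph
        return my_func(h)       # the inner loop then calls backward() repeatedly
```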
Note: I do not need to compute gradients w.r.t. the output of my_func.
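Since those gradients are not needed, my guess is that I could detach input from the surrounding graph before running the inner loop, so that each backward() inside my_func only ever touches the freshly built inner graph. A minimal sketch of what I mean (same placeholder names as above); would this be the right approach?

```python
    def forward(self, x):
        h = self.some_layer(x)
        # Detach so the inner loop's backward() never walks the outer graph.
        # This assumes gradients w.r.t. my_func's output really are not needed.
        return my_func(h.detach())
```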