My optimization problem is as follows:

```
import torch

a = torch.rand(2, 2, requires_grad=True)
b = torch.rand(1, requires_grad=True)

# f is a very time-consuming function that I don't
# want to call in every iteration.
c = f(b)

for i in range(N_epoch):
    for j in range(N_iteration):  # loop over all batches
        L = loss(data[j], a, c)
        L.backward()
        with torch.no_grad():
            # update a
            a -= lr_a * a.grad
            # update b and recompute c
            b -= lr_b * b.grad
        a.grad.zero_()
        b.grad.zero_()
        c = f(b)
```

The function `f` that computes the intermediate variable `c` is very time-consuming. I want to accumulate the gradient of `b` obtained from each data batch and update `b` (and recompute `c`) only at the end of every epoch. However, this raises an error, because the computational graph is freed once `backward()` is called.
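
Concretely, here is a minimal sketch of what I would like to do instead (same placeholder names as above, with the default `retain_graph=False`); it fails on the second batch of the first epoch with a RuntimeError about trying to backward through the graph a second time:

```
c = f(b)
for i in range(N_epoch):
    for j in range(N_iteration):  # loop over all batches
        L = loss(data[j], a, c)
        # The first backward() frees the graph built by f(b) -> c,
        # so the second batch raises a RuntimeError here.
        L.backward()
        with torch.no_grad():
            a -= lr_a * a.grad
        a.grad.zero_()  # reset a's gradient, but keep accumulating b.grad
    # update b only once per epoch, using the accumulated gradient
    with torch.no_grad():
        b -= lr_b * b.grad
    b.grad.zero_()
    c = f(b)  # f is now called only once per epoch
```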

Could anyone give me some advice?