How to update a variable only after all batches have finished

My optimization problem is as follows:

a = torch.rand(2, 2, requires_grad=True)
b = torch.rand(1, requires_grad=True)
# f is a very time-consuming function which I don't
# want to call every iteration.
c = f(b)
for i in range(N_epoch):
    for j in range(N_iteration):
        # loop over all batches
        L = loss(data[j], a, c)
        L.backward()  # accumulates gradients into a and b
        # update a every iteration
        with torch.no_grad():
            a -= lr_a * a.grad
        a.grad.zero_()
    # update b at the end of the epoch and recompute c
    with torch.no_grad():
        b -= lr_b * b.grad
    b.grad.zero_()
    c = f(b)

The function f that computes the intermediate variable c is very time-consuming, so I want to accumulate the gradient of b over all data batches and update b only at the end of each epoch. However, this raises an error because the computational graph is freed once backward is called. Could anyone give me some advice?
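The failure can be reproduced with a toy stand-in for f (the real f and loss are much more complex):

```python
import torch

b = torch.rand(1, requires_grad=True)
c = b * 2  # toy stand-in for the expensive f(b)

# The first backward frees the intermediate buffers of the graph
# between b and c...
c.sum().backward()

# ...so a second backward through the same graph raises a RuntimeError
# ("Trying to backward through the graph a second time").
try:
    c.sum().backward()
    reused_ok = True
except RuntimeError:
    reused_ok = False
```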

@ptrblck Could you please give me some advice? Thanks in advance.

I’m not sure how your use case works exactly, but

to keep the computation graph (and the intermediate tensors) alive, you could use backward(retain_graph=True).
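A minimal sketch of one epoch, with a toy f and loss standing in for yours:

```python
import torch

b = torch.rand(1, requires_grad=True)
c = (b * 3).sum()  # toy stand-in for the expensive f(b)

# One epoch: accumulate b's gradient over several "batches"
# without recomputing c.
for j in range(4):
    L = c * (j + 1)                # toy per-batch loss built on c
    L.backward(retain_graph=True)  # keep the graph through f alive

# b.grad now holds the sum of the per-batch gradients:
# dL_j/db = 3 * (j + 1), so 3 * (1 + 2 + 3 + 4) = 30
accumulated = b.grad.item()

# epoch-end update, then clear the accumulated gradient
with torch.no_grad():
    b -= 0.01 * b.grad
b.grad.zero_()
```

After the update you would recompute c = f(b) for the next epoch, as in your loop.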

OK, thanks for your advice. The real forward model is complex and the data is big, so I'm afraid this will cause memory problems. I will try it anyway.
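If retain_graph=True uses too much memory, one alternative I'm considering is to detach c for the inner loop, accumulate dL/dc there, and then run a single backward through f per epoch. A toy sketch (stand-in f and loss, not the real model):

```python
import torch

def f(b):
    # toy stand-in for the expensive function
    return b * 3

b = torch.rand(1, requires_grad=True)

# Compute c once, then cut it off from b's graph for the inner loop.
c = f(b)
c_leaf = c.detach().requires_grad_(True)

# Inner loop: backward stops at c_leaf, so the graph through f
# is never retained across batches.
for j in range(4):
    L = (c_leaf * (j + 1)).sum()  # toy per-batch loss
    L.backward()

# One backward through f per epoch, seeded with the accumulated dL/dc.
c.backward(c_leaf.grad)
# b.grad = 3 * (1 + 2 + 3 + 4) = 30
```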