My optimization problem is as follows:

```
import torch

a = torch.rand(2, 2, requires_grad=True)
b = torch.rand(1, requires_grad=True)

# f is a very time-consuming function that I don't
# want to call in every iteration.
c = f(b)

for i in range(N_epoch):
    for j in range(N_iteration):  # loop over all batches
        L = loss(data[j], a, c)
        L.backward()
        with torch.no_grad():
            # update a
            a -= lr_a * a.grad
            # update b and recompute c
            b -= lr_b * b.grad
        a.grad.zero_()
        b.grad.zero_()
        c = f(b)
```

The function `f` that computes the intermediate variable `c` is very time-consuming. I want to accumulate the gradient of `b` obtained from each data batch and update `b` (and recompute `c`) only at the end of every epoch. However, this raises an error, because the computational graph is freed once `backward()` is called.
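
Concretely, here is a minimal sketch of what I would like to do instead (same placeholder names as above, with the default `retain_graph=False`); it fails on the second batch of the first epoch with a RuntimeError about trying to backward through the graph a second time:

```
c = f(b)
for i in range(N_epoch):
    for j in range(N_iteration):  # loop over all batches
        L = loss(data[j], a, c)
        # The first backward() frees the graph built by f(b) -> c,
        # so the second batch raises a RuntimeError here.
        L.backward()
        with torch.no_grad():
            a -= lr_a * a.grad
        a.grad.zero_()  # reset a's gradient, but keep accumulating b.grad
    # update b only once per epoch, using the accumulated gradient
    with torch.no_grad():
        b -= lr_b * b.grad
    b.grad.zero_()
    c = f(b)  # f is now called only once per epoch
```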

Could anyone give me some advice?