How does autograd execute backward functions in an iterative program?

Hi! I have a question about using autograd in an iterative program.
Here I have a simulation program consisting of three custom modules and two built-in nn modules. The custom modules describe a physical process and the nn modules serve as controllers. In the forward pass, the modules are executed for 100 steps, from an initial state to a final state. The states computed in the custom modules depend on the iteration number t. The loss is defined as the difference between the final state and a target, and I would like to optimize the weights of the nn modules.

  1. I am wondering how autograd will execute the backward functions of this program. I assume that it will execute all the backward functions 100 times, i.e., from step 100 back to step 1; however, in my test it seems to execute them only once, which confuses me.
  2. How will an overwritten tensor such as ret be recorded? Do I need to keep a different version of it for each step? (See the toy sketch after this list for what I mean.)
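
For reference, here is a toy sketch (hypothetical, not my actual modules) of the kind of overwriting I mean in question 2:

import torch

w = torch.tensor(3.0, requires_grad=True)
state = torch.tensor(1.0)
for t in range(5):
    # the Python name `state` is rebound every iteration, but the previous
    # value's graph node is kept alive because the new value depends on it
    state = state * w

state.backward()
print(w.grad)  # state = w**5, so d(state)/dw = 5 * w**4 = 405.0

My understanding from this toy case is that old values are not lost just because the name is overwritten, but I am not sure how this plays out with my custom modules.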

It would be appreciated if someone could help explain the underlying mechanism.

Thanks!

steps = 100
for t in range(steps):
    out = self.custom_module_1(t, input1)   # input1 is never updated inside the loop
    out1 = self.nn1(out)
    out2 = self.nn2(out1)
    out3 = self.custom_module_2(t, out2)
    ret = self.custom_module_3(t, out3)     # ret is overwritten every iteration

loss_tensor = self.compute_loss(ret, target)
loss_tensor.backward()

=================== Update ======================
I found that it was because torch was unaware of the iterative scheme, so the backward only executed once. I changed the code a bit (below), using double buffering to make torch aware of the iterative scheme. Now I can get the expected backward graph; however, the graph is extremely large and the backward pass is slow. I am wondering if there are any best practices for differentiating an iterative program?

for t in range(100):
    out = self.custom_module_1(t, state0)   # state0 holds the previous step's output (the initial state on step 0)
    out1 = self.nn1(out)
    out2 = self.nn2(out1)
    out3 = self.custom_module_2(t, out2)
    state1 = self.custom_module_3(t, out3)
    state0 = state1                         # double buffering: feed this step's output into the next step

loss_tensor = self.compute_loss(state1, target)
loss_tensor.backward()

According to your code, only one backward call is being made outside the loop.

You could shift the backward call inside the loop if you want multiple backward passes.

PS - Also, don’t forget to call optimizer.zero_grad() in that case to clear the parameters’ gradients; otherwise the gradients from multiple backward calls will accumulate and lead to wrong updates via optimizer.step(), which ideally should also be inside the loop.
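
For example, a rough sketch of that per-step variant (assuming an Adam optimizer over the controller parameters and a per-step loss against the target; adapt to your setup):

optimizer = torch.optim.Adam(
    list(self.nn1.parameters()) + list(self.nn2.parameters()), lr=1e-3)

for t in range(100):
    optimizer.zero_grad()                      # clear previously accumulated gradients
    out = self.custom_module_1(t, input1)
    out1 = self.nn1(out)
    out2 = self.nn2(out1)
    out3 = self.custom_module_2(t, out2)
    ret = self.custom_module_3(t, out3)
    loss_tensor = self.compute_loss(ret, target)
    loss_tensor.backward()                     # one backward pass per step
    optimizer.step()                           # update the nn weights every step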

Thanks for the reply! I would like to have only one backward pass here. I have made an updated version (please refer to the update under the original question) which connects the computation graph across iterations. However, the backward pass becomes very slow, probably because the graph is so large. I am wondering if there are any best practices for differentiating an iterative program?