Hi! I have a question about using autograd in an iterative program.
Here I have a simulation program consisting of three custom modules and two built-in nn modules. The custom modules describe a physical process and the nn modules serve as a controller. In the forward pass, the modules are executed for 100 steps, from an initial state to a final state. The states computed in the custom modules depend on the iteration number t. The loss is defined as the difference between the final state and a target, and I would like to optimize the weights of the nn modules.
- I am wondering how autograd will execute the backward functions of this program. I assumed it would execute all the backward functions 100 times, i.e., from step 100 back to step 1, but in my test it seems to execute only once, which confuses me.
- How are overwritten tensors such as `ret` recorded? Do I need to keep a separate version of it for each step?
I would appreciate it if someone could explain the underlying mechanism.
```python
steps = 100
for t in range(steps):
    out = self.custom_module_1(t, input1)
    out1 = self.nn1(out)
    out2 = self.nn2(out1)
    out3 = self.custom_module_2(t, out2)
    ret = self.custom_module_3(t, out3)

loss_tensor = self.compute_loss(ret, target)
loss_tensor.backward()
```
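For context, here is a tiny scalar probe I used to reproduce the behavior (a toy stand-in for my modules, not the real code). Each iteration reads the same input tensor and rebinds the output variable, so only the last iteration ends up on the path from the weight to the loss:

```python
import torch

# Toy probe (hypothetical names): the input `x` is the same tensor every
# iteration and `ret` is rebound each time, so earlier iterations are not
# connected to the loss.
w = torch.tensor(3.0, requires_grad=True)
x = torch.tensor(2.0)

for t in range(5):
    ret = w * x  # does NOT consume the previous `ret`

loss = ret
loss.backward()

print(w.grad.item())  # 2.0 -- the gradient of a single step, not 5 accumulated steps
```

This matches what I observed: backward runs through one step only, even though the loop body executed five times.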
=================== Update ======================
I found the cause: PyTorch was unaware of the iterative scheme, so backward only executed once. Each iteration read the same `input1` and overwrote `ret`, so only the last step was connected to the loss. I changed the code a bit, using double buffering to thread the state from one step into the next so that autograd records every iteration. Now I get the expected backward graph, but the graph is extremely large and the backward pass is slow. Are there any best practices for differentiating iterative programs?
```python
for t in range(100):
    out = self.custom_module_1(t, state0)
    out1 = self.nn1(out)
    out2 = self.nn2(out1)
    out3 = self.custom_module_2(t, out2)
    state1 = self.custom_module_3(t, out3)
    state0 = state1

loss_tensor = self.compute_loss(state1)
loss_tensor.backward()
```
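To frame the question: one approach I am considering, based on the `torch.utils.checkpoint` docs, is gradient checkpointing, which recomputes each step's activations during backward instead of storing them, trading compute for memory. A minimal sketch with a stand-in step function (hypothetical; my real step is the `custom_module_1 -> nn1 -> nn2 -> custom_module_2 -> custom_module_3` chain):

```python
import torch
from torch.utils.checkpoint import checkpoint

# Stand-in for one simulation step (hypothetical, not my real modules).
net = torch.nn.Linear(4, 4)

def step(state):
    return torch.tanh(net(state))

state = torch.randn(1, 4)
for t in range(100):
    # Activations inside `step` are not saved; they are recomputed during
    # backward, so the stored graph per step stays small.
    # (use_reentrant=False needs a recent PyTorch.)
    state = checkpoint(step, state, use_reentrant=False)

loss = state.pow(2).sum()
loss.backward()
print(net.weight.grad is not None)  # True
```

Another option I have seen for RNN-style training is truncated backpropagation through time: periodically call `state = state.detach()` and do a partial backward, at the cost of biased gradients. I am unsure which fits a physical simulation better.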