Hi! I have a question about using autograd in an iterative program.

Here I have a simulation program that consists of three custom modules and two built-in nn modules. The custom modules describe a physical process, and the nn modules serve as a controller. In the forward pass, the modules are executed for 100 steps, from an initial state to a final state. The states computed in the custom modules depend on the iteration number `t`.

The loss is defined as the difference between the final state and a target, and I would like to optimize the weights of the nn modules.

- I am wondering how autograd will execute the backward functions of this program. I assume it will execute all backward functions 100 times, i.e., from step 100 back to step 1; however, in my test it seems to execute only once, which confuses me.
- How will an overwritten tensor such as `ret` be recorded? Do I need to keep a different version of it for each step?
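For context on the second question, here is a tiny check I did (illustrative only, not my real modules): rebinding the Python name `ret` does not overwrite anything autograd has recorded, because the graph holds its own references to each intermediate tensor.

```python
import torch

x = torch.ones(2, requires_grad=True)
ret = x * 2
old_fn = ret.grad_fn          # node recorded for the first multiplication

ret = ret * 3                 # rebinds the name; the old node stays in the graph
print(ret.grad_fn.next_functions[0][0] is old_fn)  # True
```

So the old version of `ret` is still reachable through the graph even though the Python variable now points elsewhere.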

It would be appreciated if someone could help explain the underlying mechanism.

Thanks!

```python
steps = 100
for t in range(steps):
    out = self.custom_module_1(t, input1)
    out1 = self.nn1(out)
    out2 = self.nn2(out1)
    out3 = self.custom_module_2(t, out2)
    ret = self.custom_module_3(t, out3)
loss_tensor = self.compute_loss(ret, target)
loss_tensor.backward()
```
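To illustrate what I observed, here is a minimal standalone repro (the toy computation is made up): because every iteration reads the *original* input instead of the previous step's output, only the last iteration's result is connected to the loss, so each backward function runs once.

```python
import torch

x = torch.ones(3, requires_grad=True)
w = torch.ones(3, requires_grad=True)

# Same pattern as my loop: each iteration reads the original input `x`,
# so earlier iterations are not chained into the graph.
for t in range(5):
    ret = (w * x) * (t + 1)   # overwrites `ret` each step

loss = ret.sum()
loss.backward()
# Gradient reflects only the final iteration (t = 4), not 5 chained steps.
print(w.grad)   # tensor([5., 5., 5.])
```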

=================== Update ======================

I found that the backward only executed once because torch was unaware of the iterative scheme. I changed the code a bit (below), using double buffering to feed each step's output into the next step, which makes the iterative dependence visible to torch. Now I get the expected backward graph; however, the graph is extremely large and the backward pass is slow. I am wondering if there are any best practices for differentiating iterative programs?

```python
for t in range(100):
    out = self.custom_module_1(t, state0)
    out1 = self.nn1(out)
    out2 = self.nn2(out1)
    out3 = self.custom_module_2(t, out2)
    state1 = self.custom_module_3(t, out3)
    state0 = state1  # feed this step's output into the next step
loss_tensor = self.compute_loss(state1, target)
loss_tensor.backward()
```
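One technique I found that helps with the large graph is gradient checkpointing, which trades compute for memory: intermediate activations are not stored for every step but recomputed during backward. A minimal sketch, with a stand-in `torch.nn.Linear` playing the role of one simulation step (`torch.utils.checkpoint.checkpoint` is the real API):

```python
import torch
from torch.utils.checkpoint import checkpoint

step = torch.nn.Linear(4, 4)   # stand-in for one simulation step
state = torch.randn(1, 4)

for t in range(100):
    # Activations inside `step` are recomputed during backward
    # instead of being stored for all 100 iterations.
    state = checkpoint(step, state, use_reentrant=False)

loss = state.pow(2).sum()
loss.backward()
print(step.weight.grad.shape)  # torch.Size([4, 4])
```

If full gradients through all 100 steps are not needed, another common option is truncating backpropagation by calling `state.detach()` every k steps, which caps the graph depth at k.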