Hi, I am currently implementing Trust Region Policy Optimization (TRPO) using the newest version of PyTorch (0.2), which supports higher-order gradients.
I spent several hours debugging my code and found the problem was mainly due to one flag I didn't pass.
My code looks roughly like this:
surr_loss = surr_loss_fn(policy)
surr_loss.backward()
kl_div = kl_div_func(policy)
kl_div.backward(create_graph=True)
grad = torch.cat([param.grad.view(-1) for param in policy.parameters()])
gradient_vector_product = torch.sum(grad * Variable(vector))
gradient_vector_product.backward()
This does not work, and the code complains
RuntimeError: element 0 of variables tuple is volatile
on the last backward().
If I insert this line before the last two lines of code:
grad.volatile = False
then it complains
RuntimeError: there are no graph nodes that require computing gradients
Finally, I found that the way to make the code stop complaining is to change the second line of code to the following:
surr_loss.backward(create_graph=True)
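For reference, here is a minimal standalone reproduction of the same double-backward pattern with the flag set. The toy loss and values are made up for illustration, and it is written against the current PyTorch API rather than 0.2's Variable wrapper:

```python
import torch

# Toy stand-in for the real setup: x plays the role of the policy
# parameters, and the cubic loss is chosen so the gradients are easy
# to check by hand.
x = torch.tensor([1.0, 2.0], requires_grad=True)

loss = (x ** 3).sum()              # stand-in for surr_loss; dloss/dx = 3 x^2
loss.backward(create_graph=True)   # the flag that fixed my code: keeps the
                                   # graph of the gradient computation itself

grad = x.grad.clone()              # g = 3 x^2, still attached to the graph
v = torch.tensor([1.0, 1.0])       # stand-in for `vector`
x.grad = None                      # clear first-order grads before the
                                   # second backward accumulates into x.grad

gvp = (grad * v).sum()             # scalar g . v
gvp.backward()                     # d(g . v)/dx = 6 x * v
print(x.grad)                      # tensor([ 6., 12.])
```

Without create_graph=True on the first backward, grad carries no graph to differentiate through, and the second backward fails.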
I am curious whether this is the intended behavior when computing higher-order gradients. In other words, did I do this right? And if so, why does the create_graph flag need to be set in the first backward call?