Hi, I am currently implementing Trust Region Policy Optimization (TRPO) using the newest version of PyTorch (0.2), which supports higher-order gradients.
I spent several hours debugging my code and found the problem was mainly due to one flag I didn't pass.
My code is kind of like this:
```python
surr_loss = surr_loss_fn(policy)
surr_loss.backward()
kl_div = kl_div_func(policy)
kl_div.backward(create_graph=True)
grad = torch.cat([param.grad.view(-1) for param in policy.parameters()])
gradient_vector_product = torch.sum(grad * Variable(vector))
gradient_vector_product.backward()
```
This does not work; the code complains with
RuntimeError: element 0 of variables tuple is volatile
on the last line.
If I insert this before the last two lines of code:
grad.volatile = False
then it complains with
RuntimeError: there are no graph nodes that require computing gradients
Finally, I found that the way to make the code stop complaining is to also pass create_graph=True in the second line of code, i.e. the first backward call: surr_loss.backward(create_graph=True)
I am curious whether this is the intended way to get higher-order gradients. In other words, did I do this right?
And if it is, what is the reason for setting the
create_graph flag in the first backward call?
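For what it's worth, here is a minimal, self-contained sketch of the double-backward pattern involved (a Hessian-vector product). It uses toy values and the plain tensor API rather than Variable, so it is illustrative rather than a drop-in for the TRPO code above; the point is that create_graph=True in the first differentiation is what makes the gradient itself differentiable:

```python
import torch

# f(x) = x1^2 + x2^2, so grad f = 2x and the Hessian is 2*I.
x = torch.tensor([1.0, 2.0], requires_grad=True)
loss = (x ** 2).sum()

# First differentiation: create_graph=True records the operations that
# produce `grad`, so `grad` can be differentiated again later.
(grad,) = torch.autograd.grad(loss, x, create_graph=True)  # grad = 2x

v = torch.tensor([1.0, 0.0])
gvp = torch.sum(grad * v)  # gradient-vector product, a scalar

# Second differentiation: d(grad . v)/dx = H @ v.
(hvp,) = torch.autograd.grad(gvp, x)
print(hvp)  # Hessian is 2*I, so hvp == [2., 0.]
```

Without create_graph=True in the first call, `grad` is detached from any graph and the second torch.autograd.grad call fails, which matches the kind of error above.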