```python
import torch
from torch.autograd import Variable

def trust_region_loss(model, ref_model, distribution, ref_distribution, loss, threshold):
    # Compute gradients from original loss
    model.zero_grad()
    loss.backward(retain_graph=True)
    g = [param.grad.clone() for param in model.parameters()]
    model.zero_grad()
    # KL divergence; k ← ∇_θ D_KL[π(·|s_i; θ_a) || π(·|s_i; θ)]
    kl = distribution * (distribution.log() - ref_distribution.log())
    # Compute gradients from (negative) KL loss (increases KL divergence)
    (-kl.sum(1).mean()).backward(retain_graph=True)
    k = [param.grad.clone() for param in model.parameters()]
    model.zero_grad()
    # Compute dot products of gradients
    k_dot_g = sum(torch.sum(k_p * g_p) for k_p, g_p in zip(k, g))
    k_dot_k = sum(torch.sum(k_p ** 2) for k_p in k)
    # Compute trust region update: g - max(0, (k·g - threshold) / ||k||²) k
    if float(k_dot_k) > 0:
        trust_factor = ((k_dot_g - threshold) / k_dot_k).clamp(min=0)
    else:
        trust_factor = Variable(torch.zeros(1))
    trust_update = [g_p - trust_factor.expand_as(k_p) * k_p for g_p, k_p in zip(g, k)]
    # Surrogate loss whose gradient w.r.t. each parameter is its trust_update entry
    trust_loss = 0
    for param, trust_update_p in zip(model.parameters(), trust_update):
        trust_loss += (param * trust_update_p).sum()
    return trust_loss
```
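For reference, the trust-region step the snippet is aiming for is ACER’s closed-form projection z = g − max(0, (kᵀg − δ)/‖k‖²)·k, which leaves g unchanged when kᵀg ≤ δ and otherwise removes just enough of the k component that kᵀz = δ. The sketch below applies it to flat Python lists instead of per-parameter tensors; `project_gradient` is a hypothetical helper, not part of PyTorch or of the code above.

```python
def project_gradient(g, k, delta):
    """Hypothetical helper: ACER-style step z = g - max(0, (k.g - delta) / ||k||^2) * k."""
    k_dot_g = sum(k_i * g_i for k_i, g_i in zip(k, g))
    k_dot_k = sum(k_i * k_i for k_i in k)
    if k_dot_k == 0:
        # No KL gradient: nothing to project against, keep g as-is
        return list(g)
    scale = max(0.0, (k_dot_g - delta) / k_dot_k)
    return [g_i - scale * k_i for g_i, k_i in zip(g, k)]


# Constraint active: k.g = 2.0 > delta, so g is pulled back along k until k.z == delta
print(project_gradient(g=[1.0, 1.0], k=[1.0, 1.0], delta=0.5))  # [0.25, 0.25]

# Constraint inactive: k.g = 0.0 <= delta, so g passes through unchanged
print(project_gradient(g=[1.0, -1.0], k=[1.0, 1.0], delta=0.5))  # [1.0, -1.0]
```

The `max(0, ·)` is what makes this a one-sided constraint: the update is only ever shrunk along k, never pushed further along it.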
However, the eventual `.backward()` call on this loss results in `RuntimeError: calling backward on a volatile variable`, which I assume is due to using the gradients in this way. Adding `trust_loss.volatile = False` allows the backward pass to happen, but the network doesn’t seem to improve (this could be a matter of hyperparameter tuning). Is this the correct approach?
Support for higher-order differentiation was only just added, but it probably won’t work yet for many architectures, because support for several individual operators is still being added in pulls #1507 and #1423, among a few others (e.g. convolution is still incompatible with higher-order differentiation and is currently being modified to allow double differentiation).
I don’t think it’s a tuning issue. The TRPO update just won’t be correct until you’re able to take the gradient of the gradients in `g` and `k`. Without `create_graph=True`, `g[p].grad_fn` and `k[p].grad_fn` will just be `None`, instead of recording the graph needed to differentiate through the gradients.
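A minimal illustration of that difference, written against the current `torch.autograd.grad` API rather than the 2017 `Variable`/`volatile` API used in this thread (an assumption about which PyTorch version the reader has): with `create_graph=True` the returned gradient has a `grad_fn` and can be differentiated again; without it, `grad_fn` is `None`.

```python
import torch

x = torch.tensor(3.0, requires_grad=True)

# With create_graph=True the returned gradient is itself part of a graph
(dy_dx,) = torch.autograd.grad(x ** 3, x, create_graph=True)  # 3x² = 27
print(dy_dx.grad_fn is not None)  # True

# ...so it can be differentiated again: d/dx (3x²) = 6x = 18
(d2y_dx2,) = torch.autograd.grad(dy_dx, x)
print(d2y_dx2.item())  # 18.0

# Without create_graph the gradient is a plain value with no history
(plain,) = torch.autograd.grad(x ** 3, x)
print(plain.grad_fn)  # None
```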
I’m a little confused, because Chainer has no support at all for second-order derivatives – so translating a codebase from Chainer shouldn’t have to rely on PyTorch’s incipient grad of grads support, right?
So it looks to me like in the Chainer code, the variables in `k` and `g` are treated as constants and are not backpropagated through; one way to do that in the PyTorch code would be to use `param.grad.detach()` rather than `param.grad.clone()`.
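To see why `.detach()` and `.clone()` differ, the sketch below (current PyTorch API, with `create_graph=True` forced on purpose so that the gradient carries history) shows that a cloned gradient stays connected to the graph, while a detached one is a constant with the same values. The toy tensor `w` is hypothetical.

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)
loss = (w ** 2).sum()

# Keep the graph of the gradient so the clone/detach difference is visible
(grad_w,) = torch.autograd.grad(loss, w, create_graph=True)  # 2w

cloned = grad_w.clone()     # still connected: backprop would flow through it into w
detached = grad_w.detach()  # constant: same values, no history

print(cloned.requires_grad, detached.requires_grad)  # True False
print(torch.equal(cloned, detached))  # True
```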
Nice spot! That said, when producing the grads using `.detach()` or `.clone()`, the resulting Variables (e.g. `k`) are volatile, despite the losses (`loss` and `kl`) being non-volatile, so it seems like something needs to be done with the grads.
Setting `trust_loss.volatile = False` at the end produces `RuntimeError: there are no graph nodes that require computing gradients`, which seems like further evidence that something needs to be done earlier in the function.
OK, I think I might have figured it out: each `param.grad` is either non-volatile (if the backward pass that created it was made entirely of double-backprop-ready functions) or volatile (if that backward pass included some functions that aren’t double-backprop-ready). What you actually want is for `k` and `g` to be neither `volatile=False, requires_grad=True` (because that would make PyTorch try to backpropagate through them) nor `volatile=True` (because that flag is greedy: it propagates to everything computed from the variable, so the whole network would stop building a graph). Instead, set the variables in both `k` and `g` to `volatile=False, requires_grad=False`. You can do this either by setting the flags explicitly or by creating `k` and `g` as:

```python
g = [Variable(param.grad.data) for param in model.parameters()]
# likewise for k, after the KL backward pass
```

since `volatile=False, requires_grad=False` is the default for new Variables.
The latter would exactly replicate what’s going on in Chainer, since Chainer’s `param.grad` is an array rather than a Variable, and Chainer automatically wraps arrays into Variables.
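In current PyTorch, `volatile` has been removed (since 0.4, in favour of `torch.no_grad()`), and a tensor with `requires_grad=False`, such as `param.grad.detach()`, plays the role of `Variable(param.grad.data)`. The sketch below, with a hypothetical `param` and `update`, checks the property that `trust_loss` relies on: when the update is a constant, the gradient of `(param * update).sum()` with respect to `param` is exactly `update`.

```python
import torch

param = torch.nn.Parameter(torch.tensor([0.5, -1.0, 2.0]))

# Hypothetical precomputed update direction, treated as a constant (no history)
update = torch.tensor([0.1, 0.2, -0.3])

# Surrogate whose gradient w.r.t. param is exactly `update`
surrogate = (param * update).sum()
surrogate.backward()

print(torch.equal(param.grad, update))  # True
```

Because `update` carries no graph, nothing is backpropagated through it, which matches the "treat `k` and `g` as constants" behaviour described for Chainer.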
I’m confused as to why the KL divergence KL(model_avg || current_model) is calculated as `kl = current_distribution * (current_distribution.log() - avg_distribution.log())`. It seems to be the correct way, in that it yields better performance (on Pong at least), but I don’t understand why the KL is calculated in the opposite direction to the one the notation (and the Chainer implementation) describes. Is it because k (the gradient of the KL) is later transposed in the notation before being multiplied by g, and calculating the KL divergence in the opposite direction somehow shortcuts or replaces the transpose step?
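Written out for a single state, `p * (p.log() - q.log())` summed over actions is KL(p || q), so `current_distribution * (current_distribution.log() - avg_distribution.log())` is KL(current || avg); the transpose in kᵀg only denotes a dot product and does not swap the KL’s arguments. The pure-Python sketch below, with made-up probabilities, shows that the two directions give different values.

```python
import math


def kl(p, q):
    """KL(p || q) for discrete distributions: sum over actions of p * (log p - log q)."""
    return sum(p_a * (math.log(p_a) - math.log(q_a)) for p_a, q_a in zip(p, q))


current = [0.7, 0.2, 0.1]  # hypothetical current policy at one state
average = [0.4, 0.4, 0.2]  # hypothetical average policy at the same state

# Same elementwise pattern as current * (current.log() - avg.log()), summed: KL(current || average)
print(round(kl(current, average), 3))  # 0.184
# Arguments swapped: KL(average || current), a different number
print(round(kl(average, current), 3))  # 0.192
```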
I think I’ll try testing both versions (both directions of the KL) on RoboschoolReacher-v0. Performance on that environment is supposed to change most drastically with TRPO.
`mean(1)` needs to be replaced with `sum(1)` (performance improves significantly).
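Assuming dim 1 is the action dimension (the code this refers to isn’t shown here), `mean(1)` equals `sum(1)` divided by the number of actions |A|, which scales the KL gradient k by 1/|A|. Because the linearised constraint is kᵀz ≤ δ, that is equivalent to loosening δ by a factor of |A|. The sketch below checks this on flat lists with hypothetical numbers and a hypothetical `trust_step` helper.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def trust_step(g, k, delta):
    """Hypothetical helper: z = g - max(0, (k.g - delta) / ||k||^2) * k."""
    scale = max(0.0, (dot(k, g) - delta) / dot(k, k))
    return [g_i - scale * k_i for g_i, k_i in zip(g, k)]


n_actions = 4
delta = 0.5
g = [2.0, 2.0]
k_sum = [1.0, 1.0]                           # KL gradient when the KL is sum(1) over actions
k_mean = [k_i / n_actions for k_i in k_sum]  # mean(1) divides the KL, and hence k, by n_actions

z_sum = trust_step(g, k_sum, delta)
z_mean = trust_step(g, k_mean, delta)

# Measured against the same k_sum, the mean(1) step is allowed n_actions times further
print(dot(k_sum, z_sum))   # 0.5
print(dot(k_sum, z_mean))  # 2.0
```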
Also, I tried the KL in both directions on CartPole-v0 (a discrete-control environment where TRPO helps a lot).
KL(avg || most_recent) improved more monotonically towards the maximum reward than KL(most_recent || avg).
Once at the maximum reward, KL(most_recent || avg) hovers closer to it, but its collapses are more catastrophic than those of KL(avg || most_recent).
@ethancaballero That’s probably my bad, and I wrote the KL backwards, but yes, it is supposed to be KL(average_distribution || current_distribution). I can’t make it public yet, but I’m working on a standalone implementation of ACER that I will open source in early June. I’ve added you as a collaborator, so that you can use my whole codebase rather than having to port this snippet elsewhere.
Hi @AjayTalati, thanks for the links! I’ll try to have a look at these as another point of reference when I can.
My ACER implementation is open sourced as part of an entry to the Malmo Challenge, but we never attempted to use the trust region updates for that. I also have a standalone implementation (for the Gym’s CartPole instead of the Malmo Challenge) which I will open source in the next week or so. My ACER implementation without trust region updates can get good scores, but is quite unstable. Haven’t managed to spend time debugging the trust region updates yet, but @ethancaballero has had a look too. I’m looking forward to opening it up to the PyTorch community!