Episodic Policy Gradient in PyTorch

I am trying to translate Andrej Karpathy’s Pong example to PyTorch. My problem comes down to the loss. I cannot backpropagate at each timestep because I need to wait until the end of an episode to compute the discounted return for each step. I then need to compute the gradient of every step’s loss with respect to my model’s parameters.

Here’s what I do:

My model outputs a sigmoid, so it basically chooses between two actions. I sample an action from my policy network's output prob_a = model(x), and the log-likelihood for that step is LL = torch.log(prob_a) if the action taken was a = 1 and LL = torch.log(1 - prob_a) if it was a = 0. I accumulate those log-likelihoods (which are autograd.Variables) for every step in LL_list. When the episode ends, I also compute the discounted return of every step in G_list. I then try to compute the loss over the whole episode this way:

LLs = torch.stack(LL_list)
Gs = Variable(torch.FloatTensor(G_list))

loss = torch.mean(LLs * Gs)
loss.backward()
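
For context, the per-step bookkeeping and the discounted returns look roughly like this (a simplified sketch, not my exact code; the tiny stand-in model, the dummy observations/rewards and gamma = 0.99 are placeholders):

import random
import torch
import torch.nn as nn
from torch.autograd import Variable

# stand-in policy network: one sigmoid output = P(a = 1 | x)
model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 1), nn.Sigmoid())
gamma = 0.99  # discount factor (placeholder value)

LL_list, rewards = [], []

for t in range(1000):  # episode loop; environment interaction replaced by dummy data
    x = Variable(torch.randn(1, 4))                        # observation (placeholder)
    prob_a = model(x)                                      # P(a = 1 | x)
    a = 1 if random.random() < float(prob_a.data) else 0   # sample the action
    # log-likelihood of the action actually taken, kept as a 1-element Variable
    LL = torch.log(prob_a) if a == 1 else torch.log(1 - prob_a)
    LL_list.append(LL.view(1))
    rewards.append(random.choice([-1.0, 0.0, 1.0]))        # reward (placeholder)

# end of episode: discounted return for every step
G_list, G = [], 0.0
for r in reversed(rewards):
    G = r + gamma * G
    G_list.insert(0, G)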

Episodes last for about a thousand steps (so that is the size of LLs and Gs).

What happens:

When I run that, the execution gets stuck forever in loss.backward().

Question:

I feel like I might be missing something quite important here (maybe PyTorch actually builds an immense graph and takes forever to backpropagate through it). Does anyone have an idea about what might be going on?

I guess in a more general setting, my question would be: how do I compute a loss that involves Variables obtained from several forward passes through the model?

I found a workaround. Apparently torch.stack() on my list of Variables made backpropagation ridiculously long. Computing my loss by iterating over my list of Variables instead of converting it into a single tensor seems to have resolved the issue. However, I still don’t understand why torch.stack() would cause such a slowdown.
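
In case it helps, the workaround looks roughly like this (a sketch, reusing the LL_list and G_list from above):

# accumulate the loss by iterating over the per-step Variables instead of stacking them
loss = 0
for LL, G in zip(LL_list, G_list):
    loss = loss + LL * G
loss = loss / len(LL_list)
loss.backward()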

For me, I use torch.cat to concatenate a list of torch tensors/Variables (i.e. a list of log probs or actions).

I’ve been able to do batch policy gradient this way and it’s very fast (faster than iterating through the list and summing up the loss).
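
Roughly something like this (a sketch, assuming each entry in LL_list is a 1-element Variable and G_list already holds the discounted returns):

import torch
from torch.autograd import Variable

LLs = torch.cat(LL_list)                  # 1-D Variable of shape (episode_length,)
Gs = Variable(torch.FloatTensor(G_list))  # discounted returns, same shape

loss = torch.mean(LLs * Gs)   # negate this if your optimizer minimizes and you want to ascend the expected return
loss.backward()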