My REINFORCE model is not learning

I’ve been reading through “Hands-On Machine Learning with Scikit-Learn & TensorFlow” and implementing the algorithms presented in the book using PyTorch instead of TensorFlow. (I’m a beginner, and my PhD friends recommended PyTorch to me!)

I have been working on this algorithm (link, go to “Policy Gradients”), but my PyTorch implementation is simply not learning at all. Here’s the code. Could anyone point out what I’m doing wrong? My hunch is that I’m missing logarithms somewhere, but I’m not sure.
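The hunch about logarithms is on the right track. REINFORCE relies on the log-derivative identity, so the quantity to differentiate is the log-probability of each sampled action, weighted by the return that followed it. In the sketch below, $\pi_\theta(a \mid s)$ is the policy and $G_t$ the return after step $t$ (this notation is assumed, not taken from the book):

```latex
\nabla_\theta \pi_\theta(a \mid s) = \pi_\theta(a \mid s)\, \nabla_\theta \log \pi_\theta(a \mid s)
\;\;\Longrightarrow\;\;
\nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta}\!\left[\sum_{t=0}^{T-1} G_t\, \nabla_\theta \log \pi_\theta(a_t \mid s_t)\right],
\qquad
L(\theta) = -\sum_{t=0}^{T-1} G_t \log \pi_\theta(a_t \mid s_t)
```

Minimizing $L(\theta)$ with $G_t$ held constant follows a single-episode estimate of $\nabla_\theta J$. Differentiating $\pi_\theta(a_t \mid s_t)$ instead of its logarithm drops the $1/\pi_\theta$ factor, so the update no longer estimates that gradient.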

Hi,

I tried to apply the same concept to Pong, but by scaling the gradients directly instead. It does not learn.
Please check this: Implementing reinforce using gradient scaling
my implementation

Thanks
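On the gradient-scaling approach: one way to read it is to call `backward()` on the log-probability of the sampled action and pass the negated return as the incoming gradient. The minimal sketch below assumes a single step, a toy linear policy and a made-up return (all hypothetical). It checks that this gives the same parameter gradients as backpropagating the REINFORCE loss directly:

```python
import torch

torch.manual_seed(0)
policy = torch.nn.Linear(4, 2)  # hypothetical toy policy: 4 features -> 2 action logits
state = torch.randn(4)
G = 2.0  # made-up return that followed the sampled action

probs = torch.softmax(policy(state), dim=0)
action = torch.multinomial(probs, 1).item()

# Variant 1: backpropagate the REINFORCE loss -G * log pi(a|s).
loss = -torch.log(probs[action]) * G
loss.backward()
grads_loss = [p.grad.clone() for p in policy.parameters()]

# Variant 2: backpropagate log pi(a|s) itself and scale its incoming gradient by -G.
policy.zero_grad()
log_prob = torch.log(torch.softmax(policy(state), dim=0)[action])
log_prob.backward(torch.tensor(-G))
grads_scaled = [p.grad.clone() for p in policy.parameters()]

print(all(torch.allclose(a, b) for a, b in zip(grads_loss, grads_scaled)))  # True
```

The sign matters. If the gradient is scaled by +G and the optimizer then takes a minimizing step, the probability of well-rewarded actions is pushed down instead of up.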

PyTorch lets you call .reinforce() directly on the output of a stochastic function such as torch.multinomial. For that to work, you need to sample the random variable with torch.multinomial rather than np.random.multinomial, because a NumPy sample lives outside the autograd graph. See What is action.reinforce(r) doing actually? for an example of how this works.
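The sketch below shows the sampling step. It assumes a hypothetical policy network that outputs logits for two discrete actions, with a placeholder reward. Because `.reinforce()` did not survive into later PyTorch releases, the sketch does not call it. It takes the log-probability of the action sampled with `torch.multinomial` and builds the loss from that explicitly:

```python
import torch

torch.manual_seed(0)
# Hypothetical policy: 4 state features in, logits for 2 actions out.
policy = torch.nn.Linear(4, 2)
state = torch.randn(1, 4)

probs = torch.softmax(policy(state), dim=1)       # shape (1, 2)
action = torch.multinomial(probs, num_samples=1)  # shape (1, 1), int64, carries no gradient
# Log-probability of the action that was actually sampled; this one does carry gradient.
log_prob = torch.log(probs.gather(1, action)).squeeze()

reward = 1.0  # placeholder for the return observed after taking `action`
loss = -log_prob * reward
loss.backward()  # policy.weight.grad and policy.bias.grad now hold the REINFORCE gradient
```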

I actually got it working with .reinforce(), but then it’s calculating the gradient and updating the parameters after every game, right? The algorithm from the book plays the game 10 times, computing the gradients for each game, then averages those gradients over the 10 games to update the parameters. I was wondering how I could do that with PyTorch.

I guess I could just use .reinforce(), but I thought trying to implement the algorithm from the book in PyTorch would be good practice.
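One way to get the book’s “average the gradients over 10 games, then update once” behaviour in PyTorch relies on `.backward()` adding into `.grad` rather than overwriting it. Divide each game’s loss by the number of games, call `.backward()` after every game, and call `optimizer.step()` only once at the end. In this sketch, `play_episode` is a hypothetical stand-in that fakes an episode with random states and made-up returns so the code runs without an environment:

```python
import torch

torch.manual_seed(0)
policy = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(policy.parameters(), lr=0.01)
N_GAMES = 10  # games per parameter update, as in the book's algorithm


def play_episode(policy, n_steps=5):
    """Hypothetical stand-in for one game: random states, sampled actions,
    and placeholder returns instead of a real environment."""
    states = torch.randn(n_steps, 4)
    probs = torch.softmax(policy(states), dim=1)
    actions = torch.multinomial(probs, 1)
    log_probs = torch.log(probs.gather(1, actions)).squeeze(1)
    returns = torch.rand(n_steps)  # placeholder for the (discounted) returns
    return log_probs, returns


optimizer.zero_grad()
for _ in range(N_GAMES):
    log_probs, returns = play_episode(policy)
    # This game's REINFORCE loss, scaled so the summed .grad is the mean over games.
    loss = -(log_probs * returns).sum() / N_GAMES
    loss.backward()  # accumulates into .grad; no parameter update yet
optimizer.step()  # one update using the averaged gradient
```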

@yukw777 I’d recommend upgrading to master, which handles REINFORCE differently and, in my opinion, more closely resembles how the REINFORCE paper presents the algorithm.

For people looking around, this is the change @hughperkins is referring to:

The docs on the master branch reflect that: http://pytorch.org/docs/master/distributions.html

I definitely agree that this is a lot easier to understand, since it is close to what’s described in the REINFORCE paper. Since the PyTorch binary release is still at 0.2.0 and doesn’t include this change, I guess we’ll have to install PyTorch from source for now.
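For comparison, here is a sketch of the `torch.distributions` style discussed above. It wraps the policy’s output in a `Categorical` distribution, samples an action, and builds the loss from `log_prob`. The policy network and the reward value are placeholders; a real agent would get the reward from the environment.

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)
policy = torch.nn.Linear(4, 2)  # placeholder policy network: 4 features -> 2 action logits
state = torch.randn(4)

m = Categorical(logits=policy(state))  # distribution over the 2 actions
action = m.sample()                    # sampled action index (no gradient)
reward = 1.0                           # placeholder for the environment's reward
loss = -m.log_prob(action) * reward    # log pi(a|s) appears explicitly here
loss.backward()
```

The logarithm the original question was missing is now explicit in `m.log_prob(action)`.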