Challenges porting TensorFlow reinforcement learning sample to PyTorch

Hi - I’m brand new to PyTorch and have been attempting to port a simple reinforcement learning sample from TensorFlow to PyTorch to get myself up to speed. I haven’t been able to get my PyTorch version to come close to the performance of the TensorFlow network, and I’m struggling to understand why.

I believe my PyTorch network is effectively equivalent to the TensorFlow network, but I suspect there is just something slightly off.

Here’s the TensorFlow version:

import gym
import numpy as np
import random
import tensorflow as tf

env = gym.make("FrozenLake-v0")

tf.reset_default_graph()

inputs1 = tf.placeholder(shape=[1,16], dtype=tf.float32)
W = tf.Variable(tf.random_uniform([16,4], 0, 0.01))
Qout = tf.matmul(inputs1, W)
predict = tf.argmax(Qout, 1)
nextQ = tf.placeholder(shape=[1,4], dtype=tf.float32)

loss = tf.reduce_sum(tf.square(nextQ - Qout))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
updateModel = optimizer.minimize(loss)

y = 0.99
e = 0.1
num_episodes = 2000
jList = []
rList = []

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    for i in range(num_episodes):
        # Reset environment and get first new observation
        s = env.reset()
        rAll = 0
        d = False
        j = 0
        
        # Q-network
        while j < 99:
            j += 1
            
            # Choose action greedily (with chance of random action) from the Q-network
            a,allQ = sess.run([predict, Qout], feed_dict={inputs1:np.identity(16)[s:s+1]})
            if np.random.rand(1) < e:
                a[0] = env.action_space.sample()
                
            # Get new state and reward from environment
            s1,r,d,_ = env.step(a[0])
            
            # Obtain the Q' values by feeding new state through our network
            Q1 = sess.run(Qout, feed_dict={inputs1:np.identity(16)[s1:s1+1]})
            
            # Obtain maxQ' and set target value for chosen action
            maxQ1 = np.max(Q1)
            targetQ = allQ
            targetQ[0, a[0]] = r + y*maxQ1
            
            # Train network using target and predicted Q values
            _,W1 = sess.run([updateModel,W], feed_dict={inputs1:np.identity(16)[s:s+1], nextQ:targetQ})
            
            rAll += r
            s = s1
            if d == True:
                # Reduce chance of random action as we train the model
                e = 1./((i/50) + 10)
                break
                
        jList.append(j)
        rList.append(rAll)

print("Percent of successful episodes: %s%%" % str((sum(rList) / num_episodes) * 100))

And here’s my PyTorch version:

from __future__ import print_function
import gym
import numpy as np
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

env = gym.make("FrozenLake-v0")

def predict(x):
    values, indices = torch.max(x, 1)
    return indices

class DQN(nn.Module):    
    def __init__(self):
        super(DQN, self).__init__()        
        self.fc1 = nn.Linear(16, 4, bias=False)
            
    def forward(self, x):
        x = self.fc1(x)
        return x

y = 0.99
e = 0.1
num_episodes = 2000
num_q_iterations = 99
jList = []
rList = []

dqn = DQN()
criterion = nn.MSELoss()
optimizer = optim.SGD(dqn.parameters(), lr=0.1)

for i in range(num_episodes):
    # Reset environment and get first new observation
    s = env.reset()
    rAll = 0
    d = False
    j = 0

    # Q-network
    while j < num_q_iterations:
        j += 1
        
        # Choose action greedily (with chance of random action) from the Q-network    
        input = Variable(torch.from_numpy(np.identity(16)[s:s+1]).type(torch.FloatTensor))
        allQ = dqn.forward(input)
        a = predict(allQ).data[0] 
        if np.random.rand(1) < e:
            a = env.action_space.sample()
                
        # Get new state and reward from environment
        s1,r,d,_ = env.step(a)
            
        # Obtain the Q' values by feeding new state through our network        
        next_state_input = Variable(torch.from_numpy(np.identity(16)[s1:s1+1]).type(torch.FloatTensor))
        Q1 = dqn.forward(next_state_input)
            
        # Obtain maxQ' and set target value for chosen action
        maxQ1 = np.max(Q1.data.numpy())
        targetQ = allQ.data.numpy()
        targetQ[0, a] = r + y*maxQ1

        # Train network using target and predicted Q values        
        first_state_input = Variable(torch.from_numpy(np.identity(16)[s:s+1]).type(torch.FloatTensor))        
        target = Variable(torch.from_numpy(targetQ).type(torch.FloatTensor))        
        
        # Backprop & update weights
        dqn.zero_grad()
        optimizer.zero_grad()
        output = dqn(first_state_input)
        loss = criterion(output, target)    
        loss.backward()
        optimizer.step()
        
        rAll += r
        s = s1
        if d == True:
            # Reduce chance of random action as we train the model
            e = 1./((i/50) + 10)
            break
                
    jList.append(j)
    rList.append(rAll)
    
print("Percent of successful episodes: %s%%" % str((sum(rList) / num_episodes) * 100))

The TensorFlow network regularly hits ~45%, while the PyTorch one struggles to get past 5%. After extensive debugging, I suspect the problem might lie in one of these areas:

  • The loss function. I suspect that PyTorch’s MSELoss does not work the same way as tf.reduce_sum(tf.square(nextQ - Qout)). As far as I can tell, it’s doing what I think is the same thing, but I’m not 100% sure (see the sketch after this list).
  • The way I’m updating the weights of my network. I’m not 100% sure that I’m using zero_grad() and the PyTorch SGD optimizer properly.
  • I’m also not entirely sure if PyTorch’s SGD optimizer works all that similarly to TensorFlow’s GradientDescentOptimizer.
  • Weight initialization of my Linear layer - I’ve tried nn.init.uniform(self.fc1.weight, 0, 0.01) to try to match tf.random_uniform([16,4], 0, 0.01), but that does not improve performance of my PyTorch network. It looks like Linear might initialize weights by sampling from a uniform distribution anyway, so I’m only somewhat convinced this is a problem area.
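
To make the loss question concrete, here is what I’ve been comparing (a minimal sketch with made-up Q-values, not my actual training code). tf.reduce_sum(tf.square(nextQ - Qout)) sums the squared errors, while nn.MSELoss() averages them by default, so for a 1x4 output the gradients differ by a constant factor of 4:

import torch
import torch.nn as nn
from torch.autograd import Variable

# Made-up 1x4 Q-value tensors, just to compare the two reductions
output = Variable(torch.FloatTensor([[0.1, 0.2, 0.3, 0.4]]))
target = Variable(torch.FloatTensor([[0.1, 0.2, 0.3, 1.0]]))

# What tf.reduce_sum(tf.square(nextQ - Qout)) computes: the SUM of squared errors
sum_loss = (target - output).pow(2).sum()

# What nn.MSELoss() computes by default: the MEAN over all 4 elements
mean_loss = nn.MSELoss()(output, target)

# sum_loss == 4 * mean_loss here, so with SGD at lr=0.1 the effective step
# size differs by that same factor of 4 between the two frameworks.
print(sum_loss.data, mean_loss.data)

If that is the issue, nn.MSELoss(size_average=False) should behave like the summed TensorFlow loss, though I haven’t verified that this alone closes the gap.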

Those are the areas I’m actively investigating, but I’m hoping there’s something blatantly obvious about how PyTorch works that I’m just missing. Would love any thoughts!

Thanks!

oh man, it’s hard to tell what’s different.

I wanted to add some quick comments:

dqn.forward(next_state_input) # wrong
dqn(next_state_input)         # correct

Calling the module as dqn(x) goes through __call__, which runs any registered hooks before dispatching to forward; calling forward directly skips them. It doesn’t affect things in your particular code sample though.

Also, have a look at our DQN training for any gotchas you might find: http://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html#training-loop

Thank you for taking a look, it’s much appreciated!

Definitely will dig into the tutorial. Going to play around with different optimizers & loss functions. I suspect that’s where the problem lies.

I’ve also read somewhere that there is something you need to do in PyTorch if you have batches with single training examples (as I do). I’m trying to track that down and see whether it’s (1) something real, and not something I only imagined reading, and (2) actually helpful in this situation.
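
If it’s real, I think it’s about giving inputs an explicit batch dimension. This quick sketch (made-up shapes, not my actual training code) is my current understanding:

import torch

# A single training example without a batch dimension: shape [16]
single_example = torch.zeros(16)

# unsqueeze(0) adds a fake batch dimension of size 1, giving shape [1, 16],
# which is what modules like nn.Linear expect for a mini-batch of one example.
batched = single_example.unsqueeze(0)

print(single_example.size())  # torch.Size([16])
print(batched.size())         # torch.Size([1, 16])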

Thanks!

It looks like the fundamental mistake I was making was failing to unsqueeze() the mini-batch I was training the network with. In this example, each mini-batch only had 1 example in it; adding a fake batch dimension (per http://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html) seems to do the trick!

    # Train network using target and predicted Q values
    final_input_tensor = torch.from_numpy(np.identity(16)[s:s+1]).type(torch.FloatTensor)
    final_input = Variable(final_input_tensor)
    final_input = final_input.unsqueeze(0)
    target_tensor = torch.from_numpy(targetQ).type(torch.FloatTensor)        
    target = Variable(target_tensor)     
  
    # Backprop
    optimizer.zero_grad()
    output = dqn(final_input)
    criterion = nn.MSELoss()
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

I have a similar problem trying to get PyTorch to work on Pong-v0: while TensorFlow can get past 20, PyTorch merely gets past -20, which means it learns nothing! The same PyTorch code runs pretty well on another machine with an older version of PyTorch and does not work on a newer version, so I thought I should downgrade my PyTorch version. After I downgraded, the problem was solved. So I’d suggest staying with the original PyTorch version 0.1.x and waiting for 0.2.x until you are sure it will work.

It’s nice that your solution works, but I would not suggest downgrading to version 0.1.x.
You would be missing a lot of nice features. See the PyTorch Release Notes.

One thing you should try when porting your code is to enable warnings that highlight incompatible code.
Quote from the release notes:

Here is a code snippet that you can add to the top of your scripts.
Adding this code will generate warnings highlighting incompatible code.

Fix your code to no longer generate warnings.

# insert this to the top of your scripts (usually main.py)
import sys, warnings, traceback, torch
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
    sys.stderr.write(warnings.formatwarning(message, category, filename, lineno, line))
    traceback.print_stack(sys._getframe(2))
warnings.showwarning = warn_with_traceback
warnings.simplefilter('always', UserWarning)
torch.utils.backcompat.broadcast_warning.enabled = True
torch.utils.backcompat.keepdim_warning.enabled = True

Once all warnings disappear, you can remove the code snippet.

Well, actually I was wrong… after all, even downgrading the PyTorch version does not work. The same code works on another machine with an old PyTorch version and non-Anaconda Python 2.7, but it didn’t work with the old PyTorch under Anaconda Python 2.7. I just can’t find the reason why it didn’t work…

optimizer = optim.Adam(dqn.parameters(), lr=0.01)

I changed from SGD to Adam and the lr to 0.01.

The result is “Percent of successful episodes: 37.1%”!!