RNN and Adam: slower convergence than Keras

I’m training a simple RNN on this dataset: https://ufile.io/gf7xo (I include the link so you can try my code on your machine). I use Adam as the optimizer. I built the same model (with the same weight initialization) in both PyTorch and Keras (TF backend), but unfortunately PyTorch’s convergence is always slower than Keras’. If you plot the loss over the epochs, you will also see that PyTorch’s Adam is a bit unstable at this learning rate while Keras’ is not, and that is not a negligible problem. These are the results from some trials after 200 epochs:

PyTorch:
3.9312e-04
9.4073e-04
4.9248e-04
3.9022e-04

Keras:
1.2597e-04
4.9654e-05
5.8871e-05
1.1851e-04

PyTorch code:

import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import torch.backends.cudnn
torch.backends.cudnn.enabled = False  # disable cuDNN so the run is a CPU-only sanity check

BATCH_SIZE = 1
INPUT_DIM = 1
OUTPUT_DIM = 1
DTYPE = np.float64

class Net(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, hidden_layers):
        super(Net, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.hidden_layers = hidden_layers

        self.rnn = nn.RNN(input_dim, hidden_dim, hidden_layers)
        self.h2o = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
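        # initial hidden state, shape (num_layers, batch, hidden_dim)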
        h_0 = Variable(torch.zeros(self.hidden_layers, BATCH_SIZE, self.hidden_dim))
        if DTYPE == np.float32:
            h_0 = h_0.float()
        else:
            h_0 = h_0.double()

        output, h_t = self.rnn(x, h_0)
        output = self.h2o(output)
        return output


def weights_init(m):
    if isinstance(m, nn.RNN):
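        # (in newer PyTorch versions these are the in-place variants
        # xavier_uniform_, orthogonal_, constant_)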
        nn.init.xavier_uniform(m.weight_ih_l0.data)
        nn.init.orthogonal(m.weight_hh_l0.data)
        nn.init.constant(m.bias_ih_l0.data, 0)
        nn.init.constant(m.bias_hh_l0.data, 0)
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform(m.weight.data)
        nn.init.constant(m.bias.data, 0)


data = np.loadtxt('data/mg17.csv', delimiter=',', dtype=DTYPE)
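# add a batch axis: trX/trY become (seq_len=4000, batch=1, features=1),
# the time-major layout nn.RNN expects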
trX = torch.from_numpy(np.expand_dims(data[:4000, [0]], axis=1))
trY = torch.from_numpy(np.expand_dims(data[:4000, [1]], axis=1))

loss_fcn = nn.MSELoss()
model = Net(INPUT_DIM, 10, OUTPUT_DIM, 1)
if DTYPE == np.float32:
    model = model.float()
else:
    model = model.double()
model.apply(weights_init)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, betas=(0.9, 0.999), eps=2e-16, weight_decay=0)  # same hyper-parameters as the Keras run below
for e in range(500):
    model.train()
    x = Variable(trX)
    y = Variable(trY)
    model.zero_grad()
    output = model(x)
    loss = loss_fcn(output, y)
    loss.backward()
    optimizer.step()

    print("Epoch", e + 1, "TR:", loss.cpu().data.numpy()[0])

Open ~/.keras/keras.json, set floatx to float64 and epsilon to 2e-16.
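For reference, the relevant part of the file should then look something like this (a sketch; any other keys already in your keras.json keep their existing values):

{
    "floatx": "float64",
    "epsilon": 2e-16,
    "backend": "tensorflow"
}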
Keras:

import keras
import numpy as np
from keras.models import Sequential
from keras.layers import Dense, SimpleRNN

DTYPE = np.float64

data = np.loadtxt('data/mg17.csv', delimiter=',', dtype=DTYPE)
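# add a batch axis: X_data/Y_data become (batch=1, timesteps, features=1),
# the batch-major layout Keras expects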
X_data = np.expand_dims(data[:, [0]], axis=0)
Y_data = np.expand_dims(data[:, [1]], axis=0)

model = Sequential()
model.add(SimpleRNN(10, return_sequences=True, input_shape=(4000, 1)))
model.add(Dense(1, activation='linear'))

optimizer = keras.optimizers.Adam(lr=0.01, beta_1=0.9, beta_2=0.999, epsilon=2e-16, decay=0)
model.compile(loss='mean_squared_error',
              optimizer=optimizer)

model.fit(X_data[:, :4000, :], Y_data[:, :4000, :], batch_size=1, epochs=500, verbose=2, shuffle=False)

UPDATE: The situation is the same with RMSProp. Nevertheless, the issue does not appear with either SGD or LSTM/GRU.

It looks like you have different learning rates for the Keras model (lr=0.01) and the PyTorch model (lr=0.001), so that is most likely the main cause of the differing convergence rates.

I’m sorry, that is a typo I introduced while editing my code to post it here. The tests used the same learning rate (0.01).

Is this specifically a problem with Adam, or do other, simpler optimisers such as SGD + momentum also give a mismatch in results? Is the behaviour the same on CPU only? Although I don’t think there should be an issue with cuDNN and the RNN module, adding torch.backends.cudnn.enabled = False to the top of your PyTorch code will disable cuDNN and hence give you a sanity check on that.
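For example, the optimiser swap is a one-line change to the training script above (the momentum value here is just an illustrative choice):

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)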

Yes, the behaviour is the same on CPU only. I have updated the code above to disable cuDNN and run on the CPU.

Results are similar when using SGD. It seems to be an RMSProp/Adam problem, but I didn’t try other Ada* algorithms.

edit: interesting fact: the issue does not appear with LSTM and GRU; on the contrary, PyTorch performs even better!
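For anyone reproducing this, here is a minimal sketch of the LSTM variant of the Net above (it reuses BATCH_SIZE from my script; the GRU case is analogous but, like nn.RNN, carries a single hidden state):

class LSTMNet(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, hidden_layers):
        super(LSTMNet, self).__init__()
        self.hidden_dim = hidden_dim
        self.hidden_layers = hidden_layers
        self.rnn = nn.LSTM(input_dim, hidden_dim, hidden_layers)  # was nn.RNN
        self.h2o = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # an LSTM carries a (hidden, cell) tuple, each (num_layers, batch, hidden_dim)
        h_0 = Variable(torch.zeros(self.hidden_layers, BATCH_SIZE, self.hidden_dim).double())
        c_0 = Variable(torch.zeros(self.hidden_layers, BATCH_SIZE, self.hidden_dim).double())
        output, (h_t, c_t) = self.rnn(x, (h_0, c_0))
        return self.h2o(output)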

@stefanonardo I ran your code in CPU mode; what exactly should I look for in terms of “non-convergence”? The training loss seems to have gone down pretty well.

Here’s the code I ran: https://gist.github.com/soumith/ceb0d3de23585e676fd3b5e0402a45a3

Here’s the output log: https://gist.github.com/soumith/bfeeb1f5378030693b05231344c1c3f5

I ran my code again 10 times and got similar results between PyTorch and Keras. Maybe my original 4 runs were particularly unlucky for PyTorch, or maybe there was some bug in my old 0.2 implementation. Honestly, I can’t figure it out, but I’m glad (and sorry) this was a false alarm!

These are my new results:

PyTorch
1.2381e-05
2.8960e-04
1.2643e-04
7.6845e-05
2.5028e-05
1.1711e-04
1.2660e-04
1.5379e-04
1.1074e-04
4.8524e-04

Keras
3.7622e-05
2.4301e-04
9.6022e-05
1.2589e-04
3.5417e-04
6.3925e-05
8.8122e-05
2.6460e-04
2.0003e-04


Now I’m getting very bad results with truncated back-propagation. Could anyone check whether there are any bugs in my code, please? I followed this tutorial to implement the truncated backprop; a sketch of my truncation loop follows the results below. As before, I used the same hyper-parameters and initializations for both PyTorch and Keras.
Here are the results after just 100 epochs of training and the links to minimal code:

PyTorch [code]
TR: 0.000440476183096 VL: 0.00169517311316
TR: 0.000462784681366 VL: 0.00128701637499
TR: 0.000823373540768 VL: 0.00211899834873
TR: 0.000430527156073 VL: 0.00167960980949
TR: 0.000533050970649 VL: 0.000932638757326

If you set TIMESTEPS to NaN it will backpropagate through the entire sequence, and you can see that it works well that way. The issue appears only when I truncate the sequence.

Keras [code]
TR: 1.60957323398e-05 VL: 3.12658933101e-06
TR: 1.97489706594e-05 VL: 3.44138302082e-06
TR: 2.47815147053e-05 VL: 5.84050205497e-06
TR: 2.54522322033e-05 VL: 2.236503277e-05
TR: 1.96936488671e-05 VL: 6.55356349568e-06
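As mentioned above, the core of my truncation loop looks roughly like this (a minimal sketch of the usual pattern, not my exact code: it drives the Net’s submodules directly so the hidden state can be carried across chunks, and detaches it at each chunk boundary):

TIMESTEPS = 50
for e in range(100):
    # fresh hidden state at the start of each epoch
    h = Variable(torch.zeros(1, BATCH_SIZE, 10).double())
    for t in range(0, trX.size(0), TIMESTEPS):
        x = Variable(trX[t:t + TIMESTEPS])
        y = Variable(trY[t:t + TIMESTEPS])
        h = Variable(h.data)  # detach: gradients stop at the chunk boundary
        model.zero_grad()
        rnn_out, h = model.rnn(x, h)
        loss = loss_fcn(model.h2o(rnn_out), y)
        loss.backward()
        optimizer.step()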

I just tried your PyTorch code setting TIMESTEPS to NaN, and it performed much worse.

Timesteps NaN
Epoch 1 TR: 0.398600595038 VL: 0.251811799777
Epoch 2 TR: 0.256775490708 VL: 0.201037764696
Epoch 3 TR: 0.204928007695 VL: 0.148878562004
Epoch 4 TR: 0.152259112661 VL: 0.108047759769
Epoch 5 TR: 0.110495173812 VL: 0.0812407301001

Timesteps 50
Epoch 1 TR: 0.0158072562396 VL: 0.00216180915408
Epoch 2 TR: 0.00119101881924 VL: 0.00183749616851
Epoch 3 TR: 0.000887962538854 VL: 0.00180209117973
Epoch 4 TR: 0.000865832243852 VL: 0.00167242659527
Epoch 5 TR: 0.000841620917981 VL: 0.00159547644208

@jpeg729: you have to run it for more epochs. It is normal that it converges more slowly with a larger batch (here, the entire sequence), but it does converge. Try 500 epochs.

EDIT: I found the bug. It was just the reshape ordering, which must be set to Fortran-like for PyTorch (because we reshape to TxBx* instead of BxTx*). :wink:
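To illustrate the pitfall (a toy example, not the original code): when one long series is cut into chunks that go into the batch dimension of PyTorch’s time-major TxBx* layout, a Fortran-order reshape keeps each chunk contiguous along the time axis, while the default C order interleaves samples across chunks:

import numpy as np

T, B = 50, 80                        # 4000-step series cut into 80 chunks of 50
series = np.arange(T * B, dtype=np.float64)

# Fortran order fills the first (time) axis fastest, so each batch column
# is one contiguous chunk of the series -- the correct TxBx* layout:
right = series.reshape(T, B, 1, order='F')
assert (right[:, 0, 0] == series[:T]).all()

# C order fills the last axes fastest and interleaves samples across chunks:
wrong = series.reshape(T, B, 1)
assert wrong[1, 0, 0] == series[B]   # steps by B, not by 1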

Larger batches also reduce the effective learning rate: with the entire sequence as one batch you take a single optimizer step per epoch, instead of one per chunk.

I’m glad you found the bug. I was going to point out that your PyTorch version wasn’t generalising anywhere near as well as the Keras version.