[solved] PyTorch LSTM 50x slower than Keras (TF) CuDNNLSTM

I want to train a model for a time series prediction task. I built my own model in PyTorch, but it is much slower than the same model implemented in Keras: each epoch takes 50 ms in PyTorch versus 1 ms in Keras. Here is my simple code. I’d like to know whether I made a mistake or whether it’s just PyTorch. Thank you in advance.

This is my module:

import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, hidden_layers):
        super(Net, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.hidden_layers = hidden_layers

        self.lstm = nn.LSTM(input_dim, hidden_dim, hidden_layers)
        self.h2o = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Random initial hidden and cell states, shape (num_layers, batch, hidden_dim).
        # BATCH_SIZE is a global constant defined elsewhere in the script.
        h_0 = torch.randn(self.hidden_layers, BATCH_SIZE, self.hidden_dim).cuda()
        c_0 = torch.randn(self.hidden_layers, BATCH_SIZE, self.hidden_dim).cuda()

        # nn.LSTM returns (output, (h_n, c_n)); output is the hidden state at every timestep.
        output, (h_n, c_n) = self.lstm(x, (h_0, c_0))
        return self.h2o(output)
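For reference, nn.LSTM returns a pair (output, (h_n, c_n)) rather than (h, c). The output holds the top layer's hidden state at every timestep, which is what h2o should see, while h_n and c_n are only the final states. Here is a minimal CPU shape check using the thread's sizes. The zero initial states are an assumption made for the sketch; the code above uses random ones.

```python
import torch
import torch.nn as nn

# Sizes from the thread: 242 timesteps, batch of 1, 10 features, 40 hidden units, 1 layer.
seq_len, batch, input_dim, hidden_dim, num_layers = 242, 1, 10, 40, 1

lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
x = torch.randn(seq_len, batch, input_dim)      # default layout: (seq_len, batch, features)
h0 = torch.zeros(num_layers, batch, hidden_dim)
c0 = torch.zeros(num_layers, batch, hidden_dim)

# Per-timestep outputs plus a tuple holding the final hidden and cell states.
output, (h_n, c_n) = lstm(x, (h0, c0))

print(output.shape)  # torch.Size([242, 1, 40]): top-layer hidden state at every timestep
print(h_n.shape)     # torch.Size([1, 1, 40]): final hidden state of each layer
print(c_n.shape)     # torch.Size([1, 1, 40]): final cell state of each layer
```

For a single-layer, unidirectional LSTM, the last timestep of output equals h_n[-1].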

And this is the training loop:

model = Net(INPUT_DIM, 40, OUTPUT_DIM, 1).cuda()
loss_fcn = MEDLoss()
optimizer = optim.RMSprop(model.parameters())

for epoch in range(EPOCHS):
    start = time.time()
    for seq in range(11, 20):
        length = int(seq_lengths[seq])  # slice bounds must be integers
        # Each sequence is fed on its own, as a batch of size 1.
        x = X_data[:length, [seq], :].cuda()
        y = Y_data[:length, [seq], :].cuda()
        model.zero_grad()
        output = model(x)
        loss = loss_fcn(output, y)
        loss.backward()
        optimizer.step()
    print("Epoch", epoch + 1, "Loss:", loss.item(), "Time:", time.time() - start)
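When timing GPU code like this, remember that CUDA kernels are launched asynchronously, so time.time() can be read before the queued work has finished. In the loop above, the loss is copied to the host inside print before time.time() is evaluated, which forces a sync. That is easy to lose when refactoring, though. Below is a minimal sketch of a timing helper that synchronizes explicitly. timed_epoch is a hypothetical name, and the lambda in the usage line merely stands in for one training epoch.

```python
import time
import torch

def timed_epoch(run_epoch):
    """Run one epoch and return (result, elapsed seconds), syncing CUDA if present."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # don't charge earlier queued GPU work to this epoch
    start = time.perf_counter()
    result = run_epoch()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait until this epoch's kernels have actually finished
    return result, time.perf_counter() - start

# Usage with a trivial CPU workload standing in for a training epoch.
result, elapsed = timed_epoch(lambda: sum(range(1000)))
print(result, elapsed >= 0.0)
```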

Could I see the code with the dataset as well, i.e. an example that I can run?
And could I run your Keras/TF code too?

I’m sorry, I cannot publish the dataset. The input has shape (242 timesteps, 20 sequences, 10 input features) and the output has shape (242, 20, 2). Here is my full code with a randomly generated dataset so you can run it.
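The layout matters when comparing the two frameworks. By default, nn.LSTM expects (seq_len, batch, features), which is why the data here is (242, 20, 10). Keras recurrent layers instead expect (samples, timesteps, features), i.e. (20, 242, 10). nn.LSTM also accepts batch_first=True if you prefer the Keras layout. Here is a minimal sketch on random data showing that both layouts give the same result when the weights are shared.

```python
import torch
import torch.nn as nn

X_time_major = torch.randn(242, 20, 10)        # (timesteps, sequences, features), as in this thread
X_batch_major = X_time_major.transpose(0, 1)   # (sequences, timesteps, features), Keras-style

lstm_tm = nn.LSTM(10, 40)                       # default: batch_first=False
lstm_bm = nn.LSTM(10, 40, batch_first=True)
lstm_bm.load_state_dict(lstm_tm.state_dict())   # same weights; only the input layout differs

out_tm, _ = lstm_tm(X_time_major)               # (242, 20, 40)
out_bm, _ = lstm_bm(X_batch_major)              # (20, 242, 40)

print(out_tm.shape, out_bm.shape)
print(torch.allclose(out_tm.transpose(0, 1), out_bm, atol=1e-5))
```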

import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

BATCH_SIZE = 1
INPUT_DIM = 10
OUTPUT_DIM = 2
EPOCHS = 1000


def med_loss(input, target):
    # Note: sums over *all* elements before the square root
    # (the Keras version sums over the last axis only).
    return torch.mean(torch.sqrt(torch.sum(torch.pow(target - input, 2))))


class MEDLoss(nn.Module):
    def forward(self, input, target):
        # Targets are data, not parameters, so they should not require gradients.
        assert not target.requires_grad
        return med_loss(input, target)


class Net(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, hidden_layers):
        super(Net, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.hidden_layers = hidden_layers

        self.lstm = nn.LSTM(input_dim, hidden_dim, hidden_layers)
        self.h2o = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Random initial hidden and cell states, shape (num_layers, batch, hidden_dim).
        h_0 = torch.randn(self.hidden_layers, BATCH_SIZE, self.hidden_dim).cuda()
        c_0 = torch.randn(self.hidden_layers, BATCH_SIZE, self.hidden_dim).cuda()

        # nn.LSTM returns (output, (h_n, c_n)); output is the hidden state at every timestep.
        output, (h_n, c_n) = self.lstm(x, (h_0, c_0))
        return self.h2o(output)


print("Loading data")
X_data = torch.randn((242, 20, INPUT_DIM))
Y_data = torch.rand((242, 20, OUTPUT_DIM)) * 10
seq_lengths = np.full(20, 242, dtype=np.int64)  # integer lengths, usable as slice bounds

model = Net(INPUT_DIM, 40, OUTPUT_DIM, 1).cuda()
loss_fcn = MEDLoss()
optimizer = optim.RMSprop(model.parameters())

for epoch in range(EPOCHS):
    start = time.time()
    for seq in range(11, 20):
        length = int(seq_lengths[seq])
        # Each sequence is copied to the GPU and fed on its own, as a batch of size 1.
        x = X_data[:length, [seq], :].cuda()
        y = Y_data[:length, [seq], :].cuda()
        model.zero_grad()
        output = model(x)
        loss = loss_fcn(output, y)
        loss.backward()
        optimizer.step()
    print("Epoch", epoch + 1, "Loss:", loss.item(), "Time:", time.time() - start)

Thanks, I could run it. Can I have the TF/Keras equivalent as well?

Sample output that I’m seeing:

Epoch 218 Loss: [ 20.28092194] Time: 0.12861180305480957
Epoch 219 Loss: [ 21.20872307] Time: 0.12335944175720215
Epoch 220 Loss: [ 20.74290848] Time: 0.11568808555603027
Epoch 221 Loss: [ 20.72460365] Time: 0.11843180656433105
Epoch 222 Loss: [ 19.11690903] Time: 0.11834907531738281
Epoch 223 Loss: [ 22.12939262] Time: 0.11621856689453125
Epoch 224 Loss: [ 19.81811905] Time: 0.11482071876525879
Epoch 225 Loss: [ 19.9168148] Time: 0.137786865234375

Keras implementation:

import numpy as np
from keras.models import Sequential
from keras.layers import Dense, CuDNNLSTM
from keras import backend as K


def mean_eucl_dist(y_true, y_pred):
    # Euclidean distance over the last axis (the 2 coordinates) per timestep, then averaged.
    return K.mean(K.sqrt(K.sum(K.square(y_true - y_pred), axis=-1, keepdims=True)))


# Keras recurrent layers expect (samples, timesteps, features).
X_data = np.random.randn(20, 242, 10)
Y_data = np.random.rand(20, 242, 2) * 10

model = Sequential()
model.add(CuDNNLSTM(40, return_sequences=True, input_shape=(242, 10)))
model.add(Dense(2, activation='linear'))

model.compile(loss=mean_eucl_dist,
              optimizer='rmsprop',
              metrics=[mean_eucl_dist])

# batch_size counts samples (whole sequences) per gradient update.
model.fit(X_data[11:, :, :], Y_data[11:, :, :], batch_size=242, epochs=10000, shuffle=True)
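There is one difference worth noting between the two losses. The Keras mean_eucl_dist sums the squared error over the last axis only (the 2 output coordinates) before the square root, which gives the mean Euclidean distance per timestep. The PyTorch med_loss posted earlier sums over every element first, which gives the norm of the whole error tensor. Here is a small PyTorch check with hand-picked values. Both function names are hypothetical.

```python
import torch

def mean_euclidean_distance(pred, target):
    # Distance per timestep over the last axis, then averaged (like Keras mean_eucl_dist).
    return torch.sqrt(((target - pred) ** 2).sum(dim=-1)).mean()

def whole_tensor_norm(pred, target):
    # Square root of the sum over *all* elements: the Frobenius norm of the error.
    return torch.sqrt(((target - pred) ** 2).sum())

pred = torch.zeros(2, 1, 2)                            # 2 timesteps, batch of 1, 2 coordinates
target = torch.tensor([[[3.0, 4.0]], [[6.0, 8.0]]])

print(mean_euclidean_distance(pred, target).item())   # 7.5
print(whole_tensor_norm(pred, target).item())         # about 11.18
```

The per-timestep distances are 5 and 10, so their mean is 7.5. The whole-tensor norm is sqrt(9 + 16 + 36 + 64) = sqrt(125).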

Which version of TensorFlow has to be installed to run this? (Also, I don’t think X_data = np.randn((20, 242, 10)) is valid: numpy has no np.randn.)

TensorFlow 1.4, tensorflow-gpu package on pip.
As a side note, I’m not sure batch_size=242 is correct in the Keras implementation. I’m a bit confused because PyTorch and Keras seem to use the term differently. I’ve edited my code, so it should work now.

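Yes, that looks suspect: batch_size should be 1. In Keras, batch_size is the number of samples (here, whole sequences) per gradient update. With batch_size=242, the entire training set of only 9 sequences (indices 11 to 19) is processed in one go, because Keras takes min(dataset_size, batch_size) samples per batch. Each epoch therefore makes a single update instead of the 9 that the PyTorch loop makes.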
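Here is a back-of-the-envelope way to see the difference. Keras's fit splits the training samples into ceil(n_samples / batch_size) batches per epoch and performs one update per batch. The PyTorch loop performs one update per sequence. The sketch below uses a hypothetical helper, updates_per_epoch.

```python
import math

def updates_per_epoch(n_samples, batch_size):
    # One optimizer step per batch; the last batch may be smaller than batch_size.
    return math.ceil(n_samples / batch_size)

n_train = len(range(11, 20))   # sequences 11..19 -> 9 training sequences

print(updates_per_epoch(n_train, 242))  # 1: all 9 sequences in a single batch
print(updates_per_epoch(n_train, 1))    # 9: one sequence per step, like the PyTorch loop
```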

With batch_size=1 in both scripts, PyTorch runs on my machine in about 15 seconds for 100 epochs, and Keras in about 20 seconds.

I’ve made some cosmetic, best-practice changes to the PyTorch script that you may find useful; nothing major. They bring the PyTorch script down to 13 seconds per 100 epochs.
The changes are:

  • PyTorch provides a built-in mse_loss (mean squared error) in torch.nn.functional; just use that instead of rolling your own.
  • Move the whole dataset to the GPU beforehand, just like TF/Keras would do it.
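For reference, F.mse_loss with the default reduction='mean' averages the squared error over every element, with no square root. That makes it a different objective from the mean Euclidean distance. Here is a quick check against the written-out formula, using random tensors shaped like one sequence of this dataset.

```python
import torch
import torch.nn.functional as F

pred = torch.randn(242, 1, 2)
target = torch.rand(242, 1, 2) * 10

builtin = F.mse_loss(pred, target)        # default reduction='mean'
manual = ((pred - target) ** 2).mean()    # the same quantity, written out

print(torch.allclose(builtin, manual))    # True
```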

Here’s the modified PyTorch script:

import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

BATCH_SIZE = 1
INPUT_DIM = 10
OUTPUT_DIM = 2
EPOCHS = 100

class Net(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, hidden_layers):
        super(Net, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.hidden_layers = hidden_layers

        self.lstm = nn.LSTM(input_dim, hidden_dim, hidden_layers)
        self.h2o = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Create the random initial states directly on the input's device, with its dtype.
        h_0 = torch.randn(self.hidden_layers, BATCH_SIZE, self.hidden_dim, device=x.device, dtype=x.dtype)
        c_0 = torch.randn(self.hidden_layers, BATCH_SIZE, self.hidden_dim, device=x.device, dtype=x.dtype)

        # nn.LSTM returns (output, (h_n, c_n)); output is the hidden state at every timestep.
        output, (h_n, c_n) = self.lstm(x, (h_0, c_0))
        return self.h2o(output)


print("Loading data")
X_data = torch.randn((242, 20, INPUT_DIM))
Y_data = torch.rand((242, 20, OUTPUT_DIM)) * 10
# Move the whole dataset to the GPU once, up front.
X_data = X_data.cuda()
Y_data = Y_data.cuda()

model = Net(INPUT_DIM, 40, OUTPUT_DIM, 1).cuda()
optimizer = optim.RMSprop(model.parameters())

for epoch in range(EPOCHS):
    start = time.time()
    for seq in range(11, 20):
        # Slices of GPU tensors already live on the GPU: no per-step copy.
        x = X_data[:, [seq], :]
        y = Y_data[:, [seq], :]
        model.zero_grad()
        output = model(x)
        loss = F.mse_loss(output, y)
        loss.backward()
        optimizer.step()
    print("Epoch", epoch + 1, "Loss:", loss.item(), "Time:", time.time() - start)

Here are the runtime logs for PyTorch: https://gist.github.com/994c3d50545d26229b9ac9c12c070b7e
Here are the runtime logs for TF/Keras: https://gist.github.com/842a410b004a8fe7b773b1e9904befaa

And by the way, at this LSTM size the cuDNN kernels don’t really give any speedup. They shine when there are multiple layers and large hidden sizes.
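On the PyTorch side there is no separate CuDNNLSTM class. nn.LSTM generally dispatches to cuDNN kernels when its weights and inputs are on a CUDA device and cuDNN is available and enabled (torch.backends.cudnn.enabled). Here is a small sketch that reports whether that path is available on the current machine. cudnn_lstm_available is a hypothetical helper name.

```python
import torch

def cudnn_lstm_available():
    # cuDNN kernels can only be used for CUDA tensors, with cuDNN present and enabled.
    return (
        torch.cuda.is_available()
        and torch.backends.cudnn.is_available()
        and torch.backends.cudnn.enabled
    )

print("cuDNN path available:", cudnn_lstm_available())
```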


Thank you very much! Now I get comparable results: 6 seconds for 100 epochs with both PyTorch and Keras. So the problem was just my misunderstanding of what batch_size means. Some of the articles I found on the web are unclear about it, and some are simply wrong!

I need a small clarification: is a batch size of 1 always necessary to get the maximum speed-up when training sequence models (RNN, LSTM, GRU) in PyTorch?

A larger batch is faster to compute per sample, but it makes fewer parameter updates per epoch, so more epochs could be necessary to converge. So… it depends.
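To see the other end of that trade-off, here is a sketch that feeds all 9 training sequences to nn.LSTM as a single batch, so each epoch makes only one update. It is roughly the PyTorch counterpart of the Keras run with batch_size=242 earlier in the thread. It assumes the thread's random data and sizes, zero initial states, and F.mse_loss. It falls back to the CPU when no GPU is present and is an illustration, not a tuned training loop.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

X_data = torch.randn(242, 20, 10, device=device)       # (timesteps, sequences, features)
Y_data = torch.rand(242, 20, 2, device=device) * 10

lstm = nn.LSTM(10, 40).to(device)
h2o = nn.Linear(40, 2).to(device)
optimizer = torch.optim.RMSprop(list(lstm.parameters()) + list(h2o.parameters()))

x = X_data[:, 11:20, :]                                 # all 9 training sequences as one batch
y = Y_data[:, 11:20, :]

for epoch in range(5):
    optimizer.zero_grad()
    output, _ = lstm(x)                                 # initial states default to zeros
    loss = F.mse_loss(h2o(output), y)
    loss.backward()
    optimizer.step()                                    # a single update per epoch
    print("Epoch", epoch + 1, "Loss:", loss.item())
```

Compared with the per-sequence loop, each epoch here makes one larger LSTM call instead of nine small ones, but also only one parameter update.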

In that case, why didn’t a batch_size > 1 give a greater speedup in your example?

Where did you see that?