Yes, that looks suspect; batch_size should be 1. If you pass batch_size=242, it processes the entire dataset in one go (because the dataset here is only 10 samples, it either takes min(dataset_size, batch_size) or drops the training computation entirely).
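To make that concrete, here's a tiny standalone sketch (mine, not from either script) of how the number of gradient updates per epoch collapses once batch_size reaches the dataset size, assuming the min(dataset_size, batch_size) behaviour:

import math

def updates_per_epoch(dataset_size, batch_size):
    # the effective batch can never be larger than the data you actually have
    effective_batch = min(dataset_size, batch_size)
    return math.ceil(dataset_size / effective_batch)

print(updates_per_epoch(10, 1))    # 10 gradient updates per epoch
print(updates_per_epoch(10, 242))  # 1 gradient update per epoch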
At batch_size=1 for both scripts, the PyTorch one runs on my machine in about 15 seconds for 100 epochs, and the Keras one in about 20 seconds.
I’ve made some cosmetic, best-practice changes to the PyTorch one for you to maybe learn from, nothing major. They bring the script down to 13 seconds / 100 epochs.
The changes are:
- we provide F.mse_loss (mean squared error); just use that instead of rolling your own (see the short sketch right after this list)
- move the dataset fully onto the GPU beforehand, just like TF/Keras would do it
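For illustration only (written against the current PyTorch tensor API; the names below are made up, not from your script), the two changes amount to this:

import torch
import torch.nn.functional as F

pred = torch.randn(4, 3)
target = torch.randn(4, 3)

# 1) the built-in loss: F.mse_loss averages the squared error by default,
#    so it computes the same thing as a hand-rolled ((pred - target) ** 2).mean()
hand_rolled = ((pred - target) ** 2).mean()
built_in = F.mse_loss(pred, target)
print(hand_rolled.item(), built_in.item())  # same value

# 2) one host-to-device copy up front, instead of .cuda() on every slice in the loop
if torch.cuda.is_available():
    data = torch.randn(242, 20, 10).cuda()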
Here’s the modified PyTorch script:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim
import time

BATCH_SIZE = 1
INPUT_DIM = 10
OUTPUT_DIM = 2
EPOCHS = 100


class Net(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, hidden_layers):
        super(Net, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.hidden_layers = hidden_layers
        self.lstm = nn.LSTM(input_dim, hidden_dim, hidden_layers)
        self.h2o = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # random initial hidden/cell states, allocated on the same device as x
        h_0 = Variable(x.data.new(self.hidden_layers, BATCH_SIZE, self.hidden_dim).normal_())
        c_0 = Variable(x.data.new(self.hidden_layers, BATCH_SIZE, self.hidden_dim).normal_())
        output, _ = self.lstm(x, (h_0, c_0))
        # project the LSTM output at every timestep down to the output dimension
        return self.h2o(output)


print("Loading data")
X_data = torch.randn((242, 20, INPUT_DIM))
Y_data = torch.rand((242, 20, OUTPUT_DIM)) * 10

# move the full dataset to the GPU once, up front
X_data = X_data.cuda()
Y_data = Y_data.cuda()

model = Net(INPUT_DIM, 40, OUTPUT_DIM, 1).cuda()
optimizer = optim.RMSprop(model.parameters())

for epoch in range(EPOCHS):
    start = time.time()
    for seq in range(11, 20):
        # the data is already on the GPU, so slicing here is cheap
        x = Variable(X_data[:, [seq], :])
        y = Variable(Y_data[:, [seq], :])
        model.zero_grad()
        output = model(x)
        loss = F.mse_loss(output, y)
        loss.backward()
        optimizer.step()
    print("Epoch", epoch + 1, "Loss:", loss.data[0], "Time:", time.time() - start)
Here are the runtime logs for PyTorch: https://gist.github.com/994c3d50545d26229b9ac9c12c070b7e
Here are the runtime logs for TF/Keras: https://gist.github.com/842a410b004a8fe7b773b1e9904befaa
And btw, at this size of an LSTM, cuDNN kernels don't really give any speedup. The cuDNN kernels shine when there are multiple layers and large kernel sizes.
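If you want to sanity-check that on your own GPU, you can flip cuDNN off via torch.backends.cudnn.enabled and compare; here's a rough sketch using the same LSTM dimensions as the script above (timings will vary by hardware, and it assumes a CUDA build):

import time
import torch
import torch.nn as nn

def time_lstm(use_cudnn, iters=100):
    torch.backends.cudnn.enabled = use_cudnn  # globally enable/disable cuDNN kernels
    lstm = nn.LSTM(10, 40, 1).cuda()          # input_dim=10, hidden_dim=40, 1 layer
    x = torch.randn(20, 1, 10).cuda()         # (seq_len, batch, input_dim)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        lstm(x)
    torch.cuda.synchronize()
    return time.time() - start

print("with cuDNN:   ", time_lstm(True))
print("without cuDNN:", time_lstm(False))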