How can I speed up an RNN written from scratch?

Hi,

I am converting some of my old lua-torch code to pytorch. I am having problems with RNNs implemented from scratch: they seem much slower in pytorch than in lua-torch. I cannot use the pre-built modules such as nn.LSTM() or nn.GRU() because I need to implement non-standard RNN cells.

Below are two scripts (one in pytorch and one in lua-torch) in which an LSTM cell is built from scratch and then run forward and backward 1000 times with fake data. On a Titan X the computation times are:
pytorch: 4.6s
lua-torch: 1.4s

Am I doing something wrong in my pytorch code? Can I speed it up?
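
One thing I am not sure about: CUDA kernels launch asynchronously, so my timing may include (or miss) work that is still queued on the GPU. If that matters, a safer way to time the loop would be to synchronize before and after reading the clock, roughly like this (same names as in my script below; the loop body is elided):

torch.cuda.synchronize()   # assumption: make sure nothing is still queued before starting the clock
t0 = time.time()
for i in range(1000):
    # ... forward + backward exactly as in the script below ...
    pass
torch.cuda.synchronize()   # wait for the last backward to actually finish on the GPU
print('time in s : ' + str(time.time() - t0))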

PYTORCH CODE:

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import time

input_size=500
hidden_size=500
batch_size=20

class LSTMCell(nn.Module):

    def __init__(self, input_size, hidden_size):
        super(LSTMCell, self).__init__()
        self.hidden_size = hidden_size
        # single linear layer producing all four gate pre-activations at once
        self.lin = nn.Linear(input_size + hidden_size, 4 * hidden_size)

    def forward(self, x, state0):
        h0, c0 = state0
        x_and_h0 = torch.cat((x, h0), 1)
        u = self.lin(x_and_h0)
        # slice the linear output into the four gates
        i = F.sigmoid(u[:, 0*self.hidden_size : 1*self.hidden_size])  # input gate
        f = F.sigmoid(u[:, 1*self.hidden_size : 2*self.hidden_size])  # forget gate
        g = F.tanh(   u[:, 2*self.hidden_size : 3*self.hidden_size])  # candidate cell state
        o = F.sigmoid(u[:, 3*self.hidden_size : 4*self.hidden_size])  # output gate
        c = f*c0 + i*g
        h = o*F.tanh(c)
        return (h, c)

# construct LSTM Cell
rnn = LSTMCell(input_size, hidden_size)
rnn.cuda()

# generate fake data
x=torch.rand(batch_size,input_size).cuda()
h0=torch.rand(batch_size,hidden_size).cuda()
c0=torch.rand(batch_size,hidden_size).cuda()
grad_h=torch.rand(batch_size,hidden_size).cuda()
grad_c=torch.rand(batch_size,hidden_size).cuda()

# run the cell 1000 times forward and backward
t0 = time.time()
for i in range(1000):
    xx = Variable(x, requires_grad=True)
    hh0 = Variable(h0, requires_grad=True)
    cc0 = Variable(c0, requires_grad=True)
    hh, cc = rnn(xx, (hh0, cc0))
    torch.autograd.backward(variables=(hh, cc), grad_variables=(grad_h, grad_c))
print('time in s : ' + str(time.time() - t0))
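
For what it is worth, here is a variant of the same forward pass where the linear output is split with torch.chunk instead of four explicit slices. I do not know whether it changes the timing; it is only a sketch of the kind of rewrite I could try:

# alternative forward for the same LSTMCell, splitting with torch.chunk (sketch only)
def forward(self, x, state0):
    h0, c0 = state0
    u = self.lin(torch.cat((x, h0), 1))
    i, f, g, o = torch.chunk(u, 4, dim=1)   # four gate pre-activations
    c = F.sigmoid(f)*c0 + F.sigmoid(i)*F.tanh(g)
    h = F.sigmoid(o)*F.tanh(c)
    return (h, c)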

LUA-TORCH CODE:

require('nngraph')
require('cunn')

input_size=500
hidden_size=500
batch_size=20

local function LSTMCell()

	local x=nn.Identity()()
	local h0=nn.Identity()()
	local c0=nn.Identity()()
	
	local x_and_h0=nn.JoinTable(2)({x,h0})
	local u=nn.Linear(input_size+hidden_size , 4*hidden_size )(x_and_h0)
	local u_reshaped=nn.Reshape(batch_size,4,hidden_size)(u)
	local tbl=nn.SplitTable(2)(u_reshaped)
  
	local f  = nn.Sigmoid()(nn.SelectTable(1)(tbl))
	local i  = nn.Sigmoid()(nn.SelectTable(2)(tbl))
	local g  = nn.Tanh()(   nn.SelectTable(3)(tbl))
	local o  = nn.Sigmoid()(nn.SelectTable(4)(tbl))

	local c  = nn.CAddTable()({  nn.CMulTable()({f,c0}), nn.CMulTable()({i,g})  })
	local h  = nn.CMulTable()({  o, nn.Tanh()(c)   })
	   
	local mod = nn.gModule({x,h0,c0},{h,c})
	return mod

end

-- construct LSTM Cell
rnn= LSTMCell()
rnn:cuda()

-- generate fake data
x=torch.rand(batch_size,input_size):cuda()
h0=torch.rand(batch_size,hidden_size):cuda()
c0=torch.rand(batch_size,hidden_size):cuda()
grad_h=torch.rand(batch_size,hidden_size):cuda()
grad_c=torch.rand(batch_size,hidden_size):cuda()

-- run the cell 1000 times forward and backward
t0 = torch.tic()
for i=1,1000 do
	h, c = unpack( rnn:forward({x,h0,c0}) )
	tbl=rnn:backward({x,h0,c0}, {grad_h,grad_c})
end
print('time in s : ' .. torch.toc(t0))