Speed of Custom RNN is SUPER SLOW

Hi,

Based on the code here:

pytorch/benchmarks/fastrnns/custom_lstms.py at main · pytorch/pytorch · GitHub

I wrote an example to compare the computation speed of the native LSTM and the custom LSTM.
But I found that the custom LSTM is about 100 times slower than the native LSTM class.

Here is my test code:

import time

import torch
import torch.nn as nn
from models.custom_lstm import LSTMLayer, LSTMCell, script_lstm, LSTMState

input_size = 1024
cell_size = 2048
batch_size = 20
seq_len = 200

native_lstm = nn.LSTM(input_size, cell_size, 1).cuda()
custom_lstm = script_lstm(input_size, cell_size, 1).cuda()
inp = torch.randn(seq_len, batch_size, input_size).cuda()
hx = inp.new_zeros(batch_size, cell_size, requires_grad=False)
cx = inp.new_zeros(batch_size, cell_size, requires_grad=False)

t1 = time.time()
out, hid = native_lstm(inp)
t2 = time.time()
out2, hid2 = custom_lstm(inp, [(hx, cx)])
t3 = time.time()

print ('lstm:{}\ncustom lstm:{}\n'.format(t2-t1, t3-t2))

And here is the result:

lstm:0.015676498413085938
custom lstm:1.0338680744171143

My torch version is 1.3.1, the GPU is a TITAN V with CUDA 10 and cuDNN 7.4.1.
I also tried PyTorch 1.1 on a TITAN Xp with CUDA 9.1, which shows the same speed ratio.

Any idea?
Thanks so much

The TorchScript runtime does some optimizations on the first pass (it assumes you will be running your compiled model’s inference many times), so this is likely why it looks much slower. Could you try running custom_lstm a couple of times before you benchmark it and compare again?
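For example, a minimal warm-up-and-timing sketch (reusing the models and tensors from the code above; the torch.cuda.synchronize() calls are an added assumption so that time.time() measures completed GPU work rather than just kernel launches):

import time
import torch

# Warm up: let the TorchScript executor do its first-pass optimization work
for _ in range(5):
    native_lstm(inp)
    custom_lstm(inp, [(hx, cx)])

torch.cuda.synchronize()
t1 = time.time()
out, hid = native_lstm(inp)
torch.cuda.synchronize()
t2 = time.time()
out2, hid2 = custom_lstm(inp, [(hx, cx)])
torch.cuda.synchronize()
t3 = time.time()

print('lstm:{}\ncustom lstm:{}'.format(t2 - t1, t3 - t2))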

I re-tested both classes over 1000 runs, and the results seem more reasonable.

lstm:49.758071184158325
custom lstm:55.80940389633179

Thanks for your answer.


I met the same problem. Is there any way to disable the optimization, choose an optimization level, or save the model after optimization? When I load the TorchScript model in C++, the first pass takes about 20 s while later inference passes take about 0.5 s.

I think there’s a parameter called optimize for scripting with TorchScript.

In the source code, "optimize is deprecated and has no effect.":
https://pytorch.org/docs/stable/_modules/torch/jit.html#script

The code suggests an alternative: warnings.warn("`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead")

You could try setting torch._C._jit_set_profiling_executor(True) and torch._C._jit_set_profiling_mode(False).
This mode was specifically added to speed up compilation times for inference.
You could also try with torch.jit.optimized_execution(False) if compilation times are still high for you; that runs even fewer optimizations.
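A minimal sketch of how these could be applied in Python before the first forward pass, reusing the thread’s custom_lstm, inp, hx and cx (the exact flag combination is version-dependent, so treat this as an illustration rather than a guaranteed recipe):

import torch

# Internal JIT flags (private API, may change between releases)
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(False)

# Or: run the scripted module under reduced optimization
with torch.jit.optimized_execution(False):
    out2, hid2 = custom_lstm(inp, [(hx, cx)])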


OK, thanks. I will try this method.

Great, thanks. I will try these methods.

When I use Python, torch.jit.optimized_execution() solves the problem, thanks.
However, how should I solve this problem in C++?
Thanks in advance.

You could try this:

#include <torch/csrc/jit/update_graph_executor_opt.h>
// ...
// Disable graph executor optimizations before the first forward pass
setGraphExecutorOptimize(false);

Great, thanks. I solved the problem with setGraphExecutorOptimize(false);

Hi, I still have some questions about the custom RNN:

  1. I am able to reproduce senmao’s result that the native LSTM and the custom LSTM have similar performance over 1000 runs, but this is partly because the native LSTM gets slower. This can be seen in senmao’s numbers: the first run of the native LSTM takes 0.015 s, so if the performance were consistent, 1000 runs would take about 15 seconds instead of the reported 49.758 s (I have verified this myself; see the timing sketch after this list).

  2. Although I have no idea why the native LSTM gets slower, I worked around the problem by changing the hyperparameters to:
    input_size = 37
    cell_size = 256
    batch_size = 128
    seq_len = 60
    Now the native LSTM is stable. With these settings the custom LSTM is about 10 times slower than the native one; over 1000 runs I get lstm: 1.54 s, custom lstm: 19.75 s. Can anyone suggest how the custom LSTM could be modified to reach performance comparable to the native LSTM?
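For reference, a minimal per-run timing sketch (reusing native_lstm, custom_lstm, inp, hx and cx from the first post; the warm-up loop and the explicit torch.cuda.synchronize() calls are additions so that the first-pass JIT optimization and asynchronous CUDA launches do not skew the totals):

import time
import torch

# Warm-up so the TorchScript first pass is not counted in the totals
for _ in range(5):
    native_lstm(inp)
    custom_lstm(inp, [(hx, cx)])

t_native = 0.0
t_custom = 0.0
for _ in range(1000):
    torch.cuda.synchronize()
    t0 = time.time()
    out, hid = native_lstm(inp)
    torch.cuda.synchronize()
    t_native += time.time() - t0

    t0 = time.time()
    out2, hid2 = custom_lstm(inp, [(hx, cx)])
    torch.cuda.synchronize()
    t_custom += time.time() - t0

print('lstm: {:.3f}s\ncustom lstm: {:.3f}s'.format(t_native, t_custom))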

Thanks so much!

Hi senmao! I’m writing because I tried to reproduce your results on my own, but I couldn’t get them. I have the same setup as you posted, and I’m measuring the time with this code:

t_native = 0.0
t_custom = 0.0

for i in range(1000):
    t1 = time.time()
    out, hid = native_lstm(inp)
    t2 = time.time()
    t_native += t2 - t1

    t3 = time.time()
    out2, hid2 = custom_lstm(inp, [(hx, cx)])
    t4 = time.time()
    t_custom += t4 - t3

Thanks for all!

Regards,
David.