Optimizing custom RNN implementation

Hi,

I’m currently testing a variant of the LSTM architecture called subLSTM. I was trying to get an efficient implementation to speed up my experiments, since my PyTorch implementation is still very slow compared to the library LSTM. I also tried TorchScript, but it’s still much slower than the library version. Specifically, I used jit.script to compile the inner loop of the RNN (similar to the LSTM implementation; a sketch of what I mean is after the questions below). So my questions are:

  1. Why does this not run as fast? Is it a problem with the way I’m using jit.script?
  2. I was trying to access the RNN implementations in _C._VariableFunctions, but I can’t find them. Is there any way of modifying that code? The modifications I need are trivial, and while I could implement it myself, I probably wouldn’t get code as efficient, since I am not that experienced with CUDA (and my C++ is pretty rusty).
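
For context, here is roughly the shape of the scripted cell and time loop I’m describing. This is only a minimal sketch: the gate math below is a plain LSTM-style placeholder, not the actual subLSTM equations, and the names and sizes are made up.

```python
import torch
from torch import Tensor, nn
from typing import List, Tuple

class CellSketch(nn.Module):
    """Placeholder cell: plain LSTM-style gates stand in for the subLSTM math."""
    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.weight_ih = nn.Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.randn(4 * hidden_size, hidden_size))
        self.bias = nn.Parameter(torch.zeros(4 * hidden_size))

    def forward(self, x: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        h, c = state
        gates = torch.mm(x, self.weight_ih.t()) + torch.mm(h, self.weight_hh.t()) + self.bias
        i, f, g, o = gates.chunk(4, 1)
        c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        h = torch.sigmoid(o) * torch.tanh(c)
        return h, (h, c)

class LayerSketch(nn.Module):
    """Time loop over the cell; this whole module is compiled with jit.script."""
    def __init__(self, cell: CellSketch):
        super().__init__()
        self.cell = cell

    def forward(self, inputs: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        # inputs: (seq_len, batch, input_size)
        outputs = torch.jit.annotate(List[Tensor], [])
        for t in range(inputs.size(0)):
            out, state = self.cell(inputs[t], state)
            outputs += [out]
        return torch.stack(outputs), state

layer = torch.jit.script(LayerSketch(CellSketch(64, 128)))
inputs = torch.randn(50, 32, 64)
state = (torch.zeros(32, 128), torch.zeros(32, 128))
out, state = layer(inputs, state)  # out: (50, 32, 128)
```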

Here is a link to my code:

Thanks,

You could take the LLTM from the C++ extension tutorial as a blueprint. That should give very similar speedups to basing your work on the native code.
That said, in our own benchmarking, a JITed vanilla LSTM is almost as fast as cuDNN for the forward and roughly the same speed as PyTorch’s own C++ implementation for the backward (but slower than cuDNN by a factor of 2.25) on current master (which is faster in the backward than 1.0, where the factor was 3x).
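
On the Python side, wiring up such an extension is short; roughly like this (the file name and bound function names are placeholders for whatever you write in C++):

```python
from torch.utils.cpp_extension import load

# JIT-compile the extension the way the LLTM tutorial does.
# "sublstm.cpp" stands in for your own source, which would expose its
# forward/backward functions via PYBIND11_MODULE, just like lltm.cpp.
sublstm_cpp = load(name="sublstm_cpp", sources=["sublstm.cpp"], verbose=True)

# sublstm_cpp.forward(...) / sublstm_cpp.backward(...) then get wrapped in a
# torch.autograd.Function, exactly as the tutorial does for lltm_cpp.
```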

Best regards

Thomas


Thanks for the reply,

I assume the speed differences you mention are for the LSTM cell, so processing longer sequences with the full LSTM architecture will widen the gap, especially during training?

Actually, I do my LSTM benchmarking with the PyTorch benchmark repository and optimize against that before submitting patches to PyTorch. That seems much more useful to me than a somewhat artificial benchmark on cells. As far as I can tell, the default sequence length there is 100, and you can pass a parameter if you want to benchmark different sequence lengths.

Best regards

Thomas


Hi @tom,

Fair enough, that benchmark seems more useful.

What would you say is more efficient: implementing the LSTM outer loop in Python and calling the optimized function for each cell (as the LLTM example would imply), or implementing the whole thing in C++ as a single function or an ATen module? I guess I’m trying to pinpoint exactly how the current LSTM implementation does it.

Thanks for the help

So I discussed optimization in general a bit in a talk I gave last December. Whenever I looked at it, the difference between Python and “the same thing in C” was about 10%. That is obviously not true universally, but it seems to be a realistic ballpark number in many cases. The true gains come

  • when you “fuse” (mainly pointwise) operations, i.e. have them in a single CUDA kernel, but that is rather time-consuming to implement,
  • when you minimize the number of operations (e.g. by applying weight_ih to all inputs in one go) - this is what the premul variants in the PyTorch benchmark repository do; see the sketch after the next paragraph.

To me this suggests that the for loop itself being in Python or C++ is not the crucial bit; what matters is whether you do things in parallel before and after, and whether you are efficient inside the loop.
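
For the second point, the premul idea is roughly this (a minimal sketch with made-up shapes, not the benchmark code itself):

```python
import torch

seq_len, batch, input_size, hidden_size = 100, 64, 256, 512
inputs = torch.randn(seq_len, batch, input_size)
weight_ih = torch.randn(4 * hidden_size, input_size)
weight_hh = torch.randn(4 * hidden_size, hidden_size)
h = torch.zeros(batch, hidden_size)

# One big matmul over the whole sequence, outside the time loop ...
igates = torch.matmul(inputs, weight_ih.t())  # (seq_len, batch, 4 * hidden_size)

# ... so only the hidden-to-hidden matmul is left inside it.
for t in range(seq_len):
    gates = igates[t] + torch.matmul(h, weight_hh.t())
    # gate nonlinearities and the state update would go here
```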

NVidia has a great writeup about how they optimized LSTM.

In current PyTorch master, the scripted LSTM is of similar speed to the ATen implementation.

Maybe you can get the best results on current PyTorch by manually defining a backward in a custom autograd.Function, but using @torch.jit.script-ed functions for both the forward and the backward. I must admit I haven’t looked into this too much, though, because I mainly approached things from a “can we improve PyTorch to make this better” perspective.
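
The pattern would look roughly like this (a toy sketch with a trivial placeholder op, not the LSTM/subLSTM math):

```python
import torch
from torch import Tensor
from typing import Tuple

@torch.jit.script
def fused_forward(x: Tensor, w: Tensor) -> Tensor:
    return torch.tanh(torch.mm(x, w))

@torch.jit.script
def fused_backward(grad_out: Tensor, x: Tensor, w: Tensor, out: Tensor) -> Tuple[Tensor, Tensor]:
    grad_pre = grad_out * (1 - out * out)                        # derivative of tanh
    return torch.mm(grad_pre, w.t()), torch.mm(x.t(), grad_pre)  # grads for x and w

class FusedOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        out = fused_forward(x, w)
        ctx.save_for_backward(x, w, out)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, w, out = ctx.saved_tensors
        return fused_backward(grad_out, x, w, out)

x = torch.randn(8, 16, requires_grad=True)
w = torch.randn(16, 32, requires_grad=True)
FusedOp.apply(x, w).sum().backward()
```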

I’ll probably try to write a few things about my new LSTM optimizations - they actually get a bit better than advertised there, to about 1.33x cuDNN wall-clock time for the jit-premul LSTM backward - not too bad for less than 200 lines of code diff.

Best regards

Thomas


Hi @tom,

Thanks for the tips. I’ll see what I can get by doing that. It’s just hard to compare models against LSTMs when they take half a day to train, and given the model I am comparing, I would expect it to run faster than an LSTM (the forward and backward operations are simpler). Just one final question (for now :smiley:): at the end of the following sentence, did you mean a CUDA kernel?

Maybe you can get the best results on current PyTorch by manually defining a backward in a custom autograd.Function,

so that it would be a CUDA kernel for the LSTM cell operation and scripted code for the forward and backward loop?

Best

What I meant was, depending on what you want to do, you might get by with scripting the cell operations instead of writing CUDA.
The thing that’s not as fast is the backward, so writing your own scripted operation there might be an option - I’m guessing you don’t want to wait until the better backward fusing is merged.

Best regards

Thomas


I put out a PR for my optimization work and wrote a bit about JIT and LSTM.
So with a bit of analysis of what’s slowing things down, and a fair amount of bookkeeping, you can speed things up quite a bit.

Best regards

Thomas

Great! I’ll check it out. I made some changes to the way I wrote the functions in Python and, with the JIT version, I got better performance than the LSTM on CPU, by a lot. But when I use the GPU, the difference is massive in favor of the LSTM. So I guess I’ll have to code it by hand. It is time-consuming, but in the end I think it’s something I should know how to do.

It’s tricky: in my Python code I tried using addmm instead of F.linear and pre-transposing the weights before the time loop (which should give better performance), but on the GPU I get almost no improvement.
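
Roughly what I mean (a simplified sketch of my loop, with made-up shapes):

```python
import torch

seq_len, batch, input_size, hidden_size = 100, 64, 256, 512
x = torch.randn(seq_len, batch, input_size)
h = torch.zeros(batch, hidden_size)
weight_ih = torch.randn(4 * hidden_size, input_size)
weight_hh = torch.randn(4 * hidden_size, hidden_size)
bias = torch.zeros(4 * hidden_size)

# Transpose once outside the loop instead of relying on F.linear each step.
w_ih_t = weight_ih.t().contiguous()
w_hh_t = weight_hh.t().contiguous()

for t in range(seq_len):
    # addmm folds the bias add into the matmul call.
    gates = torch.addmm(bias, x[t], w_ih_t) + torch.mm(h, w_hh_t)
    # gate nonlinearities and the state update would go here
```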

Thanks for all the help.

Best

PS: I noticed that LSTMs use two bias vectors instead of one. I know this is done to keep it similar to cuDNN, but I can’t find the reason why they defined the model that way. What’s more interesting is that the link you posted earlier seems to imply only one bias. Any ideas on this?

I just read on your PR that you don’t know why this is.

Well, I do know it’s for cuDNN compatibility, and no, you would not need it for your own implementation.
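
A quick numerical check of why a single bias is mathematically enough (arbitrary shapes):

```python
import torch

batch, input_size, hidden_size = 8, 16, 32
x, h = torch.randn(batch, input_size), torch.randn(batch, hidden_size)
w_ih = torch.randn(4 * hidden_size, input_size)
w_hh = torch.randn(4 * hidden_size, hidden_size)
b_ih, b_hh = torch.randn(4 * hidden_size), torch.randn(4 * hidden_size)

# The two biases only ever appear as a sum, so folding them into one
# vector gives the same pre-activation gates.
two_bias = x @ w_ih.t() + b_ih + h @ w_hh.t() + b_hh
one_bias = x @ w_ih.t() + h @ w_hh.t() + (b_ih + b_hh)
print(torch.allclose(two_bias, one_bias))  # True (up to floating point)
```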