25x slowdown when using a custom GRU layer

I’m trying to create a custom GRU layer that effectively introduces one more gate. Rather than paste the entire module code here, I have uploaded a snippet to a Python file on GitHub. The file can be imported to get the BlockSparseGRU class. I’m trying to design the class to be as close to nn.GRU as I can; it is in fact a subclass of nn.GRU.

I’m finding that my custom module is significantly slower to train compared to a standard GRU. Here’s a timeit script:

import timeit

setup_code = '''
import torch, torch.nn
from bsgru import BlockSparseGRU
x = torch.rand(13, 50, 511)
net1 = torch.nn.GRU(511, 128)
net2 = BlockSparseGRU(511, 128, num_blocks=4)'''

times1 = timeit.repeat('net1(x)', setup_code, repeat=5, number=10)
times2 = timeit.repeat('net2(x)', setup_code, repeat=5, number=10)
print('GRU time: {}'.format(min(times1)))
print('BSGRU time: {}'.format(min(times2)))

I then get the following runtimes:

>> GRU time: 0.06263355700002649
>> BSGRU time: 1.570136519000016

So, as you can see, using my custom GRU layer causes a 25x slowdown. I’m wondering if anyone has tips on how to optimize the custom module. I know nn.GRU is written in C++, but I wasn’t expecting such a huge difference in runtimes. I have started reading into the PyTorch C++ frontend, but even there I’d have to define the forward function manually, and I’m worried I would hit a similar slowdown.

GRU has two gates: reset and update. LSTM has three gates: input, output and forget.

How is your modified GRU functionally different from an LSTM?

Hey Johnson, I’m trying to introduce a clustering gate: it takes the input at time step t and classifies it into one of K groups. Depending on which groups the current and previous time steps fall into, I use the [i,j]-th chunk of the hidden-to-hidden weight matrix. I have some equations written out for my thesis; let me know if providing those would help, or if I’ve got the idea across.
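Here’s a rough sketch of the idea in code (the shapes and names are purely illustrative, not the exact equations from my thesis):

import torch

# illustrative sizes only: K blocks, hidden size H, input size D
K, H, D = 4, 128, 511
B = H // K                                    # block size

x_t = torch.rand(1, D)                        # input at time step t
h_prev = torch.zeros(1, H)                    # previous hidden state
k_prev = torch.zeros(1, K)                    # previous block assignment
k_prev[0, 0] = 1.0

# clustering gate: classify the current time step into one of K groups
W_ik, W_hk = torch.randn(D, K), torch.randn(H, K)
k_t = torch.softmax(x_t @ W_ik + h_prev @ W_hk, dim=-1)

# the (previous, current) assignment pair selects the [i,j]-th
# (B x B) chunk of the hidden-to-hidden weight matrix
U = torch.randn(K, K, B, B)
kk = k_prev.squeeze(0).unsqueeze(1) * k_t.squeeze(0).unsqueeze(0)  # (K, K)
U_active = (kk[..., None, None] * U).sum(dim=(0, 1))               # (B, B)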

Setting those details aside, though: suppose I hypothetically wanted to re-implement an LSTM by hand. Are there any tips for achieving similar runtime efficiency purely with the Python API, or would it help to build the module in C++?

The first thing I notice is that you’re running over the inputs with a for loop (see line 181). If I’m not mistaken, that can be made asynchronous. One of the main advantages of tensorized operations is that they maximize the utilization of multi-processing units/CUDA cores.

The second thing I notice is that you’re running check() (for inf/nan) at nearly every step of the process. Why not just run it once on the first input and then throw in an occasional torch.clip or torch.clamp between layers? If you clamp the values to the range -100 to 100, there is basically no situation in which a matmul will turn them into an inf value. If you’re using a division somewhere that I’m not seeing, you can mask the denominator for zeros and add a small epsilon to those entries. That should eliminate any possible inf/nan values.
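For example, a minimal sketch of what I mean (standalone, not taken from your module):

import torch

h = torch.randn(50, 128) * 1e3

# an occasional clamp instead of a check() at every step
h = torch.clamp(h, min=-100.0, max=100.0)

# if there is a division somewhere, guard the denominator instead of
# checking the result for inf/nan afterwards
denom = torch.randn(50, 128)
eps = 1e-8
safe_denom = torch.where(denom == 0, torch.full_like(denom, eps), denom)
out = h / safe_denom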

Thanks for pointing this out! Cutting out the inf/nan checks altogether saved about 0.3 seconds of runtime. Once I have the module fully optimized, I’ll go back and add clamp statements.

But about the for loop on line 181: that one is looping over the input sequence (over the timesteps), so is it safe/valid to do that asynchronously? Wouldn’t you need the value of h_{t-1} before you can compute h_{t}?

As for the actual method of doing this, is it as simple as adding async and await keywords, or is there some Torch-specific API for asynchronous loops?

h_{t} represents the hidden state. And you are correct that you need the prior hidden state to calculate h_{t}.

But that’s not what you’re looping through on line 181. You’re looping through the input features.

Suppose we have a GRU chatbot. We write a sentence and the bot gives an output. We send in tokenized sentences as inputs, and the individual words of the sentence represent input features. We can process each sentence asynchronously. Then we send out our chatbot’s reply along with the hidden state that contains the dialogue context, so the bot can “remember” what we’re talking about.

Another example of input features might be OHLCV price data, where each step contains 5 input features. What we can’t know, however, is what the next time step’s input will be.

So the input features should be calculated independently and can therefore be processed asynchronously. Where the loop comes in is when we want to send in the second sequence, but that is all handled outside of the module, in the training loop. The data coming in at each time step can be processed asynchronously.

I think I might still have a GRU set up in a spreadsheet. I can check tomorrow and clarify further. But I’m certain the inputs at each timestep are handled separately.

Had a closer look at my notes and you’re correct. dim=1 is the sequence length, which does have a time dependency.

One other issue I was considering is how your tensor is being fed to each module. The vanilla nn.GRU takes (seq_len, batch_size, in_features), since batch_first defaults to False. So the vanilla GRU is looping over a sequence length of 13, while your custom GRU defaults to batch_first=True and is looping over a sequence length of 50. That might account for some of the slowdown.
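To make that concrete, here’s how the same tensor is interpreted under each convention (sizes taken from your script; note that the output shapes match either way, so the mismatch is silent):

import torch

x = torch.rand(13, 50, 511)

net_seq_first = torch.nn.GRU(511, 128)                      # expects (seq_len, batch, features)
net_batch_first = torch.nn.GRU(511, 128, batch_first=True)  # expects (batch, seq_len, features)

out1, _ = net_seq_first(x)    # treats 13 as seq_len and 50 as batch
out2, _ = net_batch_first(x)  # treats 13 as batch and 50 as seq_len
print(out1.shape)  # torch.Size([13, 50, 128])
print(out2.shape)  # torch.Size([13, 50, 128])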

Thanks for pointing this out! Here’s the updated script after rectifying that mismatch, along with removing the check() statements:

import timeit

setup_code = '''
import torch, torch.nn
from bsgru import BlockSparseGRU
x = torch.rand(13, 50, 511)
net1 = torch.nn.GRU(511, 128, batch_first=True)
net2 = BlockSparseGRU(511, 128, num_blocks=8)'''

times = timeit.repeat('net1(x)', setup_code, repeat=5, number=10)
print('GRU time: {}'.format(min(times)))
times = timeit.repeat('net2(x)', setup_code, repeat=5, number=10)
print('BSGRU time: {}'.format(min(times)))

Seems like I still get more or less the same runtime difference.

>> GRU time: 0.0769930789974751
>> BSGRU time: 1.578554071995313

I feel like the only option left is to re-implement the model using the C++ frontend, but I’m not sure how to get started with that, whether the rest of my training script would also need to be rewritten, or whether I can invoke the C++ module directly from Python code.
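From what I can tell so far, torch.utils.cpp_extension can compile a C++ source file at runtime and expose it as an importable Python module, so in principle the rest of the training script could stay in Python. A purely hypothetical sketch (neither the file nor the forward function below exist yet):

import torch
from torch.utils.cpp_extension import load

# hypothetical: "bsgru_cpp.cpp" would define the cell and bind a forward()
# function via pybind11; nothing like this exists in my repo yet
bsgru_cpp = load(name="bsgru_cpp", sources=["bsgru_cpp.cpp"], verbose=True)

x = torch.rand(13, 50, 511)
h0 = torch.zeros(1, 50, 128)
out, h = bsgru_cpp.forward(x, h0)  # callable from the existing Python training loop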

I ran your code. The sigmoid and tanh functions are definitely your main bottlenecks. It’s possible they opted to use more time-efficient activation functions in the C++ implementation, such as relu or some approximation with lower compute time.

Hi, have you considered this to be a result of custom RNNs lacking cuDNN optimization?

All RNN implementations provided by PyTorch leverage cuDNN kernels, which are highly optimized low-level implementations of the underlying algorithms provided by NVIDIA. The main problem with them is that they are inflexible: anything not supported by the default kernel can’t use it. Here’s a quote from PyTorch:

However, many users want to implement their own custom RNNs, taking ideas from recent literature. Applying Layer Normalization to LSTMs is one such use case. Because the PyTorch CUDA LSTM implementation uses a fused kernel, it is difficult to insert normalizations or even modify the base LSTM implementation. Many users have turned to writing custom implementations using standard PyTorch operators, but such code suffers from high overhead: most PyTorch operations launch at least one kernel on the GPU and RNNs generally run many operations due to their recurrent nature.

Could you take a look at that article from PyTorch and let me know if your Python code can be sped up using those recommendations, specifically TorchScript?

Hi Aidan, thanks for suggesting I use TorchScript! I’m part of the way through rewriting my custom GRU, though I’ve found a problem with using getattr and setattr to look up layers by name under TorchScript. Here’s a minimal example that replicates the issue:

Example Code
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.jit as jit
from typing import List, Tuple
from torch import Tensor
from torch.nn.functional import one_hot, softmax
import numpy as np


class BSGRUCell(jit.ScriptModule):
    def __init__(self, input_size, hidden_size, num_blocks, bias=True):
        super(BSGRUCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_blocks = num_blocks
        block_size = hidden_size // num_blocks
        if block_size * num_blocks != hidden_size:
            raise ValueError("hidden_size must be divisible by num_blocks")
        self.block_size = block_size
        self.weight_ik = Parameter(torch.randn(input_size, num_blocks))
        self.weight_hk = Parameter(torch.randn(hidden_size, num_blocks))
        self.weight_ih = Parameter(torch.randn(
            1, num_blocks, input_size, 3 * block_size))
        self.weight_hh = Parameter(torch.randn(
            1, num_blocks, num_blocks, block_size, 3 * block_size))
        self.bias_ik = Parameter(torch.randn(num_blocks))
        self.bias_ih = Parameter(torch.randn(3 * hidden_size))
        self.bias_hh = Parameter(torch.randn(3 * hidden_size))

    @jit.script_method
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        # "x" and "y" variable names refer to previous and current timesteps
        hx, kx = state

        # predict the current active block for this time step
        ky = (torch.mm(input, self.weight_ik) +
              torch.mm(hx, self.weight_hk) +
              self.bias_ik)
        if self.training:
            beta = 5.
            ky = softmax(beta*ky, dim=-1)
        else:
            ky = one_hot(ky.argmax(-1), num_classes=self.num_blocks).float()

        # use the current block predictions to sparsely activate
        # the input to hidden (W-matrix) weights
        mat_W = torch.mul(ky.unsqueeze(-1).unsqueeze(-1), self.weight_ih)
        mat_W = mat_W.view(-1, 3 * self.hidden_size, self.input_size)
        mat_W = torch.transpose(mat_W, 1, 2)

        # use the previous + current block predictions to sparsely activate
        # the hidden to hidden (U-matrix) weights
        kk = torch.bmm(kx.unsqueeze(2), ky.unsqueeze(1))
        mat_U = torch.mul(kk.unsqueeze(-1).unsqueeze(-1), self.weight_hh)
        mat_U = mat_U.view(-1, 3 * self.hidden_size, self.hidden_size)
        mat_U = torch.transpose(mat_U, 1, 2)

        # compute all the gate values
        i_gates = torch.bmm(input.unsqueeze(1), mat_W) + self.bias_ih
        h_gates = torch.bmm(hx.unsqueeze(1), mat_U) + self.bias_hh
        i_z, i_r, i_n = i_gates.squeeze(1).chunk(3, -1)
        h_z, h_r, h_n = h_gates.squeeze(1).chunk(3, -1)
        
        updategate = torch.sigmoid(i_z + h_z)
        resetgate = torch.sigmoid(i_r + h_r)
        newgate = torch.tanh(i_n + h_n * resetgate)
        hy = newgate + updategate * (hx - newgate)

        return hy, (hy, ky)



class BSGRU(jit.ScriptModule):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        num_blocks: int = 1,
        bias: bool = True,
        batch_first: bool = False,
        bidirectional: bool = False,
        beta: float = 1.,
        dropout: float = 0.,
        device=None,
        dtype=None) -> None:

        super(BSGRU, self).__init__()
        self.num_layers = num_layers
        self.batch_first = batch_first
        for l in range(num_layers):
            setattr(self, "layer_{}".format(l), 
                BSGRUCell(input_size, hidden_size, num_blocks, bias=bias))

    @jit.script_method
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        
        # put seq_len as the first axes
        if self.batch_first:
            input = torch.transpose(input, 0, 1)

        in_tensor = input
        for l in range(self.num_layers):
            in_seq = in_tensor.unbind(0)
            out_seq = torch.jit.annotate(List[Tensor], [])
            for i in range(len(in_seq)):
                out, state = getattr(self, "layer_{}".format(l))(in_seq[i], state)
                out_seq += [out]
            in_tensor = torch.stack(out_seq)
        return in_tensor, state


seq_len, batch_size, input_size, hidden_size, num_blocks = 11, 5, 7, 4, 2
inp = torch.randn(seq_len, batch_size, input_size)
h0 = torch.rand(batch_size, hidden_size)
k0 = torch.from_numpy(np.tile([0,1],[batch_size, 1])).float()
rnn = BSGRU(input_size, hidden_size, 1, num_blocks)
out = rnn(inp, (h0, k0))
print(out)

The error I get:

RuntimeError: 
getattr's second argument must be a string literal:
  File "temp.py", line 114
            out_seq = torch.jit.annotate(List[Tensor], [])
            for i in range(len(in_seq)):
                out, state = getattr(self, "layer_{}".format(l))(in_seq[i], state)
                             ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
                out_seq += [out]
            in_tensor = torch.stack(out_seq)

I was able to get the BSGRUCell running successfully, but how does one write modules with a variable number of layers if the layers aren’t supposed to be looked up by attribute name?
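One thing I’m considering (untested, and I’m not sure whether the TorchScript version I’m on supports it) is replacing the getattr/setattr pattern with an nn.ModuleList and iterating over it directly, roughly like this:

import torch
import torch.nn as nn
import torch.jit as jit
from typing import List, Tuple
from torch import Tensor


class BSGRUList(jit.ScriptModule):
    def __init__(self, input_size, hidden_size, num_layers, num_blocks):
        super().__init__()
        # first layer takes input_size, deeper layers take hidden_size
        self.layers = nn.ModuleList([
            BSGRUCell(input_size if l == 0 else hidden_size,
                      hidden_size, num_blocks)
            for l in range(num_layers)])

    @jit.script_method
    def forward(self, input: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        in_tensor = input
        for layer in self.layers:  # iterate over the ModuleList directly
            out_seq = torch.jit.annotate(List[Tensor], [])
            for x_t in in_tensor.unbind(0):
                out, state = layer(x_t, state)
                out_seq += [out]
            in_tensor = torch.stack(out_seq)
        return in_tensor, state

Does iterating over an nn.ModuleList like that work inside a script_method, or does it hit the same restriction as getattr?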