Why is my custom RNN implementation so slow?

Hello, I am currently creating a generic custom RNN implementation in PyTorch; however, it takes about 60x longer to train than if I just used nn.RNN. Here is my forward method:

def forward(self, idx, targets=None):
    emb = self.embedding(idx)          # (B, T, C) token embeddings

    B, T, C = emb.shape

    # initial hidden state; assumes the hidden size equals the embedding size C
    h = torch.zeros((B, C), device=emb.device)

    logits = []

    # step through the sequence one timestep at a time
    for t in range(T):
        token = emb[:, t, :]                          # (B, C) embeddings at step t
        h = torch.tanh(self.ih(token) + self.hh(h))   # RNN cell update
        logits.append(self.head(h))                   # project to vocabulary logits

    logits = torch.stack(logits, dim=1)  # (B, T, vocab_size)

    if targets is None:
        loss = None
    else:
        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        targets = targets.view(B * T)
        loss = F.cross_entropy(logits, targets)

    return logits, loss

The nn.RNN module dispatches to cuDNN by default (you can verify this by profiling your use case and checking the kernel names), which uses an optimized algorithm. Your custom implementation launches a lot of small kernels inside the Python loop and thus suffers from this overhead. You could try torch.compile to see if it improves the performance.
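
As a rough sketch of that profiling step (assuming a CUDA-enabled build and a GPU; the sizes below are arbitrary), something like this should show kernels with "cudnn" in their names when the cuDNN path is taken:

import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity

rnn = nn.RNN(input_size=128, hidden_size=128, batch_first=True).cuda()
x = torch.randn(32, 256, 128, device="cuda")

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    out, h = rnn(x)

# cuDNN-backed kernels typically contain "cudnn" in their names
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

Running the same profile over your custom forward should also make the kernel-launch overhead visible: many tiny kernels per timestep instead of a few large ones.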


Thanks, I have tried torch.compile, but it gives only a minimal performance boost. Is it possible to get near-cuDNN performance with optimizations? And what optimizations would I need to make?

You could look at PyTorch's RNNBase class in the GitHub repo and see which calls it uses for its optimized internal functions:

Or modify that file in your torch install to suit your needs.
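
If it helps, a quick way to find where that file lives in your local install is something like this (nn.RNNBase is defined in torch/nn/modules/rnn.py alongside RNN, LSTM, and GRU):

import inspect
import torch.nn as nn

# Prints the path of the rnn.py file inside your local torch install
print(inspect.getsourcefile(nn.RNNBase))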


Thanks! Where is the actual forward loop? It seems to lead to _VF.rnn_tanh, but I checked the files called _VF and I can't find anything to do with RNNs.

RNNBase is a parent class that is subclassed further down in the same file by the various RNN implementations, such as RNN, GRU, LSTM, etc.

For example, go to line 467, where the RNN class is defined. It has its own forward method, and within that method, as you noticed, _VF.rnn_tanh etc. are called.
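
For illustration, a minimal sketch of that dispatch for an unpacked, batch-first input might look like this (based on the call in torch/nn/modules/rnn.py; _VF is an internal API, so the exact signature may differ between releases):

import torch
from torch import _VF

rnn = torch.nn.RNN(input_size=16, hidden_size=32, batch_first=True)
x = torch.randn(8, 50, 16)      # (batch, seq, features)
h0 = torch.zeros(1, 8, 32)      # (num_layers, batch, hidden)

# Roughly what nn.RNN.forward does for an unpacked input; the real code also
# handles packed sequences, hx defaults, and flat-weight checks.
output, h_n = _VF.rnn_tanh(x, h0, rnn._flat_weights, rnn.bias, rnn.num_layers,
                           rnn.dropout, rnn.training, rnn.bidirectional,
                           rnn.batch_first)
print(output.shape, h_n.shape)  # torch.Size([8, 50, 32]) torch.Size([1, 8, 32])

The point is that the whole time loop runs inside that single call rather than in Python.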

These are handled in C++. That sends you here:

Which is just a wrapper for the files in ATen/native, such as the cuDNN implementation here:

Note that PyTorch backend functions are optimized for each device (e.g. CPU, CUDA). Also note that some functions may be called directly from the cuDNN library, which is closed source.
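
For reference, you can check whether your build links against cuDNN (and disable it if you want to compare against the native CUDA fallback) via torch.backends.cudnn:

import torch

print(torch.backends.cudnn.is_available())  # True if torch was built with cuDNN
print(torch.backends.cudnn.version())       # cuDNN version as an integer, or None
# torch.backends.cudnn.enabled = False      # disables cuDNN dispatch for comparison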
