RNNCell uses much more memory than RNN

Hi everyone! I am working with somewhat complicated RNN architectures that receive inputs from multiple sources, in a way that requires me to process each RNN layer separately and sequentially. Because of this, I am using the RNNCell block. However, I’ve noticed that when training these networks, the memory utilization far exceeds that of RNNs with the same architecture. As a simple experiment, I built a 1-layer RNN using the classic RNN() module and compared it to an equivalent RNNCell() implementation. This is what my sample code looks like:

For RNN -

import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, nonlinearity = 'relu', batch_first = True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # initial hidden state: (num_layers, batch_size, hidden_size)
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        output, hidden = self.rnn(x, h0)  # output shape: (batch_size, seq_length, hidden_size)
        output = output[:, -1, :]         # keep only the last time step
        output = self.fc(output)

        return output

For the analogous RNNCell -

## make the model - a recurrent network with one hidden layer, and a fully connected output layer
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # note: nn.RNNCell has no num_layers argument; its signature is
        # RNNCell(input_size, hidden_size, bias=True, nonlinearity='tanh')
        self.rnn = nn.RNNCell(input_size, hidden_size, nonlinearity = 'relu')
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # step through the sequence one time step at a time;
        # RNNCell initializes the hidden state to zeros when hx is None
        h = None
        for t in range(x.shape[1]):
            h = self.rnn(x[:, t, :], h)
        output = self.fc(h)
        return output
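For reference, here is a minimal sketch of how I measure the peak training memory with torch.cuda (this isn't my exact benchmarking code; model, x, and y are placeholders for the actual network, input batch, and labels):

import torch
import torch.nn.functional as F

# Placeholder benchmarking sketch: `model`, `x`, and `y` stand in for the
# actual network, input batch, and labels.
torch.cuda.reset_peak_memory_stats(device)

model.train()
out = model(x)                    # forward pass builds the autograd graph
loss = F.cross_entropy(out, y)
loss.backward()                   # backward pass consumes the saved intermediates

peak_gb = torch.cuda.max_memory_allocated(device) / 1024**3
print(f"peak allocated: {peak_gb:.2f} GB")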

I’ve determined that during validation the memory usage is quite similar, yet during training the RNN uses 3.43 GB while the RNNCell uses 8.97 GB. I’ve also looked at the computational graphs using torchviz. The RNN generates one that looks like the following:

[torchviz graph for the nn.RNN model]

while the RNNCell generates a much more complicated one, which scales with the sequence length:

[torchviz graph for the nn.RNNCell model]
My understanding is that the RNN() class is conceptually built on RNNCell(), so I assume some kind of optimization is happening behind the scenes to avoid the massive memory usage. I’m aware that PyTorch supports gradient checkpointing and am wondering if this might be what’s happening inside the actual RNN code.
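To illustrate what I mean, here is a minimal sketch of how gradient checkpointing could be bolted onto the RNNCell forward above (this is an assumption on my part about a possible fix, not what nn.RNN actually does internally; the chunk size of 16 is arbitrary):

import torch
from torch.utils.checkpoint import checkpoint

def forward(self, x):
    # Sketch of a drop-in forward for the RNNCell model above: only the hidden
    # state at chunk boundaries is stored; everything inside a chunk is
    # recomputed during the backward pass.
    chunk = 16  # arbitrary block size
    h = torch.zeros(x.size(0), self.hidden_size, device=x.device)

    def run_chunk(x_chunk, h):
        for t in range(x_chunk.size(1)):
            h = self.rnn(x_chunk[:, t, :], h)
        return h

    for start in range(0, x.size(1), chunk):
        h = checkpoint(run_chunk, x[:, start:start + chunk, :], h,
                       use_reentrant=False)

    return self.fc(h)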

So I guess my questions are:

  1. How does RNN manage to be so memory efficient? I don’t believe its memory usage scales with sequence length, but I haven’t tested this extensively myself.
  2. How can I make my code as memory efficient as (or close to) what is possible with RNN? At the moment, I am limited in how large my networks can be, even when running them on A100s.

Thanks in advance for any and all feedback!

The nn.RNN implementation should use cuDNN, while I would guess the RNNCell implementation uses a native implementation and might thus create a lot of intermediates (the second picture is unreadable, so I’m unsure what is being called).
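As a quick sanity check (just a sketch), you could verify that the cuDNN backend is available and enabled:

import torch

print(torch.backends.cudnn.is_available())  # True if cuDNN can be used
print(torch.backends.cudnn.enabled)         # global switch, True by default
print(torch.backends.cudnn.version())       # loaded cuDNN version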

Ah yes, it looks like the image got compressed. I’ve attached some screenshots, and as you’ve already said, it does look like RNN utilizes cuDNN while RNNCell does not. I will look into enabling this for my approach. Thank you!
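In case it’s useful to anyone finding this later, one direction I’m planning to try (purely a sketch, not verified): keep one single-layer nn.RNN per layer and still run the layers sequentially, so each layer processes the full sequence with the fused/cuDNN kernel while leaving room to mix in the extra inputs between layers.

import torch
import torch.nn as nn

class StackedRNN(nn.Module):
    # Hypothetical sketch: one single-layer nn.RNN per layer, applied
    # sequentially, so each layer can still hit the cuDNN path over the
    # whole sequence.
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        sizes = [input_size] + [hidden_size] * num_layers
        self.layers = nn.ModuleList(
            nn.RNN(sizes[i], sizes[i + 1], num_layers=1,
                   nonlinearity='relu', batch_first=True)
            for i in range(num_layers)
        )
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        for rnn in self.layers:
            # extra per-layer inputs could be combined with x here
            x, _ = rnn(x)
        return self.fc(x[:, -1, :])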