Custom RNN uses twice as much GPU memory due to torch.mm always making new CUDA allocations

Hi,

I’m using custom RNN kernels and am finding that training uses roughly 2x more CUDA memory than with LSTMs. To investigate, I recorded CUDA allocations and found that memory use grows far more than expected because of the repeated torch.mm calls inside the RNN loop. I was not expecting the same line in my model to keep increasing memory use. As a fix, I tried the “out” keyword argument to avoid new CUDA allocations, but got:

RuntimeError: mm(): functions with out=… arguments don’t support automatic differentiation, but one of the arguments requires grad.
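A minimal CPU-only sketch that reproduces this error (the tensor shapes are arbitrary): out= is refused while autograd is recording an op whose input requires grad, but the same preallocated buffer can be reused under torch.no_grad(), which only helps for inference, not for training.

```python
import torch

a = torch.randn(4, 3, requires_grad=True)
b = torch.randn(3, 5)
buf = torch.empty(4, 5)  # preallocated output buffer

# While autograd is recording, out= is refused because `a` requires grad.
try:
    torch.mm(a, b, out=buf)
    raised = False
except RuntimeError as err:
    raised = True
    print(err)

# With gradient tracking disabled (inference only), the buffer can be reused.
with torch.no_grad():
    torch.mm(a, b, out=buf)

print(raised, torch.allclose(buf, a.detach() @ b))
```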

Is there any way to avoid new CUDA allocations when using torch.mm inside a loop? This is a really big problem for me, as it severely limits my model size and/or truncated backprop window. Thanks in advance, and fingers crossed! (If I have to learn CUDA to get memory usage as low as the built-in RNN layers, that would really suck.)

Model forward code below.

import gc
from typing import List, Tuple

import torch
from torch import Tensor

@torch.compile
def forward(
    self, input_seq: Tensor, hidden: Tensor) -> Tuple[Tensor, Tensor]:
    
    nseq, nbatch, nx = input_seq.shape

    epss = torch.randn((nseq, nbatch, self.hidden_size),device=input_seq.device)
    epss = epss.unbind(0)
    
    inputs = input_seq.unbind(0)
    outputs = torch.jit.annotate(List[Tensor], [])

    # x_results = torch.randn((nbatch, self.hidden_size))
    # z_results = torch.randn((nbatch, self.hidden_size))

    for i in range(len(input_seq)):
        x = inputs[i]
        eps = epss[i]
        
        predicted_distribution = torch.mm(hidden, self.weight_encoder) 
        mean_, z = predicted_distribution.chunk(2,1)
        
        z = mean_ + eps * torch.exp(0.5*z)

        # lines below allocate new CUDA memory for each i
        x_results = torch.mm(x, self.weight_ih) 
        z_results = torch.mm(z, self.weight_zh) 
        # lines below don't support autograd 
        # torch.mm(x, self.weight_ih, out=x_results) 
        # torch.mm(z, self.weight_zh, out=z_results) 

        r, z, n       = x_results.chunk(3, 1)
        z_r, z_z, z_n = z_results.chunk(3, 1)
        
        z = torch.sigmoid(z + z_z)
        n = torch.tanh(n + torch.sigmoid(r + z_r) * z_n)
            
        hidden = n + torch.mul(z, (hidden - n))

        outputs += [hidden]
        gc.collect()  # does not free tensors saved for backward; the autograd graph still references them

    return torch.stack(outputs), hidden
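To run the forward pass above outside its original class, here is a self-contained sketch. The class name LatentGRU and all weight shapes are hypothetical, inferred from how forward() uses self.weight_encoder, self.weight_ih and self.weight_zh. torch.compile and gc.collect() are left out, and the latent sample is renamed latent so it is not confused with the update gate z (the computation is otherwise the same).

```python
import torch
from torch import Tensor, nn
from typing import List, Tuple


class LatentGRU(nn.Module):
    """Hypothetical wrapper for the forward() above; weight shapes are assumptions."""

    def __init__(self, input_size: int, hidden_size: int) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        # hidden -> (mean, log-variance), hence 2 * hidden_size columns
        self.weight_encoder = nn.Parameter(0.1 * torch.randn(hidden_size, 2 * hidden_size))
        # input -> (r, z, n) gate pre-activations
        self.weight_ih = nn.Parameter(0.1 * torch.randn(input_size, 3 * hidden_size))
        # latent sample -> (r, z, n) gate pre-activations
        self.weight_zh = nn.Parameter(0.1 * torch.randn(hidden_size, 3 * hidden_size))

    def forward(self, input_seq: Tensor, hidden: Tensor) -> Tuple[Tensor, Tensor]:
        nseq, nbatch, _ = input_seq.shape
        epss = torch.randn(nseq, nbatch, self.hidden_size, device=input_seq.device)
        outputs: List[Tensor] = []
        for x, eps in zip(input_seq.unbind(0), epss.unbind(0)):
            mean_, logvar = torch.mm(hidden, self.weight_encoder).chunk(2, 1)
            latent = mean_ + eps * torch.exp(0.5 * logvar)  # reparameterized sample
            r, z, n = torch.mm(x, self.weight_ih).chunk(3, 1)
            z_r, z_z, z_n = torch.mm(latent, self.weight_zh).chunk(3, 1)
            z = torch.sigmoid(z + z_z)
            n = torch.tanh(n + torch.sigmoid(r + z_r) * z_n)
            hidden = n + z * (hidden - n)
            outputs.append(hidden)
        return torch.stack(outputs), hidden


torch.manual_seed(0)
model = LatentGRU(input_size=8, hidden_size=16)
out, h_last = model(torch.randn(5, 4, 8), torch.zeros(4, 16))
out.sum().backward()
print(out.shape, h_last.shape)  # torch.Size([5, 4, 16]) torch.Size([4, 16])
```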

It seems you are explicitly storing the matmul outputs:

x_results = torch.mm(x, self.weight_ih) 
z_results = torch.mm(z, self.weight_zh) 
...
r, z, n       = x_results.chunk(3, 1)
z_r, z_z, z_n = z_results.chunk(3, 1)
z = torch.sigmoid(z + z_z)
n = torch.tanh(n + torch.sigmoid(r + z_r) * z_n)
hidden = n + torch.mul(z, (hidden - n))
outputs += [hidden]

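Since outputs stores hidden, which is computed from the matmul outputs in a differentiable way, autograd has to keep these intermediate tensors alive so that it can compute the gradients during the backward call. This is also why the out argument fails: out= variants don't support automatic differentiation, so they are rejected whenever one of the inputs requires grad.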
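One way to see that the saved state grows with every iteration is to observe what autograd saves via torch.autograd.graph.saved_tensors_hooks. The sketch below uses a simplified GRU-style step (no latent sampling; gru_like_steps and saved_bytes are hypothetical helpers) and sums the sizes of the distinct storages passed to the pack hook. Because chunk() returns views, the whole matmul output buffer stays alive as long as any of its chunks is saved.

```python
import torch
from torch.autograd.graph import saved_tensors_hooks


def gru_like_steps(x_seq, h, w_ih, w_hh):
    """Simplified GRU-style loop (no latent sampling); returns the last hidden state."""
    for x in x_seq.unbind(0):
        r, z, n = torch.mm(x, w_ih).chunk(3, 1)
        h_r, h_z, h_n = torch.mm(h, w_hh).chunk(3, 1)
        z = torch.sigmoid(z + h_z)
        n = torch.tanh(n + torch.sigmoid(r + h_r) * h_n)
        h = n + z * (h - n)
    return h


def saved_bytes(seq_len, batch=4, nx=8, hidden=16):
    """Total size of the distinct storages autograd saves for backward in one forward pass."""
    storages = {}  # data_ptr -> nbytes; views of one buffer share a storage, so it counts once

    def pack(t):
        s = t.untyped_storage()
        storages[s.data_ptr()] = s.nbytes()
        return t  # store the tensor unchanged; we only observe it

    def unpack(t):
        return t

    w_ih = torch.randn(nx, 3 * hidden, requires_grad=True)
    w_hh = torch.randn(hidden, 3 * hidden, requires_grad=True)
    x_seq = torch.randn(seq_len, batch, nx)
    h0 = torch.zeros(batch, hidden)
    with saved_tensors_hooks(pack, unpack):
        gru_like_steps(x_seq, h0, w_ih, w_hh)
    return sum(storages.values())


b10, b20, b30 = saved_bytes(10), saved_bytes(20), saved_bytes(30)
print(b10, b20, b30)
```

Every additional 10 timesteps adds the same number of bytes, i.e. the memory held for backward grows linearly with the sequence length, which matches the growth seen in the allocation trace.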

If you are limited by GPU memory, you could try e.g. offloading the intermediate activations to the CPU as described here.
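A minimal sketch of offloading with torch.autograd.graph.save_on_cpu, which copies every tensor saved for backward to host memory during forward and back to the device during backward. It uses a simplified GRU-style step (a hypothetical stand-in, not the model above), falls back to CPU when no GPU is present, and only requests pinned memory when CUDA is available. The check at the end confirms that offloading does not change the gradients.

```python
import torch
from torch.autograd.graph import save_on_cpu


def gru_like_steps(x_seq, h, w_ih, w_hh):
    """Simplified GRU-style loop (no latent sampling); returns the last hidden state."""
    for x in x_seq.unbind(0):
        r, z, n = torch.mm(x, w_ih).chunk(3, 1)
        h_r, h_z, h_n = torch.mm(h, w_hh).chunk(3, 1)
        z = torch.sigmoid(z + h_z)
        n = torch.tanh(n + torch.sigmoid(r + h_r) * h_n)
        h = n + z * (h - n)
    return h


def weight_grads(offload):
    torch.manual_seed(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    w_ih = (0.1 * torch.randn(8, 48, device=device)).requires_grad_()
    w_hh = (0.1 * torch.randn(16, 48, device=device)).requires_grad_()
    x_seq = torch.randn(10, 4, 8, device=device)
    h0 = torch.zeros(4, 16, device=device)
    if offload:
        # Saved tensors live in host memory between forward and backward.
        with save_on_cpu(pin_memory=torch.cuda.is_available()):
            h = gru_like_steps(x_seq, h0, w_ih, w_hh)
    else:
        h = gru_like_steps(x_seq, h0, w_ih, w_hh)
    h.sum().backward()  # unpacking happens here, outside the context manager
    return w_ih.grad, w_hh.grad


g_plain = weight_grads(offload=False)
g_offload = weight_grads(offload=True)
print(all(torch.allclose(a, b) for a, b in zip(g_plain, g_offload)))
```

This trades GPU memory for host-device transfer time, so it is most useful when the sequence is long and the per-step compute is large enough to hide the copies.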

Right, so storing x_results and z_results for every iteration is actually needed for gradient computation; that makes sense. But why does the built-in LSTM CUDA kernel use much less GPU memory, despite a similar number of model parameters and activations? Is there any way to use similar tricks/optimizations without writing explicit CUDA code?
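One option that needs no CUDA code is activation checkpointing (a different technique from the CPU offloading mentioned above): run the loop in chunks of timesteps under torch.utils.checkpoint.checkpoint, so that only each chunk's outputs and starting hidden state are kept, while the intermediates inside a chunk are recomputed during backward. A plausible part of the gap to the built-in LSTM is that a fused kernel keeps far fewer per-step intermediates than an eager loop, which saves a tensor for most elementwise ops; checkpointing narrows that gap by trading compute for memory. The helper names and chunk_len below are hypothetical, and the step is a simplified GRU-style update rather than the model above.

```python
import torch
from torch.utils.checkpoint import checkpoint


def gru_like_chunk(x_chunk, h, w_ih, w_hh):
    """Simplified GRU-style loop over a chunk of timesteps; returns all hidden states."""
    outs = []
    for x in x_chunk.unbind(0):
        r, z, n = torch.mm(x, w_ih).chunk(3, 1)
        h_r, h_z, h_n = torch.mm(h, w_hh).chunk(3, 1)
        z = torch.sigmoid(z + h_z)
        n = torch.tanh(n + torch.sigmoid(r + h_r) * h_n)
        h = n + z * (h - n)
        outs.append(h)
    return torch.stack(outs)


def forward_checkpointed(x_seq, h, w_ih, w_hh, chunk_len=4):
    """Checkpoint every chunk_len steps; intermediates inside a chunk are recomputed in backward."""
    pieces = []
    for x_chunk in x_seq.split(chunk_len, dim=0):
        out = checkpoint(gru_like_chunk, x_chunk, h, w_ih, w_hh, use_reentrant=False)
        pieces.append(out)
        h = out[-1]  # hidden state carried into the next chunk
    return torch.cat(pieces, dim=0)


torch.manual_seed(0)
dt = torch.float64
w_ih = (0.1 * torch.randn(8, 48, dtype=dt)).requires_grad_()
w_hh = (0.1 * torch.randn(16, 48, dtype=dt)).requires_grad_()
x_seq = torch.randn(10, 4, 8, dtype=dt)
h0 = torch.zeros(4, 16, dtype=dt)

# Reference: plain loop that keeps every intermediate for backward.
ref = gru_like_chunk(x_seq, h0, w_ih, w_hh)
ref.sum().backward()
ref_grads = (w_ih.grad.clone(), w_hh.grad.clone())
w_ih.grad = None
w_hh.grad = None

out = forward_checkpointed(x_seq, h0, w_ih, w_hh, chunk_len=4)
out.sum().backward()
print(torch.allclose(out, ref), torch.allclose(w_ih.grad, ref_grads[0]))
```

In the model above, epss is drawn before the loop, so it could be sliced and passed into each chunk as an extra argument. If sampling happened inside a chunk instead, checkpoint's default preserve_rng_state=True restores the RNG state so the recomputed noise matches the original forward pass.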