torch.fft.rfft consumes too much GPU memory


I have the following code:

import torch

def get_memory_mb(data):
    return data.numel() * data.element_size() / (1024*1024)

n_fft = 153999
with torch.no_grad():
    rir = torch.randn(64, 6, 10, 10000).float().cuda()
    rir_fft = torch.fft.rfft(rir, n=n_fft, dim=-1)
    print("rir memory should be: {} MB".format(get_memory_mb(rir)))
    print("rir_fft memory should be: {} MB".format(get_memory_mb(rir_fft)))

The output is as follows:

rir memory should be: 146.484375 MB
rir_fft memory should be: 2255.859375 MB
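Those two numbers are consistent with the size arithmetic: rfft of a real float32 signal with n=153999 returns n_fft // 2 + 1 = 77000 complex64 bins (8 bytes each) along the transformed dim. A quick sanity check of the arithmetic, no GPU needed:

```python
def tensor_mb(shape, element_size):
    """Theoretical size of a tensor in MiB."""
    numel = 1
    for d in shape:
        numel *= d
    return numel * element_size / (1024 * 1024)

n_fft = 153999
n_bins = n_fft // 2 + 1  # rfft keeps only the non-negative frequency bins

print(tensor_mb((64, 6, 10, 10000), 4))   # float32 input:    146.484375
print(tensor_mb((64, 6, 10, n_bins), 8))  # complex64 output: 2255.859375
```

So the printed sizes only account for the input and output tensors themselves, not for any workspace the FFT backend allocates on top of them.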

According to the output, I would expect the code to consume no more than 3 GB of GPU memory, but it actually consumes around 16 GB. Is there anything wrong with the way I am using torch.fft.rfft without autograd? And is there any way to reduce the memory usage of this function? (I'm trying to use it on the GPU as an efficient way to compute a 1-D convolution between two large 1-D tensors.)
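Since the goal is FFT-based 1-D convolution, here is a minimal sketch of that technique for reference: pad both signals to a common length of at least len(x) + len(h) - 1 (a power of two in this sketch, which tends to be a friendly size for FFT backends; exact workspace behaviour is backend-specific), multiply the spectra, and invert. The 144000-sample second signal in the usage line is an illustrative assumption, chosen only so the output length matches the n_fft above:

```python
import torch

def fft_conv1d(x, h):
    """Linear convolution of two 1-D tensors via rfft/irfft.

    Pads both signals to the next power of two >= len(x) + len(h) - 1 so the
    circular convolution computed by the FFT equals the linear convolution.
    """
    out_len = x.shape[-1] + h.shape[-1] - 1
    n = 1 << (out_len - 1).bit_length()  # next power of two
    y = torch.fft.irfft(torch.fft.rfft(x, n=n) * torch.fft.rfft(h, n=n), n=n)
    return y[..., :out_len]

x = torch.randn(10000)
h = torch.randn(144000)  # illustrative length: 10000 + 144000 - 1 = 153999
y = fft_conv1d(x, h)
print(y.shape)  # torch.Size([153999])
```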

I was able to alleviate the problem by calling torch.cuda.empty_cache() to release some memory after the FFT.
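Concretely, the pattern looks like this (a sketch: the batch size is reduced here and it falls back to CPU when no GPU is present, neither of which was in the original run):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

n_fft = 153999
with torch.no_grad():
    rir = torch.randn(1, 6, 10, 10000, device=device)
    rir_fft = torch.fft.rfft(rir, n=n_fft, dim=-1)
    del rir                   # drop the time-domain tensor once it is no longer needed
    torch.cuda.empty_cache()  # return cached blocks (e.g. FFT workspace) to the driver

print(rir_fft.shape)  # torch.Size([1, 6, 10, 77000])
```

Note that empty_cache() only returns memory that the caching allocator is holding but no longer using; it does not shrink live tensors, so deleting or reusing intermediates first matters.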