Question: What takes so much GPU memory?

I am a college student majoring in astronomy, and I would like to do an FFT and de-dispersion on the GPU with PyTorch (torch). The original data is a 4 GB time series in float32 format, and the same holds for the de-dispersion factor. But the computation takes more than 24 GB, which exceeds my GPU memory. I am confused about what takes so much memory and how I can reduce it. Thank you!

import gc
import math

import torch

def dedisperse(spectrum, fre=5e8, DM=0.0, fre0=1e9, device='cuda', log_memory=False):  # do dedispersion for the signal
    # Create frequency data (float64)
    deltat = torch.arange(spectrum.shape[0], device=device, dtype=torch.float64)
    deltat = deltat * fre / spectrum.shape[0] + fre0
    print('deltat shape', deltat.shape)
    print('spectrum.shape[0]', spectrum.shape[0])
    phase = 4.148808e15 * (deltat**(-2) - fre0**(-2)) * DM * deltat  # calculate the phase caused by the time delay
    print('phase shape', phase.shape)
    print('phase:', phase)
    # DM
    del deltat
    torch.cuda.empty_cache()
    gc.collect()
    # get_gpu_memory() is a user-defined helper, not shown here
    print('get_gpu_memory_del_deltat:', get_gpu_memory())
    spectrum = spectrum * torch.exp(-2j * math.pi * phase)
    spectrum_size_gb = spectrum.element_size() * spectrum.numel() / (1024 ** 3)
    print('get_gpu_memory_dedispers:', get_gpu_memory(), spectrum_size_gb, 'GB', spectrum.element_size(), spectrum.numel())
    del phase
    torch.cuda.empty_cache()
    gc.collect()
    print('get_gpu_memory_del_phase:', get_gpu_memory())
    print('dedisperse completed')
    return spectrum  # ddm_data

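Several things add up when you compute spectrum * torch.exp(-2j * math.pi * phase):

  • Complex type promotion: because phase is float64, -2j * math.pi * phase is complex128 (16 bytes per element), and multiplying spectrum by it promotes the result to complex128 as well, whether spectrum was float32 or complex64. Each such tensor is four times the size of a float32 array of the same length.
  • Float64 operations: your deltat and phase are computed in float64, which uses twice the memory of float32.
  • Intermediate tensors: the expression allocates two full-size complex128 temporaries (the exponent and its exp()) plus the complex128 output. The caller also still holds a reference to the original spectrum, so reassigning spectrum inside the function does not free it, and torch.cuda.empty_cache() only releases cached blocks that no live tensor is using.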
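A rough way to see how these add up is to tally the live tensors at each stage of that one line. The sketch below is a model, not a measurement: it assumes phase stays float64 as in the question, PyTorch's usual type promotion, and that CPython releases each temporary as soon as the call consuming it returns. It ignores blocks the caching allocator keeps reserved and the CUDA context itself. The function name live_bytes_per_stage is hypothetical.

```python
def live_bytes_per_stage(n, spectrum_itemsize):
    """Rough model of live tensor bytes while evaluating
    spectrum * torch.exp(-2j * math.pi * phase) with a float64 `phase`.

    n: number of elements; spectrum_itemsize: bytes per element of the input
    (4 for float32, 8 for float64 or complex64, 16 for complex128).
    """
    s = n * spectrum_itemsize  # input spectrum, still referenced by the caller
    p = n * 8                  # phase, float64
    c = n * 16                 # one complex128 temporary (float64 promotes to complex128)
    return {
        "phase computed": s + p,
        "t1 = -2j * math.pi * phase": s + p + c,
        "t2 = torch.exp(t1)": s + p + 2 * c,       # t1 is released only after exp returns
        "out = spectrum * t2": s + p + 2 * c,      # t1 gone; t2 and the complex128 output alive
        "after del phase, on return": s + c,       # input (held by caller) + output
    }


GIB = 1024 ** 3
n = 2 ** 28  # hypothetical: 256 Mi complex64 samples = 2 GiB of input
for stage, nbytes in live_bytes_per_stage(n, spectrum_itemsize=8).items():
    print(f"{stage:28s} {nbytes / GIB:5.1f} GiB")
# peak: 12.0 GiB, i.e. 48 bytes per element, six times the 2 GiB input
```

Under these assumptions the peak is reached while exp() and the multiplication run, not at the end, which is why counting only the data and the phase underestimates it.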

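Try this:

import math

import torch

def dedisperse(spectrum, fre=5e8, DM=0.0, fre0=1e9, device='cuda'):
    # Use float32 consistently
    deltat = torch.arange(spectrum.shape[0], device=device, dtype=torch.float32)
    deltat = deltat * (fre / spectrum.shape[0]) + fre0

    # Calculate phase in float32
    phase = 4.148808e15 * (deltat**(-2) - fre0**(-2)) * DM * deltat
    del deltat  # free the frequency array before allocating the complex factor

    # Convert once; this is a no-op (no copy) if spectrum is already complex64 on
    # `device`, in which case the caller's tensor is modified in place below
    spectrum = spectrum.to(device=device, dtype=torch.complex64)
    # Multiply in place; exp(...) is complex64 because phase is float32
    spectrum.mul_(torch.exp(-2j * math.pi * phase))

    return spectrum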
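One caveat with computing the phase in float32: the phase formula yields millions of cycles, and float32 keeps only about 7 significant digits. The pure-Python check below, assuming the question's default fre and fre0 and a hypothetical DM of 1.0, evaluates the phase near the top of the band and measures the float32 spacing there. The helper names to_float32 and float32_spacing are illustrative.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 and back."""
    return struct.unpack("f", struct.pack("f", x))[0]


def float32_spacing(x: float) -> float:
    """Gap between |x| (rounded to float32) and the next larger float32."""
    bits = struct.unpack("I", struct.pack("f", abs(x)))[0]
    nxt = struct.unpack("f", struct.pack("I", bits + 1))[0]
    return nxt - to_float32(abs(x))


# Phase in cycles near the top of the band, using the question's formula
# with fre=5e8, fre0=1e9 and a hypothetical DM of 1.0
fre, fre0, DM = 5e8, 1e9, 1.0
f = fre0 + fre
phase = 4.148808e15 * (f ** -2 - fre0 ** -2) * DM * f
print(phase)                       # roughly -3.46e6 cycles
print(float32_spacing(phase))      # 0.25: float32 cannot resolve finer than a quarter cycle here
print(phase - to_float32(phase))   # rounding error of the float32 phase, in cycles
```

If this matters for your DM range, one option is to compute the phase in float64, keep only its fractional part (phase - torch.floor(phase)), and cast that to float32 before building the complex factor; whole cycles do not change exp(-2j * pi * phase).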

Thank you for replying! In fact, I made a mistake earlier: the program runs fine when the data is in float32 format, but it reports that there is not enough memory when the data is float64. However, by my calculation the total memory should be 16 GB, yet it takes more than 24 GB. I am confused about what else takes so much memory.
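For the float64 case, a 16 GB estimate that counts the data and the phase leaves out the full-size complex128 temporaries that the exponent, exp() and the product each allocate. One way to bound them, sketched here under assumptions (spectrum is a 1-D complex64 or complex128 tensor already on the target device; dedisperse_chunked and chunk are hypothetical names), is to process the spectrum in slices and multiply in place, so only chunk-sized temporaries exist at any moment.

```python
import math

import torch


def dedisperse_chunked(spectrum, fre=5e8, DM=0.0, fre0=1e9, chunk=1 << 24):
    """Multiply `spectrum` in place by exp(-2j*pi*phase), one slice at a time.

    Sketch only: assumes `spectrum` is a 1-D complex64 or complex128 tensor
    already on the target device. `chunk` (hypothetical) caps the length of
    every temporary, so extra memory scales with chunk, not len(spectrum).
    """
    if not spectrum.is_complex():
        raise TypeError("expected a complex spectrum, e.g. the output of torch.fft.fft")
    n = spectrum.shape[0]
    # Real dtype matching the spectrum: float32 for complex64, float64 for complex128
    angle_dtype = torch.float32 if spectrum.dtype == torch.complex64 else torch.float64
    for start in range(0, n, chunk):
        stop = min(start + chunk, n)
        # Frequencies of this slice only, same formula as in the question (float64)
        f = torch.arange(start, stop, device=spectrum.device, dtype=torch.float64)
        f = f * (fre / n) + fre0
        phase = 4.148808e15 * (f ** -2 - fre0 ** -2) * DM * f
        # Whole cycles do not change exp(-2j*pi*phase); keep only the fraction so
        # the angle stays accurate even after casting to float32
        frac = phase - torch.floor(phase)
        angle = (-2 * math.pi * frac).to(angle_dtype)
        # Unit-magnitude complex factor with the same dtype as the spectrum
        factor = torch.polar(torch.ones_like(angle), angle)
        spectrum[start:stop].mul_(factor)  # in-place multiply on a view of the slice
    return spectrum
```

To see what the tensors themselves use, as opposed to what nvidia-smi reports (which also includes the CUDA context and blocks that PyTorch's caching allocator keeps reserved), you can call torch.cuda.reset_peak_memory_stats() before the computation and compare torch.cuda.max_memory_allocated() with torch.cuda.memory_reserved() afterwards.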