I am a college student majoring in astronomy, and I would like to do an FFT and de-dispersion on the GPU with the torch module. The original data is a 4 GB time series in float32 format, and the de-dispersion factor (the phase array) has the same size. But the computation takes more than 24 GB, which exceeds my GPU memory. I am confused about what takes so much memory and how I can improve it. Thank you!
```python
import gc
import math
import torch

def dedisperse(spectrum, fre=5e8, DM=0.0, fre0=1e9, device='cuda', log_memory=False):
    """Do de-dispersion on the signal (spectrum is the FFT of the time series)."""
    # Frequency of each FFT bin: fre0 + k * fre / N
    deltat = torch.arange(spectrum.shape[0], device=device, dtype=torch.float64)
    deltat = deltat * fre / spectrum.shape[0] + fre0
    print('deltat shape', deltat.shape)
    print('spectrum.shape[0]', spectrum.shape[0])
    # Calculate the phase (in cycles) caused by the dispersive time delay for this DM
    phase = 4.148808e15 * (deltat ** -2 - fre0 ** -2) * DM * deltat
    print('phase shape', phase.shape)
    print('phase:', phase)
    del deltat
    torch.cuda.empty_cache()
    gc.collect()
    if log_memory:
        print('get_gpu_memory_del_deltat:', get_gpu_memory())  # get_gpu_memory() is my own helper
    # Multiply by the chirp; this is where the memory blows up
    spectrum = spectrum * torch.exp(-2j * math.pi * phase)
    spectrum_size_gb = spectrum.element_size() * spectrum.numel() / (1024 ** 3)
    if log_memory:
        print('get_gpu_memory_dedisperse:', get_gpu_memory(), spectrum_size_gb, 'GB',
              spectrum.element_size(), spectrum.numel())
    del phase
    torch.cuda.empty_cache()
    gc.collect()
    if log_memory:
        print('get_gpu_memory_del_phase:', get_gpu_memory())
    print('dedisperse completed')
    return spectrum
```
Several things inflate the memory here (a leaner sketch follows the list):

- Complex promotion: when you compute `spectrum * torch.exp(-2j * math.pi * phase)`, the chirp inherits `phase`'s dtype. Since `phase` is float64, the exponential is complex128 (16 bytes per element), and the product promotes your spectrum to complex128 as well, four times the footprint of float32.
- Float64 operations: your `deltat` and `phase` are computed in float64, which uses twice the memory of float32 and, per the previous point, drags every downstream complex tensor up to complex128.
- Intermediate tensors: the expression creates several full-size unnamed temporaries (`-2j * math.pi * phase`, then the `exp` result) that aren't freed until the enclosing operation finishes.
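Here is one way to cut the peak. This is a minimal sketch rather than a drop-in replacement: the name `dedisperse_lean`, the `chunk` parameter, and the wrap-modulo-one-cycle trick are my additions, and I'm assuming `spectrum` is the complex64 output of `torch.fft.fft` on float32 data (if not, the `.to(torch.complex64)` line makes a converted copy). The idea is to build the phase in float64 for accuracy, wrap it into [0, 1) cycles so a float32 copy loses nothing that matters, and apply a complex64 chirp in place, chunk by chunk:

```python
import math
import torch

def dedisperse_lean(spectrum, fre=5e8, DM=0.0, fre0=1e9, device='cuda',
                    chunk=1 << 26):
    """Coherent de-dispersion with a complex64 chirp, applied in place in chunks."""
    n = spectrum.shape[0]
    # Phase in float64 for accuracy: it can span millions of cycles,
    # well beyond float32's ~7 significant digits.
    deltat = torch.arange(n, device=device, dtype=torch.float64)
    deltat = deltat * (fre / n) + fre0
    phase = 4.148808e15 * (deltat ** -2 - fre0 ** -2) * DM * deltat
    del deltat
    # exp(-2*pi*i*phase) only depends on phase mod 1, so wrap before downcasting.
    phase.remainder_(1.0)              # in place: no extra float64 copy
    phase = phase.to(torch.float32)    # 4 bytes/sample instead of 8
    torch.cuda.empty_cache()
    spectrum = spectrum.to(torch.complex64)  # no-op if the FFT already gave complex64
    # Apply the chirp chunk by chunk so the exp() temporaries stay small;
    # with a float32 phase, -2j * math.pi * phase[sl] is complex64, not complex128.
    for start in range(0, n, chunk):
        sl = slice(start, start + chunk)
        spectrum[sl] *= torch.exp(-2j * math.pi * phase[sl])
    del phase
    return spectrum
```

You can check the effect by calling `torch.cuda.reset_peak_memory_stats()` before the function and `torch.cuda.max_memory_allocated()` after it.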
Thank you for replying! In fact I made a mistake before: the program runs fine when the data is in float32 format, but it runs out of memory when the data is float64. However, by my calculation the total memory should be 16 GB, while it actually takes more than 24 GB. I am confused about what else takes so much memory.
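The gap is consistent with the unnamed complex128 temporaries from the points above, which a count of only the named tensors misses. A rough accounting, assuming ~1.07e9 samples (4 GB as float32, 8 GB as float64):

```python
GB = 1024 ** 3
n = 4 * GB // 4   # ~1.07e9 samples

phase = n * 8     # float64 phase:                               8 GB
temp  = n * 16    # unnamed temporary -2j * math.pi * phase,
                  # complex128 because phase is float64:        16 GB
chirp = n * 16    # torch.exp(...) output, complex128:          16 GB
print((phase + temp + chirp) / GB, 'GB')  # 40.0 GB, before spectrum is even counted
```

On top of that, `spectrum` itself and the full-size complex128 product allocation are alive at the same moment, so the instantaneous demand is well past 24 GB even though the named tensors sum to only 16 GB.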