Hi everyone,
I’m currently working on training a PyTorch model for singing voice/music source separation. Training runs without any out-of-memory problems, but when I run the inference code below I get a CUDA out-of-memory error, even on a GPU with 24 GB of memory. Why does inference use so much more memory than training? I already release memory manually in my code — which operation is responsible for the excessive memory usage?
Any advice on how to resolve this would be greatly appreciated. Thank you!
Here’s my code:
def main():
    """Run source-separation inference on the MUSDB18 test mixtures.

    Loads a trained MMDenseNet checkpoint, then for every track name in the
    module-level ``names`` list: load ``mixture.wav``, compute its STFT, run
    the network on the full spectrogram, invert with iSTFT, and save the
    estimated waveform under ``estimated/``.

    Relies on module-level ``names``, ``musdb18_root``, ``MMDenseNet``,
    ``torch``/``torchaudio``/``os`` imports.
    """
    sample_rate = 44100
    device = torch.device('cuda:0')
    model_path = '/data2/Will/DNN-based_source_separation-main/DNN-based_source_separation-main/src/tmp/model/best.pth'
    model = MMDenseNet.build_model(model_path)  # the model
    # map_location keeps checkpoint tensors on CPU; the model is moved to the
    # GPU explicitly below.
    config = torch.load(model_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(config['state_dict'])
    model = model.to(device)
    print(f"Total number of parameters: {sum(p.numel() for p in model.parameters())}")
    model.eval()
    channel = model.in_channels
    n_fft = 2048
    hop = 1024
    window_fn = torch.hann_window(n_fft, periodic=True, device=device)
    # BUG FIX: the original only assigned ABS_path when the directory already
    # existed; after a fresh os.mkdir the later os.path.join(ABS_path, ...)
    # raised NameError. makedirs(exist_ok=True) covers both cases.
    os.makedirs('estimated', exist_ok=True)
    ABS_path = os.path.join(os.getcwd(), 'estimated')
    for name in names:
        print(name)
        mixture_path = os.path.join(musdb18_root, 'test', name, "mixture.wav")
        with torch.no_grad():
            source, sr = torchaudio.load(mixture_path)
            source = source.to(device)
            source_duration = source.size(1) / 44100
            # NOTE(review): the entire track is transformed and pushed through
            # the network in one forward pass. For long songs the full
            # spectrogram plus all intermediate activations is the likely
            # source of the CUDA OOM (training typically sees short, fixed
            # length clips) — consider slicing the STFT into fixed-size
            # chunks along the time axis and concatenating the outputs.
            source_stft = torch.stft(source, n_fft=n_fft, hop_length=hop,
                                     window=window_fn, return_complex=True)
            source_stft = torch.unsqueeze(source_stft, dim=0)  # add batch dim
            print(source_stft.shape)
            estimated = model(source_stft)
            print(estimated.shape)
            channels = estimated.size()[:-2]  # keep the B, C dims
            estimated = estimated.view(-1, *estimated.size()[-2:])
            estimated_out = torch.istft(estimated, n_fft=n_fft, hop_length=hop,
                                        window=window_fn, return_complex=False)
            estimated_out = estimated_out.view(*channels, -1).squeeze(0).cpu()
            print('TEST', estimated_out.shape, source.shape)
            est_path = os.path.join(ABS_path, '{}.wav'.format(name))
            torchaudio.save(est_path, estimated_out, sample_rate=sample_rate,
                            channels_first=True, bits_per_sample=16)
            # Drop references to per-track tensors and return cached blocks to
            # the allocator before the next (possibly longer) track.
            del source, source_stft, estimated, estimated_out
            torch.cuda.empty_cache()