CUDA out of memory during inference

Hi everyone,
I’m currently working on training a PyTorch model for singing voice/music source separation. I don’t run out of memory while training the model, but when I use the following inference code I get a CUDA out of memory error, even though the GPU has 24 GB of memory. Why does inference use so much more memory than training? I already release memory manually in the code, so which operation is responsible for the excessive memory usage?
Any advice on how to resolve this would be greatly appreciated. Thank you!
Here’s my code:

def main():
    sample_rate = 44100
    device = torch.device('cuda:0')
    model_path = '/data2/Will/DNN-based_source_separation-main/DNN-based_source_separation-main/src/tmp/model/best.pth'
    model = MMDenseNet.build_model(model_path)  # the model
    config = torch.load(model_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(config['state_dict'])
    model = model.to(device)
    print(f"Total number of parameters: {sum(p.numel() for p in model.parameters())}")
    model.eval()
    channel = model.in_channels
    n_fft = 2048
    hop = 1024
    window_fn = torch.hann_window(n_fft, periodic=True, device=device)
    # make sure the output directory exists and keep its absolute path
    if not os.path.isdir('estimated/'):
        os.mkdir('estimated/')
    ABS_path = os.path.join(os.getcwd(), 'estimated')

    for name in names:
        print(name)
        mixture_path = os.path.join(musdb18_root, 'test', name, "mixture.wav")
        with torch.no_grad():
            source, sr = torchaudio.load(mixture_path)
            source = source.to(device)
            source_duration = source.size(1) / 44100
            source_stft = torch.stft(source, n_fft=n_fft, hop_length=hop, window=window_fn, return_complex=True)
            source_stft = torch.unsqueeze(source_stft, dim=0)
            print(source_stft.shape)
            estimated = model(source_stft)
            print(estimated.shape)
            channels = estimated.size()[:-2]  # keep the (B, C) dims
            estimated = estimated.view(-1, *estimated.size()[-2:])
            estimated_out = torch.istft(estimated, n_fft=n_fft, hop_length=hop, window=window_fn, return_complex=False)
            estimated_out = estimated_out.view(*channels, -1).squeeze(0).cpu()
            print('TEST', estimated_out.shape, source.shape)
            est_path = os.path.join(ABS_path, '{}.wav'.format(name))
            torchaudio.save(est_path, estimated_out, sample_rate=sample_rate, channels_first=True, bits_per_sample=16)
            del source, source_stft, estimated, estimated_out
            torch.cuda.empty_cache()

Your code is incomplete and not properly formatted. However, depending on the call order, the training step might still keep some intermediates alive, which would then cause the inference to run out of memory. You could double-check this by printing the allocated memory before running the inference code.
If that’s not the issue, compare the input shapes of the training vs. inference code and make sure the inference does not use a significantly larger input.
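E.g. a minimal sketch of that check (run_inference() is just a placeholder for your own inference entry point):

import torch

def report_cuda_memory(tag):
    # tensors currently alive vs. memory reserved by the caching allocator
    allocated = torch.cuda.memory_allocated() / 1024**2
    reserved = torch.cuda.memory_reserved() / 1024**2
    print(f"[{tag}] allocated: {allocated:.1f} MiB, reserved: {reserved:.1f} MiB")

report_cuda_memory('after training')    # leftovers from the training step show up here
# run_inference()                       # your inference loop
report_cuda_memory('after inference')

If the "after training" numbers are already large, some training tensors (extra model copies, optimizer states, cached activations) are still alive and should be freed before running inference.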


Thank you for your reply, I will double check this part.

def forward(self, input):
    bands, sections = self.bands, self.sections
    n_bins = input.size(2)  # dim=2 because size = (batch, in_channels, n_bins, n_frames)
    if sum(sections) == n_bins:
        x_valid, x_invalid = input, None
    else:
        sections = [sum(sections), n_bins - sum(sections)]
        x_valid, x_invalid = torch.split(input, sections, dim=2)

    x_valid = self.transform_affine_in(x_valid)
    x = self.band_split(x_valid)
    x_f = torch.cat(x,dim=2)
    x_bands = []

    for band, x_band in zip(bands, x):
        x_band = self.net[band](x_band)
        x_bands.append(x_band)

    x_bands = torch.cat(x_bands, dim=2)
    x_full = self.net[FULL](x_f)
    x = torch.cat([x_bands, x_full], dim=1)
    x = self.dense_block(x)
    x = self.norm2d(x)
    real_part = x[:, :, 0::2, :]
    imag_part = x[:, :, 1::2, :]
    real_part = self.glu2dR(real_part)
    imag_part = self.glu2dI(imag_part)
    x = torch.complex(real_part, imag_part)        
    x = x * input 
    _, _, _, n_frames = x.size()
    _, _, _, n_frames_in = input.size()
    padding_width = n_frames - n_frames_in
    padding_left = padding_width // 2
    padding_right = padding_width - padding_left

    x = F.pad(x, (-padding_left, -padding_right))

    if x_invalid is None:
        output = x
    else:
        output = torch.cat([x, x_invalid], dim=2)

    return output

def transform_affine_in(self, input):
    eps = self.eps
    output = (input - self.bias_in.unsqueeze(dim=1)) / (torch.abs(self.scale_in.unsqueeze(dim=1)) + eps)

    return output

def transform_affine_out(self, input):
    output = self.scale_out.unsqueeze(dim=1) * input + self.bias_out.unsqueeze(dim=1)
    return output

Is it the mask multiplication at the end (x = x * input) that is using too much memory? When I let the model output the result directly (without the multiplication), it worked fine. Does this element-wise multiplication have a big impact on memory usage?

You could check whether broadcasting is used, which could increase the memory usage. If not, the resulting x will be a newly allocated tensor with the same shape as the input x tensor.
You can also check the memory usage, e.g. via torch.cuda.memory_allocated().
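For example, a rough sketch to see what the multiplication itself allocates (assuming x and input are already on the GPU):

before = torch.cuda.memory_allocated()
x = x * input  # element-wise product; the result has the broadcasted shape of both operands
after = torch.cuda.memory_allocated()
print(f"mask multiplication allocated {(after - before) / 1024**2:.1f} MiB")

If x and input have the same shape, the multiplication only needs one extra tensor of that shape; if a dimension is broadcast, the output grows to the broadcasted shape and the allocation with it.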
