Need help with GPU memory issues

I have a loss function that requires multiple internal passes:

import numpy as np
import torch

def my_loss_func(logits, sigma, labels, num_passes):
    total_loss = 0
    img_batch_size = logits.shape[0]
    vol_std = np.zeros((img_batch_size, num_passes))
    for fpass in range(num_passes):
        # standard-normal noise on the same device and dtype as logits
        noise_array = torch.randn_like(logits)
        stochastic_output = logits + sigma * noise_array
        # del only removes the Python name; tensors saved by autograd stay alive until backward
        del noise_array
        # the volume statistic is only logged, so no autograd graph is needed for it
        with torch.no_grad():
            temp_vol = torch.softmax(stochastic_output, dim=1)[:, 0, ...]
            vol_std[:, fpass] = temp_vol.reshape(img_batch_size, -1).sum(1).cpu().numpy()
        del temp_vol
        # numerically stable log(sum(exp(x))) over the last dimension
        exponent_B = torch.logsumexp(stochastic_output, dim=-1, keepdim=True)
        inner_logits = exponent_B - stochastic_output
        soft_inner_logits = labels * inner_logits
        total_loss += torch.exp(soft_inner_logits)
        del exponent_B, inner_logits, soft_inner_logits
    mean_loss = total_loss / num_passes
    actual_loss = torch.mean(torch.log(mean_loss))
    batch_std = np.std(vol_std, axis=1)
    return actual_loss, batch_std
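To see why the del calls in the loop do not bring memory down, one can count the tensors autograd keeps alive for the backward pass. The CPU sketch below is only an illustration under assumptions: it uses a simplified copy of the loop body, an arbitrary shape (4, 3, 8, 8), and torch.autograd.graph.saved_tensors_hooks (available in recent PyTorch releases); saved_bytes_for is a hypothetical helper.

```python
import torch
from torch.autograd.graph import saved_tensors_hooks


def saved_bytes_for(num_passes, shape=(4, 3, 8, 8)):
    """Bytes of distinct tensors autograd keeps for backward (toy CPU version of the loop)."""
    torch.manual_seed(0)
    logits = torch.randn(shape, requires_grad=True)
    sigma = torch.rand(shape, requires_grad=True)
    labels = torch.rand(shape)
    saved = {}  # data_ptr -> size in bytes, so a tensor saved several times is counted once

    def pack(t):
        saved[t.data_ptr()] = t.numel() * t.element_size()
        return t  # keep the tensor itself, exactly as autograd would without hooks

    with saved_tensors_hooks(pack, lambda t: t):
        total = 0
        for _ in range(num_passes):
            noise = torch.randn_like(logits)
            out = logits + sigma * noise
            del noise  # removes the name only; `sigma * noise` saved noise for backward
            inner = torch.logsumexp(out, dim=-1, keepdim=True) - out
            total = total + torch.exp(labels * inner)
    return sum(saved.values())


for n in (1, 2, 4):
    print(n, saved_bytes_for(n))
```

The retained bytes should grow linearly with the number of passes, which matches the out-of-memory error appearing once num_passes gets large.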

Both logits and sigma are network outputs and therefore require gradients. I run out of GPU memory once num_passes exceeds 50. Are there other ways to optimise memory usage so that I can run more passes? I'm not at all concerned with readability or ugly solutions; anything will do.

@PedsB, I am not sure this is the most efficient way to do multiple internal passes, but first: have you checked whether other running processes are taking up GPU memory?

If so, you could kill them all: kill -9 $(nvidia-smi | sed -n 's/|\s*[0-9]*\s*\([0-9]*\)\s*.*/\1/p' | sort | uniq | sed '/^$/d')

Secondly, the problem could be at this line of code: vol_std = np.zeros((img_batch_size, num_passes))
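Note that vol_std is a NumPy array, so it is allocated in host RAM rather than GPU memory, and its size can be checked directly. Using a batch size of 4 and 50 passes as example values:

```python
import numpy as np

img_batch_size, num_passes = 4, 50  # example values only
vol_std = np.zeros((img_batch_size, num_passes))  # float64 array in host (CPU) RAM
print(vol_std.nbytes)  # 1600 bytes = 4 * 50 * 8
```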

What’s your batch_size?

If possible, try a smaller batch size.

Thanks

Hi @YASJAY, I should've clarified: I'm running this job on a GPU cluster, so I'm confident the GPU memory is completely free before the job starts. My batch size can't be changed at all, but it is already small: four.

My issue is that I quickly run out of memory, likely because of the total_loss += torch.exp(soft_inner_logits) step: I believe each pass adds another computation graph on top of the previous ones, and all of them are kept in memory until backward. What I'm really looking for is a more memory-efficient way to write this, e.g. backpropagating inside the loop to free intermediate memory (which unfortunately I can't do here, since the loss can't be calculated until I've gone through all N passes).