Out of memory with a loop in a custom loss function

Hi,
due to the nature of the loss I implemented, it can only be computed with a for-loop. When doing so, I run out of CUDA memory. Here is the code:

    for i in range(n_predictions):
        for j in range(i,n_predictions):
            loss += multiplicator[j]/\
                    (1+torch.exp(-self.gamma*(prediction[i]-prediction[j])))

I assume that a new graph is constructed on every iteration because of “loss += …”? How can I prevent this?

Thanks in advance!

  1. What is the value of n_predictions?
  2. When does it run out of memory? (Are you accumulating history across iterations?) Can you post a link to the rest of your code?

`loss += ...` will combine the current graph of `loss` with the subgraph for `multiplicator[j] / ...`. That’s unavoidable if you want to backpropagate through `loss`. If you don’t need to compute a derivative, use:

    with torch.no_grad():
        <code>
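For example, if you only need the value of this pairwise term (no backward pass), the same loop can run without building any graph. A rough sketch reusing the names from your snippet:

    loss_value = 0.0
    with torch.no_grad():  # nothing inside this block is recorded for autograd
        for i in range(n_predictions):
            for j in range(i, n_predictions):
                loss_value += multiplicator[j] / \
                        (1 + torch.exp(-self.gamma * (prediction[i] - prediction[j])))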

You can rewrite your expression to avoid the for-loops. I’m not sure if this will help with memory usage. It’ll likely be faster.

    diff = prediction.reshape(-1, 1) - prediction  # diff[i][j] = prediction[i] - prediction[j]
    denom = 1 + torch.exp(-self.gamma * diff)
    losses = multiplicator / denom                 # broadcasting picks multiplicator[j] per column
    losses = torch.triu(losses)                    # keep the upper triangle, i.e. zero out entries with j < i
    loss = losses.sum()
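If you want to sanity-check the rewrite against your loop, something like this should do (a rough sketch with made-up sizes; `gamma` stands in for `self.gamma`):

    import torch

    torch.manual_seed(0)
    n, gamma = 8, 2.0
    prediction = torch.randn(n)
    multiplicator = torch.rand(n)

    # loop version
    loop_loss = sum(multiplicator[j] / (1 + torch.exp(-gamma * (prediction[i] - prediction[j])))
                    for i in range(n) for j in range(i, n))

    # vectorized version
    diff = prediction.reshape(-1, 1) - prediction
    vec_loss = torch.triu(multiplicator / (1 + torch.exp(-gamma * diff))).sum()

    print(torch.allclose(loop_loss, vec_loss))  # expected: True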

Thanks a lot, I will check your code. Multiplicator is part of my ground truth, so it should not have a graph. I checked before and am fairly confident that the problem is not caused by accumulating graphs across iterations but already occurs within the first call; I will verify that again, though. Currently I try to ensure that by

    del backprop_loss

in every iteration after the optimizer step etc.
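For reference, the relevant part of my training loop looks roughly like this (sketched with placeholder names such as `loader`, `model`, and `compute_loss`, since I cannot share the real code yet):

    for batch, target in loader:
        prediction = model(batch)
        backprop_loss = compute_loss(prediction, target)  # the pairwise loss from above
        optimizer.zero_grad()
        backprop_loss.backward()
        optimizer.step()
        del backprop_loss  # drop the reference so its graph can be freed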

The value of n_predictions is my batch size, which used to be 16; I set it to 10 now and with that it does not seem to run out of memory, but I will confirm that as well. Unfortunately I cannot link my code at the moment; it may be published in a while though, and I will link the git repository here then.

Thanks a lot for the code. My implementation did in fact slow everything down noticeably, which I attribute to the two loops. I will check again with a bigger batch size and see if your code fixes the issue.

Thanks again, your version not only seems to have fixed my memory problem but is also faster, which is great. I will still link the repository here once the code is public.