GPU out of memory after I call the backward function

I use a 32GB GPU to train gpt2-xl and find that every time I call backward(), memory usage increases by about 10GB, so training stops after 2 epochs because the GPU runs out of memory. I have tried a few things, like calling torch.cuda.empty_cache() or ‘del loss, output’ after optimizer.step(), but they don’t seem to help. I want to know: is it because I didn’t delete the computation graph of the last epoch correctly? :tired_face:

Are you seeing the memory increase while (or after) calling backward() and is it stable afterwards?
If so, then note that the gradients would need to be stored on the device and will also take memory.
If you are seeing an increased memory usage of 10GB in each iteration, you are most likely storing the computation graph accidentally, e.g. by appending the loss (without detaching it) to a list.
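For example, collecting losses for logging like this (a minimal sketch; model, criterion, and loader are placeholder names) keeps every iteration’s graph, including its intermediate activations, alive:

losses = []
for data, target in loader:
    output = model(data)
    loss = criterion(output, target)
    losses.append(loss)           # stores the tensor together with its entire computation graph
    # losses.append(loss.item())  # stores only the Python float and lets the graph be freed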


Yeah, it’s the second situation you mentioned. I actually append the model outputs to a list because the loss calculation needs three forward passes, so I use the list to collect the three outputs and then use torch.cat to concatenate them into one tensor, and the loss is calculated from that tensor. But I do call ‘del loss’; shouldn’t that step clear the computation graph?

del loss would eventually delete the loss after the gradients were already calculated, but your use case would still store 3 full computation graphs (with their intermediate activations, which are needed to compute the gradients). I assume you are trying to delete the loss after loss.backward() was called? If so, wouldn’t the OOM error already be raised at that point?
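You could check where the memory actually goes by printing torch.cuda.memory_allocated() around those calls, e.g. (a rough sketch; loss here is the loss built from your three forward passes):

print(torch.cuda.memory_allocated() / 1024**3)  # after the forward passes: weights + activations of the stored graphs
loss.backward()
print(torch.cuda.memory_allocated() / 1024**3)  # intermediate activations freed, parameter gradients now allocated
del loss
print(torch.cuda.memory_allocated() / 1024**3)  # barely changes: the graph was already freed by backward()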

I think I didn’t state the problem clearly :hushed:. In my code, there are three forward passes and one backward pass in each epoch, like this:

for epoch in range(n):
    outputs = []
    outputs.append(model(x1)['logit'][0])
    outputs.append(model(x2)['logit'][0])
    outputs.append(model(x3)['logit'][0])
    loss = loss_cal(torch.cat(outputs).reshape(a, b, c))
    loss.backward()  # increases memory by about 10GB
    # ... (code like step() and zero_grad())
    torch.cuda.empty_cache()  # decreases memory by about 2GB
    del loss  # nothing changes

It runs epoch 1 safely, with memory increasing by about 10GB, because there is still enough memory at that point (20GB used out of 32GB), but in the next epoch, when I call loss.backward(), the OOM happens.

You could set the gradients to None instead of resetting them to zero in order to save more memory; I guess you are running into the OOM in the second step since the gradients are now allocated and use additional memory. Use optimizer.zero_grad(set_to_none=True) and rerun the code.
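Applied to your loop it could look roughly like this (a sketch reusing the placeholder names from your snippet; optimizer stands for whichever optimizer you are using):

for epoch in range(n):
    outputs = []
    outputs.append(model(x1)['logit'][0])
    outputs.append(model(x2)['logit'][0])
    outputs.append(model(x3)['logit'][0])
    loss = loss_cal(torch.cat(outputs).reshape(a, b, c))
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)  # drops the .grad buffers instead of filling them with zeros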