Calling backward doubles memory

I am currently profiling the memory consumption of my model. My code is essentially:

log_probs = model.log_prob(inp)   # forward pass
loss = -log_probs.mean()          # negative log-likelihood loss
loss.backward()                   # backward pass

I have ~4 GB of GPU memory in use (3980 MB) before calling backward and ~8 GB (8882 MB) after, which is surprising to me. PyTorch keeps the whole computational graph around until backward is called, and it should be able to free that graph as it computes the gradients. My model has 2,281,048 parameters, and one gradient is computed per parameter, so I am not sure why I now occupy ~8 GB of GPU memory.

If I am correct, this means:
(8 gigabytes) / (4 bytes per float32) = 2,000,000,000 floats
which is far more than I would expect: the gradients for 2,281,048 parameters should only need about 2,281,048 × 4 bytes ≈ 9 MB.
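
To sanity-check the arithmetic, a quick back-of-the-envelope script (the parameter count is taken from above; everything is assumed to be float32):

n_params = 2_281_048                            # parameter count from above
bytes_per_float32 = 4

grad_mb = n_params * bytes_per_float32 / 1e6    # one gradient per parameter
print(f"parameter gradients: ~{grad_mb:.1f} MB")      # ~9.1 MB

n_floats = 8e9 / bytes_per_float32              # float32 values that fit in 8 GB
print(f"float32 values in 8 GB: {n_floats:.1e}")      # ~2.0e+09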

If this is not expected, is there any way I could figure out what's going on, for example by getting the memory consumption per module?
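
One rough way to get per-module numbers (not from this thread, just a sketch under the assumption that the model is an nn.Module) is to register forward hooks that read torch.cuda.memory_allocated after each submodule runs:

import torch

def log_module_memory(model):
    # print the running total of Tensor memory after each submodule's forward pass
    def make_hook(name):
        def hook(module, inputs, output):
            torch.cuda.synchronize()
            print(f"{name}: {torch.cuda.memory_allocated() / 1e6:.1f} MB allocated")
        return hook
    handles = [m.register_forward_hook(make_hook(n)) for n, m in model.named_modules()]
    return handles  # call h.remove() on each handle when done profiling

The differences between consecutive lines then give a rough per-module activation footprint during the forward pass.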

EDIT: The reason for my investigation is that I get an OOM when scaling to a higher-dimensional dataset, and I am currently looking into reducing the memory requirement.

Hi,

How do you measure the GPU memory usage? You should be using torch.cuda.memory_allocated to get the memory actually used by Tensors. Note that this number only counts Tensors, not the memory required by the CUDA driver at initialization.
If you check via nvidia-smi, you will see a larger number because, for speed reasons, we use a caching allocator that does not return memory to the driver when Tensors are freed.
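
To illustrate the distinction, a minimal sketch (not from the thread; on older PyTorch versions memory_reserved was called memory_cached):

import torch

def report(tag):
    # memory held by live Tensors vs. memory the caching allocator keeps reserved
    torch.cuda.synchronize()
    allocated = torch.cuda.memory_allocated() / 1e6
    reserved = torch.cuda.memory_reserved() / 1e6
    print(f"{tag}: allocated={allocated:.0f} MB, reserved={reserved:.0f} MB")

# e.g. call report("before backward") and report("after backward") around the
# loss.backward() line from the snippet above; nvidia-smi roughly tracks "reserved"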

Ok, thanks, I have not looked at torch.cuda.memory_allocated yet.

This is off-topic, but do you by chance know what happens when cuBLAS fails to allocate memory (or whether it uses PyTorch's allocator)? I suspect that the caching allocator consumes all the memory and then cuBLAS fails when it is used in optimizer.step(), because the additional memory required after loss.backward() is indeed very small.

I don't think cuBLAS manages any memory itself, it takes pointers to inputs. Are you thinking of a particular function in there?

Hmm, unfortunately I didn't save the error. It might have been MAGMA? It was an assert error and not a PyTorch OOM error. I will come back or open a new issue if I manage to reproduce it (or if checking torch.cuda.memory_allocated before optimizer.step() helps me track down the issue).
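
For reference, one way to test the caching-allocator suspicion above would be to hand cached blocks back to the driver right before the optimizer step. A minimal sketch of such a training step, assuming the loss computation from the original post (model, opt, and inp are placeholders):

import torch

def training_step(model, opt, inp):
    opt.zero_grad()
    loss = -model.log_prob(inp).mean()
    loss.backward()
    # return cached-but-unused blocks to the driver so that external libraries
    # (e.g. cuBLAS workspaces) can allocate; this only releases memory not held
    # by live Tensors, and it can slow down subsequent allocations
    torch.cuda.empty_cache()
    opt.step()
    return loss.item()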