PyTorch uses a memory caching mechanism, which reuses already allocated device memory to avoid expensive memory allocation calls.
E.g. your training loop might use 7GB of memory to store the model parameters, the input data, and the intermediate tensors needed to calculate the gradients in the backward pass.
Once the backward pass is done and the gradients are calculated, the intermediates are no longer needed and can be freed.
However, to avoid `free` and `alloc` calls, this memory is pushed to the cache and reused later.
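You can see this behavior directly, e.g. with a minimal sketch like the one below (assuming a CUDA device is available; the tensor size is just an illustration): `torch.cuda.memory_allocated()` reports the memory currently used by tensors, while `torch.cuda.memory_reserved()` reports what the caching allocator holds on to.

```python
import torch

device = "cuda"

# Stand-in for the intermediate tensors created during forward/backward
# (~1GB of float32 values, purely illustrative).
intermediates = torch.randn(1024, 1024, 256, device=device)

print(torch.cuda.memory_allocated() / 1024**3)  # memory used by live tensors
print(torch.cuda.memory_reserved() / 1024**3)   # memory held by the caching allocator

# Deleting the tensor frees it from PyTorch's point of view ...
del intermediates
print(torch.cuda.memory_allocated() / 1024**3)  # drops back down

# ... but the memory stays in the cache to be reused by the next allocation,
# so memory_reserved() stays high (and nvidia-smi still shows it as in use).
print(torch.cuda.memory_reserved() / 1024**3)

# torch.cuda.empty_cache() would release the cached, unused blocks back to the device,
# at the cost of new (slow) allocation calls later.
torch.cuda.empty_cache()
print(torch.cuda.memory_reserved() / 1024**3)
```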