Model.to("cpu") does not release GPU memory allocated by registered buffer

I have noticed that a registered buffer does not release its GPU memory when the model is moved back to the CPU.

Here is minimal code to reproduce the observation:

import torch
from torch import nn
from subprocess import Popen, PIPE

class TestNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Register a small non-persistent buffer (the third argument is persistent=False).
        self.register_buffer("test", torch.Tensor([1,2,3]).view(-1, 1, 1), False)

def log_nvidia_smi():
    print("--- nvidia-smi ---")
    for line in Popen('nvidia-smi', shell=True, stdout=PIPE).stdout:
        line = line.decode("utf-8")
        if "python" in line:
            print(line.rstrip())
    print("------------------")

def log_where_buffer_is(model):
    for buffer_name, buffer in model.named_buffers():
        print(f"buffer {buffer_name} is on {buffer.device}")

model = TestNet()

print("< before moving the model to gpu >")
log_where_buffer_is(model)
log_nvidia_smi()

# move model to gpu
model.to(torch.device("cuda"))

print("\n< after moving the model to gpu >")
log_where_buffer_is(model)
log_nvidia_smi()

# move model to cpu
model.to(torch.device("cpu"))

print("\n< after moving the model back to cpu >")
log_where_buffer_is(model)
log_nvidia_smi()

# delete model and clear cuda cache
del model
torch.cuda.empty_cache()

print("\n< after deleting the model >")
log_nvidia_smi()

The above code logs the following messages:

< before moving the model to gpu >
buffer test is on cpu
--- nvidia-smi ---
------------------

< after moving the model to gpu >
buffer test is on cuda:0
--- nvidia-smi ---
|    0   N/A  N/A     2527      C   ...rtualenvs/nova/bin/python     1005MiB |
------------------

< after moving the model back to cpu >
buffer test is on cpu
--- nvidia-smi ---
|    0   N/A  N/A     2527      C   ...rtualenvs/nova/bin/python     1005MiB |
------------------

< after deleting the model >
--- nvidia-smi ---
|    0   N/A  N/A       2527      C   ...rtualenvs/nova/bin/python     1003MiB |
------------------

I have looked through the PyTorch documentation and it seems there isn't a way to delete buffers.
I also tried manually deleting the tensors, but that didn't help.

Is there a way to properly clear all the GPU memory after moving the model back to CPU?

In this case the actual memory used by the model itself (it looks like it's just a 6-element tensor) is minuscule compared to the memory used by the CUDA context (~300MiB+) and the PyTorch CUDA kernels (~600MiB+). What is happening is that when you invoke .cuda() on something for the first time, or initialize a device tensor, this pulls all of PyTorch's CUDA kernels into GPU memory and creates a CUDA context. (If you had called a library function in cuDNN or cuBLAS, you would expect this usage to go even higher when those kernels are loaded!) Unfortunately, just because there are no more GPU tensors doesn't mean that this magically goes away.

If you want to see the effect of releasing the GPU memory actually held by the model, you might want to increase the amount of GPU memory used by the model (e.g., have it use up 1GiB+).
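
For reference, here is a minimal sketch (not from the original post) that separates the memory actually held by tensors from the context overhead, using torch.cuda.memory_allocated() and torch.cuda.memory_reserved():

import torch

# Allocate a large tensor on the GPU (assumes a CUDA device is available).
x = torch.rand(10000, 10000, 10, device="cuda")

# memory_allocated() reports bytes currently held by live tensors;
# memory_reserved() reports bytes held by PyTorch's caching allocator.
print(torch.cuda.memory_allocated() / 2**20, "MiB allocated")
print(torch.cuda.memory_reserved() / 2**20, "MiB reserved")

# Free the tensor and return cached blocks to the driver.
del x
torch.cuda.empty_cache()

# Both counters drop back toward zero, but nvidia-smi still shows the
# CUDA context and kernel images, which PyTorch does not control.
print(torch.cuda.memory_allocated() / 2**20, "MiB allocated")
print(torch.cuda.memory_reserved() / 2**20, "MiB reserved")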


Thank you for your explanation.
I have replaced the line

self.register_buffer("test", torch.Tensor([1,2,3]).view(-1, 1, 1), False)

with

self.register_buffer("test", torch.rand([10000, 10000, 10]), False)

so that the model allocates a large amount of GPU memory.

The results look as follows:

< before moving the model to gpu >
buffer test is on cpu
--- nvidia-smi ---
------------------

< after moving the model to gpu >
--- nvidia-smi ---
|    0   N/A  N/A      3401      C   ...rtualenvs/nova/bin/python     4819MiB |
------------------

< after moving the model back to cpu >
--- nvidia-smi ---
|    0   N/A  N/A      3401      C   ...rtualenvs/nova/bin/python     4819MiB |
------------------

< after deleting the model >
--- nvidia-smi ---
|    0   N/A  N/A      3401      C   ...rtualenvs/nova/bin/python     1003MiB |
------------------

As you mentioned, the remaining 1003MiB seems unrelated to the model itself; it is not released even though I am now allocating a much larger amount of memory for the model.


I have done a little more googling based on your explanation.

One of the GitHub issues I found says that deleting the PyTorch CUDA context is not going to be supported.

Is this relevant?

Yes, I don’t think there are plans to do something like this in pure PyTorch.

“you cannot delete the CUDA context while the PyTorch process is still running”

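As an aside (a common workaround, not something suggested in this thread): since the context lives for the lifetime of the process, one way to get all GPU memory back is to do the CUDA work in a short-lived worker process, e.g. via torch.multiprocessing, so the context is released when the worker exits. A rough sketch:

import torch
from torch import nn
import torch.multiprocessing as mp

class TestNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("test", torch.rand(1000, 1000), False)

def run_on_gpu(queue):
    # All CUDA work (and therefore the CUDA context) lives in this worker process.
    model = TestNet().to("cuda")
    queue.put(model.test.sum().item())
    # When the worker exits, its CUDA context and all GPU memory are released.

if __name__ == "__main__":
    ctx = mp.get_context("spawn")  # CUDA requires the "spawn" start method
    queue = ctx.Queue()
    worker = ctx.Process(target=run_on_gpu, args=(queue,))
    worker.start()
    print(queue.get())
    worker.join()
    # The parent process never touched CUDA, so nvidia-smi shows no usage for it.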

I highly recommend adding this capability; it's really necessary when deploying with PyTorch.
Loading the model from disk and initializing an instance each time is time consuming, so it would help a lot if the PyTorch CUDA context could be deleted without exiting the process.
Thank you

CUDA 11.7+ ships with lazy module loading, which will only load the needed device kernels and will thus significantly reduce the startup time as well as CUDA context size. You could try it out and see if this works for you.

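If you want to try it explicitly, lazy loading is controlled by the CUDA_MODULE_LOADING environment variable, which has to be set before CUDA is initialized in the process. A minimal sketch (assuming a CUDA 11.7+ driver; recent PyTorch builds may already enable lazy loading by default):

import os

# Must be set before the process makes its first CUDA call.
os.environ["CUDA_MODULE_LOADING"] = "LAZY"

import torch

x = torch.zeros(1, device="cuda")  # device kernels are now loaded lazily, on first use
print(torch.cuda.memory_allocated(), "bytes held by tensors")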