After calling torch.nn.Module.cuda(), the model doesn't seem to be freed from RAM

As per the documentation here for the cuda method of torch.nn.Module, the model's parameters and buffers should be moved to the GPU. So I assumed the copy of the model in RAM would be released, but when I checked the Resident Set Size (RSS) of the program, it had actually grown to about 3x what it was with the initial model on the CPU.
Here is the code I used to test this:

import torch
import psutil
import gc
import time
import sys

class TestModel(torch.nn.Module):
    def __init__(self) -> None:
        super(TestModel, self).__init__()
        self.ip_l = torch.nn.Linear(1, 512)
        # note: this repeats the *same* Linear module 1000 times, so the model
        # only holds a few layers' worth of unique parameters
        self.hidden = torch.nn.Sequential(*([torch.nn.Linear(512, 512)] * 1000))
        self.op_l = torch.nn.Linear(512, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.op_l(self.hidden(self.ip_l(x)))


if __name__ == "__main__":
    ovr_ram = psutil.virtual_memory().total / (1024 * 1024)

    proc = psutil.Process()
    time.sleep(1)
    print(f"Before model init, \n\tRSS(mb) - {(proc.memory_percent() / 100) * ovr_ram}")

    model = TestModel()
    time.sleep(1)
    print(f"Model of size {sys.getsizeof(model) / 1024} kb at {set([i.device for i in model.parameters()])}, \n\tRSS(mb) - {(proc.memory_percent() / 100) * ovr_ram}") 

    model.cuda(0)
    time.sleep(1)
    print(f"Model of size {sys.getsizeof(model) / 1024} kb at {set([i.device for i in model.parameters()])}, \n\tRSS(mb) - {(proc.memory_percent() / 100) * ovr_ram}")

    gc.collect()
    time.sleep(1)
    print(f"After garbage collect model of size {sys.getsizeof(model) / 1024} kb, \n\tRSS(mb) - {(proc.memory_percent() / 100) * ovr_ram}")

Is this the right way to profile the memory usage of the program? And is torch.nn.Module making a copy of the model parameters on the GPU, leaving the CPU copy behind?
I tried this on an Ubuntu 22.04 x86_64 machine with torch 2.0.0+cu117.
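
As a side note, sys.getsizeof(model) only reports the size of the Python wrapper object, not the tensor storage, so a rough estimate of the actual parameter and buffer memory can be computed from the tensors themselves. A minimal sketch, assuming the model object from the script above:

# rough estimate of tensor memory held by the model (parameters + buffers)
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
buffer_bytes = sum(b.numel() * b.element_size() for b in model.buffers())
print(f"parameters: {param_bytes / (1024 * 1024):.2f} MB, buffers: {buffer_bytes / (1024 * 1024):.2f} MB")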

The model itself is tiny, so I assume the lazily loaded CUDA driver is what uses the host RAM. Update to the most recent stable or nightly release with CUDA 11.8 or 12.1 and the host RAM usage should decrease.
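
To check whether the host-RAM growth really comes from CUDA initialization rather than from the parameters, the cost of the context alone can be measured in isolation. A minimal sketch (the exact numbers will depend on the driver and CUDA toolkit version):

import torch
import psutil

proc = psutil.Process()
print(f"RSS before CUDA init: {proc.memory_info().rss / (1024 * 1024):.1f} MB")

# creating the first CUDA tensor initializes the CUDA context and loads the driver libraries
_ = torch.zeros(1, device="cuda:0")
torch.cuda.synchronize()

print(f"RSS after CUDA init:  {proc.memory_info().rss / (1024 * 1024):.1f} MB")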


That worked, thanks @ptrblck