Memory increase per epoch on CUDA but not CPU

Hi there! I am working on a custom GNN implemented in PyTorch. During training on the GPU, I observe that VRAM usage, main-memory usage, and training time per epoch all increase over time, while GPU utilization drops (down to 0%). I could not find anything in the forum or the documentation that helped. Common remedies, such as (a) not appending tensors that are still attached to the computational graph to a list without detaching them, (b) calling torch.cuda.empty_cache() or gc.collect() together with additional del statements, and (c) optimizing data loading with a DataLoader and a custom Dataset, led only to minor improvements or did not seem to apply to my case. I have not found a good way to use the PyTorch profiler for my problem.
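One way to apply `torch.profiler` to a memory question is to profile a handful of complete training steps with `profile_memory=True` and sort the operator table by the CPU memory each operator allocated itself. The sketch below assumes you can wrap one training step in a callable. The helper name `profile_steps` is hypothetical, and so is the idea of comparing an early epoch with a late one.

```python
from typing import Callable

import torch
from torch.profiler import ProfilerActivity, profile


def profile_steps(step_fn: Callable[[], None], num_steps: int = 5, use_cuda: bool = False) -> str:
    """Run `step_fn` a few times under the profiler and return a memory-sorted table.

    `step_fn` is a hypothetical callable performing one full training step
    (forward, loss, backward, optimizer step).
    """
    activities = [ProfilerActivity.CPU]
    if use_cuda:
        activities.append(ProfilerActivity.CUDA)
    with profile(activities=activities, profile_memory=True, record_shapes=True) as prof:
        for _ in range(num_steps):
            step_fn()
    # Sort by main-memory allocations, since main-memory growth is the open question.
    return prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10)
```

You could call this once in an early epoch and again after memory has started to grow, then compare the two tables. Allocations that PyTorch operators do not make through the tensor allocator may not appear in the table at all.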

However, I found that I could at least resolve the VRAM increase by deleting these lines:

# copy model
model_float32_cuda = copy.deepcopy(model)
model_float64_cuda = copy.deepcopy(model)
# trace & save model (cuda, float32)
model_float32_cuda.device = cuda
model_float32_cuda.dtype = torch.float32
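model_scripted_float32_cuda = torch.jit.optimize_for_inference(torch.jit.script(model_float32_cuda.float()))
torch.jit.save(model_scripted_float32_cuda, os.path.join(PARAMETERS["save_path"], ""))
torch.save(model_float32_cuda.state_dict(), os.path.join(PARAMETERS["save_path"], "model_float32_cuda_state_dict.pth"))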
print(f"Saving model at: {model_float32_cuda.device}, {model_float32_cuda.dtype}")

# trace & save model (cuda, float64)
model_float64_cuda.device = cuda
model_float64_cuda.dtype = torch.float64
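model_scripted_float64_cuda = torch.jit.optimize_for_inference(torch.jit.script(model_float64_cuda.double()))
torch.jit.save(model_scripted_float64_cuda, os.path.join(PARAMETERS["save_path"], ""))
torch.save(model_float64_cuda.state_dict(), os.path.join(PARAMETERS["save_path"], "model_float64_cuda_state_dict.pth"))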
print(f"Saving model at: {model_float64_cuda.device}, {model_float64_cuda.dtype}")

I suspect VRAM increased with every epoch because the copied models stayed cached on the GPU. I never tried deleting the models explicitly; instead, I now save only the state dict.
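This is a sketch of the "save only the state dict" approach. It adds one assumption: a scripted copy is exported once, after training, instead of every epoch. The copy is made on the CPU and then released, so no extra model replica stays on the GPU. The function name and file names are hypothetical, and the sketch omits `torch.jit.optimize_for_inference` for brevity.

```python
import copy
import os

import torch


def export_model(model: torch.nn.Module, save_dir: str, dtype: torch.dtype = torch.float32) -> str:
    """Save the state dict plus one TorchScript copy of `model` in `dtype`, then free the copy.

    Intended to be called once after training, not every epoch.
    File names below are hypothetical placeholders.
    """
    # The state dict is enough to resume training or rebuild the model later.
    torch.save(model.state_dict(), os.path.join(save_dir, "model_state_dict.pth"))

    # Script a throwaway CPU copy in the requested precision; `model` itself is untouched.
    export_copy = copy.deepcopy(model).to(device="cpu", dtype=dtype).eval()
    scripted = torch.jit.script(export_copy)
    suffix = str(dtype).replace("torch.", "")  # e.g. "float64"
    path = os.path.join(save_dir, f"model_{suffix}_scripted.pt")
    torch.jit.save(scripted, path)

    # Drop the only references to the copies. The CUDA caching allocator keeps freed
    # blocks reserved, so empty_cache() only hands them back to the driver.
    del scripted, export_copy
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return path
```

You can later load the scripted file onto the GPU with `torch.jit.load(path, map_location="cuda")`.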

I suspect that the slowdown, the main-memory consumption, and the reduced GPU load are connected. I have made the following observations: (a) main memory stays constant until a fixed number of samples (epochs × training-set size) has been processed and increases linearly afterwards. For example, with 500 data points, memory is constant for 600 epochs; with 5,000 data points, it starts increasing after 60 epochs. Both cases correspond to roughly 300,000 processed samples. (b) The problem seems to occur only when training on the GPU, not on the CPU. (c) The number of tensors and objects stays constant, except during the first couple of epochs. The problem occurs on different machines with different versions of PyTorch.
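You can check observation (a) more systematically by logging main memory against the cumulative number of processed samples rather than against epochs. Suppose runs with different training-set sizes start growing at the same sample count. Then the growth depends on how much data has passed through the model, not on the number of epochs. The sketch below is minimal. `MemoryLog`, its `tolerance_mb`/`warmup` parameters and the 50 MB default are hypothetical choices. It reads RSS with psutil, as `print_memory_usage` in the script does.

```python
import os
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class MemoryLog:
    """Main-memory usage (MB) recorded against the cumulative number of processed samples."""
    points: List[Tuple[int, float]] = field(default_factory=list)

    def record(self, samples_seen: int, rss_mb: Optional[float] = None) -> None:
        if rss_mb is None:
            import psutil  # same library the training script already uses
            rss_mb = psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2
        self.points.append((samples_seen, rss_mb))

    def growth_onset(self, tolerance_mb: float = 50.0, warmup: int = 0) -> Optional[int]:
        """First sample count at which RSS exceeds the baseline by more than tolerance_mb.

        The baseline is the reading at index `warmup`, so early warm-up epochs can be skipped.
        """
        if len(self.points) <= warmup:
            return None
        baseline = self.points[warmup][1]
        for samples_seen, rss_mb in self.points[warmup:]:
            if rss_mb - baseline > tolerance_mb:
                return samples_seen
        return None
```

With `drop_last=True`, each epoch processes exactly `len(training_loader) * PARAMETERS["batch_size"]` samples. After each epoch you could therefore call `memory_log.record((epoch + 1) * len(training_loader) * PARAMETERS["batch_size"])`.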

My best hypothesis is that I keep a reference to the computational graph somewhere (which, as I understand it, lives on the CPU?), but I cannot figure out where, nor what the best debugging strategy would be. I also don’t understand why the problem seems to occur only with CUDA and not with CPU execution. Are different data structures used depending on the device? I could also imagine that a reference is kept somewhere during the transfer from CPU to CUDA (see code below). For example, the model initialization is a bit clunky because I used torchlayers.
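One way to test the "leaked graph reference" hypothesis is to count live tensors that still carry a `grad_fn`, meaning tensors that keep an autograd graph alive, and print the count once per epoch. If the count keeps rising, something is holding on to forward-pass outputs, for example a stored `loss` where `loss.item()` was intended. This is a minimal sketch; the helper name is hypothetical.

```python
import gc

import torch


def count_graph_tensors() -> int:
    """Count live tensors that still reference an autograd graph (non-None grad_fn)."""
    count = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) and obj.grad_fn is not None:
                count += 1
        except Exception:
            # some objects (e.g. dead weakref proxies) raise when inspected
            continue
    return count
```

A constant count across epochs would suggest that any growth comes from somewhere other than retained autograd graphs.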

I am using:

pytorch 2.0.1 py3.11_cuda11.8_cudnn8.7.0_0 pytorch
pytorch-cuda 11.8 h7e8668a_5 pytorch

on Linux 5.19.0-46-generic #47~22.04.1-Ubuntu.
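PyTorch ships a script that prints the full environment for reports like this: `python -m torch.utils.collect_env`. If you only need the key versions from inside the training script, a small helper like the one below is enough. `environment_report` is a hypothetical name.

```python
import platform

import torch


def environment_report() -> dict:
    """Collect the version details usually asked for when reporting a PyTorch issue."""
    return {
        "torch": torch.__version__,
        "cuda (build)": torch.version.cuda,  # None for CPU-only builds
        "cudnn": torch.backends.cudnn.version(),  # None if cuDNN is unavailable
        "cuda available": torch.cuda.is_available(),
        "platform": platform.platform(),
    }
```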

Here is a trimmed-down version of my training script. I have also added some graphs showing the observed memory increase and the training-time slowdown.

Thank you already for your help! :slightly_smiling_face:

Let me know if I can add any further information; unfortunately, I cannot give away too many details about the model because it is unpublished work.

#!/usr/bin/env python

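# import statements
import gc
import os
import sys
import time

import numpy as np
import psutil
import torch
import torchlayers
from torch.utils.data import DataLoader

# MyDataset, MyModel and load_parameters come from the project's own (unpublished) code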

def instantiate_model(PARAMETERS: dict, training_data: MyDataset):
    # loader to instantiate model
    instantiation_loader = DataLoader(
        training_data, batch_size=PARAMETERS["batch_size"], shuffle=False
    )

    # save E0 (and dE0)
    PARAMETERS["E0"] = training_data._e0.item()
    PARAMETERS["E0_IDX"] = training_data._e0_idx.item()

    # backup user selection
    user_device_name = PARAMETERS["device_name"]
    user_device = PARAMETERS["device"]

    # change device temporarily to CPU and instantiate model on CPU
    PARAMETERS['device_name'] = 'cpu'
    PARAMETERS['device'] = torch.device(PARAMETERS['device_name'])

    model = MyModel(**PARAMETERS)

    print(f"E0 saved in parameters: {PARAMETERS['E0']}")
    print(f"E0_IDX saved in parameters: {PARAMETERS['E0_IDX']}")

    # get a sample batch and show it to torchlayers
    sample_batch = next(iter(instantiation_loader))

    # check dtype
    for key in sample_batch:
        if sample_batch[key] is not None and sample_batch[key].dtype != torch.int64:
            if PARAMETERS["dtype"] == "float32":
                assert sample_batch[key].dtype == torch.float32
            elif PARAMETERS["dtype"] == "float64":
                assert sample_batch[key].dtype == torch.float64
            else:
                print(f"Unsupported dtype: {PARAMETERS['dtype']}")

    sample_input = (sample_batch["feature_1"][0], sample_batch["feature_2"], None, sample_batch["feature_4"], sample_batch["feature_5"])
    model = torchlayers.build(model, sample_input)
    print("Sample batch on CPU:")
    prediction_cpu = model(sample_input)
    print("Prediction (CPU):")

    print("Prediction + E0 (CPU)")
    print(prediction_cpu.detach() + PARAMETERS["E0"])
    del prediction_cpu
    # move model and data back to the user-selected device
    if PARAMETERS['device_name'] != user_device_name:
        PARAMETERS['device_name'] = user_device_name
        PARAMETERS['device'] = user_device
        model = model.to(PARAMETERS["device"])
        model.device = PARAMETERS["device"]
        model.dtype = PARAMETERS["dtype"]

        # sample input on device
        sample_input = tuple(x.to(PARAMETERS["device"]) if isinstance(x, torch.Tensor) else None for x in sample_input)
        prediction_cuda = model(sample_input)
        print("Prediction (Device):")
        print("Prediction + E0 (Device)")
        print(prediction_cuda.detach() + PARAMETERS["E0"])
        del prediction_cuda

    del sample_input
    del sample_batch
    del instantiation_loader

    return model

def train_one_epoch(epoch, model, optimizer, loss_fn, training_loader, PARAMETERS):
    # put model in train mode
    model.train()
    training_start = time.time()

    prediction_times = list()
    loss_times = list()
    backward_times = list()
    optimizer_times = list()
    clean_up_times = list()

    for idx, batch in enumerate(training_loader):
        if idx % 100 == 0:
            print(f"Training: Epoch: {epoch}. Batch {idx} / {len(training_loader)}")

        # transfer batch to GPU and prepare input
        for key in batch:
            if batch[key] is not None:
                batch[key] = batch[key].to(PARAMETERS["device"])

        # check dtype
        for key in batch:
            if batch[key] is not None and batch[key].dtype != torch.int64:
                if PARAMETERS["dtype"] == "float32":
                    assert batch[key].dtype == torch.float32
                elif PARAMETERS["dtype"] == "float64":
                    assert batch[key].dtype == torch.float64
                else:
                    print(f"Unsupported dtype: {PARAMETERS['dtype']}")

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # make prediction
        input = (batch["feature_1"][0], batch["feature_2"], None, batch["feature_4"], batch["feature_5"])
        prediction_start = time.time()
        prediction = model(input)
        prediction_stop = time.time()
        prediction_times.append(prediction_stop - prediction_start)
        # loss calculation
        loss_start = time.time()
        loss = loss_fn(prediction, batch['feature_1'])
        loss_end = time.time()
        loss_times.append(loss_end - loss_start)

        # backprop
        backward_start = time.time()
        loss.backward()
        backward_end = time.time()
        backward_times.append(backward_end - backward_start)

        # gradient clipping
        for param in model.parameters():
            if param.grad is not None and torch.isnan(param.grad).any():
                print("nan gradient found")
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=PARAMETERS['max_grad_norm'])
        # optimizer
        optimizer_start = time.time()
        optimizer.step()
        optimizer_end = time.time()
        optimizer_times.append(optimizer_end - optimizer_start)

        clean_up_start = time.time()
        del loss
        del prediction
        del input
        del batch
        clean_up_end = time.time()
        clean_up_times.append(clean_up_end - clean_up_start)

    training_end = time.time()

    training_time = training_end - training_start
    prediction_time = np.sum(prediction_times)
    loss_time = np.sum(loss_times)
    backward_time = np.sum(backward_times)
    optimizer_time = np.sum(optimizer_times)
    cleanup_time = np.sum(clean_up_times)

    unaccounted_time = training_time - prediction_time - loss_time - backward_time - optimizer_time - cleanup_time

    print(f"Training time: {training_time:.6f}")
    print(f"Unaccounted time: {unaccounted_time:.6f}")
    print(f"Prediction time: {prediction_time:.6f}")
    print(f"Loss time: {loss_time:.6f}")
    print(f"Backward time: {backward_time:.6f}")
    print(f"Optimizer time: {optimizer_time:.6f}")
    print(f"Cleanup time: {cleanup_time:.6f}")

def print_memory_usage():
    objects = gc.get_objects()
    tensors = [obj for obj in objects if torch.is_tensor(obj)]
    current_cuda = torch.cuda.memory_allocated() / 1024 ** 2
    peak_cuda = torch.cuda.max_memory_allocated() / 1024 ** 2
    percent_of_peak = current_cuda / peak_cuda * 100 if peak_cuda else 0.0
    print(
        f"Number of objects: {len(objects)}",
        f"Number of tensors: {len(tensors)}",
        f"Current CUDA memory usage (MB): {current_cuda:.8f}",
        f"Peak CUDA memory usage (MB): {peak_cuda:.8f}",
        f"CUDA memory usage (percent of peak): {percent_of_peak:.8f}%",
        f"Current main memory usage (MB): {psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2:.8f}",
        sep="\n",
    )

if __name__ == "__main__":
    usage = f"{sys.argv[0]} parameters.yaml"
    if len(sys.argv) != 2:
        print(usage)
        sys.exit(1)

    # load parameters
    PARAMETERS = load_parameters(sys.argv[1])
    PARAMETERS["parameters_file"] = sys.argv[1]

    # set precision
    if PARAMETERS["dtype"] == "float32":
        torch.set_default_dtype(torch.float32)
    elif PARAMETERS["dtype"] == "float64":
        torch.set_default_dtype(torch.float64)
    else:
        print(f"Unsupported dtype: {PARAMETERS['dtype']}")

    # load training and validation data
    training_data = MyDataset(...)  # arguments omitted
    validation_data = MyDataset(...)  # arguments omitted
    assert training_data._e0 == validation_data._e0
    assert training_data._e0_idx == validation_data._e0_idx

    model = instantiate_model(PARAMETERS, training_data)

    # assert data and model have precision requested
    for param in model.parameters():
        if PARAMETERS["dtype"] == "float32":
            assert param.dtype == torch.float32
        elif PARAMETERS["dtype"] == "float64":
            assert param.dtype == torch.float64
        else:
            print(f"Unsupported dtype: {PARAMETERS['dtype']}")

    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=PARAMETERS['learning_rate'])

    # scheduler
    decay_rate = np.exp(np.log(PARAMETERS['decay_factor']) / (PARAMETERS["num_epochs"]))
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=decay_rate)
    print(f"ExponentialLR scheduler with decay rate {decay_rate}")
    # instantiate data loaders
    training_loader = DataLoader(training_data, batch_size=PARAMETERS["batch_size"], shuffle=True, drop_last=True)
    validation_loader = DataLoader(validation_data, batch_size=PARAMETERS["batch_size"], shuffle=False, drop_last=True)

    for epoch in range(PARAMETERS["num_epochs"]):
        train_one_epoch(epoch, model, optimizer, loss_fn, training_loader, PARAMETERS)
        scheduler.step()
        print_memory_usage()
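Two details of `train_one_epoch` may distort its timings. First, CUDA kernels run asynchronously. `time.time()` around `model(input)` or `loss.backward()` therefore mostly measures kernel launches, and the remaining GPU work is counted at the next line that waits for the GPU. Second, the NaN check `if torch.isnan(param.grad).any()` converts a GPU tensor to a Python bool once per parameter, which forces a host-device synchronization each time. The sketch below addresses both. `timed` and `clip_and_check` are hypothetical helpers. The extra synchronizations in `timed` slow training slightly, so it is meant for diagnosis only.

```python
import time
from contextlib import contextmanager
from typing import Iterator, List

import torch


@contextmanager
def timed(times: List[float], device: torch.device) -> Iterator[None]:
    """Append the duration of the enclosed block to `times`.

    On a CUDA device we synchronize before starting and before stopping the
    clock, so queued GPU work is attributed to the block that launched it.
    """
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    start = time.perf_counter()
    try:
        yield
    finally:
        if device.type == "cuda":
            torch.cuda.synchronize(device)
        times.append(time.perf_counter() - start)


def clip_and_check(model: torch.nn.Module, max_norm: float) -> bool:
    """Clip gradients and report whether the total gradient norm was finite.

    clip_grad_norm_ already reduces all gradients to a single norm, so one
    check on that value replaces one host-device sync per parameter.
    """
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
    return bool(torch.isfinite(total_norm))
```

In `train_one_epoch`, the prediction timing could then read `with timed(prediction_times, PARAMETERS["device"]): prediction = model(input)`. The NaN loop plus clipping could become `if not clip_and_check(model, PARAMETERS['max_grad_norm']): print("non-finite gradient norm")`. This check runs after clipping and also flags infinite gradients, not only NaN.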


I fixed my problem. It turns out we had the @torch.jit.script decorator on several functions, both inside and outside our model. Removing the decorators resolved the problem. The decorators were there so that the model could be scripted and later run from C++ in production, which I think can be achieved another way.
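One possible alternative is to leave helper functions undecorated, train in eager mode, and call `torch.jit.script` on the trained module only once at export time. TorchScript compiles functions reached from `forward()` recursively, so the helpers do not need their own decorators. The resulting file can be loaded from C++ with libtorch's `torch::jit::load`. All names below are hypothetical. This is a sketch of the export workflow, not a confirmed explanation of the memory growth.

```python
import io

import torch


def gated_message(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Plain Python helper: no @torch.jit.script decorator, so it runs eagerly during training.
    return torch.sigmoid(x @ weight) * x


class MessageLayer(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(dim, dim) * 0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return gated_message(x, self.weight)


def export_for_cpp(model: torch.nn.Module) -> bytes:
    """Script the trained model once; gated_message is compiled recursively from forward()."""
    scripted = torch.jit.script(model.eval())
    buffer = io.BytesIO()
    torch.jit.save(scripted, buffer)
    return buffer.getvalue()
```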