Memory leak when using autograd.grad

Hi,
When I use autograd.grad to compute higher-order derivatives, some memory stays allocated after every call, and I eventually hit an ‘out of memory’ error.
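For context, the pattern in question is the usual gradient-penalty-style double backward; a minimal sketch of just that part (the names here are illustrative, not taken from the script below):

import torch
from torch.autograd import Variable

x = Variable(torch.randn(8, 10), requires_grad=True)
w = Variable(torch.randn(10, 1), requires_grad=True)
out = x.mm(w).sum()

# first-order gradient w.r.t. the input, keeping the graph so that
# it can be differentiated a second time
grad_x = torch.autograd.grad(outputs=out, inputs=x,
                             create_graph=True, retain_graph=True,
                             only_inputs=True)[0]

# a loss built from the gradient itself, followed by a second backward pass
grad_loss = grad_x.pow(2).sum()
grad_loss.backward()
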
The full script that reproduces the problem is below:

from __future__ import print_function
import argparse
import random
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.autograd import Variable

try:
    import gpustat
except ImportError:
    raise ImportError("pip install gpustat")


def show_memusage(device=1):
    gpu_stats = gpustat.GPUStatCollection.new_query()
    item = gpu_stats.jsonify()["gpus"][device]
    print("Used/total: {}/{}".format(item["memory.used"], item["memory.total"]))


def run_simple_experiment():

    class simpleNet(nn.Module):
        def __init__(self):
            super(simpleNet, self).__init__()
            self.main = nn.Sequential(
                nn.Linear(500, 5000),
                nn.BatchNorm1d(5000),
                nn.Linear(5000, 5000),
                nn.BatchNorm1d(5000),
                nn.Linear(5000, 5000),
                nn.BatchNorm1d(5000),
                nn.Linear(5000, 5000),
                nn.BatchNorm1d(5000),
                nn.Linear(5000, 5000),
                nn.BatchNorm1d(5000),
                nn.Linear(5000, 1000),
                nn.Linear(1000, 1),
            )

        def forward(self, x):
            return self.main(x)

    net = simpleNet()
    if opt.cuda:
        net.cuda()

    optimizer = optim.RMSprop(net.parameters(), lr=opt.lr)

    # random input batches and a scalar grad_output for the first backward
    input = torch.randn(opt.bsize, 500).type(dtype)
    input2 = torch.randn(opt.bsize, 500).type(dtype)
    one = dtype([1])

    for _ in range(10):
        vinput = Variable(input, requires_grad=True)

        # primary loss and a plain first backward pass
        err = net(vinput).mean(0).view(1)
        err.backward(one, retain_graph=opt.use_grad_loss)

        if opt.use_grad_loss:
            # gradient-penalty-style term: gradient of the outputs w.r.t. the input
            vinput2 = Variable(input2, requires_grad=True)
            outputs = net(vinput2)
            gradients = torch.autograd.grad(outputs=outputs, inputs=vinput2,
                                            grad_outputs=torch.ones(outputs.size()).cuda()
                                            if opt.cuda else torch.ones(outputs.size()),
                                            create_graph=True, retain_graph=True, only_inputs=True)[0]
            grad_loss = gradients.view(opt.bsize, -1).sum(1).mean(0).view(1)
            grad_loss.backward()

        optimizer.step()


################################################################################################
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--manualSeed', type=int, default=1, help='seed, default %(default)s')
    parser.add_argument('--lr', type=float, default=1e-4, help='default: %(default)s')
    parser.add_argument('--cuda', action='store_true', help='enable gpu, default: %(default)s')
    parser.add_argument('--bsize', type=int, default=64, help='batch size %(default)s')
    parser.add_argument('--devID', type=int, default=1, help='GPU device ID %(default)s')
    parser.add_argument('--use_grad_loss', action='store_true', help='default: %(default)s')
    opt = parser.parse_args()
    print(opt)

    if torch.cuda.is_available() and not opt.cuda:
        print("Please run it on gpu, with --cuda")

    if opt.cuda:
        torch.cuda.manual_seed(opt.manualSeed)
        dtype = torch.cuda.FloatTensor
        cudnn.benchmark = True
    else:
        dtype = torch.FloatTensor

    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    for _ in range(10):
        run_simple_experiment()
        if opt.cuda:
            print('gpu-mem-usage after f-n call')
            show_memusage(opt.devID)
            print()

Please run it with: --devID <YOUR_GPU_DEVID> --cuda [--use_grad_loss]
To print out the memory usage it uses ‘gpustat’ (pip install gpustat).
(On Python 3 gpustat currently raises an error; you can work around it by replacing iteritems with items in the offending line of its source.)
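
(If patching gpustat is inconvenient, a rough alternative for the same printout, assuming nvidia-smi is on your PATH, would be something along these lines:)

import subprocess

def show_memusage(device=1):
    # ask nvidia-smi for used/total memory in MiB, one line per GPU
    out = subprocess.check_output(
        ['nvidia-smi', '--query-gpu=memory.used,memory.total',
         '--format=csv,noheader,nounits']).decode()
    used, total = out.splitlines()[device].split(',')
    print('Used/total: {}/{}'.format(used.strip(), total.strip()))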

In main, I call the function a few times and print the used GPU memory after each call:

  1. with --use_grad_loss, the used GPU memory keeps adding up;
  2. without --use_grad_loss, the used GPU memory stays constant.

Output snippets of the two runs:

python sanity-check.py --cuda
Namespace(bsize=64, cuda=True, lr=0.0001, manualSeed=1, niters=5, use_grad_loss=False)
gpu-mem-usage after f-n call
Used/total: 1783/6082

gpu-mem-usage after f-n call
Used/total: 1783/6082

gpu-mem-usage after f-n call
Used/total: 1783/6082

gpu-mem-usage after f-n call
Used/total: 1783/6082

gpu-mem-usage after f-n call
Used/total: 1783/6082

gpu-mem-usage after f-n call
Used/total: 1783/6082

gpu-mem-usage after f-n call
Used/total: 1783/6082

gpu-mem-usage after f-n call
Used/total: 1783/6082

gpu-mem-usage after f-n call
Used/total: 1783/6082

gpu-mem-usage after f-n call
Used/total: 1783/6082

################################################################
python sanity-check.py --cuda --use_grad_loss
Namespace(bsize=64, cuda=True, lr=0.0001, manualSeed=1, niters=5, use_grad_loss=True)
gpu-mem-usage after f-n call
Used/total: 2379/6082

gpu-mem-usage after f-n call
Used/total: 2571/6082

gpu-mem-usage after f-n call
Used/total: 2764/6082

gpu-mem-usage after f-n call
Used/total: 2958/6082

gpu-mem-usage after f-n call
Used/total: 3151/6082

gpu-mem-usage after f-n call
Used/total: 3343/6082

gpu-mem-usage after f-n call
Used/total: 3537/6082

gpu-mem-usage after f-n call
Used/total: 3730/6082

gpu-mem-usage after f-n call
Used/total: 3923/6082

gpu-mem-usage after f-n call
Used/total: 4116/6082

What I tried:

- both a source build and a conda-installed build of PyTorch;
- with and without cuDNN;
- different models and losses.

In the example above I keep a lot of parameters on purpose, so that the leak shows up as a notable difference (in the GB range) over the run. If I remove the BatchNorm layers there is no problem, even though BatchNorm is listed among the nn layers that support higher-order gradients. I am not sure the issue is limited to BatchNorm, though; I remember that without it I was getting NaNs from autograd.grad.
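
(For what it’s worth, that check just amounts to building the same stack with and without the BatchNorm layers; a sketch is below, where use_bn is only an illustrative flag and not an option of the script above:)

import torch.nn as nn

def make_net(use_bn=True):
    # same stack as simpleNet above, with the BatchNorm layers made optional
    layers = [nn.Linear(500, 5000)]
    for _ in range(4):
        if use_bn:
            layers.append(nn.BatchNorm1d(5000))
        layers.append(nn.Linear(5000, 5000))
    if use_bn:
        layers.append(nn.BatchNorm1d(5000))
    layers.append(nn.Linear(5000, 1000))
    layers.append(nn.Linear(1000, 1))
    return nn.Sequential(*layers)
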
Please let me know if more details are needed.
Many thanks!

It could be this: https://github.com/pytorch/pytorch/pull/2326

Great! I guess my source checkout was older; I just rebuilt and it works. Many thanks!