torch.autograd keeps increasing GPU memory usage on the backward pass

Hello,
I am working on SinGAN, which uses a gradient penalty loss that keeps increasing GPU memory usage to the point that I cannot train even on an A100 (40 GB). I am not sure where the problem is or what the ways around it are. I have gone through previous related posts, but they either had a problem in their own logic or were talking about a memory leak (which may happen if you set create_graph=False). I am also pasting the chunk of code below where the memory increase happens continuously.

```python
LAMBDA = opt.lambda_grad
device = opt.device

# Random interpolation coefficient, broadcast to the shape of the real samples
alpha = torch.rand(1, 1)
alpha = alpha.expand(real.size())
alpha = alpha.to(device)

# Interpolate between real and fake samples
interpolates = alpha * real + ((1 - alpha) * fake)
interpolates = interpolates.to(device)
interpolates.requires_grad_(True)  # Variable is deprecated; set requires_grad directly

disc_interpolates = netD(interpolates)

# Gradient of the critic output w.r.t. the interpolates;
# create_graph=True so the penalty itself is differentiable
gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                grad_outputs=torch.ones_like(disc_interpolates),
                                create_graph=True, retain_graph=True,
                                only_inputs=True)[0]

gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
gradient_penalty.backward()
```

The increase in GPU memory usage happens precisely at `gradient_penalty.backward()` (the last line above). I have already tried detaching and deleting the `gradients` and `gradient_penalty` variables; it doesn't work.
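For anyone who wants to reproduce this, here is a minimal standalone sketch of the same pattern. The toy `netD` and the tensor sizes are made up just so the snippet is self-contained and runnable; they are not the actual SinGAN models:

```python
import torch
import torch.nn as nn

device = "cuda"

# Hypothetical stand-in for the SinGAN discriminator, for illustration only
netD = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding=1),
    nn.LeakyReLU(0.2),
    nn.Conv2d(32, 1, 3, padding=1),
).to(device)

real = torch.randn(1, 3, 64, 64, device=device)
fake = torch.randn(1, 3, 64, 64, device=device)

for step in range(5):
    alpha = torch.rand(1, 1, device=device).expand(real.size())
    interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    disc_interpolates = netD(interpolates)

    gradients = torch.autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]

    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    gradient_penalty.backward()
    netD.zero_grad()

    # If the leak is present, this number keeps climbing across iterations
    print(step, torch.cuda.memory_allocated(device))
```

If the leak is present, the value reported by `torch.cuda.memory_allocated()` climbs across iterations instead of staying flat.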

The problem goes away as soon as I switch to PyTorch 1.4.1. However, I am not sure why the problem exists in PyTorch 1.7.0.

@ptrblck any insights here would be great :smile: This seems to be a PyTorch-related issue. I just checked that the code works fine on 1.5.0 as well.