Hi,
I have read the other threads related to this, so forgive me if I've overlooked something obvious, but I've tried to implement the advice given there and am still seeing a persistent memory leak.
Here's the code, adapted from "Improved Training of Wasserstein GANs":
```python
class WassGP(torch.nn.Module):
    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.config = config

    def gradient_penalty(self, imgs, fake_inp, lyrs, dev):
        dev = dev if dev == 'cpu' else f'cuda:{dev}'
        eps = torch.rand([imgs.size(0), 1, 1, 1], device=dev)
        x_hat = eps * imgs + (1 - eps) * fake_inp
        x_hat.requires_grad = True
        outp = self.model(x_hat, lyrs, dev)
        gradient, = torch.autograd.grad(outp, x_hat, torch.ones_like(outp), create_graph=True)
        grad_loss = self.config.training.lambda_grad * (torch.linalg.norm(gradient.view(outp.size(0), -1), 2, dim=1) - 1) ** 2
        grad_loss.mean().backward()

    def forward(self, real_inpt, fake_inpt, layer, dev, res, alpha):
        fake_out = self.model(fake_inpt, layer, dev, res=res, alpha=alpha)
        real_out = self.model(real_inpt, layer, dev, res=res, alpha=alpha)
        loss_real = self.config.training.lambda_disc_real * real_out
        loss_fake = self.config.training.lambda_disc_fake * fake_out
        drift_loss = self.config.training.epsilon_drift * real_out.pow(2) if self.config.training.drift_loss else 0.
        if self.config.training.use_gradient_penalty:
            self.gradient_penalty(real_inpt, fake_inpt, layer, dev)
        return (fake_out - real_out + drift_loss).mean()
```
I implemented this as a Module subclass because I was experimenting with a few things (hooks, etc.), but the leak occurs whether gradient_penalty is a regular function, a method on a non-Module class, or anything else I can think of. I have also tried returning grad_loss and backpropagating it together with the other losses, with the same result.
If I leave out create_graph=True in the call to autograd.grad(), there is no leak, but I also don't get the intended behavior, since gradients don't flow backward through the penalty term. The existing threads about memory leaks with create_graph=True in .backward() usually recommend using autograd.grad() with create_graph=True instead, which is what I'm doing, so I'm very confused.
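To be concrete, the return-the-loss variant I tried looks roughly like this, written as a standalone sketch with a generic `critic` callable rather than my exact model signature (so the names here are illustrative, not my real code):

```python
import torch

def gradient_penalty(critic, real, fake, device="cpu"):
    # Sketch of WGAN-GP that RETURNS the penalty instead of calling
    # .backward() inside the function, so one backward() call at the
    # call site covers all loss terms and nothing should hold the graph.
    eps = torch.rand(real.size(0), 1, 1, 1, device=device)
    # Random interpolation between real and fake samples.
    x_hat = (eps * real + (1 - eps) * fake).requires_grad_(True)
    out = critic(x_hat)
    # create_graph=True so the penalty itself is differentiable.
    grad, = torch.autograd.grad(out, x_hat, torch.ones_like(out), create_graph=True)
    return ((grad.view(grad.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()
```

At the call site I add this penalty to the critic loss and call .backward() once on the sum, but the memory growth is the same.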
I looked at Python's garbage-collector tensor list, as described in https://discuss.pytorch.org/t/how-to-debug-causes-of-gpu-memory-leaks/6741/2, and it's unclear what exactly is happening. I see a lot of non-Parameter tensors (as well as Parameter tensors) being created, but that could be related to the many model submodules that aren't used in early training (this is StyleGAN, using a ProGAN-style progressive training strategy).
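The check I ran is essentially the snippet from that thread, wrapped in a helper so I can compare the list between iterations (the helper name is mine):

```python
import gc
import torch

def live_tensors():
    # Enumerate every tensor the Python garbage collector currently
    # tracks, as suggested in the linked thread. Comparing the result
    # across training iterations shows which tensors accumulate.
    tensors = []
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                tensors.append((type(obj).__name__, tuple(obj.size())))
        except Exception:
            pass  # some gc-tracked objects raise on attribute access
    return tensors
```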
I'm using torch 1.7.1+cu101.
Any help is much appreciated.