Memory blowup when computing high-order gradients

I’m using the new autograd.grad function to penalize the gradient norm, as needed for implementing improved WGANs (WGAN-GP). This function computes the gradient penalty:

import torch
from torch import autograd
from torch.autograd import Variable

def calc_gradient_penalty(D, real_data, fake_data, _lambda):
    # dtype is defined elsewhere (e.g. torch.cuda.FloatTensor)
    batch_size = real_data.size(0)

    # Random interpolation between real and fake samples
    eps = torch.rand(batch_size, 1).expand(real_data.size()).type(dtype)
    x_hat = eps * real_data + (1 - eps) * fake_data
    x_hat = Variable(x_hat, requires_grad=True)

    D_x_hat = D(x_hat)

    # Gradient of the critic output w.r.t. the interpolated input;
    # create_graph=True so we can backprop through the penalty itself
    grad_params = autograd.grad(outputs=D_x_hat, inputs=x_hat,
                                grad_outputs=torch.ones(D_x_hat.size()).type(dtype),
                                create_graph=True, retain_graph=True, only_inputs=True)

    # L2 norm of the gradient, accumulated over all returned gradients
    grad_norm = 0
    for grad in grad_params:
        grad_norm += grad.pow(2).sum()
    grad_norm = grad_norm.sqrt()

    gradient_penalty = ((grad_norm - 1) ** 2).mean() * _lambda
    return gradient_penalty

Later in the training loop I call gradient_penalty.backward(). The memory usage grows with each iteration and exceeds my limit fairly quickly. I suspect this is because I retain the graph.
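For context, my discriminator update looks roughly like this (netD, netG, optimizerD, the noise dimension and LAMBDA are placeholder names, not my exact code):

for real_data in dataloader:  # dataloader yields batches of real samples
    optimizerD.zero_grad()

    real_v = Variable(real_data.type(dtype))
    noise = Variable(torch.randn(real_data.size(0), 128).type(dtype))
    fake_v = netG(noise).detach()  # no gradients through the generator here

    # Wasserstein critic loss plus the gradient penalty
    d_loss = netD(fake_v).mean() - netD(real_v).mean()
    d_loss.backward()

    gradient_penalty = calc_gradient_penalty(netD, real_v.data, fake_v.data, LAMBDA)
    gradient_penalty.backward()

    optimizerD.step()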

How can I resolve this problem?

I think your problem is related to this issue: https://github.com/pytorch/pytorch/issues/2287


Thanks!
Problem solved after removing the batchnorm layers.
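For anyone hitting the same thing: the WGAN-GP paper also recommends not using batch normalization in the critic (they suggest layer normalization as a replacement). A rough sketch of what my critic looks like after the change (the 2-D input and layer sizes are just placeholders, not my actual model):

import torch.nn as nn

# Plain Linear + ReLU critic with the batchnorm layers removed
D = nn.Sequential(
    nn.Linear(2, 512),
    # nn.BatchNorm1d(512),  # removed: batchnorm was causing the memory blowup
    nn.ReLU(),
    nn.Linear(512, 512),
    # nn.BatchNorm1d(512),  # removed
    nn.ReLU(),
    nn.Linear(512, 1),
)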