I’m using the new autograd.grad function to penalize the norm of the critic’s gradients, as required for training improved WGANs (WGAN-GP). This function computes the gradient penalty:
```python
import torch
from torch import autograd
from torch.autograd import Variable

def calc_gradient_penalty(D, real_data, fake_data, _lambda):
    # dtype is defined at module level, e.g. torch.cuda.FloatTensor
    batch_size = real_data.size(0)
    # One random interpolation coefficient per sample
    eps = torch.rand(batch_size, 1).expand(real_data.size()).type(dtype)
    x_hat = eps * real_data + (1 - eps) * fake_data
    x_hat = Variable(x_hat, requires_grad=True)
    D_x_hat = D(x_hat)
    # Gradient of the critic output w.r.t. the interpolated input;
    # create_graph=True so the penalty can itself be backpropagated
    grad_params = autograd.grad(
        outputs=D_x_hat, inputs=x_hat,
        grad_outputs=torch.ones(D_x_hat.size()).type(dtype),
        create_graph=True, retain_graph=True, only_inputs=True)
    # Per-sample 2-norm of the gradient
    grad = grad_params[0].view(batch_size, -1)
    grad_norm = grad.norm(2, dim=1)
    # Penalize deviation of each sample's gradient norm from 1
    gradient_penalty = ((grad_norm - 1) ** 2).mean() * _lambda
    return gradient_penalty
```
Later in the training loop I call gradient_penalty.backward().
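For context, my critic update looks roughly like the sketch below; G, D_solver, sample_noise, data_loader, and LAMBDA are placeholders standing in for my actual objects, not my exact code:

```python
from torch.autograd import Variable

# Schematic critic update; calc_gradient_penalty receives plain
# tensors and wraps the interpolates in a Variable itself.
for batch in data_loader:
    D_solver.zero_grad()

    real_data = batch.type(dtype)               # plain tensor
    noise = Variable(sample_noise(batch.size(0)).type(dtype))
    fake_data = G(noise).data                   # plain tensor, detached

    # Wasserstein critic loss plus the gradient penalty
    d_loss = D(Variable(fake_data)).mean() - D(Variable(real_data)).mean()
    gradient_penalty = calc_gradient_penalty(D, real_data, fake_data, LAMBDA)

    d_loss.backward()
    gradient_penalty.backward()   # memory keeps growing after this
    D_solver.step()
```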
Memory usage grows with each iteration and exceeds my limit fairly quickly. I suspect this is because I retain the graph (retain_graph=True in the autograd.grad call).
How can I resolve this problem?