Memory Leak in WGAN-GP autograd

I want to use WGAN-GP, but when I run the code below, I get an error:

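import torch

# Note: the critic D and the penalty weight lambda_term are globals defined elsewhere.
def calculate_gradient_penalty(real_images, fake_images):
    # One random interpolation coefficient per sample, expanded to the image shape.
    t = torch.rand(real_images.size(0), 1, 1, 1).to(real_images.device)
    t = t.expand(real_images.size())

    # Random points on the straight lines between real and fake samples.
    interpolates = t * real_images + (1 - t) * fake_images
    interpolates.requires_grad_(True)

    disc_interpolates = D(interpolates)

    # Gradient of the critic output with respect to the interpolated inputs.
    grad = torch.autograd.grad(
        outputs=disc_interpolates, inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True, retain_graph=True, allow_unused=True)[0]

    # Penalize deviations of each sample's gradient norm from 1.
    grad_norm = torch.norm(torch.flatten(grad, start_dim=1), dim=1)
    loss_gp = torch.mean((grad_norm - 1) ** 2) * lambda_term

    return loss_gp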
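For reference, here is a self-contained sketch of the same penalty, assuming the critic D and lambda_term are passed in as arguments instead of being read from globals; the toy linear critic and the 1x4x4 "images" are purely illustrative. It computes lambda * mean((||grad D(x_hat)||_2 - 1)^2) on random interpolates x_hat = t * x_real + (1 - t) * x_fake. Because a linear critic D(x) = w.x + b has gradient w everywhere, the penalty should equal lambda * (||w|| - 1)^2, which makes an easy sanity check.

```python
import torch
import torch.nn as nn


def calculate_gradient_penalty(D, real_images, fake_images, lambda_term=10.0):
    """WGAN-GP penalty; D and lambda_term are explicit arguments (an assumption)."""
    # One interpolation coefficient per sample, broadcast over C, H, W.
    t = torch.rand(real_images.size(0), 1, 1, 1, device=real_images.device)
    interpolates = (t * real_images + (1 - t) * fake_images).requires_grad_(True)

    disc_interpolates = D(interpolates)

    # create_graph=True makes the returned gradient differentiable, so the
    # penalty can be backpropagated into D's parameters. retain_graph is left
    # at its default, which is the value of create_graph.
    grad = torch.autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
    )[0]

    grad_norm = grad.flatten(start_dim=1).norm(dim=1)
    return lambda_term * ((grad_norm - 1) ** 2).mean()


# Usage with a toy critic for 1x4x4 inputs (illustrative only).
torch.manual_seed(0)
critic = nn.Sequential(nn.Flatten(), nn.Linear(16, 1))
real = torch.randn(8, 1, 4, 4)
fake = torch.randn(8, 1, 4, 4)

gp = calculate_gradient_penalty(critic, real, fake, lambda_term=10.0)
gp.backward()  # the penalty's gradient reaches the critic's weights
```

In a full critic update the returned penalty would typically be added to D(fake).mean() - D(real).mean() before a single backward() call.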

RuntimeError                              Traceback (most recent call last)
in

/opt/conda/lib/python3.8/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    243                 create_graph=create_graph,
    244                 inputs=inputs)
--> 245         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    246 
    247     def register_hook(self, hook):

/opt/conda/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    143         retain_graph = create_graph
    144 
--> 145     Variable._execution_engine.run_backward(
    146         tensors, grad_tensors_, retain_graph, create_graph, inputs,
    147         allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag

RuntimeError: CUDA out of memory. Tried to allocate 64.00 MiB (GPU 2; 15.75 GiB total capacity; 13.76 GiB already allocated; 2.75 MiB free; 14.50 GiB reserved in total by PyTorch)

Does anyone know how to solve this problem?

Could you check whether prob_interpolated has a valid .grad_fn attribute?
If print(prob_interpolated.grad_fn) returns None, the model output is detached from the computation graph, and this error would be raised.
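A minimal sketch of that check, using a hypothetical toy critic: the critic output here plays the role of prob_interpolated (disc_interpolates in the question's code). When the output is computed under torch.no_grad() (or is detached), its grad_fn is None and torch.autograd.grad cannot differentiate it.

```python
import torch
import torch.nn as nn

D = nn.Linear(4, 1)  # toy critic, illustrative only
x = torch.randn(2, 4, requires_grad=True)

out = D(x)
print(out.grad_fn)  # not None: the output is attached to the autograd graph

with torch.no_grad():
    detached_out = D(x)
print(detached_out.grad_fn)  # None: no graph was recorded

raised = False
try:
    torch.autograd.grad(detached_out, x, grad_outputs=torch.ones_like(detached_out))
except RuntimeError as err:
    # Message along the lines of "element 0 of tensors does not require grad
    # and does not have a grad_fn".
    raised = True
    print(err)
```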

@ptrblck Thanks, I already solved that problem, but now there seems to be a memory leak: the memory usage increases after every epoch. Do you have any suggestions?
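One frequent cause of memory that grows every epoch, described in the PyTorch FAQ and only an assumption about this particular training loop (which isn't shown), is accumulating loss tensors across iterations, e.g. total_loss = total_loss + loss_gp: the running total keeps a reference to every iteration's autograd history. The sketch below, with a hypothetical toy critic and loop, contrasts that with accumulating plain Python floats via .item().

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
D = nn.Linear(4, 1)  # toy critic (hypothetical)
opt = torch.optim.SGD(D.parameters(), lr=0.01)

leaky_total = 0.0  # becomes a tensor that keeps every iteration's graph reachable
safe_total = 0.0   # stays a Python float, so no graph references survive

for _ in range(3):
    x = torch.randn(8, 4)
    loss = D(x).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

    leaky_total = leaky_total + loss  # tensor with a grad_fn chaining all losses
    safe_total += loss.item()         # plain float, detached from autograd

print(type(leaky_total), leaky_total.grad_fn is not None)
print(type(safe_total))
```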

In the grad call you are using retain_graph=True, so I would assume that the graph is not freed afterwards. Is this on purpose, and if so, why do you need to retain the graph?
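A small sketch probing how retain_graph behaves: according to the torch.autograd.grad documentation, retain_graph defaults to the value of create_graph, so with create_graph=True the graph is kept even without passing retain_graph=True explicitly. The toy critic and the helper grad_twice are hypothetical; the helper just checks whether a second grad call on the same graph still works.

```python
import torch
import torch.nn as nn

D = nn.Linear(4, 1)  # toy critic (hypothetical)


def grad_twice(create_graph):
    """Call autograd.grad twice on one graph; return True if the second call works."""
    x = torch.randn(3, 4, requires_grad=True)
    out = torch.tanh(D(x))  # tanh saves its output, which is freed with the graph
    ones = torch.ones_like(out)
    torch.autograd.grad(out, x, grad_outputs=ones, create_graph=create_graph)
    try:
        torch.autograd.grad(out, x, grad_outputs=ones)
        return True
    except RuntimeError:
        # "Trying to backward through the graph a second time ..."
        return False


print(grad_twice(create_graph=True))   # graph retained (retain_graph defaults to True)
print(grad_twice(create_graph=False))  # graph freed after the first call
```

So retain_graph=True is redundant next to create_graph=True; either way, the retained graph is only released once nothing references it any more.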

@ptrblck In WGAN-GP, I think that without it the gradient of the GP part of the loss cannot backpropagate through the entire D network. Without it I also get an error: RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn…
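A hedged sketch of where that error can come from, using a toy critic and a hypothetical helper named penalty: the flag that makes the penalty differentiable with respect to D's parameters is create_graph=True. With create_graph=False, autograd.grad returns a gradient without a grad_fn, so a penalty built from it does not require grad and backward() raises the "element 0 of tensors does not require grad" error. With create_graph=True alone (retain_graph left at its default), the penalty backpropagates into D.

```python
import torch
import torch.nn as nn


def penalty(D, x, create_graph):
    """Toy gradient penalty on inputs x; retain_graph is left at its default."""
    x = x.clone().requires_grad_(True)
    out = D(x)
    grad = torch.autograd.grad(out, x, grad_outputs=torch.ones_like(out),
                               create_graph=create_graph)[0]
    return ((grad.norm(dim=1) - 1) ** 2).mean()


torch.manual_seed(0)
D = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1))  # toy critic
x = torch.randn(5, 4)

gp_detached = penalty(D, x, create_graph=False)
print(gp_detached.requires_grad)  # False: the gradient is not part of a graph

raised = False
try:
    gp_detached.backward()
except RuntimeError as err:
    raised = True
    print(err)  # "element 0 of tensors does not require grad ..."

gp = penalty(D, x, create_graph=True)
gp.backward()  # works without retain_graph=True
print(D[0].weight.grad is not None)  # the penalty reached D's parameters
```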