Hi, I’m trying to understand the WGAN-GP code. ... create_graph=True, retain_graph=True,...
However, I don’t understand what the create_graph argument does.
I’ve checked the docs, but they aren’t clear to me. Where is the second-order gradient calculated?
If True, graph of the derivative will be constructed, allowing to compute higher order derivative products. Default: False.
When I change it to ... create_graph=True, retain_graph=True,...
or ... create_graph=True, retain_graph=None,...
the code still works, but the model generates the same image (mode collapse?)
In WGAN-GP, you need the gradient of the gradient norm (with respect to the critic’s parameters), because you are optimizing the critic so that the norm of its gradient with respect to its input stays close to 1 (this is how the Lipschitz constraint is enforced, as a soft penalty).
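A minimal sketch of the penalty described above, assuming a standard WGAN-GP setup where the critic takes flat feature vectors. The helper `gradient_penalty`, the toy `critic`, and the tensor shapes are hypothetical, not taken from the code in question. The key detail is `create_graph=True`: it keeps the gradient norm connected to the critic’s weights, so calling `.backward()` on the penalty produces gradients for them.

```python
import torch
import torch.nn as nn

def gradient_penalty(critic: nn.Module, real: torch.Tensor, fake: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: mean of (||grad_x_hat D(x_hat)||_2 - 1)^2."""
    # One random interpolation weight per sample, broadcast over feature dims.
    eps = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (eps * real + (1 - eps) * fake).requires_grad_(True)
    d_hat = critic(x_hat)
    # create_graph=True records the gradient computation itself, so the
    # penalty below stays differentiable w.r.t. the critic's parameters.
    (grads,) = torch.autograd.grad(
        outputs=d_hat,
        inputs=x_hat,
        grad_outputs=torch.ones_like(d_hat),
        create_graph=True,
    )
    grads = grads.reshape(grads.size(0), -1)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()

torch.manual_seed(0)
critic = nn.Sequential(nn.Linear(4, 8), nn.LeakyReLU(0.2), nn.Linear(8, 1))
real, fake = torch.randn(16, 4), torch.randn(16, 4)
gp = gradient_penalty(critic, real, fake)
gp.backward()  # backpropagates through the gradient norm into the critic weights
print(critic[0].weight.grad.abs().sum())
```

In a real training loop the penalty would be scaled by a coefficient and added to the critic loss before a single `.backward()`.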
With create_graph=True, we declare that we want to perform further operations on the gradients, so the autograd engine records the gradient computation itself as a graph that can be backpropagated through.
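A small scalar example of what `create_graph=True` enables, using y = x³ as an illustrative stand-in for the critic output. The first call returns dy/dx = 3x² as a tensor that is itself part of a graph, and the second call differentiates it again to get d²y/dx² = 6x. At x = 2 both happen to equal 12. Without `create_graph`, the returned gradient carries no graph.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 3

# First derivative dy/dx = 3x^2. create_graph=True records this computation,
# so dy_dx has a grad_fn and can be differentiated again.
(dy_dx,) = torch.autograd.grad(y, x, create_graph=True)

# Second derivative d2y/dx2 = 6x, obtained by differentiating the gradient.
(d2y_dx2,) = torch.autograd.grad(dy_dx, x)
print(dy_dx.item(), d2y_dx2.item())  # 12.0 12.0

# With the default create_graph=False the gradient is a plain tensor.
(g,) = torch.autograd.grad(x ** 3, x)
print(g.requires_grad)  # False
```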
retain_graph=True declares that we want to reuse the overall graph multiple times, so it should not be freed after .backward() is called. From looking at the code, we do not call .backward() on the same graph again, so retain_graph=True is not needed in this case.
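To see what `retain_graph` controls, here is a toy example with arbitrary values. The first graph is kept with `retain_graph=True` and can be backpropagated twice; the second is freed after its first `.backward()`, so a second call raises a `RuntimeError` (the exact message varies across PyTorch versions).

```python
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)

y = (x * x).sum()
y.backward(retain_graph=True)  # keep the saved tensors so the graph can be reused
y.backward()                   # second pass works; gradients accumulate in x.grad
grad_after_two_passes = x.grad.clone()
print(grad_after_two_passes)   # tensor([4., 8.])

z = (x * x).sum()
z.backward()                   # retain_graph defaults to False: saved tensors are freed
try:
    z.backward()               # reusing the freed graph fails
    second_pass_failed = False
except RuntimeError:
    second_pass_failed = True
print(second_pass_failed)      # True
```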
When you changed to retain_graph=None, the mode collapse you saw is most likely coincidental; I don’t think it is caused by changing retain_graph (at least from the two minutes I spent reading the code).
I actually made a mistake in the original post; what I actually changed the code to was:
... create_graph=False, retain_graph=True,...
or ... create_graph=False, retain_graph=None,...
the code still works, but the model generates the same image (mode collapse?)
If create_graph=False, it should technically error out when you try to compute the gradient of the gradient. It is a bit strange that it silently ignores this.
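One likely explanation for the missing error, shown with a hypothetical single-layer critic: with `create_graph=False` the returned gradients are plain tensors with no graph, so the penalty built from them is a constant. Adding a constant to a loss that still depends on the critic’s output (`out.mean()` here stands in for the Wasserstein terms) backpropagates without complaint; the penalty simply contributes nothing to the weight gradients. Calling `.backward()` on such a penalty alone would raise. `retain_graph=True` is passed only so that `out` can still be backpropagated after `torch.autograd.grad`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
critic = nn.Linear(3, 1)  # hypothetical toy critic: D(x) = w.x + b
x_hat = torch.randn(5, 3, requires_grad=True)

def penalty_and_output(create_graph):
    out = critic(x_hat)
    (grads,) = torch.autograd.grad(
        out, x_hat, grad_outputs=torch.ones_like(out),
        create_graph=create_graph, retain_graph=True,
    )
    gp = ((grads.norm(dim=1) - 1) ** 2).mean()
    return gp, out

# create_graph=False: grads (and therefore gp) carry no graph.
gp, out = penalty_and_output(create_graph=False)
gp_tracked_without = gp.requires_grad
critic.zero_grad()
(gp + out.mean()).backward()  # no error: out.mean() still has a graph
grad_without = critic.weight.grad.clone()

# create_graph=True: gp depends on the weights and changes their gradient.
gp, out = penalty_and_output(create_graph=True)
critic.zero_grad()
(gp + out.mean()).backward()
grad_with = critic.weight.grad.clone()

print(gp_tracked_without)                       # False
print(torch.allclose(grad_without, grad_with))  # False
```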
That’s interesting. I do think retain_graph=False doesn’t work here, even with create_graph=True, when using the normal discriminator loss, i.e. gradient_penalty - d_real_out + d_fake_out.
First of all, per the PyTorch docs, retain_graph=None defaults to the value of create_graph.
So when we pass the following:
... create_graph=True, retain_graph=False,...
and use the test code above, changing only one line:
gp = ((gradients.norm(dim=1)-1)**2).mean() + y.mean()
it fails with:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
So it does seem that retain_graph=True (or None, when create_graph=True) is necessary with the normal discriminator loss.
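A self-contained reproduction of this failure, assuming a toy two-layer critic in place of the original test code (which isn’t shown here). The penalty and `y.mean()` both depend on `y`’s graph, so freeing that graph inside `torch.autograd.grad` with `retain_graph=False` makes the later `.backward()` fail, while leaving `retain_graph=None` (which falls back to `create_graph=True`) keeps the graph alive.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
critic = nn.Sequential(nn.Linear(3, 8), nn.Tanh(), nn.Linear(8, 1))  # toy critic
x = torch.randn(4, 3, requires_grad=True)

def critic_loss(retain_graph):
    y = critic(x)
    (gradients,) = torch.autograd.grad(
        y, x, grad_outputs=torch.ones_like(y),
        create_graph=True, retain_graph=retain_graph,
    )
    # The penalty and y.mean() share y's graph, as in the one-line change above.
    return ((gradients.norm(dim=1) - 1) ** 2).mean() + y.mean()

try:
    critic_loss(retain_graph=False).backward()
    failed_with_false = False
except RuntimeError:
    failed_with_false = True  # y's graph was freed inside autograd.grad

critic.zero_grad()
critic_loss(retain_graph=None).backward()  # None -> create_graph=True, graph retained
ok_with_none = True
print(failed_with_false, ok_with_none)  # True True
```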