The code for StarGAN is here. In line 244, the discriminator prediction for real data is computed. In line 250, the discriminator prediction for fake data is computed, but with .detach(). Why do we need .detach() when zero_grad() will eventually clear all previously computed gradients? Also, why is it used only for the fake data and not the real data?
The purposes of detach and zero_grad differ.
detach is used to literally detach a tensor from its current computation graph.
x.detach() hence returns a new tensor that isn’t attached to any computation graph and doesn’t require a gradient. Importantly, note that the new tensor shares the same storage as the original.
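A minimal sketch of both properties (new tensor with no gradient tracking, but shared storage):

```python
import torch

x = torch.ones(3, requires_grad=True)
y = x.detach()                 # new tensor, cut off from the graph

grad_tracking = y.requires_grad  # False: y requires no gradient
y[0] = 5.0                       # in-place write to y ...
shared_value = x.data[0].item()  # ... is visible through x (shared storage)
```
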
zero_grad simply zeros out the gradients of the parameters (of the model or optimizer) it is called on. The tensors in the computation graph aren’t affected.
As for your question: detach is used when the calculation history of a tensor is no longer required for computing any gradients, so detaching it from its graph lets us save memory.
zero_grad is used to prevent accumulation of gradients from several backward calls.
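The accumulation behaviour can be seen with a single scalar parameter (a toy example, not from the StarGAN code):

```python
import torch

w = torch.tensor(2.0, requires_grad=True)

# Two backward calls without zeroing: gradients accumulate.
(w * 3).backward()
(w * 3).backward()
accumulated = w.grad.item()   # 3 + 3 = 6.0

w.grad.zero_()                # what optimizer.zero_grad() does per parameter
(w * 3).backward()
fresh = w.grad.item()         # 3.0, no stale gradient left over
```
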
Thank you for your reply.
I am a little confused as to why the discriminator uses .detach() with the fake data only. Shouldn’t it be used for both the real and the fake data?
Hi, this is more of an architectural question and I am not into CV, but here is what I think:
I think you are talking about this part of the code -
# Compute loss with fake images.
x_fake = self.G(x_real, c_trg)
out_src, out_cls = self.D(x_fake.detach())
d_loss_fake = torch.mean(out_src)
As I understand it, the fake data is generated by the generator part of the GAN.
The goal of the discriminator is to act as a classifier that separates real and fake images. During discriminator training, only this classifier needs to be trained (not the generator), so the fake data must be detached from its computation graph, since that graph involves the generator’s parameters; otherwise the backward pass on the discriminator loss would also populate gradients for the generator.
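This can be checked with toy stand-ins for G and D (hypothetical linear layers, not the real StarGAN modules): after backward on the fake-data loss, the detach leaves the generator without any gradients.

```python
import torch
import torch.nn as nn

# Toy stand-ins for the StarGAN generator and discriminator.
G = nn.Linear(4, 4)
D = nn.Linear(4, 1)
d_optim = torch.optim.Adam(D.parameters(), lr=1e-3)

x_real = torch.randn(8, 4)
x_fake = G(x_real)

# detach cuts the graph at x_fake, so the loss cannot reach G's parameters.
out_fake = D(x_fake.detach())
d_loss = torch.mean(out_fake)

d_optim.zero_grad()
d_loss.backward()

g_untouched = all(p.grad is None for p in G.parameters())      # True
d_has_grads = all(p.grad is not None for p in D.parameters())  # True
```

Without the .detach(), backward would walk through G as well and fill in G’s .grad fields, even though only D’s optimizer steps in this phase.
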
Also, for real images there is no need to call detach, since they come straight from the dataset and have no computation graph associated with them.
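That is easy to verify on a freshly created tensor (standing in for a loaded batch of real images):

```python
import torch

x_real = torch.randn(2, 3)   # data from a DataLoader: requires_grad is False
tracks_grad = x_real.requires_grad  # False
graph_node = x_real.grad_fn         # None: no graph to detach from
```
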
Thanks a lot. This makes things so much clearer.