How to turn off gradient during GAN training

(Zhenlan Wang) #1

I am going through the DCGAN tutorial.

One question I have is: how do you turn off gradient history tracking for the discriminator when you are training the generator? In the tutorial it is not turned off, as shown below.

# This part trains the generator
label.fill_(real_label)  # fake labels are real for generator cost
# Since we just updated D, perform another forward pass of the all-fake batch through D
output = netD(fake).view(-1)
# Calculate G's loss based on this output
errG = criterion(output, label)
# Calculate gradients for G
errG.backward()
I see that gradient tracking is turned off for the generator when training the discriminator, by calling detach() on the fake images, but not the other way around. Thanks in advance for your help :smiley:
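For reference, a minimal sketch of the detach() mechanism mentioned above, using tiny linear layers as hypothetical stand-ins for netG and netD:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the tutorial's netG / netD
netG = nn.Linear(4, 4)
netD = nn.Linear(4, 1)

noise = torch.randn(2, 4)
fake = netG(noise)

# Discriminator step: detach() cuts the graph, so gradients
# from this backward pass never reach the generator
out = netD(fake.detach()).sum()
out.backward()

assert netG.weight.grad is None       # generator parameters untouched
assert netD.weight.grad is not None   # discriminator parameters got gradients
```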


Since the discriminator’s optimizer won’t be called during the generator updates, nothing bad will happen.
You could set the requires_grad attribute of the discriminator’s parameters to False and reset it to True after the generator updates, but this is not really necessary, as you are using different optimizers.
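A minimal sketch of that freeze/unfreeze pattern, again using small linear layers as hypothetical stand-ins for netG and netD. Note that gradients still flow *through* the frozen discriminator back to the generator:

```python
import torch
import torch.nn as nn

netG = nn.Linear(4, 4)   # hypothetical stand-ins
netD = nn.Linear(4, 1)

# Freeze D's parameters for the generator update
for p in netD.parameters():
    p.requires_grad = False

fake = netG(torch.randn(2, 4))
errG = netD(fake).sum()
errG.backward()

assert netD.weight.grad is None      # no gradients accumulated on D
assert netG.weight.grad is not None  # G still gets gradients through D

# Unfreeze afterwards for the next discriminator update
for p in netD.parameters():
    p.requires_grad = True
```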

(Zhenlan Wang) #3

Thanks for your reply. But my understanding of how optimizers work is that they take .grad and the current weight and then update the weight. So even though the discriminator’s optimizer does not get called, .grad gets calculated for the discriminator’s weights when you call .backward, and on the forward pass the activations are cached because requires_grad == True. Let me know if that is the case or I misunderstand something.


Your understanding is basically correct. Since netD.zero_grad() is called before updating the discriminator, these gradients will be cleared.
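A small sketch of that clearing behavior, with a linear layer as a hypothetical stand-in for netD. Depending on the PyTorch version, zero_grad() either zeroes the .grad tensors or sets them to None (set_to_none=True is the default in recent releases):

```python
import torch
import torch.nn as nn

netD = nn.Linear(4, 1)   # hypothetical stand-in

# Backward pass during the generator step populates D's gradients
out = netD(torch.randn(2, 4)).sum()
out.backward()
assert netD.weight.grad is not None  # stale gradients from the generator step

# netD.zero_grad() before the discriminator update clears them
netD.zero_grad()
g = netD.weight.grad
assert g is None or bool((g == 0).all())
```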

(Zhenlan Wang) #5

In other words, even though the discriminator’s optimizer won’t be called during the generator updates, you will still save some time and memory by turning requires_grad off. This is because, regardless of whether the optimizer is ever called on a weight, weight.grad gets calculated (if requires_grad is on).


I’m not sure about the computation and memory usage, since you need the gradients to backpropagate to the generator.

(Zhenlan Wang) #7

I will benchmark it and report back.

(Zhenlan Wang) #8

There is a slight improvement in terms of time. I singled out the part of the code that trains the generator, as shown below.

label = torch.full((batch_size,), real_label, device=device)

def train():
    noise = torch.randn(batch_size, nz, 1, 1, device=device)
    fake = netG(noise)
    output = netD(fake).view(-1)
    errG = criterion(output, label)

%timeit train()
# -> 44.3 ms ± 1.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

for p in netD.parameters():
    p.requires_grad = False

%timeit train()
# -> 41 ms ± 253 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


Thanks for the debugging!
This post might explain the benefits you are seeing.

(Zhenlan Wang) #10

Thanks for the reference and your help. Much appreciated :smiley: