How to implement gradient accumulation for a GAN?

I want to take advantage of gradient accumulation to train a GAN with a larger effective batch size. I understand that for a normal network we just do something like

output = net(input)
loss = criterion(output, target_var)
loss = loss / accumulate_steps        # scale so the accumulated gradient matches a big-batch average
loss.backward()                       # gradients accumulate in .grad across iterations
if iterations % accumulate_steps == 0:
    optimizer.step()                  # update once every accumulate_steps mini-batches
    optimizer.zero_grad()

But how can I implement this for a GAN? In GAN training we alternately update G and D, and when computing the gradients for G, unwanted gradients also get accumulated in D, because G's loss is computed through the discriminator. So we normally clear D's gradients in every iteration, which conflicts with the gradient accumulation strategy.

Here is GAN training code without gradient accumulation:

#-----------
# Update G
#-----------
optimizer_G.zero_grad()
gen_imgs = generator(input_noise)
g_loss = adversarial_loss(discriminator(gen_imgs), label_real)
g_loss.backward()         # note: this also writes gradients into D's parameters
optimizer_G.step()

#----------
# Update D
#----------
optimizer_D.zero_grad()   # this step clears the unwanted gradients that g_loss.backward() left in D
real_loss = adversarial_loss(discriminator(real_imgs), label_real)
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), label_fake)
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optimizer_D.step()
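
To see the problem concretely, here is a quick sanity check (run it right after g_loss.backward(), before optimizer_D.zero_grad()):

# D's parameters already hold gradients from the generator's loss,
# even though we only meant to update G here:
leaked = sum(p.grad.abs().sum().item()
             for p in discriminator.parameters() if p.grad is not None)
print(leaked)  # > 0, which is why optimizer_D.zero_grad() is needed every iteration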

You could try to set the .requires_grad attribute of the parameters of the model that should not get gradients to False. This would still calculate the gradients for e.g. the generator, while no new gradients would be accumulated in the discriminator.
After this step, you would have to set these attributes back to True for the parameters.
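
Something like this sketch for the G update (set_requires_grad is just a helper written here for illustration, and accum_steps / iteration follow the accumulation pattern from the question); the D update stays as in the original code, just without calling optimizer_D.zero_grad() every iteration:

def set_requires_grad(model, flag):
    # illustrative helper, not a PyTorch built-in
    for p in model.parameters():
        p.requires_grad_(flag)

#-----------
# Update G (with D frozen, so D's .grad buffers stay clean)
#-----------
set_requires_grad(discriminator, False)
gen_imgs = generator(input_noise)
g_loss = adversarial_loss(discriminator(gen_imgs), label_real) / accum_steps
g_loss.backward()   # still backprops *through* D to reach G, but stores gradients only in G
set_requires_grad(discriminator, True)
if iteration % accum_steps == 0:
    optimizer_G.step()
    optimizer_G.zero_grad()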

@AbnerC @ptrblck Thanks for your question and solution.
@AbnerC I wonder if you have tried the mentioned solution. If I understand correctly, it means we need to set requires_grad of the discriminator's parameters to False when training the generator, set it back to True when training the discriminator, and do this in every iteration.


An easier solution for gradient accumulation:

#-----------
# Update G
#-----------
gen_imgs = generator(input_noise)
g_loss = adversarial_loss(discriminator(gen_imgs), label_real)
g_loss = g_loss / accum_steps                         # scale for accumulation, as in the plain example
g_loss.backward(inputs=list(generator.parameters()))  # accumulate gradients only in G's parameters
if iteration % accum_steps == 0:
    optimizer_G.step()
    optimizer_G.zero_grad()

#----------
# Update D
#----------
real_loss = adversarial_loss(discriminator(real_imgs), label_real)
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), label_fake)  # detach() blocks gradients from reaching G
d_loss = (real_loss + fake_loss) / 2 / accum_steps    # scale for accumulation
d_loss.backward()
if iteration % accum_steps == 0:
    optimizer_D.step()
    optimizer_D.zero_grad()
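
Tensor.backward(inputs=...) computes and accumulates gradients only for the tensors you list, so nothing is written into D's .grad buffers during the G step; the D step needs no inputs argument because gen_imgs.detach() already blocks the path back into G. (This assumes iteration starts counting at 1, e.g. from enumerate(dataloader, start=1), so the optimizers don't step at iteration 0.) You can verify that the G step leaves D untouched with the same kind of check as before:

# Sanity check on the very first iteration, right after
# g_loss.backward(inputs=list(generator.parameters())):
assert all(p.grad is None for p in discriminator.parameters())  # nothing leaked into D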