I’m trying to implement WGAN-GP with PyTorch’s Automatic Mixed Precision (AMP). In a Wasserstein GAN the critic network gets updated more frequently than the generator (e.g. 5 critic updates per generator update).

My code goes something like:

```
for i in range(5):
    discriminator_optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        # do some stuff
        loss_discriminator = compute_discriminator_loss()
        loss_gradient = compute_gradient_penalty_loss()
        ld = loss_gradient + loss_discriminator
    scaler.scale(ld).backward()
    scaler.step(discriminator_optimizer)
    # note: no scaler.update() inside the critic loop

generator_optimizer.zero_grad()
with torch.cuda.amp.autocast():
    # do some stuff
    loss_generator = compute_generator_loss()
scaler.scale(loss_generator).backward()
scaler.step(generator_optimizer)
scaler.update()
```

but I get the error

```
RuntimeError: step() has already been called since the last update().
```

when I call `backward()` on the discriminator loss for the second time (i.e. on the second critic iteration).

The error does not appear if I call `scaler.update()` after each `backward()` call, but the PyTorch AMP examples state that `scaler.update` "should only be called once".
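To show what I mean, here is my current workaround in reduced form: one scaler, with `update()` after every `step()`. The networks and losses below are toy stand-ins (simple linear layers, no gradient penalty), not my real WGAN-GP code, and the `autocast`/`GradScaler` parts fall back to no-ops on CPU:

```python
import torch

# Reduced sketch of the workaround: scaler.update() after every scaler.step().
# Models and losses are toy stand-ins, not the real WGAN-GP code.
device = "cuda" if torch.cuda.is_available() else "cpu"
generator = torch.nn.Linear(2, 2).to(device)
discriminator = torch.nn.Linear(2, 1).to(device)
g_opt = torch.optim.Adam(generator.parameters())
d_opt = torch.optim.Adam(discriminator.parameters())
scaler = torch.cuda.amp.GradScaler()  # disables itself when CUDA is absent

real = torch.randn(4, 2, device=device)
noise = torch.randn(4, 2, device=device)

for i in range(5):  # 5 critic updates per generator update
    d_opt.zero_grad()
    with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
        # stand-in for compute_discriminator_loss() + gradient penalty
        ld = (discriminator(generator(noise).detach()).mean()
              - discriminator(real).mean())
    scaler.scale(ld).backward()
    scaler.step(d_opt)
    scaler.update()  # extra update() here is what silences the RuntimeError

g_opt.zero_grad()
with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
    lg = -discriminator(generator(noise)).mean()  # stand-in generator loss
scaler.scale(lg).backward()
scaler.step(g_opt)
scaler.update()
```

This runs without the error, but I’m not sure calling `update()` that often is correct.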

So what’s the correct way to update the scaler here?

(I’m using PyTorch 1.9.0.)
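For reference, my understanding of the multiple-optimizer example in the AMP docs is roughly the sketch below (toy models, hypothetical losses): `step()` is called once per optimizer and `update()` once per iteration. It doesn’t obviously cover stepping the *same* optimizer several times before an `update()`, which is exactly my critic loop:

```python
import torch

# Sketch of the documented multi-optimizer AMP pattern, as I understand it:
# one step() per optimizer, then exactly one update() per iteration.
# Models, losses, and learning rates here are placeholders.
device = "cuda" if torch.cuda.is_available() else "cpu"
model0 = torch.nn.Linear(2, 1).to(device)
model1 = torch.nn.Linear(2, 1).to(device)
optimizer0 = torch.optim.SGD(model0.parameters(), lr=1e-3)
optimizer1 = torch.optim.SGD(model1.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()  # disables itself when CUDA is absent

x = torch.randn(8, 2, device=device)
for epoch in range(2):
    optimizer0.zero_grad()
    optimizer1.zero_grad()
    with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
        loss0 = model0(x).pow(2).mean()  # placeholder loss
        loss1 = model1(x).pow(2).mean()  # placeholder loss
    scaler.scale(loss0).backward()
    scaler.scale(loss1).backward()
    scaler.step(optimizer0)  # one step() per optimizer...
    scaler.step(optimizer1)
    scaler.update()          # ...then a single update() per iteration
```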