Using PyTorch's AMP with multiple scaler backwards per epoch

I’m trying to implement a Wasserstein GAN with gradient penalty (WGAN-GP) using PyTorch’s Automatic Mixed Precision. In a Wasserstein GAN the critic network gets updated more frequently than the generator (e.g. 5 times per generator iteration).
My code goes something like:

for i in range(5):
    with torch.cuda.amp.autocast():
        discriminator_optimizer.zero_grad()
        # do some stuff
        loss_discriminator = compute_discriminator_loss()
        loss_gradient = compute_gradient_penalty_loss()
        ld = loss_gradient + loss_discriminator
    scaler.scale(ld).backward()
    scaler.step(discriminator_optimizer)
with torch.cuda.amp.autocast():
    generator_optimizer.zero_grad()
    # do some stuff
    loss_generator = compute_generator_loss()
scaler.scale(loss_generator).backward()
scaler.step(generator_optimizer)
scaler.update()

but I get the error

RuntimeError: step() has already been called since the last update().

This happens on the second pass through the discriminator loop (the second backward() and step() on the discriminator loss).
The error does not appear if I update the scaler after each backward() call, but the PyTorch examples state that "scaler.update should only be called once".
So what’s the correct way to update the scaler here?

(I’m using PyTorch 1.9.0)

hi,
Not an expert on AMP, but it seems you can't call step() again with the same scale factor; it has to be updated first. I think that is by design, because it makes sense that way: the scale factor is meant to follow the dynamics of the gradients, which change with every step(). So every time you step(), you need to update(). If you step several times without updating, you leave the scale factor behind.

I think you should use two gradient scalers, because your two losses are progressing at different rates, so each one needs its own scale factor that is updated with update(). This will allow you to update the discriminator's scale factor as often as you like. It does not seem that you are retaining any gradients to share between the discriminator and the generator, so you can scale their gradients separately.
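A minimal sketch of the two-scaler idea, written against the torch.cuda.amp API used in the question. The tiny linear critic and generator, the names scaler_d and scaler_g, and the train_iteration helper are placeholders I made up for illustration, and the gradient penalty term is left out to keep it short, so treat this as an outline under those assumptions rather than a finished WGAN-GP loop.

```python
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # assumption: AMP is switched off on CPU so the sketch still runs

# Hypothetical stand-ins for the real critic and generator networks.
critic = nn.Linear(8, 1).to(device)
generator = nn.Linear(4, 8).to(device)
opt_d = torch.optim.Adam(critic.parameters(), lr=1e-4)
opt_g = torch.optim.Adam(generator.parameters(), lr=1e-4)

# One scaler per optimizer, so each scale factor is updated on its own schedule.
scaler_d = torch.cuda.amp.GradScaler(enabled=use_amp)
scaler_g = torch.cuda.amp.GradScaler(enabled=use_amp)


def train_iteration(real, n_critic=5):
    # Critic: several step()/update() pairs per generator step.
    for _ in range(n_critic):
        opt_d.zero_grad()
        with torch.cuda.amp.autocast(enabled=use_amp):
            noise = torch.randn(real.size(0), 4, device=device)
            fake = generator(noise).detach()  # no generator gradients needed here
            # Wasserstein critic loss; the gradient penalty is omitted in this sketch.
            loss_d = critic(fake).mean() - critic(real).mean()
        scaler_d.scale(loss_d).backward()
        scaler_d.step(opt_d)
        scaler_d.update()  # update after every step, so step() never repeats

    # Generator: one step()/update() pair with its own scaler.
    opt_g.zero_grad()
    with torch.cuda.amp.autocast(enabled=use_amp):
        noise = torch.randn(real.size(0), 4, device=device)
        loss_g = -critic(generator(noise)).mean()
    # This also leaves gradients on the critic; the next opt_d.zero_grad() clears them.
    scaler_g.scale(loss_g).backward()
    scaler_g.step(opt_g)
    scaler_g.update()
    return loss_d.item(), loss_g.item()


real_batch = torch.randn(16, 8, device=device)
losses = [train_iteration(real_batch) for _ in range(2)]
```

Each scaler only ever sees one step() between its own update() calls, so the "step() has already been called" check is never triggered.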

I think the PyTorch example works because both losses advance at the same rate.

That makes sense. I just wasn’t sure if it’s ok to use two different scalers. Thanks.