I’m trying to implement WGAN-GP with PyTorch’s Automatic Mixed Precision (AMP). In a Wasserstein GAN the critic network gets updated more frequently than the generator (e.g. 5 critic updates per generator update).

My code goes something like:

```
for i in range(5):
    discriminator_optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        # do some stuff
        loss_discriminator = compute_discriminator_loss()
        loss_gradient = compute_gradient_penalty_loss()
        ld = loss_gradient + loss_discriminator
    scaler.scale(ld).backward()
    scaler.step(discriminator_optimizer)
    # note: no scaler.update() inside the critic loop

generator_optimizer.zero_grad()
with torch.cuda.amp.autocast():
    # do some stuff
    loss_generator = compute_generator_loss()
scaler.scale(loss_generator).backward()
scaler.step(generator_optimizer)
scaler.update()
```

but I get the error

```
RuntimeError: step() has already been called since the last update().
```

when I call `backward()` on the discriminator loss for the second time (i.e. on the second critic iteration).

The error does not appear if I call `scaler.update()` after each `backward()` call, but the PyTorch AMP examples state that `scaler.update` "should only be called once".
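To show what I mean, here is my current workaround in reduced form: one scaler, with `update()` after every `step()`. The networks and losses below are toy stand-ins (simple linear layers, no gradient penalty), not my real WGAN-GP code, and the `autocast`/`GradScaler` parts fall back to no-ops on CPU:

```python
import torch

# Reduced sketch of the workaround: scaler.update() after every scaler.step().
# Models and losses are toy stand-ins, not the real WGAN-GP code.
device = "cuda" if torch.cuda.is_available() else "cpu"
generator = torch.nn.Linear(2, 2).to(device)
discriminator = torch.nn.Linear(2, 1).to(device)
g_opt = torch.optim.Adam(generator.parameters())
d_opt = torch.optim.Adam(discriminator.parameters())
scaler = torch.cuda.amp.GradScaler()  # disables itself when CUDA is absent

real = torch.randn(4, 2, device=device)
noise = torch.randn(4, 2, device=device)

for i in range(5):  # 5 critic updates per generator update
    d_opt.zero_grad()
    with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
        # stand-in for compute_discriminator_loss() + gradient penalty
        ld = (discriminator(generator(noise).detach()).mean()
              - discriminator(real).mean())
    scaler.scale(ld).backward()
    scaler.step(d_opt)
    scaler.update()  # extra update() here is what silences the RuntimeError

g_opt.zero_grad()
with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
    lg = -discriminator(generator(noise)).mean()  # stand-in generator loss
scaler.scale(lg).backward()
scaler.step(g_opt)
scaler.update()
```

This runs without the error, but I’m not sure calling `update()` that often is correct.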

So what’s the correct way to update the scaler here?

(I’m using PyTorch 1.9.0.)
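For reference, my understanding of the multiple-optimizer example in the AMP docs is roughly the sketch below (toy models, hypothetical losses): `step()` is called once per optimizer and `update()` once per iteration. It doesn’t obviously cover stepping the *same* optimizer several times before an `update()`, which is exactly my critic loop:

```python
import torch

# Sketch of the documented multi-optimizer AMP pattern, as I understand it:
# one step() per optimizer, then exactly one update() per iteration.
# Models, losses, and learning rates here are placeholders.
device = "cuda" if torch.cuda.is_available() else "cpu"
model0 = torch.nn.Linear(2, 1).to(device)
model1 = torch.nn.Linear(2, 1).to(device)
optimizer0 = torch.optim.SGD(model0.parameters(), lr=1e-3)
optimizer1 = torch.optim.SGD(model1.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()  # disables itself when CUDA is absent

x = torch.randn(8, 2, device=device)
for epoch in range(2):
    optimizer0.zero_grad()
    optimizer1.zero_grad()
    with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
        loss0 = model0(x).pow(2).mean()  # placeholder loss
        loss1 = model1(x).pow(2).mean()  # placeholder loss
    scaler.scale(loss0).backward()
    scaler.scale(loss1).backward()
    scaler.step(optimizer0)  # one step() per optimizer...
    scaler.step(optimizer1)
    scaler.update()          # ...then a single update() per iteration
```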