How to use AMP in a GAN

Generally speaking, the steps for using amp look like this:

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

This is my GAN code (before using amp):

optimizer1 = SGD(model.D.parameters(), ......)
optimizer2 = SGD(model.G.parameters(), ......)
for i in range(D_training_steps):
    ......
    loss1.backward()
    optimizer1.step()
    ......

    loss2.backward()      # after updating D above, I compute loss2 and update D again
    optimizer1.step()
    ......
for i in range(G_training_steps):
    loss3.backward()
    optimizer2.step()     # the optimizer for G

So my question is: when should I call scaler.update() when using amp?
Thanks for your help!

This example should work for your use case. You can use a single scaler for both losses or create separate ones, if you want.
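
As a rough sketch of how this could look for the loop in the question: each optimizer step gets its own scale/backward, scaler.step, scaler.update sequence, with a single scaler shared by all three losses. The tiny nn.Linear stand-ins for model.D and model.G, the tensor shapes, the use of BCEWithLogitsLoss for every loss, and the amp_step helper are all assumptions made for illustration, not taken from the original code. AMP is only enabled when CUDA is available, so the snippet also runs (without mixed precision) on CPU.

```python
import torch
from torch import nn

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = device.type == "cuda"  # scaler and autocast become no-ops on CPU

# Hypothetical stand-ins for model.D and model.G (the real models are not shown).
D = nn.Linear(8, 1).to(device)
G = nn.Linear(4, 8).to(device)
optimizer1 = torch.optim.SGD(D.parameters(), lr=0.01)
optimizer2 = torch.optim.SGD(G.parameters(), lr=0.01)
bce = nn.BCEWithLogitsLoss()  # assumed loss; it is safe to use under autocast

# One scaler shared by all losses. Newer PyTorch releases also provide
# torch.amp.GradScaler as the preferred spelling.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


def amp_step(loss, optimizer):
    """Hypothetical helper: backward -> step -> update for one optimizer step."""
    optimizer.zero_grad()
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # unscales grads, skips the step on inf/NaN
    scaler.update()                # adjust the scale after this step


D_training_steps, G_training_steps, batch = 2, 1, 16
real = torch.randn(batch, 8, device=device)
ones = torch.ones(batch, 1, device=device)
zeros = torch.zeros(batch, 1, device=device)
d_before = D.weight.detach().clone()

for i in range(D_training_steps):
    with torch.autocast(device_type=device.type, enabled=use_amp):
        loss1 = bce(D(real), ones)
    amp_step(loss1, optimizer1)

    # After updating D above, compute loss2 and update D again.
    with torch.autocast(device_type=device.type, enabled=use_amp):
        fake = G(torch.randn(batch, 4, device=device)).detach()
        loss2 = bce(D(fake), zeros)
    amp_step(loss2, optimizer1)

for i in range(G_training_steps):
    with torch.autocast(device_type=device.type, enabled=use_amp):
        fake = G(torch.randn(batch, 4, device=device))
        loss3 = bce(D(fake), ones)
    amp_step(loss3, optimizer2)
```

If several losses were instead backpropagated before a single optimizer step, scaler.update() would be called once, after all optimizers used in that iteration have been stepped.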
