As the doc says, "If your network has multiple losses, you must call scaler.scale on each of them individually." That looks like this:
```python
import torch
from torch.cuda.amp import autocast

scaler = torch.cuda.amp.GradScaler()
with autocast():
    loss0 = some_loss       # placeholders for the real loss computations
    loss1 = another_loss
# One backward call per scaled loss, as in the docs example.
scaler.scale(loss0).backward(retain_graph=True)
scaler.scale(loss1).backward()
```
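(The `retain_graph=True` on the first backward is needed because both losses back-propagate through the same graph, which PyTorch would otherwise free after the first backward call.)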
This is relatively inconvenient, especially when we want many weighted losses. What if we just summed all the losses into one and called backward once? It would look like this:
```python
scaler = torch.cuda.amp.GradScaler()
with autocast():
    loss0 = some_loss
    loss1 = another_loss
    loss_all = w0 * loss0 + w1 * loss1   # weighted sum of the losses
# A single backward pass through the combined loss.
scaler.scale(loss_all).backward()
```
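For reference, here is a fuller, runnable sketch of the summed-loss variant; the model, optimizer, data loader, second loss term, and the weights `w0`/`w1` are placeholders standing in for my actual setup:

```python
import torch
from torch.cuda.amp import autocast, GradScaler

model = torch.nn.Linear(10, 2).cuda()      # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = GradScaler()
w0, w1 = 0.7, 0.3                          # example loss weights

for input, target in data_loader:          # placeholder data loader
    optimizer.zero_grad()
    with autocast():
        output = model(input)
        loss0 = torch.nn.functional.mse_loss(output, target)
        loss1 = output.abs().mean()        # stand-in for a second loss term
        loss_all = w0 * loss0 + w1 * loss1

    scaler.scale(loss_all).backward()      # single backward over the weighted sum
    scaler.step(optimizer)                 # unscales grads, then optimizer.step()
    scaler.update()                        # adjusts the scale for the next iteration
```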
Is this kind of implementation correct?