What's the correct way of using AMP with multiple losses?

As the docs say, "If your network has multiple losses, you must call scaler.scale on each of them individually." This is what it looks like:

scaler = torch.cuda.amp.GradScaler()
with torch.cuda.amp.autocast():
    loss0 = some_loss     # first loss, computed inside the autocast region
    loss1 = another_loss  # second loss, computed inside the autocast region

# scale and backprop each loss separately
scaler.scale(loss0).backward(retain_graph=True)
scaler.scale(loss1).backward()

This is relatively inconvenient, especially when we have many weighted losses. What if we just sum all the losses into one and call backward once? That would look like this:

scaler = torch.cuda.amp.GradScaler()
with torch.cuda.amp.autocast():
    loss0 = some_loss
    loss1 = another_loss
    # combine the losses into a single weighted sum
    loss_all = w0 * loss0 + w1 * loss1

# scale and backprop the combined loss once
scaler.scale(loss_all).backward()

Is this kind of implementation correct?


I have the same problem here; this use case is not very well explained in the docs. Can anyone provide an explanation?

Yes, that would work, and you could do the same without AMP.

There's no need for retain_graph=True on the first backward call unless the two losses share part of the backward graph. The example uses scaler.scale(loss0).backward(retain_graph=True) only because the two losses share some of the backward graph; it has nothing to do with AMP.
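For completeness, here is a minimal sketch of one full training iteration with a single combined loss, including the scaler.step / scaler.update calls that the snippets above omit. The model, optimizer, weights, and loss functions are placeholders chosen just to make the sketch self-contained, not code from the original post:

import torch
from torch.cuda.amp import GradScaler, autocast

# toy model and optimizer so the example runs on its own (illustrative names only)
model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()
w0, w1 = 0.7, 0.3  # arbitrary loss weights

inputs = torch.randn(8, 10, device="cuda")
target = torch.randn(8, 2, device="cuda")

optimizer.zero_grad()
with autocast():
    out = model(inputs)
    loss0 = torch.nn.functional.mse_loss(out, target)  # stand-in for some_loss
    loss1 = torch.nn.functional.l1_loss(out, target)   # stand-in for another_loss
    loss_all = w0 * loss0 + w1 * loss1                  # single weighted sum

scaler.scale(loss_all).backward()  # one backward pass through the combined loss
scaler.step(optimizer)             # unscales gradients, skips the step if they contain inf/NaN
scaler.update()                    # adjusts the scale factor for the next iteration

The same step/update calls apply to the per-loss variant; only the backward part changes.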
