What's the correct way of using AMP with multiple losses?

Yes, that would work, and you could do the same without AMP.

There's no need for retain_graph=True on the first backward call unless the two losses share part of their backward graph. In the example, scaler.scale(loss0).backward(retain_graph=True) uses retain_graph=True only because loss0 and loss1 share some of the graph; it has nothing to do with AMP.
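Here is a minimal sketch of the multiple-losses pattern, loosely following the PyTorch AMP examples. The two nn.Linear models, the data shapes, and the train_step helper with its shared_graph flag are hypothetical, chosen only for illustration. AMP is enabled only when a GPU is present; with enabled=False, GradScaler and autocast become pass-throughs, so the same loop also runs in plain fp32 on CPU (the "same without AMP" case).

```python
import torch
import torch.nn as nn

# Enable AMP only on a GPU; with enabled=False the scaler and autocast are no-ops.
use_amp = torch.cuda.is_available()
device = "cuda" if use_amp else "cpu"

torch.manual_seed(0)
# Hypothetical models, optimizers, and loss for illustration.
model0 = nn.Linear(4, 1).to(device)
model1 = nn.Linear(4, 1).to(device)
optimizer0 = torch.optim.SGD(model0.parameters(), lr=0.1)
optimizer1 = torch.optim.SGD(model1.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
# Newer PyTorch releases also spell this torch.amp.GradScaler("cuda", ...).
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


def train_step(inp, target, shared_graph):
    """One step with two losses; shared_graph picks overlapping vs. disjoint graphs."""
    optimizer0.zero_grad()
    optimizer1.zero_grad()
    with torch.autocast(device_type=device, enabled=use_amp):
        output0 = model0(inp)
        output1 = model1(inp)
        if shared_graph:
            # Both losses depend on output0 and output1, so their backward graphs overlap.
            loss0 = loss_fn(2 * output0 + 3 * output1, target)
            loss1 = loss_fn(3 * output0 - 5 * output1, target)
        else:
            # Each loss depends on only one model, so the graphs are disjoint.
            loss0 = loss_fn(output0, target)
            loss1 = loss_fn(output1, target)
    # retain_graph=True is needed only when the graphs overlap; it is unrelated to AMP.
    scaler.scale(loss0).backward(retain_graph=shared_graph)
    scaler.scale(loss1).backward()
    scaler.step(optimizer0)
    scaler.step(optimizer1)
    scaler.update()  # call once per iteration, after all optimizers have stepped
    return loss0.item(), loss1.item()


x = torch.randn(8, 4, device=device)
y = torch.randn(8, 1, device=device)
print(train_step(x, y, shared_graph=True))
print(train_step(x, y, shared_graph=False))
```

If the shared-graph branch dropped retain_graph=True, the second backward() would raise a RuntimeError, because the first call already freed the buffers saved by the overlapping layers. The disjoint branch never needs it.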