Yes, that would work, and you could do the same without amp.
There’s no need for retain_graph=True on the first backward unless the two losses share part of the backward graph. In the example, scaler.scale(loss0).backward(retain_graph=True) passes retain_graph=True only because the two losses share some backward graph; it has nothing to do with amp.
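A minimal sketch of the shared-graph case (the trunk/head setup here is hypothetical, not from the original example; GradScaler(enabled=False) keeps the amp-style call pattern runnable on CPU). Both losses depend on the same trunk output, so the first backward must retain the graph for the second:

```python
import torch

# Hypothetical two-head model: both losses depend on the same shared
# trunk output, so their backward passes traverse a shared graph.
torch.manual_seed(0)
trunk = torch.nn.Linear(4, 4)
head0 = torch.nn.Linear(4, 1)
head1 = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(
    list(trunk.parameters()) + list(head0.parameters()) + list(head1.parameters()),
    lr=0.1,
)
# enabled=False makes scaling a no-op, so this also runs without CUDA;
# the call pattern is identical to a real amp setup.
scaler = torch.cuda.amp.GradScaler(enabled=False)

x = torch.randn(8, 4)
features = trunk(x)               # shared subgraph
loss0 = head0(features).mean()
loss1 = head1(features).mean()

opt.zero_grad()
# retain_graph=True is needed only because loss1's backward will walk
# the same trunk subgraph again; amp itself does not require it.
scaler.scale(loss0).backward(retain_graph=True)
scaler.scale(loss1).backward()    # last backward can free the graph
scaler.step(opt)
scaler.update()
```

If the two losses came from fully independent graphs (e.g. two separate forward passes), both backward calls would work without retain_graph=True, with or without amp.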