I have multiple GradScalers that operate on losses which might overlap in the computation graph, and each scaler corresponds to a different optimizer. Should I call scaler0.scale(loss) and scaler1.scale(loss) on the same loss? I'm not sure how to do this correctly.
Here is a minimal example (not tested):
import torch
from torch import autocast

scaler0 = torch.cuda.amp.GradScaler()
scaler1 = torch.cuda.amp.GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer0.zero_grad()
        optimizer1.zero_grad()

        # Both models contribute to a single loss, so the loss's graph
        # spans the parameters of model0 and model1.
        with autocast(device_type='cuda', dtype=torch.float16):
            output0 = model0(input)
            output1 = model1(input)
            loss = loss_fn(2 * output0 + 3 * output1, target)

        # This is the part I don't know how to do: both backward calls
        # accumulate (differently scaled) gradients into the same .grad tensors.
        scaler0.scale(loss).backward(retain_graph=True)
        scaler1.scale(loss).backward()

        scaler0.step(optimizer0)
        scaler1.step(optimizer1)
        scaler0.update()
        scaler1.update()
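
For contrast, here is the pattern I would use if each optimizer had its own independent loss (a hypothetical sketch reusing the names from above, with made-up per-model losses loss0 and loss1, and assuming model0 and model1 share no parameters; not tested). In that case each scaler only ever scales gradients produced from its own loss:

# Hypothetical variant: one independent loss per model/optimizer,
# so the two backward passes touch disjoint sets of parameters.
with autocast(device_type='cuda', dtype=torch.float16):
    loss0 = loss_fn(model0(input), target)   # depends only on model0
    loss1 = loss_fn(model1(input), target)   # depends only on model1

scaler0.scale(loss0).backward()   # gradients for model0, scaled by scaler0
scaler1.scale(loss1).backward()   # gradients for model1, scaled by scaler1
scaler0.step(optimizer0)
scaler1.step(optimizer1)
scaler0.update()
scaler1.update()

My question is what the right pattern is when, as in the first example, the two backward passes share parts of the computation graph.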