I ran into a problem when calling backward() on two losses one after the other while using DDP (DistributedDataParallel).
It works fine without DDP.
I have a custom autograd Function like this:
class WeightedSumFunc(torch.autograd.Function):
    a_ij_require_grad = True

    @staticmethod
    def forward(ctx, a_ij, e_ijk):
        ctx.save_for_backward(a_ij, e_ijk)
        output = torch.einsum('bi,bij->bj', a_ij, e_ijk)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # (N, dim)
        a_ij, e_ijk = ctx.saved_tensors
        if WeightedSumFunc.a_ij_require_grad:
            d_a_ij = torch.einsum('bj,bij->bi', grad_output, e_ijk)
        else:
            d_a_ij = None
        d_e_ijk = torch.einsum('bi,bj->bij', a_ij, grad_output)
        return d_a_ij, d_e_ijk
I also have two losses, loss1 and loss2. Here is my training code:
with torch.cuda.amp.autocast(dtype=amp_dtype):
    loss1, loss2 = self.model(**inputs)
WeightedSumFunc.a_ij_require_grad = True
scaler.scale(loss1 / accumulation_steps).backward(retain_graph=True)
WeightedSumFunc.a_ij_require_grad = False
scaler.scale(loss2 / accumulation_steps).backward()
I want gradients to flow back through a_ij during the backward pass of loss1 but not during the backward pass of loss2. Some parameters contribute to both loss1 and loss2, and some contribute to only one of them. I expect the gradients from the two backward calls to accumulate in .grad, and that is what happens without DDP.
However, I receive a runtime error message when using DDP:
“RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons:
1) Use of a module parameter outside the forward function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop.
2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple checkpoint functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations.”
I don't use any parameter outside the forward function, and I don't use multiple checkpoint functions either.
In my use case, how can I accumulate the gradients from both losses?
Or is there a more elegant way to achieve the same goal?
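One possible alternative, as a minimal single-process sketch rather than a confirmed fix: compute the part of the graph that feeds loss2 from a_ij.detach(), sum the two losses, and call backward() only once per forward pass. The toy losses (out.sum() and (out ** 2).sum()) and the helper names grads_two_backward and grads_single_backward are made up for illustration, and the sketch assumes the model can be restructured so that loss2 is computed from a second, detached output. It only checks that both approaches produce the same accumulated gradients without DDP.

```python
import torch


class WeightedSumFunc(torch.autograd.Function):
    # Class-level switch from the post: controls whether a_ij gets a gradient.
    a_ij_require_grad = True

    @staticmethod
    def forward(ctx, a_ij, e_ijk):
        ctx.save_for_backward(a_ij, e_ijk)
        return torch.einsum('bi,bij->bj', a_ij, e_ijk)

    @staticmethod
    def backward(ctx, grad_output):
        a_ij, e_ijk = ctx.saved_tensors
        d_a_ij = None
        if WeightedSumFunc.a_ij_require_grad:
            d_a_ij = torch.einsum('bj,bij->bi', grad_output, e_ijk)
        d_e_ijk = torch.einsum('bi,bj->bij', a_ij, grad_output)
        return d_a_ij, d_e_ijk


def grads_two_backward(a, e):
    """Approach from the post: two backward calls, toggling the class flag."""
    a = a.detach().clone().requires_grad_(True)
    e = e.detach().clone().requires_grad_(True)
    out = WeightedSumFunc.apply(a, e)
    loss1 = out.sum()          # toy stand-in for loss1
    loss2 = (out ** 2).sum()   # toy stand-in for loss2
    WeightedSumFunc.a_ij_require_grad = True
    loss1.backward(retain_graph=True)
    WeightedSumFunc.a_ij_require_grad = False
    loss2.backward()
    WeightedSumFunc.a_ij_require_grad = True  # restore the default
    return a.grad, e.grad


def grads_single_backward(a, e):
    """Hypothetical alternative: detach a_ij on the loss2 branch, one backward."""
    a = a.detach().clone().requires_grad_(True)
    e = e.detach().clone().requires_grad_(True)
    out_full = torch.einsum('bi,bij->bj', a, e)            # feeds loss1
    out_no_a = torch.einsum('bi,bij->bj', a.detach(), e)   # feeds loss2
    loss1 = out_full.sum()
    loss2 = (out_no_a ** 2).sum()
    (loss1 + loss2).backward()  # a single backward pass for both losses
    return a.grad, e.grad


torch.manual_seed(0)
a0 = torch.randn(4, 3, dtype=torch.float64)
e0 = torch.randn(4, 3, 5, dtype=torch.float64)
ga_two, ge_two = grads_two_backward(a0, e0)
ga_one, ge_one = grads_single_backward(a0, e0)
print(torch.allclose(ga_two, ga_one), torch.allclose(ge_two, ge_one))
```

In the training loop this would presumably become a single scaler.scale((loss1 + loss2) / accumulation_steps).backward() call, so each parameter's gradient hook should fire only once per forward pass. The cost is that everything between the weighted sum and loss2 has to be computed a second time on the detached branch.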
Update:
I tried wrapping the forward and backward passes in no_sync() and all-reducing the gradients manually:
with model.no_sync():
    result = self.model(**inputs, coeff=coeff, temperature=temperature)
    if loss1 is not None:
        WeightedSumFunc.a_ij_require_grad = True
        scaler.scale(loss1 / accumulation_steps).backward(retain_graph=True)
    WeightedSumFunc.a_ij_require_grad = False
    scaler.scale(result.non_struct_loss / accumulation_steps).backward()
if (step + 1) % accumulation_steps == 0:
    for p in model.parameters():
        if p.grad is not None:
            torch.distributed.all_reduce(p.grad, op=torch.distributed.ReduceOp.SUM)
            p.grad /= torch.distributed.get_world_size()
Now no error is thrown, but the loss does not behave as expected: the loss curve is still the same as when training on a single GPU. How can I synchronize the gradients manually?
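For reference, the manual synchronization step can be factored into a small helper. This is only a sketch of the averaging I have in mind, not a verified solution: the name all_reduce_mean_grads is made up, and it assumes every rank has a gradient for the same set of parameters, since otherwise the all_reduce calls would not match up across ranks.

```python
import torch
import torch.distributed as dist


def all_reduce_mean_grads(parameters):
    """Average .grad across all ranks in place.

    Does nothing when no process group is initialized, so the same
    training script can also run on a single device.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return
    world_size = dist.get_world_size()
    for p in parameters:
        # Assumption: p.grad is None on either all ranks or no rank.
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
            p.grad.div_(world_size)
```

It would be called once after the last micro-batch's backward pass and before scaler.step(optimizer). Averaging gradients this way does not change the loss each rank logs for its own batch, so it is also worth checking that every rank actually receives different data (for example through a DistributedSampler) before comparing the curve against the single-GPU run.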