DistributedDataParallel hangs when there is an autograd call in backward

Hi,

I have a model which uses a custom autograd Function. The backward method of this Function makes multiple torch.autograd.grad calls in a loop, somewhat like this:

import torch
import torch.nn as nn
from torch.autograd import Function


class func1(Function):
    @staticmethod
    def forward(ctx, input1, input2, *args):
        ctx.save_for_backward(input1, input2)
        return input2

    @staticmethod
    def backward(ctx, grad):
        input1, input2 = ctx.saved_tensors

        for ii in range(10):
            # each iteration differentiates input2 w.r.t. input1 through the shared graph
            new, = torch.autograd.grad(input2, input1, grad_outputs=grad,
                                       retain_graph=True)
        # no gradient for input1; the computed gradient is returned in input2's slot
        return (None, new)


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.net1 = torch.nn.Linear(10, 10)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5)

    def forward(self, x):
        out = self.net1(x)
        out = func1.apply(x, out)
        out2 = self.relu(out)
        out2 = func1.apply(x, out2)
        out3 = self.net2(out2)
        out3 = func1.apply(x, out3)
        return out3

This works fine when I run on a single GPU, but when I run with DDP it hangs in the loss.backward() call. I can see that it hangs inside the for loop in func1's backward. Peculiarly, all the workers hang in the same iteration of the loop.

Edit: CPU and GPU utilization stay high (100%) for all processes when it hangs. The code is run with torch.distributed.launch.
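
For reference, the setup is roughly like this minimal sketch (the data and loss are placeholders, and --local_rank is assumed to be supplied by torch.distributed.launch):

import argparse
import torch
import torch.distributed as dist
import torch.nn as nn

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)  # filled in by torch.distributed.launch
args = parser.parse_args()

dist.init_process_group(backend="nccl")
torch.cuda.set_device(args.local_rank)

model = MyModel().cuda(args.local_rank)
ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])

# placeholder batch; input1 must require grad for the autograd.grad call in func1.backward
x = torch.randn(20, 10, device="cuda", requires_grad=True)
loss = ddp_model(x).sum()
loss.backward()  # hangs here under DDP, inside the loop in func1.backward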
Any help would be appreciated!

Thanks!!

Hi,

This is a known limitation of DDP I’m afraid. You can see the issue tracking this here: https://github.com/pytorch/pytorch/issues/24005

We are planning on getting to this after the 1.5 release.

Thank you for your response.

After some testing, I have now realised that the hang is due to the syncBNs in the model; it works fine with regular BNs. The graph between input2 and input1 contains syncBNs too, and the many autograd.grad() calls give rise to many all_reduce calls in the syncBNs' backward, which hang. I think this is what's happening: https://github.com/pytorch/pytorch/pull/14267#discussion_r257051495
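
Roughly, the structure is like this hypothetical single-head sketch (HeadModel is made up for illustration; the real model is much larger):

class HeadModel(nn.Module):
    # A BatchNorm sits on the path between x (input1) and the tensor passed to
    # func1 (input2), so every torch.autograd.grad call in func1.backward has to
    # backprop through it. After convert_sync_batchnorm, that backward issues an
    # all_reduce on every rank, once per loop iteration.
    def __init__(self):
        super(HeadModel, self).__init__()
        self.net1 = torch.nn.Linear(10, 10)
        self.bn = torch.nn.BatchNorm1d(10)
        self.net2 = torch.nn.Linear(10, 5)

    def forward(self, x):
        out = self.bn(self.net1(x))
        out = func1.apply(x, out)
        return self.net2(out)

model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(HeadModel()).cuda()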

My model has many parallel heads that contain syncBNs, and those could be deadlocking too.

Do you see a solution/workaround for this?

Thanks again!

@alekhka as a temporary workaround, you could try doing the allreduce manually after the backward pass (see this comment). This will be slower, as there will be no overlap between computation and communication, but it should hopefully avoid the hang.
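
Something along these lines, roughly (a sketch only; model, x and optimizer stand in for your own objects, and it assumes the forward/backward run on the plain, non-DDP-wrapped module so DDP's reducer does not fire):

import torch.distributed as dist

# run forward/backward without DDP's automatic bucketed allreduce
loss = model(x).sum()
loss.backward()

# then average the gradients across ranks by hand
world_size = dist.get_world_size()
for p in model.parameters():
    if p.grad is not None:
        dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
        p.grad.div_(world_size)

optimizer.step()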

That should be the same as setting delay_allreduce=True in Apex, right? That doesn't fix it either.

I think the deadlock is between the all_reduce calls in the syncBNs' backward() across the parallel heads, not between the gradient all_reduce and the syncBN all_reduce.

Removing the parallel heads and keeping just one head works fine. Replacing syncBN with regular BN also works fine with the parallel-head model.

Thank you!