Hi,
I have a model which uses a custom autograd Function. The backward method of this Function makes multiple torch.autograd.grad calls in a loop, somewhat like this:
import torch
import torch.nn as nn
from torch.autograd import Function

class func1(Function):
    @staticmethod
    def forward(ctx, input1, input2, *args):
        ctx.save_for_backward(input1, input2)
        return input2

    @staticmethod
    def backward(ctx, grad):
        input1, input2 = ctx.saved_tensors
        for ii in range(10):
            # torch.autograd.grad returns a tuple; take the single gradient
            new, = torch.autograd.grad(input2, input1, grad_outputs=grad,
                                       retain_graph=True)
        # one return value per forward input: None for input1, the gradient for input2
        return (None, new)
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.net1 = torch.nn.Linear(10, 10)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5)

    def forward(self, x):
        out = self.net1(x)
        out = func1.apply(x, out)
        out2 = self.relu(out)
        out2 = func1.apply(x, out2)
        out3 = self.net2(out2)
        out3 = func1.apply(x, out3)
        return out3
This works fine when I run on a single GPU. But when I run with DDP it hangs in the loss.backward() call. I can see that it hangs inside the for loop in func1's backward. Peculiarly, all the workers hang in the same iteration of the loop.
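For reference, this is roughly how the model is wrapped and trained (simplified; the loss, optimizer, and data loader below are stand-ins for my actual setup):

import argparse
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# torch.distributed.launch passes --local_rank to each process
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)
local_rank = parser.parse_args().local_rank

dist.init_process_group(backend="nccl")
torch.cuda.set_device(local_rank)

model = MyModel().to(local_rank)
model = DDP(model, device_ids=[local_rank])

criterion = nn.MSELoss()                                  # placeholder loss
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # placeholder optimizer

for x, target in loader:                                  # placeholder DataLoader (with DistributedSampler)
    x, target = x.to(local_rank), target.to(local_rank)
    optimizer.zero_grad()
    out = model(x)
    loss = criterion(out, target)
    loss.backward()                                       # <-- hangs here under DDP
    optimizer.step()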
Edit: CPU and GPU utilization stay high (100%) for all processes while it hangs. The code is run with torch.distributed.launch.
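The launch command is along these lines (script name and GPU count are just illustrative):

python -m torch.distributed.launch --nproc_per_node=4 train.py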
Any help would be appreciated!
Thanks!!