Convergence issue with PyTorch distributed

Hi,
I’m trying to do distributed training with the torch.distributed package, hosting different parts of the model on different nodes. These nodes are initialized independently; the aggregated outputs of the first (bottom) part of the model are used as inputs to the second (top) part. Backpropagation mirrors this: the top model first computes gradients w.r.t. its inputs, and these are sent back to the bottom parts on the individual clients, which then finish the backward pass.

The code looks approximately like this (the process-group setup and a single-process version of the hand-off are sketched below for reference):

  • Init top (rank 0):
model_top = nn.Sequential(...)
opt_top = Adam(model_top.parameters())
loss_top = nn.BCELoss()
  • Init bottom (ranks 1 to global_size - 1):
model_bottom = nn.Sequential(...)
opt_bottom = Adam(model_bottom.parameters())
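
For completeness, every node also joins a common process group before the steps above, roughly like this (backend and rendezvous settings here are just placeholders, not the exact ones used):

import os
import torch.distributed as dist

# Assumed rendezvous setup: rank / world size come from the environment (env:// init).
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
dist.init_process_group(
    backend="gloo",        # assumption; nccl would also work for GPU tensors
    init_method="env://",
    rank=rank,
    world_size=world_size,
)
# Rank 0 hosts the top model; ranks 1..world_size-1 host the bottom models.
global_group = dist.new_group(ranks=list(range(world_size)))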

Training loop:

  • Bottom:
output = model_bottom(sample)
# send this client's activations to the top model (rank 0)
dist.send(output, 0, global_group)
# receive the gradient w.r.t. these activations back from rank 0
gradient = torch.zeros_like(output).to(device)
dist.recv(gradient, 0, global_group)
output.backward(gradient)
opt_bottom.step()
  • Top:
outputs = []
for j in range(1, global_size):
    # receive one client's activations from rank j
    output = torch.zeros((len(batch), 1000))
    dist.recv(output, j, global_group)
    outputs.append(output)

# stack the received activations into a leaf tensor that requires grad
outputs = torch.autograd.Variable(
    torch.stack(outputs), requires_grad=True
)
# aggregate the clients' activations by averaging them
outputs = outputs.to(torch.float32).mean(dim=0)

predictions = model_top(outputs)
loss = loss_top(predictions, data)
loss.backward(retain_graph=True)
if outputs.grad is not None:
    gradients = outputs.grad
else:
    logger.warning("Gradient not computed in backward pass!")
    gradients = torch.autograd.grad(loss, outputs)[0]
opt_top.step()

# send each client its part of the gradient
for j in range(1, global_size):
    dist.send(gradients[j - 1], j, global_group)
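
For comparison, the same activation/gradient hand-off for a single client on a single process looks roughly like the sketch below (layer sizes, shapes, and names are made up for illustration; no send/recv involved):

import torch
import torch.nn as nn

bottom = nn.Linear(10, 1000)                            # stand-in for one client's bottom model
top = nn.Sequential(nn.Linear(1000, 1), nn.Sigmoid())   # stand-in for the top model
loss_fn = nn.BCELoss()

sample = torch.randn(4, 10)
target = torch.randint(0, 2, (4, 1)).float()

# "Bottom": forward pass; keep the output around for the later backward
bottom_out = bottom(sample)

# "Top": treat the received activation as a leaf tensor that requires grad
received = bottom_out.detach().requires_grad_(True)
pred = top(received)
loss = loss_fn(pred, target)
loss.backward()

# gradient w.r.t. the received activation; this is what would be sent back
grad_to_send = received.grad

# "Bottom": continue backpropagation through the bottom model with that gradient
bottom_out.backward(grad_to_send)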

The questions here are:

  • Is this a correct approach to custom distributed training?
  • Is there a reason why the model should not converge? (It converges when running on a single machine.)
  • Is there a better way to write this in PyTorch?

Thanks for your question. What you have described sounds like pipeline parallelism. Have you tried our solution, PiPPy? (Ref: https://github.com/pytorch/tau/blob/main/README.md) We do the split automatically for you. Also, what is the model in your use case?

cc: @kwen2501