When using torch.distributed.algorithms.join, I get a `gloo::EnforceNotMet` error

When I split the training set, the data cannot be divided evenly among all trainers. Without `torch.distributed.algorithms.join`, the training script hangs. But when I use `torch.distributed.algorithms.join`, I get the following error instead:

    terminate called after throwing an instance of 'gloo::EnforceNotMet'
      what():  [enforce fail at ../third_party/gloo/gloo/transport/tcp/pair.cc:510] op.preamble.length <= op.nbytes. 24640 vs 4
    Traceback (most recent call last):
      File "/home/dzzhen/.local/lib/python3.9/site-packages/torch/distributed/algorithms/join.py", line 274, in __exit__
      File "/home/dzzhen/m5-gnn/python/m5gnn/model/rgcn_node_base.py", line 179, in fit
        loss.backward()
      File "/home/dzzhen/.local/lib/python3.9/site-packages/torch/_tensor.py", line 307, in backward
        join_hook.main_hook()
      File "/home/dzzhen/.local/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 193, in main_hook
        ddp._match_all_reduce_for_bwd_pass()
      File "/home/dzzhen/.local/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1070, in _match_all_reduce_for_bwd_pass
        torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
      File "/home/dzzhen/.local/lib/python3.9/site-packages/torch/autograd/__init__.py", line 154, in backward
        work.wait()
    RuntimeError: [../third_party/gloo/gloo/transport/tcp/pair.cc:589] Read error [10.2.23.254]:1439: Connection reset by peer
        Variable._execution_engine.run_backward(
    RuntimeError: [../third_party/gloo/gloo/transport/tcp/pair.cc:589] Read error [10.2.23.254]:49550: Connection reset by peer

My code is a bit complex; it has multiple `nn.Module`s. Here is a snippet:

    with Join([model, embed_layer] + list(bert_model.values())):
        for i, (input_nodes, seeds, blocks) in enumerate(loader):
            total_steps += 1
            blocks = [blk.to(device) for blk in blocks]

Is the error related to how many modules are passed to Join? I tried different combinations of modules and always got the same error.
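
For comparison, the basic single-module pattern with uneven inputs (a minimal, self-contained sketch; the toy model, port, and batch counts are placeholders rather than my real setup) looks like this:

    import os
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    from torch.distributed.algorithms.join import Join
    from torch.nn.parallel import DistributedDataParallel as DDP

    def worker(rank, world_size):
        # Placeholder rendezvous settings for a single-machine sketch.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29501"
        dist.init_process_group("gloo", rank=rank, world_size=world_size)
        model = DDP(torch.nn.Linear(1, 1))
        # Each rank gets a different number of batches, so inputs are uneven.
        inputs = [torch.ones(1) for _ in range(5 + rank)]
        with Join([model]):
            for x in inputs:
                model(x).sum().backward()
        dist.destroy_process_group()

    if __name__ == "__main__":
        mp.spawn(worker, args=(2,), nprocs=2)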

cc @awgu for help with join context manager

Is `model` a `DistributedDataParallel` instance? Do `embed_layer` and the modules in `bert_model.values()` perform any collective communications?

If the answers are yes and no, respectively, then can you try `with Join([model]):` (i.e. do not pass in the other modules)?
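
Roughly, what I have in mind is something like this (a sketch only; `loader`, `device`, `optimizer`, and the forward call are placeholders for your actual training loop):

    from torch.distributed.algorithms.join import Join

    # Sketch: pass only the DDP-wrapped model to Join and leave the other
    # modules out of the context manager. `loader`, `device`, `optimizer`,
    # and the forward call below are placeholders for the real training loop.
    with Join([model]):
        for i, (input_nodes, seeds, blocks) in enumerate(loader):
            blocks = [blk.to(device) for blk in blocks]
            # embed_layer / bert_model are still used in the forward pass as
            # before; they simply are not registered with Join.
            loss = model(blocks)  # placeholder forward that returns the loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()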