How to properly use distributed.all_gather?

Hi, I am trying to use the torch.distributed.all_gather function and I’m confused about the ‘tensor_list’ parameter.

Here’s my snippet.

# output: dictionary of tensors, e.g. {'key1': tensor, 'key2': tensor}
output = self.model(input)

for key in output.keys():
    tensor_list = [torch.zeros_like(output[key]) for _ in range(dist.get_world_size())]
    output[key] = self.gather_tensor(tensor_list, output[key])

def gather_tensor(self, tensor_list, tensor):
    rt = tensor.clone()
    dist.all_gather(tensor_list, rt)
    return torch.stack(tensor_list)

I modeled ‘gather_tensor’ on the ‘reduce_tensor’ function from the apex example.

However, it raises the RuntimeError below:

in gather_tensor, dist.all_gather(tensor_list, rt)
File "/opt/conda/lib/python3.6/site-packages/torch/distributed/", line 1027, in all_gather
work = _default_pg.allgather([tensor_list], [tensor])
RuntimeError: Input tensor sequence should have the same number of tensors as the output tensor sequence

What’s wrong with my code?

output = self.model(input)
tensor_list = [torch.zeros_like(output['key1']) for _ in range(dist.get_world_size())]
for key in output.keys():
    output[key] = self.gather_tensor(tensor_list, output[key])

I moved the tensor_list initialization in front of the for loop, as in the second snippet above. The RuntimeError goes away, but now a deadlock seems to occur.
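
For reference, here is a minimal sketch of the calling pattern that all_gather expects, as far as I understand it: the output list holds exactly world_size tensors with the same shape and dtype as the input, and every rank issues the same collective calls in the same order (otherwise ranks can hang waiting on each other). The names gather_dict and init_single_process_group are hypothetical helpers, and the one-process gloo group is only an assumption so the snippet runs on a single CPU. It is not a diagnosis of the error above.

```python
import os
import tempfile

import torch
import torch.distributed as dist


def gather_tensor(tensor):
    """Gather `tensor` from every rank and stack the results along a new dim 0."""
    # all_gather needs exactly world_size output buffers, each with the same
    # shape and dtype as the input tensor.
    tensor_list = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(tensor_list, tensor.contiguous())
    return torch.stack(tensor_list)


def gather_dict(output):
    """Gather every tensor in a dict (hypothetical helper)."""
    # Sorting the keys keeps the sequence of collective calls identical on
    # all ranks, which collectives require in order not to hang.
    return {key: gather_tensor(output[key]) for key in sorted(output)}


def init_single_process_group():
    """Hypothetical helper: a one-process gloo group, for illustration only."""
    if not dist.is_initialized():
        # File-based rendezvous in a fresh temp directory (an assumption made
        # here to avoid depending on a free TCP port).
        rendezvous = os.path.join(tempfile.mkdtemp(), "rendezvous")
        dist.init_process_group(
            backend="gloo",
            init_method=f"file://{rendezvous}",
            rank=0,
            world_size=1,
        )


if __name__ == "__main__":
    init_single_process_group()
    output = {"key1": torch.ones(2, 3), "key2": torch.arange(4.0)}
    gathered = gather_dict(output)
    # With world_size == 1 each result gains a leading dimension of size 1.
    print({key: tuple(value.shape) for key, value in gathered.items()})
    # prints {'key1': (1, 2, 3), 'key2': (1, 4)}
    dist.destroy_process_group()
```

In a real job the process group would come from the launcher, and each gathered tensor would have a leading dimension of world_size. Reusing one tensor_list built from output['key1'] only fits keys whose tensors share that shape and dtype, which is why the sketch allocates fresh buffers on every call.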